diff --git a/Sources/HuggingFace/Hub/HubClient+Files.swift b/Sources/HuggingFace/Hub/HubClient+Files.swift index 2511f05..e2eb7ee 100644 --- a/Sources/HuggingFace/Hub/HubClient+Files.swift +++ b/Sources/HuggingFace/Hub/HubClient+Files.swift @@ -245,6 +245,7 @@ public extension HubClient { /// - revision: Git revision /// - useRaw: Use raw endpoint /// - cachePolicy: Cache policy for the request + /// - resumable: Whether to enable resumable downloads (Apple platforms only) /// - progress: Optional Progress object to track download progress /// - Returns: Final destination URL func downloadFile( @@ -255,6 +256,7 @@ public extension HubClient { revision: String = "main", useRaw: Bool = false, cachePolicy: URLRequest.CachePolicy = .useProtocolCachePolicy, + resumable: Bool = false, progress: Progress? = nil ) async throws -> URL { // Check cache first @@ -289,15 +291,33 @@ public extension HubClient { var request = try await httpClient.createRequest(.get, url: url) request.cachePolicy = cachePolicy - #if canImport(FoundationNetworking) - let (tempURL, response) = try await session.asyncDownload(for: request, progress: progress) - #else - let (tempURL, response) = try await session.download( - for: request, - delegate: progress.map { DownloadProgressDelegate(progress: $0) } - ) - #endif - _ = try httpClient.validateResponse(response, data: nil) + let (tempURL, response): (URL, URLResponse) + do { + #if canImport(FoundationNetworking) + (tempURL, response) = try await session.asyncDownload(for: request, progress: progress) + #else + (tempURL, response) = try await session.download( + for: request, + delegate: progress.map { DownloadProgressDelegate(progress: $0) } + ) + #endif + _ = try httpClient.validateResponse(response, data: nil) + } catch { + guard resumable else { throw error } + + #if !canImport(FoundationNetworking) + // Best-effort: persist resume data if URLSession provides it so a later call can resume. + if let resumeData = (error as NSError).userInfo[NSURLSessionDownloadTaskResumeData] as? Data { + let resumeDataURL = destination.appendingPathExtension("resumeData") + try? FileManager.default.createDirectory( + at: resumeDataURL.deletingLastPathComponent(), + withIntermediateDirectories: true + ) + try? resumeData.write(to: resumeDataURL, options: .atomic) + } + #endif + throw error + } // Store in cache before moving to destination if let cache = cache, @@ -589,6 +609,7 @@ public extension HubClient { /// - destination: Local destination directory /// - revision: Git revision (branch, tag, or commit) /// - matching: Glob patterns to filter files (empty array downloads all files) + /// - resumable: Whether to enable resumable downloads (Apple platforms only) /// - progressHandler: Optional closure called with progress updates /// - Returns: URL to the local snapshot directory func downloadSnapshot( @@ -597,6 +618,7 @@ public extension HubClient { to destination: URL, revision: String = "main", matching globs: [String] = [], + resumable: Bool = false, progressHandler: ((Progress) -> Void)? = nil ) async throws -> URL { let filenames = try await listFiles(in: repo, kind: kind, revision: revision, recursive: true) @@ -622,6 +644,7 @@ public extension HubClient { to: fileDestination, kind: kind, revision: revision, + resumable: resumable, progress: fileProgress ) @@ -635,6 +658,100 @@ public extension HubClient { progressHandler?(progress) return destination } + + #if !canImport(FoundationNetworking) + /// Resume a repository snapshot download to a local directory. + /// + /// This behaves like `downloadSnapshot`, but will attempt to resume individual file downloads + /// if a sidecar resume-data file exists at `/.resumeData`. + /// + /// If the final file already exists at the destination, it is skipped. + /// If resume data exists, `URLSession.download(resumeFrom:)` is used via `resumeDownloadFile(...)`. + /// Otherwise a normal download is performed via `downloadFile(...)`. + /// + /// - Important: This method does not create resume data by itself; it only consumes it. + func resumeDownloadSnapshot( + of repo: Repo.ID, + kind: Repo.Kind = .model, + to destination: URL, + revision: String = "main", + matching globs: [String] = [], + progressHandler: ((Progress) -> Void)? = nil + ) async throws -> URL { + let filenames = try await listFiles(in: repo, kind: kind, revision: revision, recursive: true) + .map(\.path) + .filter { filename in + guard !globs.isEmpty else { return true } + return globs.contains { glob in + fnmatch(glob, filename, 0) == 0 + } + } + + let progress = Progress(totalUnitCount: Int64(filenames.count)) + progressHandler?(progress) + + for filename in filenames { + let fileProgress = Progress(totalUnitCount: 100, parent: progress, pendingUnitCount: 1) + let fileDestination = destination.appendingPathComponent(filename) + let resumeDataURL = fileDestination.appendingPathExtension("resumeData") + + // If the final file already exists, treat it as completed. + if FileManager.default.fileExists(atPath: fileDestination.path) { + fileProgress.completedUnitCount = 100 + continue + } + + // Ensure parent directory exists. + try FileManager.default.createDirectory( + at: fileDestination.deletingLastPathComponent(), + withIntermediateDirectories: true + ) + + if FileManager.default.fileExists(atPath: resumeDataURL.path) { + do { + let resumeData = try Data(contentsOf: resumeDataURL) + _ = try await resumeDownloadFile( + resumeData: resumeData, + to: fileDestination, + progress: fileProgress + ) + try? FileManager.default.removeItem(at: resumeDataURL) + } catch { + // If resume fails, delete stale resume data and fall back to a full download. + try? FileManager.default.removeItem(at: resumeDataURL) + _ = try await downloadFile( + at: filename, + from: repo, + to: fileDestination, + kind: kind, + revision: revision, + resumable: true, + progress: fileProgress + ) + } + } else { + _ = try await downloadFile( + at: filename, + from: repo, + to: fileDestination, + kind: kind, + revision: revision, + resumable: true, + progress: fileProgress + ) + } + + if Task.isCancelled { + return destination + } + + fileProgress.completedUnitCount = 100 + } + + progressHandler?(progress) + return destination + } + #endif } // MARK: -