Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 126 additions & 9 deletions Sources/HuggingFace/Hub/HubClient+Files.swift
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ public extension HubClient {
/// - revision: Git revision
/// - useRaw: Use raw endpoint
/// - cachePolicy: Cache policy for the request
/// - resumable: Whether to enable resumable downloads (Apple platforms only)
/// - progress: Optional Progress object to track download progress
/// - Returns: Final destination URL
func downloadFile(
Expand All @@ -255,6 +256,7 @@ public extension HubClient {
revision: String = "main",
useRaw: Bool = false,
cachePolicy: URLRequest.CachePolicy = .useProtocolCachePolicy,
resumable: Bool = false,
progress: Progress? = nil
) async throws -> URL {
// Check cache first
Expand Down Expand Up @@ -289,15 +291,33 @@ public extension HubClient {
var request = try await httpClient.createRequest(.get, url: url)
request.cachePolicy = cachePolicy

#if canImport(FoundationNetworking)
let (tempURL, response) = try await session.asyncDownload(for: request, progress: progress)
#else
let (tempURL, response) = try await session.download(
for: request,
delegate: progress.map { DownloadProgressDelegate(progress: $0) }
)
#endif
_ = try httpClient.validateResponse(response, data: nil)
let (tempURL, response): (URL, URLResponse)
do {
#if canImport(FoundationNetworking)
(tempURL, response) = try await session.asyncDownload(for: request, progress: progress)
#else
(tempURL, response) = try await session.download(
for: request,
delegate: progress.map { DownloadProgressDelegate(progress: $0) }
)
#endif
_ = try httpClient.validateResponse(response, data: nil)
} catch {
guard resumable else { throw error }

#if !canImport(FoundationNetworking)
// Best-effort: persist resume data if URLSession provides it so a later call can resume.
if let resumeData = (error as NSError).userInfo[NSURLSessionDownloadTaskResumeData] as? Data {
let resumeDataURL = destination.appendingPathExtension("resumeData")
try? FileManager.default.createDirectory(
at: resumeDataURL.deletingLastPathComponent(),
withIntermediateDirectories: true
)
try? resumeData.write(to: resumeDataURL, options: .atomic)
}
#endif
throw error
}

// Store in cache before moving to destination
if let cache = cache,
Expand Down Expand Up @@ -589,6 +609,7 @@ public extension HubClient {
/// - destination: Local destination directory
/// - revision: Git revision (branch, tag, or commit)
/// - matching: Glob patterns to filter files (empty array downloads all files)
/// - resumable: Whether to enable resumable downloads (Apple platforms only)
/// - progressHandler: Optional closure called with progress updates
/// - Returns: URL to the local snapshot directory
func downloadSnapshot(
Expand All @@ -597,6 +618,7 @@ public extension HubClient {
to destination: URL,
revision: String = "main",
matching globs: [String] = [],
resumable: Bool = false,
progressHandler: ((Progress) -> Void)? = nil
) async throws -> URL {
let filenames = try await listFiles(in: repo, kind: kind, revision: revision, recursive: true)
Expand All @@ -622,6 +644,7 @@ public extension HubClient {
to: fileDestination,
kind: kind,
revision: revision,
resumable: resumable,
progress: fileProgress
)

Expand All @@ -635,6 +658,100 @@ public extension HubClient {
progressHandler?(progress)
return destination
}

#if !canImport(FoundationNetworking)
/// Resume a repository snapshot download to a local directory.
///
/// This behaves like `downloadSnapshot`, but will attempt to resume individual file downloads
/// if a sidecar resume-data file exists at `<destination>/<path>.resumeData`.
///
/// If the final file already exists at the destination, it is skipped.
/// If resume data exists, `URLSession.download(resumeFrom:)` is used via `resumeDownloadFile(...)`.
/// Otherwise a normal download is performed via `downloadFile(...)`.
///
/// - Important: This method does not create resume data by itself; it only consumes it.
func resumeDownloadSnapshot(
of repo: Repo.ID,
kind: Repo.Kind = .model,
to destination: URL,
revision: String = "main",
matching globs: [String] = [],
progressHandler: ((Progress) -> Void)? = nil
) async throws -> URL {
let filenames = try await listFiles(in: repo, kind: kind, revision: revision, recursive: true)
.map(\.path)
.filter { filename in
guard !globs.isEmpty else { return true }
return globs.contains { glob in
fnmatch(glob, filename, 0) == 0
}
}

let progress = Progress(totalUnitCount: Int64(filenames.count))
progressHandler?(progress)

for filename in filenames {
let fileProgress = Progress(totalUnitCount: 100, parent: progress, pendingUnitCount: 1)
let fileDestination = destination.appendingPathComponent(filename)
let resumeDataURL = fileDestination.appendingPathExtension("resumeData")

// If the final file already exists, treat it as completed.
if FileManager.default.fileExists(atPath: fileDestination.path) {
fileProgress.completedUnitCount = 100
continue
}

// Ensure parent directory exists.
try FileManager.default.createDirectory(
at: fileDestination.deletingLastPathComponent(),
withIntermediateDirectories: true
)

if FileManager.default.fileExists(atPath: resumeDataURL.path) {
do {
let resumeData = try Data(contentsOf: resumeDataURL)
_ = try await resumeDownloadFile(
resumeData: resumeData,
to: fileDestination,
progress: fileProgress
)
try? FileManager.default.removeItem(at: resumeDataURL)
} catch {
// If resume fails, delete stale resume data and fall back to a full download.
try? FileManager.default.removeItem(at: resumeDataURL)
_ = try await downloadFile(
at: filename,
from: repo,
to: fileDestination,
kind: kind,
revision: revision,
resumable: true,
progress: fileProgress
)
}
} else {
_ = try await downloadFile(
at: filename,
from: repo,
to: fileDestination,
kind: kind,
revision: revision,
resumable: true,
progress: fileProgress
)
}

if Task.isCancelled {
return destination
}

fileProgress.completedUnitCount = 100
}

progressHandler?(progress)
return destination
}
#endif
}

// MARK: -
Expand Down