Skip to content

Commit 0e67916

Browse files
adrgrondinpcuenca
andauthored
Add overloads for snapshot functions to expose download speed (#225)
* Add download speed info * Add download speed test * swiftformat --------- Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
1 parent ad43b7f commit 0e67916

File tree

5 files changed

+124
-13
lines changed

5 files changed

+124
-13
lines changed

Sources/Hub/Downloader.swift

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ final class Downloader: NSObject, Sendable, ObservableObject {
1717

1818
enum DownloadState {
1919
case notStarted
20-
case downloading(Double)
20+
case downloading(Double, Double?)
2121
case completed(URL)
2222
case failed(Error)
2323
}
@@ -149,12 +149,12 @@ final class Downloader: NSObject, Sendable, ObservableObject {
149149
// Calculate and show initial progress
150150
if let expectedSize = await self.downloadResumeState.expectedSize, expectedSize > 0 {
151151
let initialProgress = Double(resumeSize) / Double(expectedSize)
152-
await self.broadcaster.broadcast(state: .downloading(initialProgress))
152+
await self.broadcaster.broadcast(state: .downloading(initialProgress, nil))
153153
} else {
154-
await self.broadcaster.broadcast(state: .downloading(0))
154+
await self.broadcaster.broadcast(state: .downloading(0, nil))
155155
}
156156
} else {
157-
await self.broadcaster.broadcast(state: .downloading(0))
157+
await self.broadcaster.broadcast(state: .downloading(0, nil))
158158
}
159159

160160
request.timeoutInterval = timeout
@@ -227,6 +227,11 @@ final class Downloader: NSObject, Sendable, ObservableObject {
227227
// Create a buffer to collect bytes before writing to disk
228228
var buffer = Data(capacity: chunkSize)
229229

230+
// Track speed (bytes per second) using sampling between broadcasts
231+
var lastSampleTime = Date()
232+
var totalDownloadedLocal = await downloadResumeState.downloadedSize
233+
var lastSampleBytes = totalDownloadedLocal
234+
230235
var newNumRetries = numRetries
231236
do {
232237
for try await byte in asyncBytes {
@@ -237,17 +242,28 @@ final class Downloader: NSObject, Sendable, ObservableObject {
237242
try tempFile.write(contentsOf: buffer)
238243
buffer.removeAll(keepingCapacity: true)
239244

245+
totalDownloadedLocal += chunkSize
240246
await downloadResumeState.incDownloadedSize(chunkSize)
241247
newNumRetries = 5
242248
guard let expectedSize = await downloadResumeState.expectedSize else { continue }
243-
let progress = await expectedSize != 0 ? Double(downloadResumeState.downloadedSize) / Double(expectedSize) : 0
244-
await broadcaster.broadcast(state: .downloading(progress))
249+
let progress = expectedSize != 0 ? Double(totalDownloadedLocal) / Double(expectedSize) : 0
250+
251+
// Compute instantaneous speed based on bytes since last broadcast
252+
let now = Date()
253+
let elapsed = now.timeIntervalSince(lastSampleTime)
254+
let deltaBytes = totalDownloadedLocal - lastSampleBytes
255+
let speed = elapsed > 0 ? Double(deltaBytes) / elapsed : nil
256+
lastSampleTime = now
257+
lastSampleBytes = totalDownloadedLocal
258+
259+
await broadcaster.broadcast(state: .downloading(progress, speed))
245260
}
246261
}
247262
}
248263

249264
if !buffer.isEmpty {
250265
try tempFile.write(contentsOf: buffer)
266+
totalDownloadedLocal += buffer.count
251267
await downloadResumeState.incDownloadedSize(buffer.count)
252268
buffer.removeAll(keepingCapacity: true)
253269
newNumRetries = 5
@@ -298,7 +314,7 @@ final class Downloader: NSObject, Sendable, ObservableObject {
298314
extension Downloader: URLSessionDownloadDelegate {
299315
func urlSession(_: URLSession, downloadTask: URLSessionDownloadTask, didWriteData _: Int64, totalBytesWritten: Int64, totalBytesExpectedToWrite: Int64) {
300316
Task {
301-
await self.broadcaster.broadcast(state: .downloading(Double(totalBytesWritten) / Double(totalBytesExpectedToWrite)))
317+
await self.broadcaster.broadcast(state: .downloading(Double(totalBytesWritten) / Double(totalBytesExpectedToWrite), nil))
302318
}
303319
}
304320

Sources/Hub/HubApi.swift

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ public extension HubApi {
443443
/// We'll probably need to support Combine as well to play well with Swift UI
444444
/// (See for example PipelineLoader in swift-coreml-diffusers)
445445
@discardableResult
446-
func download(progressHandler: @escaping (Double) -> Void) async throws -> URL {
446+
func download(progressHandler: @escaping (Double, Double?) -> Void) async throws -> URL {
447447
let localMetadata = try hub.readDownloadMetadata(metadataPath: metadataDestination)
448448
let remoteMetadata = try await hub.getFileMetadata(url: source)
449449

@@ -499,8 +499,8 @@ public extension HubApi {
499499
switch state {
500500
case .notStarted:
501501
continue
502-
case let .downloading(progress):
503-
progressHandler(progress)
502+
case let .downloading(progress, speed):
503+
progressHandler(progress, speed)
504504
case let .failed(error):
505505
throw error
506506
case .completed:
@@ -583,8 +583,12 @@ public extension HubApi {
583583
backgroundSession: useBackgroundSession
584584
)
585585

586-
try await downloader.download { fractionDownloaded in
586+
try await downloader.download { fractionDownloaded, speed in
587587
fileProgress.completedUnitCount = Int64(100 * fractionDownloaded)
588+
if let speed {
589+
fileProgress.setUserInfoObject(speed, forKey: .throughputKey)
590+
progress.setUserInfoObject(speed, forKey: .throughputKey)
591+
}
588592
progressHandler(progress)
589593
}
590594
if Task.isCancelled {
@@ -598,6 +602,14 @@ public extension HubApi {
598602
return repoDestination
599603
}
600604

605+
/// New overloads exposing speed directly in the snapshot progress handler
606+
@discardableResult func snapshot(from repo: Repo, revision: String = "main", matching globs: [String] = [], progressHandler: @escaping (Progress, Double?) -> Void) async throws -> URL {
607+
try await snapshot(from: repo, revision: revision, matching: globs) { progress in
608+
let speed = progress.userInfo[.throughputKey] as? Double
609+
progressHandler(progress, speed)
610+
}
611+
}
612+
601613
@discardableResult
602614
func snapshot(from repoId: String, revision: String = "main", matching globs: [String] = [], progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
603615
try await snapshot(from: Repo(id: repoId), revision: revision, matching: globs, progressHandler: progressHandler)
@@ -612,6 +624,22 @@ public extension HubApi {
612624
func snapshot(from repoId: String, revision: String = "main", matching glob: String, progressHandler: @escaping (Progress) -> Void = { _ in }) async throws -> URL {
613625
try await snapshot(from: Repo(id: repoId), revision: revision, matching: [glob], progressHandler: progressHandler)
614626
}
627+
628+
/// Convenience overloads for other snapshot entry points with speed
629+
@discardableResult
630+
func snapshot(from repoId: String, revision: String = "main", matching globs: [String] = [], progressHandler: @escaping (Progress, Double?) -> Void) async throws -> URL {
631+
try await snapshot(from: Repo(id: repoId), revision: revision, matching: globs, progressHandler: progressHandler)
632+
}
633+
634+
@discardableResult
635+
func snapshot(from repo: Repo, revision: String = "main", matching glob: String, progressHandler: @escaping (Progress, Double?) -> Void) async throws -> URL {
636+
try await snapshot(from: repo, revision: revision, matching: [glob], progressHandler: progressHandler)
637+
}
638+
639+
@discardableResult
640+
func snapshot(from repoId: String, revision: String = "main", matching glob: String, progressHandler: @escaping (Progress, Double?) -> Void) async throws -> URL {
641+
try await snapshot(from: Repo(id: repoId), revision: revision, matching: [glob], progressHandler: progressHandler)
642+
}
615643
}
616644

617645
/// Metadata
@@ -822,6 +850,23 @@ public extension Hub {
822850
try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: glob, progressHandler: progressHandler)
823851
}
824852

853+
/// Overloads exposing speed via (Progress, Double?) where Double is bytes/sec
854+
static func snapshot(from repo: Repo, matching globs: [String] = [], progressHandler: @escaping (Progress, Double?) -> Void) async throws -> URL {
855+
try await HubApi.shared.snapshot(from: repo, matching: globs, progressHandler: progressHandler)
856+
}
857+
858+
static func snapshot(from repoId: String, matching globs: [String] = [], progressHandler: @escaping (Progress, Double?) -> Void) async throws -> URL {
859+
try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: globs, progressHandler: progressHandler)
860+
}
861+
862+
static func snapshot(from repo: Repo, matching glob: String, progressHandler: @escaping (Progress, Double?) -> Void) async throws -> URL {
863+
try await HubApi.shared.snapshot(from: repo, matching: glob, progressHandler: progressHandler)
864+
}
865+
866+
static func snapshot(from repoId: String, matching glob: String, progressHandler: @escaping (Progress, Double?) -> Void) async throws -> URL {
867+
try await HubApi.shared.snapshot(from: Repo(id: repoId), matching: glob, progressHandler: progressHandler)
868+
}
869+
825870
static func whoami(token: String) async throws -> Config {
826871
try await HubApi(hfToken: token).whoami()
827872
}

Sources/HubCLI/HubCLI.swift

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,20 @@ struct Download: AsyncParsableCommand, SubcommandWithToken {
6363
let downloadedTo = try await hubApi.snapshot(from: repo, revision: revision, matching: include) { progress in
6464
DispatchQueue.main.async {
6565
let totalPercent = 100 * progress.fractionCompleted
66-
print("\(progress.completedUnitCount)/\(progress.totalUnitCount) \(totalPercent.formatted("%.02f"))%", terminator: "\r")
66+
let speedBps = progress.userInfo[.throughputKey] as? Double
67+
let speedString = if let s = speedBps {
68+
// Human-readable speed
69+
if s >= 1024 * 1024 {
70+
String(format: " - %.2f MB/s", s / (1024 * 1024))
71+
} else if s >= 1024 {
72+
String(format: " - %.2f KB/s", s / 1024)
73+
} else {
74+
String(format: " - %.0f B/s", s)
75+
}
76+
} else {
77+
""
78+
}
79+
print("\(progress.completedUnitCount)/\(progress.totalUnitCount) \(totalPercent.formatted("%.02f"))%\(speedString)", terminator: "\r")
6780
fflush(stdout)
6881
}
6982
}

Tests/HubTests/DownloaderTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ final class DownloaderTests: XCTestCase {
160160
switch state {
161161
case .notStarted:
162162
continue
163-
case let .downloading(progress):
163+
case let .downloading(progress, _):
164164
if threshold != 1.0, progress >= threshold {
165165
// Move to next threshold and interrupt
166166
threshold = threshold == 0.5 ? 0.75 : 1.0

Tests/HubTests/HubApiTests.swift

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1111,6 +1111,43 @@ class SnapshotDownloadTests: XCTestCase {
11111111
)
11121112
}
11131113

1114+
func testRealDownloadWithSpeed() async throws {
1115+
// Use the DepthPro model weights file
1116+
let targetFile = "SAM 2 Studio 1.1.zip"
1117+
let repo = "coreml-projects/sam-2-studio"
1118+
let hubApi = HubApi(downloadBase: downloadDestination)
1119+
1120+
var lastSpeed: Double? = nil
1121+
1122+
// Add debug prints
1123+
print("Download destination before: \(downloadDestination.path)")
1124+
1125+
let downloadedTo = try await hubApi.snapshot(from: repo, matching: targetFile) { progress, speed in
1126+
if let speed {
1127+
print("Current speed: \(speed)")
1128+
}
1129+
1130+
lastSpeed = speed
1131+
}
1132+
1133+
// Add more debug prints
1134+
print("Downloaded to: \(downloadedTo.path)")
1135+
1136+
XCTAssertNotNil(lastSpeed)
1137+
1138+
let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: repo)
1139+
print("Downloaded filenames: \(downloadedFilenames)")
1140+
print("Prefix used in getRelativeFiles: \(downloadDestination.appending(path: "models/\(repo)").path)")
1141+
1142+
XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)"))
1143+
1144+
let filePath = downloadedTo.appendingPathComponent(targetFile)
1145+
XCTAssertTrue(
1146+
FileManager.default.fileExists(atPath: filePath.path),
1147+
"Downloaded file should exist at \(filePath.path)"
1148+
)
1149+
}
1150+
11141151
func testDownloadWithRevision() async throws {
11151152
let hubApi = HubApi(downloadBase: downloadDestination)
11161153
var lastProgress: Progress? = nil

0 commit comments

Comments
 (0)