Skip to content

Commit 95f7faa

Browse files
authored
Merge pull request #126 from dreadnode/users/raja/fix-refresh-token-race-condition
fix: Fix S3 credential auto-refresh mechanism and extend buffer window
2 parents 0897d54 + b9c3520 commit 95f7faa

File tree

9 files changed

+259
-253
lines changed

9 files changed

+259
-253
lines changed

docs/sdk/artifact.mdx

Lines changed: 39 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -244,10 +244,7 @@ ArtifactStorage
244244
---------------
245245

246246
```python
247-
ArtifactStorage(
248-
file_system: AbstractFileSystem,
249-
credential_refresher: Callable[[], bool] | None = None,
250-
)
247+
ArtifactStorage(credential_manager: CredentialManager)
251248
```
252249

253250
Storage for artifacts with efficient handling of large files and directories.
@@ -256,35 +253,24 @@ Supports:
256253
- Content-based deduplication using SHA1 hashing
257254
- Batch uploads for directories handled by fsspec
258255

259-
Initialize artifact storage with a file system and prefix path.
256+
Initialize artifact storage with a credential manager.
260257

261258
**Parameters:**
262259

263-
* **`file_system`**
264-
(`AbstractFileSystem`)
265-
–FSSpec-compatible file system
266-
* **`credential_refresher`**
267-
(`Callable[[], bool] | None`, default:
268-
`None`
269-
)
270-
–Optional function to refresh credentials when it's about to expire
260+
* **`credential_manager`**
261+
(`CredentialManager`)
262+
–Credential manager used to obtain the filesystem and to retry S3 operations
271263

272264
<Accordion title="Source code in dreadnode/artifact/storage.py" icon="code">
273265
```python
274-
def __init__(
275-
self,
276-
file_system: fsspec.AbstractFileSystem,
277-
credential_refresher: t.Callable[[], bool] | None = None,
278-
):
266+
def __init__(self, credential_manager: CredentialManager):
279267
"""
280-
Initialize artifact storage with a file system and prefix path.
268+
Initialize artifact storage with credential manager.
281269
282270
Args:
283-
file_system: FSSpec-compatible file system
284-
credential_refresher: Optional function to refresh credentials when it's about to expire
271+
credential_manager: Credential manager used to obtain the filesystem and to retry S3 operations
285272
"""
286-
self._file_system = file_system
287-
self._credential_refresher = credential_refresher
273+
self._credential_manager: CredentialManager = credential_manager
288274
```
289275

290276

@@ -330,23 +316,26 @@ def batch_upload_files(self, source_paths: list[str], target_paths: list[str]) -
330316
if not source_paths:
331317
return []
332318

333-
logger.debug("Batch uploading %d files", len(source_paths))
319+
def batch_upload_operation() -> list[str]:
320+
filesystem = self._credential_manager.get_filesystem()
334321

335-
srcs = []
336-
dsts = []
322+
srcs = []
323+
dsts = []
337324

338-
for src, dst in zip(source_paths, target_paths, strict=False):
339-
if not self._file_system.exists(dst):
340-
srcs.append(src)
341-
dsts.append(dst)
325+
for src, dst in zip(source_paths, target_paths, strict=False):
326+
if not filesystem.exists(dst):
327+
srcs.append(src)
328+
dsts.append(dst)
342329

343-
if srcs:
344-
self._file_system.put(srcs, dsts)
345-
logger.debug("Batch upload completed for %d files", len(srcs))
346-
else:
347-
logger.debug("All files already exist, skipping upload")
330+
if srcs:
331+
filesystem.put(srcs, dsts)
332+
logger.info("Batch upload completed for %d files", len(srcs))
333+
else:
334+
logger.info("All files already exist, skipping upload")
348335

349-
return [str(self._file_system.unstrip_protocol(target)) for target in target_paths]
336+
return [str(filesystem.unstrip_protocol(target)) for target in target_paths]
337+
338+
return self._credential_manager.execute_with_retry(batch_upload_operation)
350339
```
351340

352341

@@ -391,8 +380,9 @@ def compute_file_hash(self, file_path: Path, stream_threshold_mb: int = 10) -> s
391380
Returns:
392381
First 16 chars of SHA1 hash
393382
"""
383+
394384
file_size = file_path.stat().st_size
395-
stream_threshold = stream_threshold_mb * 1024 * 1024 # Convert MB to bytes
385+
stream_threshold = stream_threshold_mb * 1024 * 1024
396386

397387
sha1 = hashlib.sha1() # noqa: S324 # nosec
398388

@@ -478,7 +468,6 @@ Store a file in the storage system, using multipart upload for large files.
478468

479469
<Accordion title="Source code in dreadnode/artifact/storage.py" icon="code">
480470
```python
481-
@with_credential_refresh
482471
def store_file(self, file_path: Path, target_key: str) -> str:
483472
"""
484473
Store a file in the storage system, using multipart upload for large files.
@@ -490,13 +479,19 @@ def store_file(self, file_path: Path, target_key: str) -> str:
490479
Returns:
491480
Full URI with protocol to the stored file
492481
"""
493-
if not self._file_system.exists(target_key):
494-
self._file_system.put(str(file_path), target_key)
495-
logger.debug("Artifact successfully stored at %s", target_key)
496-
else:
497-
logger.debug("Artifact already exists at %s, skipping upload.", target_key)
498482

499-
return str(self._file_system.unstrip_protocol(target_key))
483+
def store_operation() -> str:
484+
filesystem = self._credential_manager.get_filesystem()
485+
486+
if not filesystem.exists(target_key):
487+
filesystem.put(str(file_path), target_key)
488+
logger.info("Artifact successfully stored at %s", target_key)
489+
else:
490+
logger.info("Artifact already exists at %s, skipping upload.", target_key)
491+
492+
return str(filesystem.unstrip_protocol(target_key))
493+
494+
return self._credential_manager.execute_with_retry(store_operation)
500495
```
501496

502497

docs/sdk/main.mdx

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -57,16 +57,14 @@ def __init__(
5757
self.otel_scope = otel_scope
5858

5959
self._api: ApiClient | None = None
60-
60+
self._credential_manager: CredentialManager | None = None
6161
self._logfire = logfire.DEFAULT_LOGFIRE_INSTANCE
6262
self._logfire.config.ignore_no_config = True
6363

6464
self._fs: AbstractFileSystem = LocalFileSystem(auto_mkdir=True)
6565
self._fs_prefix: str = ".dreadnode/storage/"
6666

6767
self._initialized = False
68-
self._credentials: UserDataCredentials | None = None
69-
self._credentials_expiry: datetime | None = None
7068
```
7169

7270

@@ -380,9 +378,7 @@ def continue_run(self, run_context: RunContext) -> RunSpan:
380378
return RunSpan.from_context(
381379
context=run_context,
382380
tracer=self._get_tracer(),
383-
file_system=self._fs,
384-
prefix_path=self._fs_prefix,
385-
credential_refresher=self._refresh_storage_credentials if self._credentials else None,
381+
credential_manager=self._credential_manager, # type: ignore[arg-type]
386382
)
387383
```
388384

@@ -526,19 +522,15 @@ def initialize(self) -> None:
526522
# )
527523
# )
528524
# )
529-
self._credentials = self._api.get_user_data_credentials()
530-
self._credentials_expiry = self._credentials.expiration
531-
resolved_endpoint = resolve_endpoint(self._credentials.endpoint)
532-
self._fs = S3FileSystem(
533-
key=self._credentials.access_key_id,
534-
secret=self._credentials.secret_access_key,
535-
token=self._credentials.session_token,
536-
client_kwargs={
537-
"endpoint_url": resolved_endpoint,
538-
"region_name": self._credentials.region,
539-
},
540-
)
541-
self._fs_prefix = f"{self._credentials.bucket}/{self._credentials.prefix}/"
525+
if self._api is not None:
526+
api = self._api
527+
self._credential_manager = CredentialManager(
528+
credential_fetcher=lambda: api.get_user_data_credentials()
529+
)
530+
self._credential_manager.initialize()
531+
532+
self._fs = self._credential_manager.get_filesystem()
533+
self._fs_prefix = self._credential_manager.get_prefix()
542534

543535
self._logfire = logfire.configure(
544536
local=not self.is_default,
@@ -1723,10 +1715,8 @@ def run(
17231715
tracer=self._get_tracer(),
17241716
params=params,
17251717
tags=tags,
1726-
file_system=self._fs,
1727-
prefix_path=self._fs_prefix,
1718+
credential_manager=self._credential_manager, # type: ignore[arg-type]
17281719
autolog=autolog,
1729-
credential_refresher=self._refresh_storage_credentials if self._credentials else None,
17301720
)
17311721
```
17321722

dreadnode/artifact/storage.py

Lines changed: 35 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,9 @@
44
"""
55

66
import hashlib
7-
import typing as t
87
from pathlib import Path
98

10-
import fsspec # type: ignore[import-untyped]
11-
12-
from dreadnode.storage_utils import with_credential_refresh
9+
from dreadnode.credential_manager import CredentialManager
1310
from dreadnode.util import logger
1411

1512
CHUNK_SIZE = 8 * 1024 * 1024 # 8MB
@@ -24,27 +21,15 @@ class ArtifactStorage:
2421
- Batch uploads for directories handled by fsspec
2522
"""
2623

27-
def __init__(
28-
self,
29-
file_system: fsspec.AbstractFileSystem,
30-
credential_refresher: t.Callable[[], bool] | None = None,
31-
):
24+
def __init__(self, credential_manager: CredentialManager):
3225
"""
33-
Initialize artifact storage with a file system and prefix path.
26+
Initialize artifact storage with credential manager.
3427
3528
Args:
36-
file_system: FSSpec-compatible file system
37-
credential_refresher: Optional function to refresh credentials when it's about to expire
29+
credential_manager: Credential manager used to obtain the filesystem and to retry S3 operations
3830
"""
39-
self._file_system = file_system
40-
self._credential_refresher = credential_refresher
41-
42-
def _refresh_credentials_if_needed(self) -> None:
43-
"""Refresh credentials if refresher is available."""
44-
if self._credential_refresher:
45-
self._credential_refresher()
31+
self._credential_manager: CredentialManager = credential_manager
4632

47-
@with_credential_refresh
4833
def store_file(self, file_path: Path, target_key: str) -> str:
4934
"""
5035
Store a file in the storage system, using multipart upload for large files.
@@ -56,13 +41,19 @@ def store_file(self, file_path: Path, target_key: str) -> str:
5641
Returns:
5742
Full URI with protocol to the stored file
5843
"""
59-
if not self._file_system.exists(target_key):
60-
self._file_system.put(str(file_path), target_key)
61-
logger.debug("Artifact successfully stored at %s", target_key)
62-
else:
63-
logger.debug("Artifact already exists at %s, skipping upload.", target_key)
6444

65-
return str(self._file_system.unstrip_protocol(target_key))
45+
def store_operation() -> str:
46+
filesystem = self._credential_manager.get_filesystem()
47+
48+
if not filesystem.exists(target_key):
49+
filesystem.put(str(file_path), target_key)
50+
logger.info("Artifact successfully stored at %s", target_key)
51+
else:
52+
logger.info("Artifact already exists at %s, skipping upload.", target_key)
53+
54+
return str(filesystem.unstrip_protocol(target_key))
55+
56+
return self._credential_manager.execute_with_retry(store_operation)
6657

6758
def batch_upload_files(self, source_paths: list[str], target_paths: list[str]) -> list[str]:
6859
"""
@@ -78,23 +69,26 @@ def batch_upload_files(self, source_paths: list[str], target_paths: list[str]) -
7869
if not source_paths:
7970
return []
8071

81-
logger.debug("Batch uploading %d files", len(source_paths))
72+
def batch_upload_operation() -> list[str]:
73+
filesystem = self._credential_manager.get_filesystem()
8274

83-
srcs = []
84-
dsts = []
75+
srcs = []
76+
dsts = []
8577

86-
for src, dst in zip(source_paths, target_paths, strict=False):
87-
if not self._file_system.exists(dst):
88-
srcs.append(src)
89-
dsts.append(dst)
78+
for src, dst in zip(source_paths, target_paths, strict=False):
79+
if not filesystem.exists(dst):
80+
srcs.append(src)
81+
dsts.append(dst)
9082

91-
if srcs:
92-
self._file_system.put(srcs, dsts)
93-
logger.debug("Batch upload completed for %d files", len(srcs))
94-
else:
95-
logger.debug("All files already exist, skipping upload")
83+
if srcs:
84+
filesystem.put(srcs, dsts)
85+
logger.info("Batch upload completed for %d files", len(srcs))
86+
else:
87+
logger.info("All files already exist, skipping upload")
9688

97-
return [str(self._file_system.unstrip_protocol(target)) for target in target_paths]
89+
return [str(filesystem.unstrip_protocol(target)) for target in target_paths]
90+
91+
return self._credential_manager.execute_with_retry(batch_upload_operation)
9892

9993
def compute_file_hash(self, file_path: Path, stream_threshold_mb: int = 10) -> str:
10094
"""
@@ -107,8 +101,9 @@ def compute_file_hash(self, file_path: Path, stream_threshold_mb: int = 10) -> s
107101
Returns:
108102
First 16 chars of SHA1 hash
109103
"""
104+
110105
file_size = file_path.stat().st_size
111-
stream_threshold = stream_threshold_mb * 1024 * 1024 # Convert MB to bytes
106+
stream_threshold = stream_threshold_mb * 1024 * 1024
112107

113108
sha1 = hashlib.sha1() # noqa: S324 # nosec
114109

dreadnode/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,4 @@
5858
)
5959

6060
# Default values for the file system credential management
61-
FS_CREDENTIAL_REFRESH_BUFFER = 300 # 5 minutes in seconds
61+
FS_CREDENTIAL_REFRESH_BUFFER = 900 # 15 minutes in seconds

0 commit comments

Comments
 (0)