@@ -41,7 +41,6 @@ class GcsArtifactService(BaseArtifactService):
4141 def __init__ (self , bucket_name : str , ** kwargs ):
4242 """Initializes the GcsArtifactService.
4343
44-
4544 Args:
4645 bucket_name: The name of the bucket to use.
4746 **kwargs: Keyword arguments to pass to the Google Cloud Storage client.
@@ -56,9 +55,9 @@ async def save_artifact(
5655 * ,
5756 app_name : str ,
5857 user_id : str ,
59- session_id : str ,
6058 filename : str ,
6159 artifact : types .Part ,
60+ session_id : Optional [str ] = None ,
6261 ) -> int :
6362 return await asyncio .to_thread (
6463 self ._save_artifact ,
@@ -75,8 +74,8 @@ async def load_artifact(
7574 * ,
7675 app_name : str ,
7776 user_id : str ,
78- session_id : str ,
7977 filename : str ,
78+ session_id : Optional [str ] = None ,
8079 version : Optional [int ] = None ,
8180 ) -> Optional [types .Part ]:
8281 return await asyncio .to_thread (
@@ -90,7 +89,7 @@ async def load_artifact(
9089
9190 @override
9291 async def list_artifact_keys (
93- self , * , app_name : str , user_id : str , session_id : str
92+ self , * , app_name : str , user_id : str , session_id : Optional [ str ] = None
9493 ) -> list [str ]:
9594 return await asyncio .to_thread (
9695 self ._list_artifact_keys ,
@@ -101,7 +100,12 @@ async def list_artifact_keys(
101100
102101 @override
103102 async def delete_artifact (
104- self , * , app_name : str , user_id : str , session_id : str , filename : str
103+ self ,
104+ * ,
105+ app_name : str ,
106+ user_id : str ,
107+ filename : str ,
108+ session_id : Optional [str ] = None ,
105109 ) -> None :
106110 return await asyncio .to_thread (
107111 self ._delete_artifact ,
@@ -113,7 +117,12 @@ async def delete_artifact(
113117
114118 @override
115119 async def list_versions (
116- self , * , app_name : str , user_id : str , session_id : str , filename : str
120+ self ,
121+ * ,
122+ app_name : str ,
123+ user_id : str ,
124+ filename : str ,
125+ session_id : Optional [str ] = None ,
117126 ) -> list [int ]:
118127 return await asyncio .to_thread (
119128 self ._list_versions ,
@@ -139,31 +148,36 @@ def _get_blob_name(
139148 self ,
140149 app_name : str ,
141150 user_id : str ,
142- session_id : str ,
143151 filename : str ,
144152 version : int ,
153+ session_id : Optional [str ] = None ,
145154 ) -> str :
146155 """Constructs the blob name in GCS.
147156
148157 Args:
149158 app_name: The name of the application.
150159 user_id: The ID of the user.
151- session_id: The ID of the session.
152160 filename: The name of the artifact file.
153161 version: The version of the artifact.
162+ session_id: The ID of the session.
154163
155164 Returns:
156165 The constructed blob name in GCS.
157166 """
158167 if self ._file_has_user_namespace (filename ):
159168 return f"{ app_name } /{ user_id } /user/{ filename } /{ version } "
169+
170+ if session_id is None :
171+ raise ValueError (
172+ "Session ID must be provided for session-scoped artifacts."
173+ )
160174 return f"{ app_name } /{ user_id } /{ session_id } /{ filename } /{ version } "
161175
162176 def _save_artifact (
163177 self ,
164178 app_name : str ,
165179 user_id : str ,
166- session_id : str ,
180+ session_id : Optional [ str ] ,
167181 filename : str ,
168182 artifact : types .Part ,
169183 ) -> int :
@@ -176,7 +190,7 @@ def _save_artifact(
176190 version = 0 if not versions else max (versions ) + 1
177191
178192 blob_name = self ._get_blob_name (
179- app_name , user_id , session_id , filename , version
193+ app_name , user_id , filename , version , session_id
180194 )
181195 blob = self .bucket .blob (blob_name )
182196
@@ -198,7 +212,7 @@ def _load_artifact(
198212 self ,
199213 app_name : str ,
200214 user_id : str ,
201- session_id : str ,
215+ session_id : Optional [ str ] ,
202216 filename : str ,
203217 version : Optional [int ] = None ,
204218 ) -> Optional [types .Part ]:
@@ -214,7 +228,7 @@ def _load_artifact(
214228 version = max (versions )
215229
216230 blob_name = self ._get_blob_name (
217- app_name , user_id , session_id , filename , version
231+ app_name , user_id , filename , version , session_id
218232 )
219233 blob = self .bucket .blob (blob_name )
220234
@@ -227,17 +241,18 @@ def _load_artifact(
227241 return artifact
228242
229243 def _list_artifact_keys (
230- self , app_name : str , user_id : str , session_id : str
244+ self , app_name : str , user_id : str , session_id : Optional [ str ]
231245 ) -> list [str ]:
232246 filenames = set ()
233247
234- session_prefix = f"{ app_name } /{ user_id } /{ session_id } /"
235- session_blobs = self .storage_client .list_blobs (
236- self .bucket , prefix = session_prefix
237- )
238- for blob in session_blobs :
239- * _ , filename , _ = blob .name .split ("/" )
240- filenames .add (filename )
248+ if session_id :
249+ session_prefix = f"{ app_name } /{ user_id } /{ session_id } /"
250+ session_blobs = self .storage_client .list_blobs (
251+ self .bucket , prefix = session_prefix
252+ )
253+ for blob in session_blobs :
254+ * _ , filename , _ = blob .name .split ("/" )
255+ filenames .add (filename )
241256
242257 user_namespace_prefix = f"{ app_name } /{ user_id } /user/"
243258 user_namespace_blobs = self .storage_client .list_blobs (
@@ -250,7 +265,11 @@ def _list_artifact_keys(
250265 return sorted (list (filenames ))
251266
252267 def _delete_artifact (
253- self , app_name : str , user_id : str , session_id : str , filename : str
268+ self ,
269+ app_name : str ,
270+ user_id : str ,
271+ session_id : Optional [str ],
272+ filename : str ,
254273 ) -> None :
255274 versions = self ._list_versions (
256275 app_name = app_name ,
@@ -260,18 +279,23 @@ def _delete_artifact(
260279 )
261280 for version in versions :
262281 blob_name = self ._get_blob_name (
263- app_name , user_id , session_id , filename , version
282+ app_name , user_id , filename , version , session_id
264283 )
265284 blob = self .bucket .blob (blob_name )
266285 blob .delete ()
267286 return
268287
269288 def _list_versions (
270- self , app_name : str , user_id : str , session_id : str , filename : str
289+ self ,
290+ app_name : str ,
291+ user_id : str ,
292+ session_id : Optional [str ],
293+ filename : str ,
271294 ) -> list [int ]:
272295 """Lists all available versions of an artifact.
273296
274- This method retrieves all versions of a specific artifact by querying GCS blobs
297+ This method retrieves all versions of a specific artifact by querying GCS
298+ blobs
275299 that match the constructed blob name prefix.
276300
277301 Args:
@@ -281,10 +305,11 @@ def _list_versions(
281305 filename: The name of the artifact file.
282306
283307 Returns:
284- A list of version numbers (integers) available for the specified artifact.
308+ A list of version numbers (integers) available for the specified
309+ artifact.
285310 Returns an empty list if no versions are found.
286311 """
287- prefix = self ._get_blob_name (app_name , user_id , session_id , filename , "" )
312+ prefix = self ._get_blob_name (app_name , user_id , filename , "" , session_id )
288313 blobs = self .storage_client .list_blobs (self .bucket , prefix = prefix )
289314 versions = []
290315 for blob in blobs :
0 commit comments