Skip to content

Commit 0e3c0f7

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Make session_id optional in BaseArtifactService methods
PiperOrigin-RevId: 816782982
1 parent f2bed14 commit 0e3c0f7

File tree

4 files changed

+126
-55
lines changed

4 files changed

+126
-55
lines changed

src/google/adk/artifacts/base_artifact_service.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
14+
from __future__ import annotations
1515

1616
from abc import ABC
1717
from abc import abstractmethod
@@ -29,9 +29,9 @@ async def save_artifact(
2929
*,
3030
app_name: str,
3131
user_id: str,
32-
session_id: str,
3332
filename: str,
3433
artifact: types.Part,
34+
session_id: Optional[str] = None,
3535
) -> int:
3636
"""Saves an artifact to the artifact service storage.
3737
@@ -42,9 +42,9 @@ async def save_artifact(
4242
Args:
4343
app_name: The app name.
4444
user_id: The user ID.
45-
session_id: The session ID.
4645
filename: The filename of the artifact.
4746
artifact: The artifact to save.
47+
session_id: The session ID. If `None`, the artifact is user-scoped.
4848
4949
Returns:
5050
The revision ID. The first version of the artifact has a revision ID of 0.
@@ -57,8 +57,8 @@ async def load_artifact(
5757
*,
5858
app_name: str,
5959
user_id: str,
60-
session_id: str,
6160
filename: str,
61+
session_id: Optional[str] = None,
6262
version: Optional[int] = None,
6363
) -> Optional[types.Part]:
6464
"""Gets an artifact from the artifact service storage.
@@ -69,8 +69,8 @@ async def load_artifact(
6969
Args:
7070
app_name: The app name.
7171
user_id: The user ID.
72-
session_id: The session ID.
7372
filename: The filename of the artifact.
73+
session_id: The session ID. If `None`, load the user-scoped artifact.
7474
version: The version of the artifact. If None, the latest version will be
7575
returned.
7676
@@ -80,7 +80,7 @@ async def load_artifact(
8080

8181
@abstractmethod
8282
async def list_artifact_keys(
83-
self, *, app_name: str, user_id: str, session_id: str
83+
self, *, app_name: str, user_id: str, session_id: Optional[str] = None
8484
) -> list[str]:
8585
"""Lists all the artifact filenames within a session.
8686
@@ -90,33 +90,48 @@ async def list_artifact_keys(
9090
session_id: The ID of the session.
9191
9292
Returns:
93-
A list of all artifact filenames within a session.
93+
A list of artifact filenames. If `session_id` is provided, returns
94+
both session-scoped and user-scoped artifact filenames. If `session_id`
95+
is `None`, returns
96+
user-scoped artifact filenames.
9497
"""
9598

9699
@abstractmethod
97100
async def delete_artifact(
98-
self, *, app_name: str, user_id: str, session_id: str, filename: str
101+
self,
102+
*,
103+
app_name: str,
104+
user_id: str,
105+
filename: str,
106+
session_id: Optional[str] = None,
99107
) -> None:
100108
"""Deletes an artifact.
101109
102110
Args:
103111
app_name: The name of the application.
104112
user_id: The ID of the user.
105-
session_id: The ID of the session.
106113
filename: The name of the artifact file.
114+
session_id: The ID of the session. If `None`, delete the user-scoped
115+
artifact.
107116
"""
108117

109118
@abstractmethod
110119
async def list_versions(
111-
self, *, app_name: str, user_id: str, session_id: str, filename: str
120+
self,
121+
*,
122+
app_name: str,
123+
user_id: str,
124+
filename: str,
125+
session_id: Optional[str] = None,
112126
) -> list[int]:
113127
"""Lists all versions of an artifact.
114128
115129
Args:
116130
app_name: The name of the application.
117131
user_id: The ID of the user.
118-
session_id: The ID of the session.
119132
filename: The name of the artifact file.
133+
session_id: The ID of the session. If `None`, only list the user-scoped
134+
artifacts versions.
120135
121136
Returns:
122137
A list of all available versions of the artifact.

src/google/adk/artifacts/gcs_artifact_service.py

Lines changed: 51 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ class GcsArtifactService(BaseArtifactService):
4141
def __init__(self, bucket_name: str, **kwargs):
4242
"""Initializes the GcsArtifactService.
4343
44-
4544
Args:
4645
bucket_name: The name of the bucket to use.
4746
**kwargs: Keyword arguments to pass to the Google Cloud Storage client.
@@ -56,9 +55,9 @@ async def save_artifact(
5655
*,
5756
app_name: str,
5857
user_id: str,
59-
session_id: str,
6058
filename: str,
6159
artifact: types.Part,
60+
session_id: Optional[str] = None,
6261
) -> int:
6362
return await asyncio.to_thread(
6463
self._save_artifact,
@@ -75,8 +74,8 @@ async def load_artifact(
7574
*,
7675
app_name: str,
7776
user_id: str,
78-
session_id: str,
7977
filename: str,
78+
session_id: Optional[str] = None,
8079
version: Optional[int] = None,
8180
) -> Optional[types.Part]:
8281
return await asyncio.to_thread(
@@ -90,7 +89,7 @@ async def load_artifact(
9089

9190
@override
9291
async def list_artifact_keys(
93-
self, *, app_name: str, user_id: str, session_id: str
92+
self, *, app_name: str, user_id: str, session_id: Optional[str] = None
9493
) -> list[str]:
9594
return await asyncio.to_thread(
9695
self._list_artifact_keys,
@@ -101,7 +100,12 @@ async def list_artifact_keys(
101100

102101
@override
103102
async def delete_artifact(
104-
self, *, app_name: str, user_id: str, session_id: str, filename: str
103+
self,
104+
*,
105+
app_name: str,
106+
user_id: str,
107+
filename: str,
108+
session_id: Optional[str] = None,
105109
) -> None:
106110
return await asyncio.to_thread(
107111
self._delete_artifact,
@@ -113,7 +117,12 @@ async def delete_artifact(
113117

114118
@override
115119
async def list_versions(
116-
self, *, app_name: str, user_id: str, session_id: str, filename: str
120+
self,
121+
*,
122+
app_name: str,
123+
user_id: str,
124+
filename: str,
125+
session_id: Optional[str] = None,
117126
) -> list[int]:
118127
return await asyncio.to_thread(
119128
self._list_versions,
@@ -139,31 +148,36 @@ def _get_blob_name(
139148
self,
140149
app_name: str,
141150
user_id: str,
142-
session_id: str,
143151
filename: str,
144152
version: int,
153+
session_id: Optional[str] = None,
145154
) -> str:
146155
"""Constructs the blob name in GCS.
147156
148157
Args:
149158
app_name: The name of the application.
150159
user_id: The ID of the user.
151-
session_id: The ID of the session.
152160
filename: The name of the artifact file.
153161
version: The version of the artifact.
162+
session_id: The ID of the session.
154163
155164
Returns:
156165
The constructed blob name in GCS.
157166
"""
158167
if self._file_has_user_namespace(filename):
159168
return f"{app_name}/{user_id}/user/{filename}/{version}"
169+
170+
if session_id is None:
171+
raise ValueError(
172+
"Session ID must be provided for session-scoped artifacts."
173+
)
160174
return f"{app_name}/{user_id}/{session_id}/{filename}/{version}"
161175

162176
def _save_artifact(
163177
self,
164178
app_name: str,
165179
user_id: str,
166-
session_id: str,
180+
session_id: Optional[str],
167181
filename: str,
168182
artifact: types.Part,
169183
) -> int:
@@ -176,7 +190,7 @@ def _save_artifact(
176190
version = 0 if not versions else max(versions) + 1
177191

178192
blob_name = self._get_blob_name(
179-
app_name, user_id, session_id, filename, version
193+
app_name, user_id, filename, version, session_id
180194
)
181195
blob = self.bucket.blob(blob_name)
182196

@@ -198,7 +212,7 @@ def _load_artifact(
198212
self,
199213
app_name: str,
200214
user_id: str,
201-
session_id: str,
215+
session_id: Optional[str],
202216
filename: str,
203217
version: Optional[int] = None,
204218
) -> Optional[types.Part]:
@@ -214,7 +228,7 @@ def _load_artifact(
214228
version = max(versions)
215229

216230
blob_name = self._get_blob_name(
217-
app_name, user_id, session_id, filename, version
231+
app_name, user_id, filename, version, session_id
218232
)
219233
blob = self.bucket.blob(blob_name)
220234

@@ -227,17 +241,18 @@ def _load_artifact(
227241
return artifact
228242

229243
def _list_artifact_keys(
230-
self, app_name: str, user_id: str, session_id: str
244+
self, app_name: str, user_id: str, session_id: Optional[str]
231245
) -> list[str]:
232246
filenames = set()
233247

234-
session_prefix = f"{app_name}/{user_id}/{session_id}/"
235-
session_blobs = self.storage_client.list_blobs(
236-
self.bucket, prefix=session_prefix
237-
)
238-
for blob in session_blobs:
239-
*_, filename, _ = blob.name.split("/")
240-
filenames.add(filename)
248+
if session_id:
249+
session_prefix = f"{app_name}/{user_id}/{session_id}/"
250+
session_blobs = self.storage_client.list_blobs(
251+
self.bucket, prefix=session_prefix
252+
)
253+
for blob in session_blobs:
254+
*_, filename, _ = blob.name.split("/")
255+
filenames.add(filename)
241256

242257
user_namespace_prefix = f"{app_name}/{user_id}/user/"
243258
user_namespace_blobs = self.storage_client.list_blobs(
@@ -250,7 +265,11 @@ def _list_artifact_keys(
250265
return sorted(list(filenames))
251266

252267
def _delete_artifact(
253-
self, app_name: str, user_id: str, session_id: str, filename: str
268+
self,
269+
app_name: str,
270+
user_id: str,
271+
session_id: Optional[str],
272+
filename: str,
254273
) -> None:
255274
versions = self._list_versions(
256275
app_name=app_name,
@@ -260,18 +279,23 @@ def _delete_artifact(
260279
)
261280
for version in versions:
262281
blob_name = self._get_blob_name(
263-
app_name, user_id, session_id, filename, version
282+
app_name, user_id, filename, version, session_id
264283
)
265284
blob = self.bucket.blob(blob_name)
266285
blob.delete()
267286
return
268287

269288
def _list_versions(
270-
self, app_name: str, user_id: str, session_id: str, filename: str
289+
self,
290+
app_name: str,
291+
user_id: str,
292+
session_id: Optional[str],
293+
filename: str,
271294
) -> list[int]:
272295
"""Lists all available versions of an artifact.
273296
274-
This method retrieves all versions of a specific artifact by querying GCS blobs
297+
This method retrieves all versions of a specific artifact by querying GCS
298+
blobs
275299
that match the constructed blob name prefix.
276300
277301
Args:
@@ -281,10 +305,11 @@ def _list_versions(
281305
filename: The name of the artifact file.
282306
283307
Returns:
284-
A list of version numbers (integers) available for the specified artifact.
308+
A list of version numbers (integers) available for the specified
309+
artifact.
285310
Returns an empty list if no versions are found.
286311
"""
287-
prefix = self._get_blob_name(app_name, user_id, session_id, filename, "")
312+
prefix = self._get_blob_name(app_name, user_id, filename, "", session_id)
288313
blobs = self.storage_client.list_blobs(self.bucket, prefix=prefix)
289314
versions = []
290315
for blob in blobs:

0 commit comments

Comments
 (0)