Skip to content

Commit 9dce06f

Browse files
DeanChensjcopybara-github
authored andcommitted
feat: Add rewind_async to support rewinding the session to before a previous invocation
PiperOrigin-RevId: 820552460
1 parent 307896a commit 9dce06f

File tree

11 files changed

+830
-4
lines changed

11 files changed

+830
-4
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from . import agent
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from google.adk import Agent
16+
from google.adk.tools.tool_context import ToolContext
17+
from google.genai import types
18+
19+
20+
async def update_state(tool_context: ToolContext, key: str, value: str) -> dict:
21+
"""Updates a state value."""
22+
tool_context.state[key] = value
23+
return {"status": f"Updated state '{key}' to '{value}'"}
24+
25+
26+
async def load_state(tool_context: ToolContext, key: str) -> dict:
27+
"""Loads a state value."""
28+
return {key: tool_context.state.get(key)}
29+
30+
31+
async def save_artifact(
32+
tool_context: ToolContext, filename: str, content: str
33+
) -> dict:
34+
"""Saves an artifact with the given filename and content."""
35+
artifact_bytes = content.encode("utf-8")
36+
artifact_part = types.Part(
37+
inline_data=types.Blob(mime_type="text/plain", data=artifact_bytes)
38+
)
39+
version = await tool_context.save_artifact(filename, artifact_part)
40+
return {"status": "success", "filename": filename, "version": version}
41+
42+
43+
async def load_artifact(tool_context: ToolContext, filename: str) -> dict:
44+
"""Loads an artifact with the given filename."""
45+
artifact = await tool_context.load_artifact(filename)
46+
if not artifact:
47+
return {"error": f"Artifact '{filename}' not found"}
48+
content = artifact.inline_data.data.decode("utf-8")
49+
return {"filename": filename, "content": content}
50+
51+
52+
# Create the agent
53+
root_agent = Agent(
54+
name="state_agent",
55+
model="gemini-2.0-flash",
56+
instruction="""You are an agent that manages state and artifacts.
57+
58+
You can:
59+
- Update state value
60+
- Load state value
61+
- Save artifact
62+
- Load artifact
63+
64+
Use the appropriate tool based on what the user asks for.""",
65+
tools=[
66+
update_state,
67+
load_state,
68+
save_artifact,
69+
load_artifact,
70+
],
71+
)
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Utility functions for handling artifact URIs."""
15+
16+
from __future__ import annotations
17+
18+
import re
19+
from typing import NamedTuple
20+
from typing import Optional
21+
22+
from google.genai import types
23+
24+
25+
class ParsedArtifactUri(NamedTuple):
26+
"""The result of parsing an artifact URI."""
27+
28+
app_name: str
29+
user_id: str
30+
session_id: Optional[str]
31+
filename: str
32+
version: int
33+
34+
35+
_SESSION_SCOPED_ARTIFACT_URI_RE = re.compile(
36+
r"artifact://apps/([^/]+)/users/([^/]+)/sessions/([^/]+)/artifacts/([^/]+)/versions/(\d+)"
37+
)
38+
_USER_SCOPED_ARTIFACT_URI_RE = re.compile(
39+
r"artifact://apps/([^/]+)/users/([^/]+)/artifacts/([^/]+)/versions/(\d+)"
40+
)
41+
42+
43+
def parse_artifact_uri(uri: str) -> Optional[ParsedArtifactUri]:
44+
"""Parses an artifact URI.
45+
46+
Args:
47+
uri: The artifact URI to parse.
48+
49+
Returns:
50+
A ParsedArtifactUri if parsing is successful, None otherwise.
51+
"""
52+
if not uri or not uri.startswith("artifact://"):
53+
return None
54+
55+
match = _SESSION_SCOPED_ARTIFACT_URI_RE.match(uri)
56+
if match:
57+
return ParsedArtifactUri(
58+
app_name=match.group(1),
59+
user_id=match.group(2),
60+
session_id=match.group(3),
61+
filename=match.group(4),
62+
version=int(match.group(5)),
63+
)
64+
65+
match = _USER_SCOPED_ARTIFACT_URI_RE.match(uri)
66+
if match:
67+
return ParsedArtifactUri(
68+
app_name=match.group(1),
69+
user_id=match.group(2),
70+
session_id=None,
71+
filename=match.group(3),
72+
version=int(match.group(4)),
73+
)
74+
75+
return None
76+
77+
78+
def get_artifact_uri(
79+
app_name: str,
80+
user_id: str,
81+
filename: str,
82+
version: int,
83+
session_id: Optional[str] = None,
84+
) -> str:
85+
"""Constructs an artifact URI.
86+
87+
Args:
88+
app_name: The name of the application.
89+
user_id: The ID of the user.
90+
filename: The name of the artifact file.
91+
version: The version of the artifact.
92+
session_id: The ID of the session.
93+
94+
Returns:
95+
The constructed artifact URI.
96+
"""
97+
if session_id:
98+
return f"artifact://apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{filename}/versions/{version}"
99+
else:
100+
return f"artifact://apps/{app_name}/users/{user_id}/artifacts/{filename}/versions/{version}"
101+
102+
103+
def is_artifact_ref(artifact: types.Part) -> bool:
104+
"""Checks if an artifact part is an artifact reference.
105+
106+
Args:
107+
artifact: The artifact part to check.
108+
109+
Returns:
110+
True if the artifact part is an artifact reference, False otherwise.
111+
"""
112+
return bool(
113+
artifact.file_data
114+
and artifact.file_data.file_uri
115+
and artifact.file_data.file_uri.startswith("artifact://")
116+
)

src/google/adk/artifacts/gcs_artifact_service.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,11 @@ def _save_artifact(
212212
blob.upload_from_string(
213213
data=artifact.text,
214214
)
215+
elif artifact.file_data:
216+
raise NotImplementedError(
217+
"Saving artifact with file_data is not supported yet in"
218+
" GcsArtifactService."
219+
)
215220
else:
216221
raise ValueError("Artifact must have either inline_data or text.")
217222

src/google/adk/artifacts/in_memory_artifact_service.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from typing import Any
1919
from typing import Optional
2020

21+
from google.adk.artifacts import artifact_util
2122
from google.genai import types
2223
from pydantic import BaseModel
2324
from pydantic import Field
@@ -122,7 +123,15 @@ async def save_artifact(
122123
elif artifact.text is not None:
123124
artifact_version.mime_type = "text/plain"
124125
elif artifact.file_data is not None:
125-
artifact_version.mime_type = artifact.file_data.mime_type
126+
if artifact_util.is_artifact_ref(artifact):
127+
if not artifact_util.parse_artifact_uri(artifact.file_data.file_uri):
128+
raise ValueError(
129+
f"Invalid artifact reference URI: {artifact.file_data.file_uri}"
130+
)
131+
# If it's a valid artifact URI, we store the artifact part as-is.
132+
# And we don't know the mime type until we load it.
133+
else:
134+
artifact_version.mime_type = artifact.file_data.mime_type
126135
else:
127136
raise ValueError("Not supported artifact type.")
128137

@@ -147,11 +156,42 @@ async def load_artifact(
147156
return None
148157
if version is None:
149158
version = -1
159+
150160
try:
151-
return versions[version].data
161+
artifact_entry = versions[version]
152162
except IndexError:
153163
return None
154164

165+
if artifact_entry is None:
166+
return None
167+
168+
# Resolve artifact reference if needed.
169+
artifact_data = artifact_entry.data
170+
if artifact_util.is_artifact_ref(artifact_data):
171+
parsed_uri = artifact_util.parse_artifact_uri(
172+
artifact_data.file_data.file_uri
173+
)
174+
if not parsed_uri:
175+
raise ValueError(
176+
"Invalid artifact reference URI:"
177+
f" {artifact_data.file_data.file_uri}"
178+
)
179+
return await self.load_artifact(
180+
app_name=parsed_uri.app_name,
181+
user_id=parsed_uri.user_id,
182+
filename=parsed_uri.filename,
183+
session_id=parsed_uri.session_id,
184+
version=parsed_uri.version,
185+
)
186+
187+
if (
188+
artifact_data == types.Part()
189+
or artifact_data == types.Part(text="")
190+
or (artifact_data.inline_data and not artifact_data.inline_data.data)
191+
):
192+
return None
193+
return artifact_data
194+
155195
@override
156196
async def list_artifact_keys(
157197
self, *, app_name: str, user_id: str, session_id: Optional[str] = None

src/google/adk/events/event_actions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,6 @@ class EventActions(BaseModel):
104104

105105
agent_state: Optional[dict[str, Any]] = None
106106
"""The agent state at the current event."""
107+
108+
rewind_before_invocation_id: Optional[str] = None
109+
"""The invocation id to rewind to. This is only set for rewind event."""

src/google/adk/flows/llm_flows/contents.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,11 +310,30 @@ def _get_contents(
310310
accumulated_input_transcription = ''
311311
accumulated_output_transcription = ''
312312

313+
# Filter out events that are annulled by a rewind.
314+
# By iterating backward, when a rewind event is found, we skip all events
315+
# from that point back to the `rewind_before_invocation_id`, thus removing
316+
# them from the history used for the LLM request.
317+
rewind_filtered_events = []
318+
i = len(events) - 1
319+
while i >= 0:
320+
event = events[i]
321+
if event.actions and event.actions.rewind_before_invocation_id:
322+
rewind_invocation_id = event.actions.rewind_before_invocation_id
323+
for j in range(0, i, 1):
324+
if events[j].invocation_id == rewind_invocation_id:
325+
i = j
326+
break
327+
else:
328+
rewind_filtered_events.append(event)
329+
i -= 1
330+
rewind_filtered_events.reverse()
331+
313332
# Parse the events, leaving the contents and the function calls and
314333
# responses from the current agent.
315334
raw_filtered_events = []
316335
has_compaction_events = False
317-
for event in events:
336+
for event in rewind_filtered_events:
318337
if _contains_empty_content(event):
319338
continue
320339
if not _is_event_belongs_to_branch(current_branch, event):

0 commit comments

Comments
 (0)