Skip to content

Commit bdb0de6

Browse files
authored
Update conversation history handling (#24)
* Update conversation history handling
* Support using `previous_response_id`
1 parent 0999b93 commit bdb0de6

File tree

17 files changed

+946
-384
lines changed

17 files changed

+946
-384
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,6 @@ env/
147147

148148
# Python package management
149149
uv.lock
150+
151+
# Internal files
152+
internal_examples/

examples/basic/agents_sdk.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
InputGuardrailTripwireTriggered,
88
OutputGuardrailTripwireTriggered,
99
Runner,
10+
SQLiteSession,
1011
)
1112
from agents.run import RunConfig
1213

@@ -50,6 +51,9 @@
5051

5152
async def main() -> None:
5253
"""Main input loop for the customer support agent with input/output guardrails."""
54+
# Create a session for the agent to store the conversation history
55+
session = SQLiteSession("guardrails-session")
56+
5357
# Create agent with guardrails automatically configured from pipeline configuration
5458
AGENT = GuardrailAgent(
5559
config=PIPELINE_CONFIG,
@@ -65,6 +69,7 @@ async def main() -> None:
6569
AGENT,
6670
user_input,
6771
run_config=RunConfig(tracing_disabled=True),
72+
session=session,
6873
)
6974
print(f"Assistant: {result.final_output}")
7075
except EOFError:

src/guardrails/_base_client.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .runtime import load_pipeline_bundles
2020
from .types import GuardrailLLMContextProto, GuardrailResult
2121
from .utils.context import validate_guardrail_context
22+
from .utils.conversation import append_assistant_response, normalize_conversation
2223

2324
logger = logging.getLogger(__name__)
2425

@@ -257,6 +258,18 @@ def _instantiate_all_guardrails(self) -> dict[str, list]:
257258
guardrails[stage_name] = instantiate_guardrails(stage, default_spec_registry) if stage else []
258259
return guardrails
259260

261+
def _normalize_conversation(self, payload: Any) -> list[dict[str, Any]]:
262+
"""Normalize arbitrary conversation payloads."""
263+
return normalize_conversation(payload)
264+
265+
def _conversation_with_response(
266+
self,
267+
conversation: list[dict[str, Any]],
268+
response: Any,
269+
) -> list[dict[str, Any]]:
270+
"""Append the assistant response to a normalized conversation."""
271+
return append_assistant_response(conversation, response)
272+
260273
def _validate_context(self, context: Any) -> None:
261274
"""Validate context against all guardrails."""
262275
for stage_guardrails in self.guardrails.values():

src/guardrails/_streaming.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from ._base_client import GuardrailsResponse
1414
from .exceptions import GuardrailTripwireTriggered
1515
from .types import GuardrailResult
16+
from .utils.conversation import merge_conversation_with_items
1617

1718
logger = logging.getLogger(__name__)
1819

@@ -25,6 +26,7 @@ async def _stream_with_guardrails(
2526
llm_stream: Any, # coroutine or async iterator of OpenAI chunks
2627
preflight_results: list[GuardrailResult],
2728
input_results: list[GuardrailResult],
29+
conversation_history: list[dict[str, Any]] | None = None,
2830
check_interval: int = 100,
2931
suppress_tripwire: bool = False,
3032
) -> AsyncIterator[GuardrailsResponse]:
@@ -46,7 +48,16 @@ async def _stream_with_guardrails(
4648
# Run output guardrails periodically
4749
if chunk_count % check_interval == 0:
4850
try:
49-
await self._run_stage_guardrails("output", accumulated_text, suppress_tripwire=suppress_tripwire)
51+
history = merge_conversation_with_items(
52+
conversation_history or [],
53+
[{"role": "assistant", "content": accumulated_text}],
54+
)
55+
await self._run_stage_guardrails(
56+
"output",
57+
accumulated_text,
58+
conversation_history=history,
59+
suppress_tripwire=suppress_tripwire,
60+
)
5061
except GuardrailTripwireTriggered:
5162
# Clear accumulated output and re-raise
5263
accumulated_text = ""
@@ -57,7 +68,16 @@ async def _stream_with_guardrails(
5768

5869
# Final output check
5970
if accumulated_text:
60-
await self._run_stage_guardrails("output", accumulated_text, suppress_tripwire=suppress_tripwire)
71+
history = merge_conversation_with_items(
72+
conversation_history or [],
73+
[{"role": "assistant", "content": accumulated_text}],
74+
)
75+
await self._run_stage_guardrails(
76+
"output",
77+
accumulated_text,
78+
conversation_history=history,
79+
suppress_tripwire=suppress_tripwire,
80+
)
6181
# Note: This final result won't be yielded since stream is complete
6282
# but the results are available in the last chunk
6383

@@ -66,6 +86,7 @@ def _stream_with_guardrails_sync(
6686
llm_stream: Any, # iterator of OpenAI chunks
6787
preflight_results: list[GuardrailResult],
6888
input_results: list[GuardrailResult],
89+
conversation_history: list[dict[str, Any]] | None = None,
6990
check_interval: int = 100,
7091
suppress_tripwire: bool = False,
7192
):
@@ -83,7 +104,16 @@ def _stream_with_guardrails_sync(
83104
# Run output guardrails periodically
84105
if chunk_count % check_interval == 0:
85106
try:
86-
self._run_stage_guardrails("output", accumulated_text, suppress_tripwire=suppress_tripwire)
107+
history = merge_conversation_with_items(
108+
conversation_history or [],
109+
[{"role": "assistant", "content": accumulated_text}],
110+
)
111+
self._run_stage_guardrails(
112+
"output",
113+
accumulated_text,
114+
conversation_history=history,
115+
suppress_tripwire=suppress_tripwire,
116+
)
87117
except GuardrailTripwireTriggered:
88118
# Clear accumulated output and re-raise
89119
accumulated_text = ""
@@ -94,6 +124,15 @@ def _stream_with_guardrails_sync(
94124

95125
# Final output check
96126
if accumulated_text:
97-
self._run_stage_guardrails("output", accumulated_text, suppress_tripwire=suppress_tripwire)
127+
history = merge_conversation_with_items(
128+
conversation_history or [],
129+
[{"role": "assistant", "content": accumulated_text}],
130+
)
131+
self._run_stage_guardrails(
132+
"output",
133+
accumulated_text,
134+
conversation_history=history,
135+
suppress_tripwire=suppress_tripwire,
136+
)
98137
# Note: This final result won't be yielded since stream is complete
99138
# but the results are available in the last chunk

0 commit comments

Comments (0)