1313from ._base_client import GuardrailsResponse
1414from .exceptions import GuardrailTripwireTriggered
1515from .types import GuardrailResult
16+ from .utils .conversation import merge_conversation_with_items
1617
1718logger = logging .getLogger (__name__ )
1819
@@ -25,6 +26,7 @@ async def _stream_with_guardrails(
2526 llm_stream : Any , # coroutine or async iterator of OpenAI chunks
2627 preflight_results : list [GuardrailResult ],
2728 input_results : list [GuardrailResult ],
29+ conversation_history : list [dict [str , Any ]] | None = None ,
2830 check_interval : int = 100 ,
2931 suppress_tripwire : bool = False ,
3032 ) -> AsyncIterator [GuardrailsResponse ]:
@@ -46,7 +48,16 @@ async def _stream_with_guardrails(
4648 # Run output guardrails periodically
4749 if chunk_count % check_interval == 0 :
4850 try :
49- await self ._run_stage_guardrails ("output" , accumulated_text , suppress_tripwire = suppress_tripwire )
51+ history = merge_conversation_with_items (
52+ conversation_history or [],
53+ [{"role" : "assistant" , "content" : accumulated_text }],
54+ )
55+ await self ._run_stage_guardrails (
56+ "output" ,
57+ accumulated_text ,
58+ conversation_history = history ,
59+ suppress_tripwire = suppress_tripwire ,
60+ )
5061 except GuardrailTripwireTriggered :
5162 # Clear accumulated output and re-raise
5263 accumulated_text = ""
@@ -57,7 +68,16 @@ async def _stream_with_guardrails(
5768
5869 # Final output check
5970 if accumulated_text :
60- await self ._run_stage_guardrails ("output" , accumulated_text , suppress_tripwire = suppress_tripwire )
71+ history = merge_conversation_with_items (
72+ conversation_history or [],
73+ [{"role" : "assistant" , "content" : accumulated_text }],
74+ )
75+ await self ._run_stage_guardrails (
76+ "output" ,
77+ accumulated_text ,
78+ conversation_history = history ,
79+ suppress_tripwire = suppress_tripwire ,
80+ )
6181 # Note: This final result won't be yielded since stream is complete
6282 # but the results are available in the last chunk
6383
@@ -66,6 +86,7 @@ def _stream_with_guardrails_sync(
6686 llm_stream : Any , # iterator of OpenAI chunks
6787 preflight_results : list [GuardrailResult ],
6888 input_results : list [GuardrailResult ],
89+ conversation_history : list [dict [str , Any ]] | None = None ,
6990 check_interval : int = 100 ,
7091 suppress_tripwire : bool = False ,
7192 ):
@@ -83,7 +104,16 @@ def _stream_with_guardrails_sync(
83104 # Run output guardrails periodically
84105 if chunk_count % check_interval == 0 :
85106 try :
86- self ._run_stage_guardrails ("output" , accumulated_text , suppress_tripwire = suppress_tripwire )
107+ history = merge_conversation_with_items (
108+ conversation_history or [],
109+ [{"role" : "assistant" , "content" : accumulated_text }],
110+ )
111+ self ._run_stage_guardrails (
112+ "output" ,
113+ accumulated_text ,
114+ conversation_history = history ,
115+ suppress_tripwire = suppress_tripwire ,
116+ )
87117 except GuardrailTripwireTriggered :
88118 # Clear accumulated output and re-raise
89119 accumulated_text = ""
@@ -94,6 +124,15 @@ def _stream_with_guardrails_sync(
94124
95125 # Final output check
96126 if accumulated_text :
97- self ._run_stage_guardrails ("output" , accumulated_text , suppress_tripwire = suppress_tripwire )
127+ history = merge_conversation_with_items (
128+ conversation_history or [],
129+ [{"role" : "assistant" , "content" : accumulated_text }],
130+ )
131+ self ._run_stage_guardrails (
132+ "output" ,
133+ accumulated_text ,
134+ conversation_history = history ,
135+ suppress_tripwire = suppress_tripwire ,
136+ )
98137 # Note: This final result won't be yielded since stream is complete
99138 # but the results are available in the last chunk
0 commit comments