1515
1616from __future__ import annotations
1717
18- from typing import Any , List , Optional
18+ from typing import Any , List , Optional , Union , cast
1919
20- from langchain_core .language_models import BaseLanguageModel
20+ from langchain_core .language_models import BaseChatModel
21+ from langchain_core .language_models .llms import BaseLLM
2122from langchain_core .messages import AIMessage , HumanMessage , SystemMessage
2223from langchain_core .prompt_values import ChatPromptValue , StringPromptValue
2324from langchain_core .runnables import Runnable
2728
2829from nemoguardrails import LLMRails , RailsConfig
2930from nemoguardrails .integrations .langchain .utils import async_wrap
30- from nemoguardrails .rails .llm .options import GenerationOptions
31+ from nemoguardrails .rails .llm .options import GenerationOptions , GenerationResponse
3132
3233
3334class RunnableRails (Runnable [Input , Output ]):
3435 def __init__ (
3536 self ,
3637 config : RailsConfig ,
37- llm : Optional [BaseLanguageModel ] = None ,
38+ llm : Optional [Union [ BaseLLM , BaseChatModel ] ] = None ,
3839 tools : Optional [List [Tool ]] = None ,
3940 passthrough : bool = True ,
4041 runnable : Optional [Runnable ] = None ,
@@ -67,12 +68,14 @@ def __init__(
6768 if self .passthrough_runnable :
6869 self ._init_passthrough_fn ()
6970
70- def _init_passthrough_fn (self ):
71+ def _init_passthrough_fn (self ) -> None :
7172 """Initialize the passthrough function for the LLM rails instance."""
7273
7374 async def passthrough_fn (context : dict , events : List [dict ]):
7475 # First, we fetch the input from the context
7576 _input = context .get ("passthrough_input" )
77+ if self .passthrough_runnable is None :
78+ raise ValueError ("No passthrough runnable provided" )
7679 async_wrapped_invoke = async_wrap (self .passthrough_runnable .invoke )
7780 _output = await async_wrapped_invoke (_input , self .config , ** self .kwargs )
7881
@@ -84,10 +87,11 @@ async def passthrough_fn(context: dict, events: List[dict]):
8487
8588 return text , _output
8689
87- self .rails .llm_generation_actions .passthrough_fn = passthrough_fn
90+ # Dynamically assign passthrough_fn to avoid type checker issues
91+ setattr (self .rails .llm_generation_actions , "passthrough_fn" , passthrough_fn )
8892
89- def __or__ (self , other ):
90- if isinstance (other , BaseLanguageModel ):
93+ def __or__ (self , other ) -> "RunnableRails[Input, Output]" : # type: ignore[override]
94+ if isinstance (other , ( BaseLLM , BaseChatModel ) ):
9195 self .llm = other
9296 self .rails .update_llm (other )
9397
@@ -193,6 +197,9 @@ def invoke(
193197 res = self .rails .generate (
194198 messages = input_messages , options = GenerationOptions (output_vars = True )
195199 )
200+ # When using output_vars=True, rails.generate returns a GenerationResponse
201+ if not isinstance (res , GenerationResponse ):
202+ raise Exception (f"Expected GenerationResponse, got { type (res )} " )
196203 context = res .output_data
197204 result = res .response
198205
@@ -203,17 +210,16 @@ def invoke(
203210 result = result [0 ]
204211
205212 if self .passthrough and self .passthrough_runnable :
206- passthrough_output = context .get ("passthrough_output" )
213+ passthrough_output = context .get ("passthrough_output" ) if context else None
207214
208215 # If a rail was triggered (input or dialog), the passthrough_output
209216 # will not be set. In this case, we only set the output key to the
210217 # message that was received from the guardrail configuration.
211218 if passthrough_output is None :
212- passthrough_output = {
213- self .passthrough_bot_output_key : result ["content" ]
214- }
219+ content = result .get ("content" ) if isinstance (result , dict ) else result
220+ passthrough_output = {self .passthrough_bot_output_key : content }
215221
216- bot_message = context .get ("bot_message" )
222+ bot_message = context .get ("bot_message" ) if context else None
217223
218224 # We make sure that, if the output rails altered the bot message, we
219225 # replace it in the passthrough_output
@@ -222,20 +228,28 @@ def invoke(
222228 elif isinstance (passthrough_output , dict ):
223229 passthrough_output [self .passthrough_bot_output_key ] = bot_message
224230
225- return passthrough_output
231+ return cast ( Output , passthrough_output )
226232 else :
227233 if isinstance (input , ChatPromptValue ):
228- return AIMessage (content = result ["content" ])
234+ content = result .get ("content" ) if isinstance (result , dict ) else result
235+ # Ensure content is a string for AIMessage
236+ content_str = str (content ) if content is not None else ""
237+ return cast (Output , AIMessage (content = content_str ))
229238 elif isinstance (input , StringPromptValue ):
230239 if isinstance (result , dict ):
231- return result [ "content" ]
240+ return cast ( Output , result . get ( "content" , "" ))
232241 else :
233- return result
242+ return cast ( Output , result )
234243 elif isinstance (input , dict ):
235244 user_input = input ["input" ]
236245 if isinstance (user_input , str ):
237- return {"output" : result ["content" ]}
246+ content = (
247+ result .get ("content" ) if isinstance (result , dict ) else result
248+ )
249+ return cast (Output , {"output" : content })
238250 elif isinstance (user_input , list ):
239- return {"output" : result }
251+ return cast (Output , {"output" : result })
252+ else :
253+ raise ValueError (f"Unexpected user_input type: { type (user_input )} " )
240254 else :
241255 raise ValueError (f"Unexpected input type: { type (input )} " )
0 commit comments