Skip to content

Commit a85121b

Browse files
committed
fix(cli): Fix TypeError in v2.x chat due to incorrect State/dict conversion
Fixes incorrect type assumptions introduced in PR #1380 that caused a TypeError when using the v2.x chat CLI. The bug incorrectly assumed that `process_events_async` returns a dict and tried to convert State objects with `State(**output_state)`.

Changes:
- Fix type annotations in `LLMRails.process_events_async` to return `Union[dict, State]`
- Remove the incorrect `asdict()` conversion and `State(**...)` reconstruction in `chat.py`
- Add integration tests to prevent regression
1 parent fac2774 commit a85121b

File tree

3 files changed

+261
-11
lines changed

3 files changed

+261
-11
lines changed

nemoguardrails/cli/chat.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import asyncio
1616
import json
1717
import os
18-
from dataclasses import asdict, dataclass, field
18+
from dataclasses import dataclass, field
1919
from typing import Dict, List, Optional, Tuple, Union, cast
2020

2121
import aiohttp
@@ -498,13 +498,12 @@ async def _check_local_async_actions():
498498

499499
output_events, output_state = await rails_app.process_events_async(
500500
input_events_copy,
501-
asdict(chat_state.state) if chat_state.state else None,
501+
chat_state.state,
502502
)
503503
chat_state.output_events = output_events
504504

505-
# process_events_async returns a Dict `state`, need to convert to dataclass for ChatState object
506505
if output_state:
507-
chat_state.output_state = cast(State, State(**output_state))
506+
chat_state.output_state = cast(State, output_state)
508507

509508
# Process output_events and potentially generate new input_events
510509
_process_output()
@@ -535,12 +534,11 @@ async def _process_input_events():
535534
chat_state.input_events = []
536535
output_events, output_state = await rails_app.process_events_async(
537536
input_events_copy,
538-
asdict(chat_state.state) if chat_state.state else None,
537+
chat_state.state,
539538
)
540539
chat_state.output_events = output_events
541540
if output_state:
542-
# process_events_async returns a Dict `state`, need to convert to dataclass for ChatState object
543-
output_state_typed: State = cast(State, State(**output_state))
541+
output_state_typed: State = cast(State, output_state)
544542
chat_state.output_state = output_state_typed
545543
debugger.set_output_state(output_state_typed)
546544

nemoguardrails/rails/llm/llmrails.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1463,9 +1463,9 @@ def generate_events(
14631463
async def process_events_async(
14641464
self,
14651465
events: List[dict],
1466-
state: Optional[dict] = None,
1466+
state: Union[Optional[dict], State] = None,
14671467
blocking: bool = False,
1468-
) -> Tuple[List[dict], dict]:
1468+
) -> Tuple[List[dict], Union[dict, State]]:
14691469
"""Process a sequence of events in a given state.
14701470
14711471
The events will be processed one by one, in the input order.
@@ -1502,9 +1502,9 @@ async def process_events_async(
15021502
def process_events(
15031503
self,
15041504
events: List[dict],
1505-
state: Optional[dict] = None,
1505+
state: Union[Optional[dict], State] = None,
15061506
blocking: bool = False,
1507-
) -> Tuple[List[dict], dict]:
1507+
) -> Tuple[List[dict], Union[dict, State]]:
15081508
"""Synchronous version of `LLMRails.process_events_async`."""
15091509

15101510
if check_sync_call_from_async_loop():
Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import os
17+
18+
import pytest
19+
20+
LIVE_TEST_MODE = os.environ.get("LIVE_TEST") or os.environ.get("LIVE_TEST_MODE")
21+
22+
23+
class TestProcessEventsAsyncV2x:
24+
"""Integration tests for LLMRails.process_events_async with v2.x runtime.
25+
26+
These tests would have caught issue #1505 where PR #1380 incorrectly
27+
assumed process_events_async returns dict and tried to convert with
28+
State(**output_state), causing TypeError.
29+
30+
The bug was introduced because:
31+
1. The type annotation in llmrails.py was wrong (said it returns dict)
32+
2. PR #1380 "fixed" chat.py based on the wrong annotation
33+
3. No tests called LLMRails.process_events_async() with v2.x
34+
(all tests called runtime.process_events() directly)
35+
"""
36+
37+
@pytest.mark.asyncio
38+
async def test_process_events_async_returns_state_object(self):
39+
"""Test that LLMRails.process_events_async returns State object for v2.x, not dict.
40+
41+
This is the critical test that would have caught the bug immediately.
42+
"""
43+
from nemoguardrails import LLMRails, RailsConfig
44+
from nemoguardrails.colang.v2_x.runtime.flows import State
45+
46+
config = RailsConfig.from_content(
47+
"""
48+
import core
49+
50+
flow main
51+
user said "hi"
52+
bot say "Hello!"
53+
""",
54+
"""
55+
colang_version: "2.x"
56+
models:
57+
- type: main
58+
engine: openai
59+
model: gpt-4o-mini
60+
""",
61+
)
62+
63+
rails = LLMRails(config)
64+
65+
events = [{"type": "UtteranceUserActionFinished", "final_transcript": "hi"}]
66+
67+
output_events, output_state = await rails.process_events_async(
68+
events, state=None
69+
)
70+
71+
assert isinstance(
72+
output_state, State
73+
), f"Expected State object, got {type(output_state)}"
74+
assert isinstance(output_events, list)
75+
assert len(output_events) > 0
76+
77+
@pytest.mark.asyncio
78+
async def test_process_events_async_accepts_state_object(self):
79+
"""Test that LLMRails.process_events_async accepts State object as input.
80+
81+
The bug in PR #1380 also incorrectly converted State to dict using asdict()
82+
before passing to process_events_async. This test verifies that passing
83+
State objects directly works correctly.
84+
"""
85+
from nemoguardrails import LLMRails, RailsConfig
86+
from nemoguardrails.colang.v2_x.runtime.flows import State
87+
from nemoguardrails.utils import new_event_dict
88+
89+
config = RailsConfig.from_content(
90+
"""
91+
import core
92+
93+
flow main
94+
user said "hi"
95+
bot say "Hello!"
96+
user said "bye"
97+
bot say "Goodbye!"
98+
""",
99+
"""
100+
colang_version: "2.x"
101+
models:
102+
- type: main
103+
engine: openai
104+
model: gpt-3.5-turbo
105+
""",
106+
)
107+
108+
rails = LLMRails(config)
109+
110+
events = [{"type": "UtteranceUserActionFinished", "final_transcript": "hi"}]
111+
112+
output_events_1, output_state_1 = await rails.process_events_async(
113+
events, state=None
114+
)
115+
116+
assert isinstance(output_state_1, State)
117+
118+
events_2 = []
119+
for event in output_events_1:
120+
if event["type"] == "StartUtteranceBotAction":
121+
events_2.append(
122+
new_event_dict(
123+
"UtteranceBotActionFinished",
124+
action_uid=event["action_uid"],
125+
is_success=True,
126+
final_script=event["script"],
127+
)
128+
)
129+
130+
events_2.append(
131+
{"type": "UtteranceUserActionFinished", "final_transcript": "bye"}
132+
)
133+
134+
output_events_2, output_state_2 = await rails.process_events_async(
135+
events_2, state=output_state_1
136+
)
137+
138+
assert isinstance(
139+
output_state_2, State
140+
), "Second call should also return State object when passing State as input"
141+
142+
143+
class TestChatV2xE2E:
144+
"""End-to-end tests for chat CLI with v2.x runtime.
145+
146+
These tests exercise the actual chat.py code paths that were broken by PR #1380.
147+
"""
148+
149+
@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.")
150+
@pytest.mark.asyncio
151+
async def test_chat_v2x_with_real_llm(self):
152+
"""E2E test of v2.x chat with real LLM.
153+
154+
This requires LIVE_TEST_MODE=1 and OpenAI API key.
155+
"""
156+
from unittest.mock import patch
157+
158+
from nemoguardrails import LLMRails, RailsConfig
159+
from nemoguardrails.cli.chat import _run_chat_v2_x
160+
from nemoguardrails.colang.v2_x.runtime.flows import State
161+
162+
config = RailsConfig.from_content(
163+
"""
164+
import core
165+
166+
flow main
167+
user said "hi"
168+
bot say "Hello from v2.x!"
169+
""",
170+
"""
171+
colang_version: "2.x"
172+
models:
173+
- type: main
174+
engine: openai
175+
model: gpt-3.5-turbo
176+
""",
177+
)
178+
179+
rails = LLMRails(config)
180+
181+
simulated_input = ["hi", "exit"]
182+
input_iter = iter(simulated_input)
183+
184+
def mock_input(*args, **kwargs):
185+
try:
186+
return next(input_iter)
187+
except StopIteration:
188+
raise KeyboardInterrupt()
189+
190+
with patch("builtins.input", side_effect=mock_input):
191+
try:
192+
await _run_chat_v2_x(rails)
193+
except (KeyboardInterrupt, StopIteration):
194+
pass
195+
196+
@pytest.mark.asyncio
197+
async def test_chat_v2x_process_events_flow(self):
198+
"""Test the exact code path that was broken in chat.py.
199+
200+
This simulates what _run_chat_v2_x does internally.
201+
"""
202+
from dataclasses import dataclass, field
203+
from typing import List, Optional
204+
205+
from nemoguardrails import LLMRails, RailsConfig
206+
from nemoguardrails.colang.v2_x.runtime.flows import State
207+
208+
@dataclass
209+
class ChatState:
210+
state: Optional[State] = None
211+
input_events: List[dict] = field(default_factory=list)
212+
output_events: List[dict] = field(default_factory=list)
213+
output_state: Optional[State] = None
214+
215+
config = RailsConfig.from_content(
216+
"""
217+
import core
218+
219+
flow main
220+
user said "hi"
221+
bot say "Hello!"
222+
""",
223+
"""
224+
colang_version: "2.x"
225+
models:
226+
- type: main
227+
engine: openai
228+
model: gpt-3.5-turbo
229+
""",
230+
)
231+
232+
rails = LLMRails(config)
233+
chat_state = ChatState()
234+
235+
chat_state.input_events = [
236+
{"type": "UtteranceUserActionFinished", "final_transcript": "hi"}
237+
]
238+
239+
input_events_copy = chat_state.input_events.copy()
240+
chat_state.input_events = []
241+
242+
output_events, output_state = await rails.process_events_async(
243+
input_events_copy,
244+
chat_state.state,
245+
)
246+
chat_state.output_events = output_events
247+
248+
if output_state:
249+
chat_state.output_state = output_state
250+
251+
assert isinstance(chat_state.output_state, State)
252+
assert len(chat_state.output_events) > 0

0 commit comments

Comments (0)