Skip to content

Commit df51265

Browse files
authored
chore(types): Type-clean /cli (37 errors) (#1380)
* Cleaned nemoguardrails/cli * Address Pouyan's feedback * Add nemoguardrails/cli to pyright pre-commit checking
1 parent f6c5830 commit df51265

File tree

5 files changed

+188
-79
lines changed

5 files changed

+188
-79
lines changed

nemoguardrails/cli/__init__.py

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616

1717
import logging
1818
import os
19-
from typing import List, Optional
19+
from enum import Enum
20+
from typing import Any, List, Literal, Optional
2021

2122
import typer
2223
import uvicorn
@@ -27,13 +28,24 @@
2728
from nemoguardrails.cli.chat import run_chat
2829
from nemoguardrails.cli.migration import migrate
2930
from nemoguardrails.cli.providers import _list_providers, select_provider_with_type
30-
from nemoguardrails.eval import cli
31+
from nemoguardrails.eval import cli as eval_cli
3132
from nemoguardrails.logging.verbose import set_verbose
3233
from nemoguardrails.utils import init_random_seed
3334

35+
36+
class ColangVersions(str, Enum):
37+
one = "1.0"
38+
two_alpha = "2.0-alpha"
39+
40+
41+
_COLANG_VERSIONS = [version.value for version in ColangVersions]
42+
43+
3444
app = typer.Typer()
3545

36-
app.add_typer(cli.app, name="eval", short_help="Evaluation a guardrail configuration.")
46+
app.add_typer(
47+
eval_cli.app, name="eval", short_help="Evaluation a guardrail configuration."
48+
)
3749
app.pretty_exceptions_enable = False
3850

3951
logging.getLogger().setLevel(logging.WARNING)
@@ -44,7 +56,8 @@ def chat(
4456
config: List[str] = typer.Option(
4557
default=["config"],
4658
exists=True,
47-
help="Path to a directory containing configuration files to use. Can also point to a single configuration file.",
59+
help="Path to a directory containing configuration files to use. "
60+
"Can also point to a single configuration file.",
4861
),
4962
verbose: bool = typer.Option(
5063
default=False,
@@ -60,7 +73,8 @@ def chat(
6073
),
6174
debug_level: List[str] = typer.Option(
6275
default=[],
63-
help="Enable debug mode which prints rich information about the flows execution. Available levels: WARNING, INFO, DEBUG",
76+
help="Enable debug mode which prints rich information about the flows execution. "
77+
"Available levels: WARNING, INFO, DEBUG",
6478
),
6579
streaming: bool = typer.Option(
6680
default=False,
@@ -77,7 +91,7 @@ def chat(
7791
):
7892
"""Start an interactive chat session."""
7993
if len(config) > 1:
80-
typer.secho(f"Multiple configurations are not supported.", fg=typer.colors.RED)
94+
typer.secho("Multiple configurations are not supported.", fg=typer.colors.RED)
8195
typer.echo("Please provide a single folder.")
8296
raise typer.Exit(1)
8397

@@ -143,23 +157,27 @@ def server(
143157
if config:
144158
# We make sure there is no trailing separator, as that might break things in
145159
# single config mode.
146-
api.app.rails_config_path = os.path.expanduser(config[0].rstrip(os.path.sep))
160+
setattr(
161+
api.app,
162+
"rails_config_path",
163+
os.path.expanduser(config[0].rstrip(os.path.sep)),
164+
)
147165
else:
148166
# If we don't have a config, we try to see if there is a local config folder
149167
local_path = os.getcwd()
150168
local_configs_path = os.path.join(local_path, "config")
151169

152170
if os.path.exists(local_configs_path):
153-
api.app.rails_config_path = local_configs_path
171+
setattr(api.app, "rails_config_path", local_configs_path)
154172

155173
if verbose:
156174
logging.getLogger().setLevel(logging.INFO)
157175

158176
if disable_chat_ui:
159-
api.app.disable_chat_ui = True
177+
setattr(api.app, "disable_chat_ui", True)
160178

161179
if auto_reload:
162-
api.app.auto_reload = True
180+
setattr(api.app, "auto_reload", True)
163181

164182
if prefix:
165183
server_app = FastAPI()
@@ -173,17 +191,14 @@ def server(
173191
uvicorn.run(server_app, port=port, log_level="info", host="0.0.0.0")
174192

175193

176-
_AVAILABLE_OPTIONS = ["1.0", "2.0-alpha"]
177-
178-
179194
@app.command()
180195
def convert(
181196
path: str = typer.Argument(
182197
..., help="The path to the file or directory to migrate."
183198
),
184-
from_version: str = typer.Option(
185-
default="1.0",
186-
help=f"The version of the colang files to migrate from. Available options: {_AVAILABLE_OPTIONS}.",
199+
from_version: ColangVersions = typer.Option(
200+
default=ColangVersions.one,
201+
help=f"The version of the colang files to migrate from. Available options: {_COLANG_VERSIONS}.",
187202
),
188203
verbose: bool = typer.Option(
189204
default=False,
@@ -209,11 +224,14 @@ def convert(
209224

210225
absolute_path = os.path.abspath(path)
211226

227+
# Typer CLI args have to use an enum, not literal. Convert to Literal here
228+
from_version_literal: Literal["1.0", "2.0-alpha"] = from_version.value
229+
212230
migrate(
213231
path=absolute_path,
214232
include_main_flow=include_main_flow,
215233
use_active_decorator=use_active_decorator,
216-
from_version=from_version,
234+
from_version=from_version_literal,
217235
validate=validate,
218236
)
219237

nemoguardrails/cli/chat.py

Lines changed: 79 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
import asyncio
1616
import json
1717
import os
18-
from dataclasses import dataclass, field
19-
from typing import Dict, List, Optional, cast
18+
from dataclasses import asdict, dataclass, field
19+
from typing import Dict, List, Optional, Tuple, Union, cast
2020

2121
import aiohttp
2222
from prompt_toolkit import HTML, PromptSession
@@ -30,7 +30,11 @@
3030
from nemoguardrails.colang.v2_x.runtime.runtime import RuntimeV2_x
3131
from nemoguardrails.logging import verbose
3232
from nemoguardrails.logging.verbose import console
33-
from nemoguardrails.streaming import StreamingHandler
33+
from nemoguardrails.rails.llm.options import (
34+
GenerationLog,
35+
GenerationOptions,
36+
GenerationResponse,
37+
)
3438
from nemoguardrails.utils import get_or_create_event_loop, new_event_dict, new_uuid
3539

3640
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -61,6 +65,8 @@ async def _run_chat_v1_0(
6165
)
6266

6367
if not server_url:
68+
if config_path is None:
69+
raise RuntimeError("config_path cannot be None when server_url is None")
6470
rails_config = RailsConfig.from_path(config_path)
6571
rails_app = LLMRails(rails_config, verbose=verbose)
6672
if streaming and not rails_config.streaming_supported:
@@ -82,7 +88,12 @@ async def _run_chat_v1_0(
8288

8389
if not server_url:
8490
# If we have streaming from a locally loaded config, we initialize the handler.
85-
if streaming and not server_url and rails_app.main_llm_supports_streaming:
91+
if (
92+
streaming
93+
and not server_url
94+
and rails_app
95+
and rails_app.main_llm_supports_streaming
96+
):
8697
bot_message_list = []
8798
async for chunk in rails_app.stream_async(messages=history):
8899
if '{"event": "ABORT"' in chunk:
@@ -101,11 +112,40 @@ async def _run_chat_v1_0(
101112
bot_message = {"role": "assistant", "content": bot_message_text}
102113

103114
else:
104-
bot_message = await rails_app.generate_async(messages=history)
115+
if rails_app is None:
116+
raise RuntimeError("Rails App is None")
117+
response: Union[
118+
str, Dict, GenerationResponse, Tuple[Dict, Dict]
119+
] = await rails_app.generate_async(messages=history)
120+
121+
# Handle different return types from generate_async
122+
if isinstance(response, tuple) and len(response) == 2:
123+
bot_message = (
124+
response[0]
125+
if response
126+
else {"role": "assistant", "content": ""}
127+
)
128+
elif isinstance(response, GenerationResponse):
129+
# GenerationResponse case
130+
response_attr = getattr(response, "response", None)
131+
if isinstance(response_attr, list) and len(response_attr) > 0:
132+
bot_message = response_attr[0]
133+
else:
134+
bot_message = {
135+
"role": "assistant",
136+
"content": str(response_attr),
137+
}
138+
elif isinstance(response, dict):
139+
# Direct dict case
140+
bot_message = response
141+
else:
142+
# String or other fallback case
143+
bot_message = {"role": "assistant", "content": str(response)}
105144

106145
if not streaming or not rails_app.main_llm_supports_streaming:
107146
# We print bot messages in green.
108-
console.print("[green]" + f"{bot_message['content']}" + "[/]")
147+
content = bot_message.get("content", str(bot_message))
148+
console.print("[green]" + f"{content}" + "[/]")
109149
else:
110150
data = {
111151
"config_id": config_id,
@@ -116,19 +156,19 @@ async def _run_chat_v1_0(
116156
async with session.post(
117157
f"{server_url}/v1/chat/completions",
118158
json=data,
119-
) as response:
159+
) as http_response:
120160
# If the response is streaming, we show each chunk as it comes
121-
if response.headers.get("Transfer-Encoding") == "chunked":
161+
if http_response.headers.get("Transfer-Encoding") == "chunked":
122162
bot_message_text = ""
123-
async for chunk in response.content.iter_any():
124-
chunk = chunk.decode("utf-8")
163+
async for chunk_bytes in http_response.content.iter_any():
164+
chunk = chunk_bytes.decode("utf-8")
125165
console.print("[green]" + f"{chunk}" + "[/]", end="")
126166
bot_message_text += chunk
127167
console.print("")
128168

129169
bot_message = {"role": "assistant", "content": bot_message_text}
130170
else:
131-
result = await response.json()
171+
result = await http_response.json()
132172
bot_message = result["messages"][0]
133173

134174
# We print bot messages in green.
@@ -297,7 +337,8 @@ def _process_output():
297337
else:
298338
console.print(
299339
"[black on magenta]"
300-
+ f"scene information (start): (title={event['title']}, action_uid={event['action_uid']}, content={event['content']})"
340+
+ f"scene information (start): (title={event['title']}, "
341+
+ f"action_uid={event['action_uid']}, content={event['content']})"
301342
+ "[/]"
302343
)
303344

@@ -333,7 +374,8 @@ def _process_output():
333374
else:
334375
console.print(
335376
"[black on magenta]"
336-
+ f"scene form (start): (prompt={event['prompt']}, action_uid={event['action_uid']}, inputs={event['inputs']})"
377+
+ f"scene form (start): (prompt={event['prompt']}, "
378+
+ f"action_uid={event['action_uid']}, inputs={event['inputs']})"
337379
+ "[/]"
338380
)
339381
chat_state.input_events.append(
@@ -370,7 +412,8 @@ def _process_output():
370412
else:
371413
console.print(
372414
"[black on magenta]"
373-
+ f"scene choice (start): (prompt={event['prompt']}, action_uid={event['action_uid']}, options={event['options']})"
415+
+ f"scene choice (start): (prompt={event['prompt']}, "
416+
+ f"action_uid={event['action_uid']}, options={event['options']})"
374417
+ "[/]"
375418
)
376419
chat_state.input_events.append(
@@ -452,12 +495,16 @@ async def _check_local_async_actions():
452495
# We need to copy input events to prevent race condition
453496
input_events_copy = chat_state.input_events.copy()
454497
chat_state.input_events = []
455-
(
456-
chat_state.output_events,
457-
chat_state.output_state,
458-
) = await rails_app.process_events_async(
459-
input_events_copy, chat_state.state
498+
499+
output_events, output_state = await rails_app.process_events_async(
500+
input_events_copy,
501+
asdict(chat_state.state) if chat_state.state else None,
460502
)
503+
chat_state.output_events = output_events
504+
505+
# process_events_async returns a Dict `state`, need to convert to dataclass for ChatState object
506+
if output_state:
507+
chat_state.output_state = cast(State, State(**output_state))
461508

462509
# Process output_events and potentially generate new input_events
463510
_process_output()
@@ -470,7 +517,8 @@ async def _check_local_async_actions():
470517
# If there are no pending actions, we stop
471518
check_task.cancel()
472519
check_task = None
473-
debugger.set_output_state(chat_state.output_state)
520+
if chat_state.output_state is not None:
521+
debugger.set_output_state(chat_state.output_state)
474522
chat_state.status.stop()
475523
enable_input.set()
476524
return
@@ -485,13 +533,16 @@ async def _process_input_events():
485533
# We need to copy input events to prevent race condition
486534
input_events_copy = chat_state.input_events.copy()
487535
chat_state.input_events = []
488-
(
489-
chat_state.output_events,
490-
chat_state.output_state,
491-
) = await rails_app.process_events_async(
492-
input_events_copy, chat_state.state
536+
output_events, output_state = await rails_app.process_events_async(
537+
input_events_copy,
538+
asdict(chat_state.state) if chat_state.state else None,
493539
)
494-
debugger.set_output_state(chat_state.output_state)
540+
chat_state.output_events = output_events
541+
if output_state:
542+
# process_events_async returns a Dict `state`, need to convert to dataclass for ChatState object
543+
output_state_typed: State = cast(State, State(**output_state))
544+
chat_state.output_state = output_state_typed
545+
debugger.set_output_state(output_state_typed)
495546

496547
_process_output()
497548
# If we don't have a check task, we start it
@@ -653,6 +704,8 @@ def run_chat(
653704
server_url (Optional[str]): The URL of the chat server. Defaults to None.
654705
config_id (Optional[str]): The configuration ID. Defaults to None.
655706
"""
707+
if config_path is None:
708+
raise RuntimeError("config_path cannot be None")
656709
rails_config = RailsConfig.from_path(config_path)
657710

658711
if verbose and verbose_llm_calls:

0 commit comments

Comments
 (0)