Skip to content

Commit 3d4d731

Browse files
committed
Address Pouyan's feedback
1 parent 961a625 commit 3d4d731

File tree

3 files changed

+5
-9
lines changed

3 files changed

+5
-9
lines changed

nemoguardrails/cli/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,7 @@ def convert(
225225
absolute_path = os.path.abspath(path)
226226

227227
# Typer CLI args have to use an enum, not literal. Convert to Literal here
228-
from_version_literal: Literal["1.0", "2.0-alpha"] = (
229-
"1.0" if from_version == ColangVersions.one else "2.0-alpha"
230-
)
228+
from_version_literal: Literal["1.0", "2.0-alpha"] = from_version.value
231229

232230
migrate(
233231
path=absolute_path,

nemoguardrails/cli/chat.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,13 +119,13 @@ async def _run_chat_v1_0(
119119
] = await rails_app.generate_async(messages=history)
120120

121121
# Handle different return types from generate_async
122-
if type(response) == Tuple[Dict, Dict]:
122+
if isinstance(response, tuple) and len(response) == 2:
123123
bot_message = (
124124
response[0]
125125
if response
126126
else {"role": "assistant", "content": ""}
127127
)
128-
elif type(response) == GenerationResponse:
128+
elif isinstance(response, GenerationResponse):
129129
# GenerationResponse case
130130
response_attr = getattr(response, "response", None)
131131
if isinstance(response_attr, list) and len(response_attr) > 0:
@@ -135,7 +135,7 @@ async def _run_chat_v1_0(
135135
"role": "assistant",
136136
"content": str(response_attr),
137137
}
138-
elif type(response) == Dict:
138+
elif isinstance(response, dict):
139139
# Direct dict case
140140
bot_message = response
141141
else:

nemoguardrails/cli/debugger.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,6 @@ def get_loop_info(flow_config: FlowConfig) -> str:
181181
if order_by_name:
182182
rows.sort(key=lambda x: x[0])
183183
else:
184-
if not state:
185-
raise RuntimeError("No state available")
186184
flow_configs: Dict[str, FlowConfig] = state.flow_configs
187185
rows.sort(key=lambda x: (-flow_configs[x[0]].loop_priority, x[0]))
188186

@@ -257,7 +255,7 @@ def tree(
257255
# Convert Spec to Spec object if it's a Dict
258256
spec: Spec = (
259257
head_element_spec_op.spec
260-
if type(head_element_spec_op.spec) == SpecOp
258+
if isinstance(head_element_spec_op.spec, Spec)
261259
else Spec(**cast(Dict, head_element_spec_op.spec))
262260
)
263261

0 commit comments

Comments
 (0)