Skip to content

Commit 49d39dc

Browse files
committed
apply review suggestions
1 parent d21b6f5 commit 49d39dc

File tree

1 file changed

+20
-63
lines changed

1 file changed

+20
-63
lines changed

tests/test_actions_llm_utils.py

Lines changed: 20 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import pytest
1617
from langchain_core.messages import AIMessage
1718

1819
from nemoguardrails.actions.llm.utils import (
@@ -27,6 +28,17 @@
2728
from nemoguardrails.context import reasoning_trace_var, tool_calls_var
2829

2930

31+
@pytest.fixture(autouse=True)
32+
def reset_context_vars():
33+
reasoning_token = reasoning_trace_var.set(None)
34+
tool_calls_token = tool_calls_var.set(None)
35+
36+
yield
37+
38+
reasoning_trace_var.reset(reasoning_token)
39+
tool_calls_var.reset(tool_calls_token)
40+
41+
3042
class MockOpenAILLM:
3143
__module__ = "langchain_openai.chat_models"
3244

@@ -222,18 +234,17 @@ def test_extract_reasoning_from_additional_kwargs_not_dict():
222234

223235

224236
def test_extract_tool_calls_from_content_blocks_single_tool_call():
225-
response = MockResponse(
226-
content_blocks=[
227-
{"type": "tool_call", "name": "foo", "args": {"a": "b"}, "id": "abc_123"}
228-
]
229-
)
237+
expected_tool_call = {
238+
"type": "tool_call",
239+
"name": "foo",
240+
"args": {"a": "b"},
241+
"id": "abc_123",
242+
}
243+
response = MockResponse(content_blocks=[expected_tool_call])
230244
tool_calls = _extract_tool_calls_from_content_blocks(response)
231245
assert tool_calls is not None
232246
assert len(tool_calls) == 1
233-
assert tool_calls[0]["type"] == "tool_call"
234-
assert tool_calls[0]["name"] == "foo"
235-
assert tool_calls[0]["args"] == {"a": "b"}
236-
assert tool_calls[0]["id"] == "abc_123"
247+
assert tool_calls[0] == expected_tool_call
237248

238249

239250
def test_extract_tool_calls_from_content_blocks_multiple_tool_calls():
@@ -304,8 +315,6 @@ def test_extract_tool_calls_from_attribute_no_attribute():
304315

305316

306317
def test_store_reasoning_traces_from_content_blocks():
307-
reasoning_trace_var.set(None)
308-
309318
response = MockResponse(
310319
content_blocks=[
311320
{"type": "text", "text": "The answer is 42."},
@@ -317,12 +326,8 @@ def test_store_reasoning_traces_from_content_blocks():
317326
reasoning = reasoning_trace_var.get()
318327
assert reasoning == "Let me think about this problem..."
319328

320-
reasoning_trace_var.set(None)
321-
322329

323330
def test_store_reasoning_traces_from_additional_kwargs():
324-
reasoning_trace_var.set(None)
325-
326331
response = MockResponse(
327332
additional_kwargs={"reasoning_content": "Provider specific reasoning"}
328333
)
@@ -331,12 +336,8 @@ def test_store_reasoning_traces_from_additional_kwargs():
331336
reasoning = reasoning_trace_var.get()
332337
assert reasoning == "Provider specific reasoning"
333338

334-
reasoning_trace_var.set(None)
335-
336339

337340
def test_store_reasoning_traces_prefers_content_blocks_over_additional_kwargs():
338-
reasoning_trace_var.set(None)
339-
340341
response = MockResponse(
341342
content_blocks=[
342343
{"type": "reasoning", "reasoning": "Content blocks reasoning"},
@@ -348,12 +349,8 @@ def test_store_reasoning_traces_prefers_content_blocks_over_additional_kwargs():
348349
reasoning = reasoning_trace_var.get()
349350
assert reasoning == "Content blocks reasoning"
350351

351-
reasoning_trace_var.set(None)
352-
353352

354353
def test_store_reasoning_traces_fallback_to_additional_kwargs():
355-
reasoning_trace_var.set(None)
356-
357354
response = MockResponse(
358355
content_blocks=[
359356
{"type": "text", "text": "No reasoning here"},
@@ -365,12 +362,8 @@ def test_store_reasoning_traces_fallback_to_additional_kwargs():
365362
reasoning = reasoning_trace_var.get()
366363
assert reasoning == "Fallback reasoning"
367364

368-
reasoning_trace_var.set(None)
369-
370365

371366
def test_store_reasoning_traces_no_reasoning():
372-
reasoning_trace_var.set(None)
373-
374367
response = MockResponse(
375368
content_blocks=[
376369
{"type": "text", "text": "Just text"},
@@ -383,8 +376,6 @@ def test_store_reasoning_traces_no_reasoning():
383376

384377

385378
def test_store_tool_calls_from_content_blocks():
386-
tool_calls_var.set(None)
387-
388379
response = MockResponse(
389380
content_blocks=[
390381
{"type": "text", "text": "Hello"},
@@ -410,12 +401,8 @@ def test_store_tool_calls_from_content_blocks():
410401
assert tool_calls[0]["name"] == "search"
411402
assert tool_calls[1]["name"] == "calculator"
412403

413-
tool_calls_var.set(None)
414-
415404

416405
def test_store_tool_calls_from_attribute():
417-
tool_calls_var.set(None)
418-
419406
response = MockResponse(
420407
tool_calls=[
421408
{"type": "tool_call", "name": "foo", "args": {"a": "b"}, "id": "abc_123"},
@@ -430,12 +417,8 @@ def test_store_tool_calls_from_attribute():
430417
assert tool_calls[0]["name"] == "foo"
431418
assert tool_calls[1]["name"] == "bar"
432419

433-
tool_calls_var.set(None)
434-
435420

436421
def test_store_tool_calls_prefers_content_blocks_over_attribute():
437-
tool_calls_var.set(None)
438-
439422
response = MockResponse(
440423
content_blocks=[
441424
{"type": "tool_call", "name": "from_blocks", "args": {}, "id": "1"},
@@ -451,12 +434,8 @@ def test_store_tool_calls_prefers_content_blocks_over_attribute():
451434
assert len(tool_calls) == 1
452435
assert tool_calls[0]["name"] == "from_blocks"
453436

454-
tool_calls_var.set(None)
455-
456437

457438
def test_store_tool_calls_fallback_to_attribute():
458-
tool_calls_var.set(None)
459-
460439
response = MockResponse(
461440
content_blocks=[
462441
{"type": "text", "text": "No tool calls here"},
@@ -472,12 +451,8 @@ def test_store_tool_calls_fallback_to_attribute():
472451
assert len(tool_calls) == 1
473452
assert tool_calls[0]["name"] == "fallback_tool"
474453

475-
tool_calls_var.set(None)
476-
477454

478455
def test_store_tool_calls_no_tool_calls():
479-
tool_calls_var.set(None)
480-
481456
response = MockResponse(
482457
content_blocks=[
483458
{"type": "text", "text": "Just text"},
@@ -490,8 +465,6 @@ def test_store_tool_calls_no_tool_calls():
490465

491466

492467
def test_store_reasoning_traces_with_real_aimessage_from_content_blocks():
493-
reasoning_trace_var.set(None)
494-
495468
message = AIMessage(
496469
content="The answer is 42.",
497470
additional_kwargs={"reasoning_content": "Let me think about this problem..."},
@@ -502,12 +475,8 @@ def test_store_reasoning_traces_with_real_aimessage_from_content_blocks():
502475
reasoning = reasoning_trace_var.get()
503476
assert reasoning == "Let me think about this problem..."
504477

505-
reasoning_trace_var.set(None)
506-
507478

508479
def test_store_reasoning_traces_with_real_aimessage_no_reasoning():
509-
reasoning_trace_var.set(None)
510-
511480
message = AIMessage(
512481
content="The answer is 42.",
513482
additional_kwargs={"other_field": "some value"},
@@ -520,8 +489,6 @@ def test_store_reasoning_traces_with_real_aimessage_no_reasoning():
520489

521490

522491
def test_store_tool_calls_with_real_aimessage_from_content_blocks():
523-
tool_calls_var.set(None)
524-
525492
message = AIMessage(
526493
"",
527494
tool_calls=[
@@ -539,12 +506,8 @@ def test_store_tool_calls_with_real_aimessage_from_content_blocks():
539506
assert tool_calls[0]["args"] == {"a": "b"}
540507
assert tool_calls[0]["id"] == "abc_123"
541508

542-
tool_calls_var.set(None)
543-
544509

545510
def test_store_tool_calls_with_real_aimessage_mixed_content():
546-
tool_calls_var.set(None)
547-
548511
message = AIMessage(
549512
"foo",
550513
tool_calls=[
@@ -560,12 +523,8 @@ def test_store_tool_calls_with_real_aimessage_mixed_content():
560523
assert tool_calls[0]["type"] == "tool_call"
561524
assert tool_calls[0]["name"] == "foo"
562525

563-
tool_calls_var.set(None)
564-
565526

566527
def test_store_tool_calls_with_real_aimessage_multiple_tool_calls():
567-
tool_calls_var.set(None)
568-
569528
message = AIMessage(
570529
"",
571530
tool_calls=[
@@ -581,5 +540,3 @@ def test_store_tool_calls_with_real_aimessage_multiple_tool_calls():
581540
assert len(tool_calls) == 2
582541
assert tool_calls[0]["name"] == "foo"
583542
assert tool_calls[1]["name"] == "bar"
584-
585-
tool_calls_var.set(None)

0 commit comments

Comments
 (0)