1111from langchain .llms .base import LLM
1212from langchain .prompts import PromptTemplate
1313from langchain .schema .runnable import RunnableMap , RunnablePassthrough
14- from langchain .load import dumpd
1514from ads .llm .guardrails import HuggingFaceEvaluation
1615from ads .llm .guardrails .base import BlockedByGuardrail , GuardrailIO
1716from ads .llm .chain import GuardrailSequence
18- from ads .llm .load import load
17+ from ads .llm .serialize import load , dump
1918
2019
2120class FakeLLM (LLM ):
@@ -64,7 +63,7 @@ def test_toxicity_without_threshold(self):
6463 chain = self .FAKE_LLM | toxicity
6564 output = chain .invoke (self .TOXIC_CONTENT )
6665 self .assertEqual (output , self .TOXIC_CONTENT )
67- serialized = dumpd (chain )
66+ serialized = dump (chain )
6867 chain = load (serialized , valid_namespaces = ["tests" ])
6968 output = chain .invoke (self .TOXIC_CONTENT )
7069 self .assertEqual (output , self .TOXIC_CONTENT )
@@ -77,7 +76,7 @@ def test_toxicity_with_threshold(self):
7776 chain = self .FAKE_LLM | toxicity
7877 with self .assertRaises (BlockedByGuardrail ):
7978 chain .invoke (self .TOXIC_CONTENT )
80- serialized = dumpd (chain )
79+ serialized = dump (chain )
8180 chain = load (serialized , valid_namespaces = ["tests" ])
8281 with self .assertRaises (BlockedByGuardrail ):
8382 chain .invoke (self .TOXIC_CONTENT )
@@ -94,7 +93,7 @@ def test_toxicity_without_exception(self):
9493 chain = self .FAKE_LLM | toxicity
9594 output = chain .invoke (self .TOXIC_CONTENT )
9695 self .assertEqual (output , toxicity .custom_msg )
97- serialized = dumpd (chain )
96+ serialized = dump (chain )
9897 chain = load (serialized , valid_namespaces = ["tests" ])
9998 output = chain .invoke (self .TOXIC_CONTENT )
10099 self .assertEqual (output , toxicity .custom_msg )
@@ -109,7 +108,7 @@ def test_toxicity_return_metrics(self):
109108 self .assertIsInstance (output , dict )
110109 self .assertEqual (output ["output" ], self .TOXIC_CONTENT )
111110 self .assertGreater (output ["metrics" ]["toxicity" ][0 ], 0.2 )
112- serialized = dumpd (chain )
111+ serialized = dump (chain )
113112 chain = load (serialized , valid_namespaces = ["tests" ])
114113 output = chain .invoke (self .TOXIC_CONTENT )
115114 self .assertIsInstance (output , dict )
@@ -123,9 +122,11 @@ class GuardrailSequenceTests(GuardrailTestsBase):
123122 def test_guardrail_sequence_with_template_and_toxicity (self ):
124123 template = PromptTemplate .from_template ("Tell me a joke about {subject}" )
125124 map_input = RunnableMap (subject = RunnablePassthrough ())
126- toxicity = HuggingFaceEvaluation (path = "toxicity" , load_args = self .LOAD_ARGS )
125+ toxicity = HuggingFaceEvaluation (
126+ path = "toxicity" , load_args = self .LOAD_ARGS , select = min
127+ )
127128 chain = GuardrailSequence .from_sequence (
128129 map_input | template | self .FAKE_LLM | toxicity
129130 )
130- output = chain .run ("cats" )
131+ output = chain .run ("cats" , num_generations = 5 )
131132 self .assertIsInstance (output , GuardrailIO )
0 commit comments