Skip to content

Commit b0a35d7

Browse files
committed
refactor(langchain): update runtime imports to langchain-core
Update Colang runtime imports to use canonical langchain-core paths for callbacks and runnables. Part of the broader migration to langchain-core for LangChain 1.x compatibility.
1 parent c13995f commit b0a35d7

File tree

10 files changed

+40
-59
lines changed

10 files changed

+40
-59
lines changed

nemoguardrails/actions/llm/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
import re
1717
from typing import Any, Dict, List, Optional, Sequence, Union
1818

19-
from langchain.base_language import BaseLanguageModel
20-
from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackManager
19+
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackManager
20+
from langchain_core.language_models import BaseLanguageModel
2121
from langchain_core.runnables import RunnableConfig
2222
from langchain_core.runnables.base import Runnable
2323

nemoguardrails/colang/v1_0/runtime/runtime.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from urllib.parse import urljoin
2323

2424
import aiohttp
25-
from langchain.chains.base import Chain
2625

2726
from nemoguardrails.actions.actions import ActionResult
2827
from nemoguardrails.actions.core import create_event
@@ -658,12 +657,6 @@ async def _process_start_action(self, events: List[dict]) -> List[dict]:
658657
parameters = inspect.signature(fn).parameters
659658
action_type = "function"
660659

661-
elif isinstance(fn, Chain):
662-
# If we're dealing with a chain, we list the annotations
663-
# TODO: make some additional type checking here
664-
parameters = fn.input_keys
665-
action_type = "chain"
666-
667660
# For every parameter that start with "__context__", we pass the value
668661
for parameter_name in parameters:
669662
if parameter_name.startswith("__context__"):
@@ -677,11 +670,9 @@ async def _process_start_action(self, events: List[dict]) -> List[dict]:
677670
if var_name in context:
678671
kwargs[k] = context[var_name]
679672

680-
# If we have an action server, we use it for non-system/non-chain actions
681-
if (
682-
self.config.actions_server_url
683-
and not action_meta.get("is_system_action")
684-
and action_type != "chain"
673+
# If we have an action server, we use it for non-system actions
674+
if self.config.actions_server_url and not action_meta.get(
675+
"is_system_action"
685676
):
686677
result, status = await self._get_action_resp(
687678
action_meta, action_name, kwargs

nemoguardrails/colang/v2_x/runtime/runtime.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020
from urllib.parse import urljoin
2121

2222
import aiohttp
23-
import langchain
24-
from langchain.chains.base import Chain
2523

2624
from nemoguardrails.actions.actions import ActionResult
2725
from nemoguardrails.colang import parse_colang_file
@@ -45,8 +43,6 @@
4543
from nemoguardrails.rails.llm.config import RailsConfig
4644
from nemoguardrails.utils import new_event_dict, new_readable_uuid
4745

48-
langchain.debug = False
49-
5046
log = logging.getLogger(__name__)
5147

5248

@@ -202,12 +198,6 @@ async def _process_start_action(
202198
parameters = inspect.signature(fn).parameters
203199
action_type = "function"
204200

205-
elif isinstance(fn, Chain):
206-
# If we're dealing with a chain, we list the annotations
207-
# TODO: make some additional type checking here
208-
parameters = fn.input_keys
209-
action_type = "chain"
210-
211201
# For every parameter that start with "__context__", we pass the value
212202
for parameter_name in parameters:
213203
if parameter_name.startswith("__context__"):
@@ -221,11 +211,9 @@ async def _process_start_action(
221211
if var_name in context:
222212
kwargs[k] = context[var_name]
223213

224-
# If we have an action server, we use it for non-system/non-chain actions
225-
if (
226-
self.config.actions_server_url
227-
and not action_meta.get("is_system_action")
228-
and action_type != "chain"
214+
# If we have an action server, we use it for non-system actions
215+
if self.config.actions_server_url and not action_meta.get(
216+
"is_system_action"
229217
):
230218
result, status = await self._get_action_resp(
231219
action_meta, action_name, kwargs

nemoguardrails/evaluate/evaluate_factcheck.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@
2020

2121
import tqdm
2222
import typer
23-
from langchain.chains import LLMChain
24-
from langchain.prompts import PromptTemplate
23+
from langchain_core.prompts import PromptTemplate
2524

2625
from nemoguardrails import LLMRails
2726
from nemoguardrails.actions.llm.utils import llm_call
@@ -94,19 +93,23 @@ def create_negative_samples(self, dataset):
9493
template=create_negatives_template,
9594
input_variables=["evidence", "answer"],
9695
)
97-
create_negatives_chain = LLMChain(prompt=create_negatives_prompt, llm=self.llm)
96+
97+
# Bind config parameters to the LLM for generating negative samples
98+
llm_with_config = self.llm.bind(temperature=0.8, max_tokens=300)
9899

99100
print("Creating negative samples...")
100101
for data in tqdm.tqdm(dataset):
101102
assert "evidence" in data and "question" in data and "answer" in data
102103
evidence = data["evidence"]
103104
answer = data["answer"]
104-
negative_answer_result = create_negatives_chain.invoke(
105-
{"evidence": evidence, "answer": answer},
106-
config={"temperature": 0.8, "max_tokens": 300},
105+
106+
# Format the prompt and invoke the LLM directly
107+
formatted_prompt = create_negatives_prompt.format(
108+
evidence=evidence, answer=answer
107109
)
108-
negative_answer = negative_answer_result["text"]
109-
data["incorrect_answer"] = negative_answer.strip()
110+
negative_answer = llm_with_config.invoke(formatted_prompt)
111+
negative_answer_content = negative_answer.content
112+
data["incorrect_answer"] = negative_answer_content.strip()
110113

111114
return dataset
112115

@@ -186,14 +189,16 @@ def run(self):
186189
split="negative"
187190
)
188191

189-
print(f"Positive Accuracy: {pos_num_correct/len(self.dataset) * 100}")
190-
print(f"Negative Accuracy: {neg_num_correct/len(self.dataset) * 100}")
192+
print(f"Positive Accuracy: {pos_num_correct / len(self.dataset) * 100}")
193+
print(f"Negative Accuracy: {neg_num_correct / len(self.dataset) * 100}")
191194
print(
192-
f"Overall Accuracy: {(pos_num_correct + neg_num_correct)/(2*len(self.dataset))* 100}"
195+
f"Overall Accuracy: {(pos_num_correct + neg_num_correct) / (2 * len(self.dataset)) * 100}"
193196
)
194197

195198
print("---Time taken per sample:---")
196-
print(f"Ask LLM:\t{(pos_time+neg_time)*1000/(2*len(self.dataset)):.1f}ms")
199+
print(
200+
f"Ask LLM:\t{(pos_time + neg_time) * 1000 / (2 * len(self.dataset)):.1f}ms"
201+
)
197202

198203
if self.write_outputs:
199204
dataset_name = os.path.basename(self.dataset_path).split(".")[0]

nemoguardrails/llm/providers/huggingface/pipeline.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515

1616
from typing import Any, List, Optional
1717

18-
from langchain.callbacks.manager import (
18+
from langchain_community.llms import HuggingFacePipeline
19+
from langchain_core.callbacks.manager import (
1920
AsyncCallbackManagerForLLMRun,
2021
CallbackManagerForLLMRun,
2122
)
22-
from langchain.schema.output import GenerationChunk
23-
from langchain_community.llms import HuggingFacePipeline
23+
from langchain_core.outputs import GenerationChunk
2424

2525

2626
class HuggingFacePipelineCompatible(HuggingFacePipeline):

nemoguardrails/logging/callbacks.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,15 @@
1818
from typing import Any, Dict, List, Optional, Union, cast
1919
from uuid import UUID
2020

21-
from langchain.callbacks import StdOutCallbackHandler
22-
from langchain.callbacks.base import (
21+
from langchain_core.agents import AgentAction, AgentFinish
22+
from langchain_core.callbacks.base import (
2323
AsyncCallbackHandler,
2424
BaseCallbackHandler,
2525
BaseCallbackManager,
2626
)
27-
from langchain.callbacks.manager import AsyncCallbackManagerForChainRun
28-
from langchain.schema import AgentAction, AgentFinish, AIMessage, BaseMessage, LLMResult
29-
from langchain_core.outputs import ChatGeneration
27+
from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
28+
from langchain_core.messages import AIMessage, BaseMessage
29+
from langchain_core.outputs import ChatGeneration, LLMResult
3030

3131
from nemoguardrails.context import explain_info_var, llm_call_info_var, llm_stats_var
3232
from nemoguardrails.logging.explain import LLMCallInfo

nemoguardrails/streaming.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,9 @@
1818
from typing import Any, AsyncIterator, Dict, List, Optional, Union
1919
from uuid import UUID
2020

21-
from langchain.callbacks.base import AsyncCallbackHandler
22-
from langchain.schema import BaseMessage
23-
from langchain.schema.messages import AIMessageChunk
24-
from langchain.schema.output import ChatGenerationChunk, GenerationChunk, LLMResult
21+
from langchain_core.callbacks.base import AsyncCallbackHandler
22+
from langchain_core.messages import AIMessageChunk, BaseMessage
23+
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
2524

2625
from nemoguardrails.utils import new_uuid
2726

tests/rails/llm/test_config.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,10 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import json
1716
from unittest.mock import MagicMock
1817

1918
import pytest
20-
from langchain.llms.base import BaseLLM
19+
from langchain_core.language_models import BaseLLM
2120
from pydantic import ValidationError
2221

2322
from nemoguardrails.rails.llm.config import Model, RailsConfig, TaskPrompt

tests/test_callbacks.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,14 @@
1717
from uuid import uuid4
1818

1919
import pytest
20-
from langchain.schema import Generation, LLMResult
2120
from langchain_core.messages import (
2221
AIMessage,
2322
BaseMessage,
2423
HumanMessage,
2524
SystemMessage,
2625
ToolMessage,
2726
)
28-
from langchain_core.outputs import ChatGeneration
27+
from langchain_core.outputs import ChatGeneration, Generation, LLMResult
2928

3029
from nemoguardrails.context import explain_info_var, llm_call_info_var, llm_stats_var
3130
from nemoguardrails.logging.callbacks import LoggingCallbackHandler

tests/test_streaming_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
from uuid import UUID
2222

2323
import pytest
24-
from langchain.schema.messages import AIMessageChunk
25-
from langchain.schema.output import ChatGenerationChunk, GenerationChunk
24+
from langchain_core.messages import AIMessageChunk
25+
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
2626

2727
from nemoguardrails.streaming import END_OF_STREAM, StreamingHandler
2828

0 commit comments

Comments
 (0)