Commit 5c6ca51

feat: support langchain v1 (#1472)
* refactor(langchain): migrate imports to canonical langchain-core paths

  Migrate all imports from deprecated proxy paths to canonical langchain-core paths to ensure compatibility with LangChain 1.x. Changes include:

  - Use `from langchain_core.language_models import BaseLLM, BaseChatModel`
  - Remove proxy imports from `langchain.chat_models.base`
  - Standardize submodule imports to top-level `langchain_core.language_models`

  This ensures forward compatibility with LangChain 1.x, where proxy imports from the main langchain package will be removed.

* refactor(langchain)!: remove deprecated Chain support from action dispatcher

  Remove support for registering LangChain Chain objects as actions in favor of the modern Runnable interface. Chain support is deprecated in LangChain 1.x and users should migrate to Runnable objects instead.

  - Remove Chain handling logic from action_dispatcher.py
  - Remove Chain-based tests from test_runnable_rails.py
  - Add deprecation warning in python-api.md documentation

* refactor(langchain)!: remove SummarizeDocument built-in action

  Remove the built-in SummarizeDocument action, which relied on deprecated LangChain Chain features. Users who need document summarization should implement custom actions using LangChain Runnable chains.

  - Delete nemoguardrails/actions/summarize_document.py
  - Remove the related import from llm/filters.py

* feat(langchain): add LangChain 1.x compatibility with fallback patterns

  Add try/except fallback patterns in examples to support both LangChain 0.x and 1.x. Under LangChain 1.x, legacy Chain features are imported from the langchain-classic package, with helpful error messages. This allows the examples to work seamlessly across LangChain versions without requiring code changes from users.

  - Add fallback imports for RetrievalQA, embeddings, text splitters, vectorstores
  - Provide clear error messages directing users to install langchain-classic

* refactor(langchain): update runtime imports to langchain-core

  Update Colang runtime imports to use canonical langchain-core paths for callbacks and runnables. Part of the broader migration to langchain-core for LangChain 1.x compatibility.

* docs(langchain): rewrite custom LLM provider guide with BaseChatModel support

  Complete rewrite of the custom LLM provider documentation with:

  - Separate, comprehensive guides for BaseLLM (text completion) and BaseChatModel (chat)
  - Correct method signatures (_call vs _generate)
  - Proper async implementations
  - Clear registration instructions (register_llm_provider vs register_chat_provider)
  - Working code examples with correct langchain-core imports
  - Important notes on choosing the right base class

  This addresses the gap where users were not properly guided on implementing custom chat models and were being directed to the wrong interface.

* feat(langchain): extend dependency constraints to support LangChain 1.x

  Extend the LangChain dependency constraints to support both 0.x and 1.x versions:

  - langchain: >=0.2.14,<0.4.0 → >=0.2.14,<2.0.0
  - langchain-core: >=0.2.14,<0.4.0 → >=0.2.14,<2.0.0
  - langchain-community: >=0.2.5,<0.4.0 → >=0.2.5,<2.0.0
1 parent 3b4204a commit 5c6ca51
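The final bullet only widens version ranges; a quick way to check what is actually installed locally against the new `<2.0.0` ceilings (a sketch, not part of this commit):

```python
from importlib.metadata import PackageNotFoundError, version

# Print the installed versions of the three packages whose constraints
# the commit widens to >=0.2.x,<2.0.0.
for pkg in ("langchain", "langchain-core", "langchain-community"):
    try:
        print(f"{pkg}: {version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")
```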

51 files changed: +331 −350 lines (only a subset of the changed files is shown below)


docs/user-guides/configuration-guide/custom-initialization.md

Lines changed: 135 additions & 26 deletions
@@ -37,56 +37,65 @@ def init(app: LLMRails):
 
 ## Custom LLM Provider Registration
 
-To register a custom LLM provider, you need to create a class that inherits from `BaseLanguageModel` and register it using `register_llm_provider`.
+NeMo Guardrails supports two types of custom LLM providers:
+1. **Text Completion Models** (`BaseLLM`) - For models that work with string prompts
+2. **Chat Models** (`BaseChatModel`) - For models that work with message-based conversations
 
-It is important to implement the following methods:
+### Custom Text Completion LLM (BaseLLM)
 
-**Required**:
+To register a custom text completion LLM provider, create a class that inherits from `BaseLLM` and register it using `register_llm_provider`.
 
-- `_call`
-- `_llm_type`
+**Required methods:**
+- `_call` - Synchronous text completion
+- `_llm_type` - Returns the LLM type identifier
 
-**Optional**:
-
-- `_acall`
-- `_astream`
-- `_stream`
-- `_identifying_params`
-
-In other words, to create your custom LLM provider, you need to implement the following interface methods: `_call`, `_llm_type`, and optionally `_acall`, `_astream`, `_stream`, and `_identifying_params`. Here's how you can do it:
+**Optional methods:**
+- `_acall` - Asynchronous text completion (recommended)
+- `_stream` - Streaming text completion
+- `_astream` - Async streaming text completion
+- `_identifying_params` - Returns parameters for model identification
 
 ```python
 from typing import Any, Iterator, List, Optional
 
-from langchain.base_language import BaseLanguageModel
 from langchain_core.callbacks.manager import (
-    CallbackManagerForLLMRun,
     AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
 )
+from langchain_core.language_models import BaseLLM
 from langchain_core.outputs import GenerationChunk
 
 from nemoguardrails.llm.providers import register_llm_provider
 
 
-class MyCustomLLM(BaseLanguageModel):
+class MyCustomTextLLM(BaseLLM):
+    """Custom text completion LLM."""
+
+    @property
+    def _llm_type(self) -> str:
+        return "custom_text_llm"
 
     def _call(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs,
+        **kwargs: Any,
     ) -> str:
-        pass
+        """Synchronous text completion."""
+        # Your implementation here
+        return "Generated text response"
 
     async def _acall(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        **kwargs,
+        **kwargs: Any,
    ) -> str:
-        pass
+        """Asynchronous text completion (recommended)."""
+        # Your async implementation here
+        return "Generated text response"
 
     def _stream(
         self,
@@ -95,22 +104,122 @@ class MyCustomLLM(BaseLanguageModel):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[GenerationChunk]:
-        pass
+        """Optional: Streaming text completion."""
+        # Yield chunks of text
+        yield GenerationChunk(text="chunk1")
+        yield GenerationChunk(text="chunk2")
+
+
+register_llm_provider("custom_text_llm", MyCustomTextLLM)
+```
+
+### Custom Chat Model (BaseChatModel)
+
+To register a custom chat model, create a class that inherits from `BaseChatModel` and register it using `register_chat_provider`.
+
+**Required methods:**
+- `_generate` - Synchronous chat completion
+- `_llm_type` - Returns the LLM type identifier
+
+**Optional methods:**
+- `_agenerate` - Asynchronous chat completion (recommended)
+- `_stream` - Streaming chat completion
+- `_astream` - Async streaming chat completion
+
+```python
+from typing import Any, Iterator, List, Optional
+
+from langchain_core.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+
+from nemoguardrails.llm.providers import register_chat_provider
+
+
+class MyCustomChatModel(BaseChatModel):
+    """Custom chat model."""
 
-    # rest of the implementation
-    ...
+    @property
+    def _llm_type(self) -> str:
+        return "custom_chat_model"
 
-register_llm_provider("custom_llm", MyCustomLLM)
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        """Synchronous chat completion."""
+        # Convert messages to your model's format and generate response
+        response_text = "Generated chat response"
+
+        message = AIMessage(content=response_text)
+        generation = ChatGeneration(message=message)
+        return ChatResult(generations=[generation])
+
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        """Asynchronous chat completion (recommended)."""
+        # Your async implementation
+        response_text = "Generated chat response"
+
+        message = AIMessage(content=response_text)
+        generation = ChatGeneration(message=message)
+        return ChatResult(generations=[generation])
+
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        """Optional: Streaming chat completion."""
+        # Yield chunks
+        chunk = ChatGenerationChunk(message=AIMessageChunk(content="chunk1"))
+        yield chunk
+
+
+register_chat_provider("custom_chat_model", MyCustomChatModel)
 ```
 
-You can then use the custom LLM provider in your configuration:
+### Using Custom LLM Providers
+
+After registering your custom provider, you can use it in your configuration:
 
 ```yaml
 models:
   - type: main
-    engine: custom_llm
+    engine: custom_text_llm  # or custom_chat_model
 ```
 
+### Important Notes
+
+1. **Import from langchain-core:** Always import base classes from `langchain_core.language_models`:
+   ```python
+   from langchain_core.language_models import BaseLLM, BaseChatModel
+   ```
+
+2. **Implement async methods:** For better performance, always implement `_acall` (for BaseLLM) or `_agenerate` (for BaseChatModel).
+
+3. **Choose the right base class:**
+   - Use `BaseLLM` for text completion models (prompt → text)
+   - Use `BaseChatModel` for chat models (messages → message)
+
+4. **Registration functions:**
+   - Use `register_llm_provider()` for `BaseLLM` subclasses
+   - Use `register_chat_provider()` for `BaseChatModel` subclasses
+
 ## Custom Embedding Provider Registration
 
 You can also register a custom embedding provider by using the `LLMRails.register_embedding_provider` function.
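Taken together, the rewritten guide registers a provider and then selects it by engine name. A minimal end-to-end sketch (assuming the `MyCustomTextLLM` definition and `register_llm_provider` call from the guide above have already run in the same process):

```python
from nemoguardrails import LLMRails, RailsConfig

# Select the provider registered above by its engine name.
config = RailsConfig.from_content(
    yaml_content="""
models:
  - type: main
    engine: custom_text_llm
"""
)
rails = LLMRails(config)

response = rails.generate(messages=[{"role": "user", "content": "Hello!"}])
print(response["content"])
```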

docs/user-guides/python-api.md

Lines changed: 2 additions & 0 deletions
@@ -132,6 +132,8 @@ For convenience, this toolkit also includes a selection of LangChain tools, wrap
 
 ### Chains as Actions
 
+> **⚠️ DEPRECATED**: Chain support is deprecated and will be removed in a future release. Please use [Runnable](https://python.langchain.com/docs/expression_language/) instead. See the [Runnable as Action Guide](langchain/runnable-as-action/README.md) for examples.
+
 You can register a Langchain chain as an action using the [LLMRails.register_action](../api/nemoguardrails.rails.llm.llmrails.md#method-llmrailsregister_action) method:
 
 ```python
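Since the new note points users at Runnables, here is a minimal sketch of the replacement pattern (the `./config` path and the `check_blocked_terms` action are hypothetical; any `Runnable` registers the same way):

```python
from langchain_core.runnables import RunnableLambda

from nemoguardrails import LLMRails, RailsConfig

config = RailsConfig.from_path("./config")  # assumes an existing config directory
rails = LLMRails(config)

# Wrap a plain function as a Runnable and register it as an action,
# replacing the deprecated Chain-based registration.
check_blocked_terms = RunnableLambda(
    lambda text: "blocked" if "proprietary" in text.lower() else "ok"
)
rails.register_action(check_blocked_terms, name="check_blocked_terms")
```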

examples/configs/rag/custom_rag_output_rails/config.py

Lines changed: 2 additions & 2 deletions
@@ -13,9 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from langchain.prompts import PromptTemplate
-from langchain_core.language_models.llms import BaseLLM
+from langchain_core.language_models import BaseLLM
 from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import PromptTemplate
 
 from nemoguardrails import LLMRails
 from nemoguardrails.actions.actions import ActionResult
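These langchain-core imports are typically composed into an LCEL pipeline. A self-contained sketch of that composition (using `FakeListLLM` as a stand-in model, which is not part of this example config):

```python
from langchain_core.language_models import FakeListLLM
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate

llm = FakeListLLM(responses=["Paris."])  # stand-in for the configured LLM

prompt = PromptTemplate.from_template(
    "Answer based only on this context:\n{context}\n\nQuestion: {question}"
)
# prompt -> llm -> string output: the Runnable pattern that replaces legacy chains
chain = prompt | llm | StrOutputParser()

print(chain.invoke({
    "context": "The capital of France is Paris.",
    "question": "What is the capital of France?",
}))
```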

examples/configs/rag/multi_kb/config.py

Lines changed: 18 additions & 5 deletions
@@ -21,11 +21,24 @@
 import pandas as pd
 import torch
 from gpt4pandas import GPT4Pandas
-from langchain.chains import RetrievalQA
-from langchain.embeddings import HuggingFaceEmbeddings
-from langchain.text_splitter import CharacterTextSplitter
-from langchain.vectorstores import FAISS
-from langchain_core.language_models.llms import BaseLLM
+
+try:
+    from langchain.chains import RetrievalQA
+    from langchain.embeddings import HuggingFaceEmbeddings
+    from langchain.text_splitter import CharacterTextSplitter
+    from langchain.vectorstores import FAISS
+except ImportError:
+    try:
+        from langchain_classic.chains import RetrievalQA
+        from langchain_classic.embeddings import HuggingFaceEmbeddings
+        from langchain_classic.text_splitter import CharacterTextSplitter
+        from langchain_classic.vectorstores import FAISS
+    except ImportError as e:
+        raise ImportError(
+            "Failed to import required LangChain modules. "
+            "For LangChain >=1.0.0, install langchain-classic: pip install langchain-classic"
+        ) from e
+from langchain_core.language_models import BaseLLM
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 
 from nemoguardrails import LLMRails, RailsConfig
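A quick way to check whether the `langchain_classic` fallback used above is importable in the current environment (a sketch, not part of the commit):

```python
import importlib.util

# langchain-classic hosts the legacy Chain/embeddings/vectorstore modules
# for LangChain >= 1.0; install it with: pip install langchain-classic
if importlib.util.find_spec("langchain_classic") is None:
    print("langchain-classic is not installed")
else:
    print("langchain-classic is available")
```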

examples/configs/rag/multi_kb/tabular_llm.py

Lines changed: 3 additions & 4 deletions
@@ -13,14 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import asyncio
 from typing import Any, Dict, List, Optional
 
-from langchain.callbacks.manager import (
+from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.llms.base import LLM
+from langchain_core.language_models import BaseLLM
 
 
 def query_tabular_data(usr_query: str, gpt: any, raw_data_frame: any):
@@ -58,7 +57,7 @@ def query_tabular_data(usr_query: str, gpt: any, raw_data_frame: any):
     return out, d2.to_string()
 
 
-class TabularLLM(LLM):
+class TabularLLM(BaseLLM):
     """LLM wrapping for GPT4Pandas."""
 
     model: str = ""

examples/configs/rag/pinecone/config.py

Lines changed: 13 additions & 6 deletions
@@ -18,11 +18,19 @@
 from typing import Optional
 
 import pinecone
-from langchain.chains import RetrievalQA
-from langchain.docstore.document import Document
-from langchain.embeddings.openai import OpenAIEmbeddings
-from langchain.vectorstores import Pinecone
-from langchain_core.language_models.llms import BaseLLM
+
+try:
+    from langchain.chains import RetrievalQA
+    from langchain.embeddings.openai import OpenAIEmbeddings
+    from langchain.vectorstores import Pinecone
+except ImportError as e:
+    raise ImportError(
+        "Failed to import required LangChain modules. "
+        "Ensure you have installed the correct version of langchain and its dependencies. "
+        f"Original error: {e}"
+    ) from e
+
+from langchain_core.language_models import BaseLLM
 
 from nemoguardrails import LLMRails
 from nemoguardrails.actions import action
@@ -31,7 +39,6 @@
 
 OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
 PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
-PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT")
 index_name = "nemoguardrailsindex"
 
 LOG_FILENAME = datetime.now().strftime("logs/mylogfile_%H_%M_%d_%m_%Y.log")

examples/scripts/demo_llama_index_guardrails.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@
 
 from typing import Any, Callable, Coroutine
 
-from langchain_core.language_models.llms import BaseLLM
+from langchain_core.language_models import BaseLLM
 
 from nemoguardrails import LLMRails, RailsConfig
 
examples/scripts/langchain/experiments.py

Lines changed: 10 additions & 2 deletions
@@ -15,8 +15,16 @@
 
 import os
 
-from langchain.chains import LLMMathChain
-from langchain.prompts import ChatPromptTemplate
+try:
+    from langchain.chains import LLMMathChain
+except ImportError as e:
+    raise ImportError(
+        "Failed to import required LangChain modules. "
+        "If you're using LangChain >= 1.0.0, ensure langchain-classic is installed. "
+        f"Original error: {e}"
+    ) from e
+
+from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.tools import Tool
 from langchain_openai.chat_models import ChatOpenAI
 from pydantic import BaseModel, Field

nemoguardrails/actions/action_dispatcher.py

Lines changed: 0 additions & 23 deletions
@@ -23,12 +23,10 @@
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast
 
-from langchain.chains.base import Chain
 from langchain_core.runnables import Runnable
 
 from nemoguardrails import utils
 from nemoguardrails.actions.llm.utils import LLMCallException
-from nemoguardrails.logging.callbacks import logging_callbacks
 
 log = logging.getLogger(__name__)
 
@@ -228,27 +226,6 @@ async def execute_action(
                     f"Synchronous action `{action_name}` has been called."
                 )
 
-            elif isinstance(fn, Chain):
-                try:
-                    chain = fn
-
-                    # For chains with only one output key, we use the `arun` function
-                    # to return directly the result.
-                    if len(chain.output_keys) == 1:
-                        result = await chain.arun(
-                            **params, callbacks=logging_callbacks
-                        )
-                    else:
-                        # Otherwise, we return the dict with the output keys.
-                        result = await chain.acall(
-                            inputs=params,
-                            return_only_outputs=True,
-                            callbacks=logging_callbacks,
-                        )
-                except NotImplementedError:
-                    # Not ideal, but for now we fall back to sync execution
-                    # if the async is not available
-                    result = fn.run(**params)
             elif isinstance(fn, Runnable):
                 # If it's a Runnable, we invoke it as well
                 runnable = fn
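With the Chain branch gone, actions of this kind go through the standard `Runnable` invocation API. A minimal sketch of that interface (using `RunnableLambda`; the dispatcher's actual invocation code is outside this hunk):

```python
import asyncio

from langchain_core.runnables import RunnableLambda

# Any Runnable exposes both sync and async entry points, so the dispatcher
# no longer needs the Chain-specific arun/acall fallback logic removed above.
double = RunnableLambda(lambda x: x * 2)

print(double.invoke(21))                # 42 (sync)
print(asyncio.run(double.ainvoke(21)))  # 42 (async)
```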
