Commit 2d773cc
chore(types): Type-clean llm/ (27 errors) (#1394)
* Cleaned llm/ type errors
* Add nemoguardrails/llm to the pyright pre-commit check
* Fix types in nemoguardrails/rails module
* Use poetry install --all-extras --with dev to install langchain_nvidia_ai_endpoints for GitHub CI tests
* Install extras in test-coverage-report so the langchain_nvidia_ai_endpoints work for pyright type-checking
* Remove tritonclient from type-checking (should this be deprecated?)
* Add upgrade-deps to the full-tests.yml file in GitHub CI/CD
* Exclude providers/trtllm/** and providers/_langchain_nvidia_ai_endpoints_patch.py from type-checking
* Roll back type cleaning under llm/providers/trtllm now they're excluded from type-checking
* Type-clean the LFU cache implementation
* Address Pouyan's feedback. Removed Model.model Optional and default value
* Fix typo
* Revert GitHub workflow changes (not needed now we exclude trtllm from type-checking)
* Remove comment from pyproject.toml
* Revert mandatory Model name field change, add None-guard back
* Address last feedback
1 parent d2d41f4 commit 2d773cc

File tree

10 files changed: +115 -47 lines changed

nemoguardrails/llm/cache/lfu.py

Lines changed: 8 additions & 3 deletions
@@ -54,7 +54,8 @@ def append(self, node: LFUNode) -> None:
         """Add node to the end of the list (before tail)."""
         node.prev = self.tail.prev
         node.next = self.tail
-        self.tail.prev.next = node
+        if self.tail.prev is not None:
+            self.tail.prev.next = node
         self.tail.prev = node
         self.size += 1

@@ -67,8 +68,10 @@ def pop(self, node: Optional[LFUNode] = None) -> Optional[LFUNode]:
             node = self.head.next

         # Remove node from the list
-        node.prev.next = node.next
-        node.next.prev = node.prev
+        if node is not None and node.prev is not None:
+            node.prev.next = node.next
+        if node is not None and node.next is not None:
+            node.next.prev = node.prev
         self.size -= 1

         return node
@@ -121,6 +124,7 @@ def __init__(
             "evictions": 0,
             "puts": 0,
             "updates": 0,
+            "hit_rate": 0.0,
         }

     def _update_node_freq(self, node: LFUNode) -> None:
@@ -288,6 +292,7 @@ def reset_stats(self) -> None:
             "evictions": 0,
             "puts": 0,
             "updates": 0,
+            "hit_rate": 0.0,
         }

     def _check_and_log_stats(self) -> None:
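The guards added here are the usual way to satisfy pyright when a doubly-linked list uses sentinel nodes: prev and next are Optional[LFUNode], so they have to be narrowed before dereferencing. A minimal, self-contained sketch of the same pattern (the Node and DList names are illustrative, not the cache's actual classes):

from typing import Optional


class Node:
    def __init__(self, value: int = 0) -> None:
        self.value = value
        self.prev: Optional["Node"] = None
        self.next: Optional["Node"] = None


class DList:
    def __init__(self) -> None:
        # Sentinel head/tail so append never has to handle an empty list.
        self.head = Node()
        self.tail = Node()
        self.head.next = self.tail
        self.tail.prev = self.head

    def append(self, node: Node) -> None:
        node.prev = self.tail.prev
        node.next = self.tail
        # Without the guard, pyright flags the dereference because
        # tail.prev is Optional[Node] and could in principle be None.
        if self.tail.prev is not None:
            self.tail.prev.next = node
        self.tail.prev = node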

nemoguardrails/llm/filters.py

Lines changed: 4 additions & 4 deletions
@@ -140,7 +140,7 @@ def to_messages(colang_history: str) -> List[dict]:
     # a message from the user, and the rest gets translated to messages from the assistant.
     lines = colang_history.split("\n")

-    bot_lines = []
+    bot_lines: list[str] = []
     for i, line in enumerate(lines):
         if line.startswith('user "'):
             # If we have bot lines in the buffer, we first add a bot message.
@@ -181,8 +181,8 @@ def to_messages_v2(colang_history: str) -> List[dict]:
     # a message from the user, and the rest gets translated to messages from the assistant.
     lines = colang_history.split("\n")

-    user_lines = []
-    bot_lines = []
+    user_lines: list[str] = []
+    bot_lines: list[str] = []
     for line in lines:
         if line.startswith("user action:"):
             if len(bot_lines) > 0:
@@ -275,7 +275,7 @@ def verbose_v1(colang_history: str) -> str:
     return "\n".join(lines)


-def to_chat_messages(events: List[dict]) -> str:
+def to_chat_messages(events: List[dict]) -> List[dict]:
     """Filter that turns an array of events into a sequence of user/assistant messages.

     Properly handles multimodal content by preserving the structure when the content
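The new annotations close a common inference gap: an empty list literal gives the checker no element type to work with, so later appends cannot be verified. A small illustration of the same fix (the function and names are made up):

from typing import List


def collect_bot_lines(lines: List[str]) -> List[str]:
    # Without the annotation, pyright infers an unknown element type for the
    # empty literal; annotating it lets the append below be checked.
    bot_lines: list[str] = []
    for line in lines:
        if line.startswith("bot "):
            bot_lines.append(line)
    return bot_lines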

nemoguardrails/llm/helpers.py

Lines changed: 12 additions & 11 deletions
@@ -13,18 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import List, Optional, Type, Union
+from typing import List, Optional, Type

 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.language_models.llms import LLM, BaseLLM
+from langchain_core.language_models.llms import LLM


-def get_llm_instance_wrapper(
-    llm_instance: Union[LLM, BaseLLM], llm_type: str
-) -> Type[LLM]:
+def get_llm_instance_wrapper(llm_instance: LLM, llm_type: str) -> Type[LLM]:
     """Wraps an LLM instance in a class that can be registered with LLMRails.

     This is useful to create specific types of LLMs using a generic LLM provider
@@ -47,7 +45,7 @@ def model_kwargs(self):
             These are needed to allow changes to the arguments of the LLM calls.
             """
             if hasattr(llm_instance, "model_kwargs"):
-                return llm_instance.model_kwargs
+                return llm_instance.model_kwargs  # type: ignore[attr-defined] (We check in line above)
             return {}

         @property
@@ -66,26 +64,29 @@ def _modify_instance_kwargs(self):
             """

             if hasattr(llm_instance, "model_kwargs"):
-                if isinstance(llm_instance.model_kwargs, dict):
-                    llm_instance.model_kwargs["temperature"] = self.temperature
-                    llm_instance.model_kwargs["streaming"] = self.streaming
+                model_kwargs = getattr(llm_instance, "model_kwargs")
+                if isinstance(model_kwargs, dict):
+                    model_kwargs["temperature"] = self.temperature
+                    model_kwargs["streaming"] = self.streaming

         def _call(
             self,
             prompt: str,
             stop: Optional[List[str]] = None,
             run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs,
         ) -> str:
             self._modify_instance_kwargs()
-            return llm_instance._call(prompt, stop, run_manager)
+            return llm_instance._call(prompt, stop, run_manager, **kwargs)

         async def _acall(
             self,
             prompt: str,
             stop: Optional[List[str]] = None,
             run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+            **kwargs,
         ) -> str:
             self._modify_instance_kwargs()
-            return await llm_instance._acall(prompt, stop, run_manager)
+            return await llm_instance._acall(prompt, stop, run_manager, **kwargs)

     return WrapperLLM
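The wrapper returns a class that closes over an already-constructed LLM instance, so a registry that expects a provider class can still be fed a concrete object. A stripped-down sketch of the same closure pattern, deliberately independent of LangChain's LLM base class (all names here are illustrative):

from typing import Type


class FakeLLM:
    """Stand-in for a concrete, already-configured LLM instance."""

    def complete(self, prompt: str) -> str:
        return f"echo: {prompt}"


def get_instance_wrapper(instance: FakeLLM, llm_type: str) -> Type:
    class WrapperLLM:
        # The returned class delegates every call to the captured instance,
        # so callers that instantiate "a provider class" still end up using
        # the one object that was wrapped.
        @property
        def type(self) -> str:
            return llm_type

        def call(self, prompt: str, **kwargs) -> str:
            return instance.complete(prompt)

    return WrapperLLM


wrapped_cls = get_instance_wrapper(FakeLLM(), llm_type="fake-llm")
print(wrapped_cls().call("hello"))  # -> "echo: hello"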

nemoguardrails/llm/models/initializer.py

Lines changed: 6 additions & 3 deletions
@@ -20,12 +20,15 @@
 from langchain_core.language_models import BaseChatModel
 from langchain_core.language_models.llms import BaseLLM

-from .langchain_initializer import ModelInitializationError, init_langchain_model
+from nemoguardrails.llm.models.langchain_initializer import (
+    ModelInitializationError,
+    init_langchain_model,
+)


-# later we can easily conver it to a class
+# later we can easily convert it to a class
 def init_llm_model(
-    model_name: Optional[str],
+    model_name: str,
     provider_name: str,
     mode: Literal["chat", "text"],
     kwargs: Dict[str, Any],
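With model_name tightened from Optional[str] to str, callers that read the name from an optional config field have to narrow it before the call, which is what the "add None-guard back" item in the commit message refers to. A hypothetical caller-side guard (the function name is invented):

from typing import Optional


def resolve_model_name(configured_name: Optional[str]) -> str:
    # Narrow Optional[str] to str before passing it to an API that now
    # requires a concrete model name.
    if configured_name is None:
        raise ValueError("A model name must be configured for this engine.")
    return configured_name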

nemoguardrails/llm/params.py

Lines changed: 8 additions & 7 deletions
@@ -36,7 +36,7 @@

 import logging
 import warnings
-from typing import Dict, Type
+from typing import Any, Dict, Type

 from langchain.base_language import BaseLanguageModel

@@ -61,18 +61,18 @@ def __init__(self, llm: BaseLanguageModel, **kwargs):
         warnings.warn(_DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)
         self.llm = llm
         self.altered_params = kwargs
-        self.original_params = {}
+        self.original_params: dict[str, Any] = {}

     def __enter__(self):
         # Here we can access and modify the global language model parameters.
-        self.original_params = {}
         for param, value in self.altered_params.items():
             if hasattr(self.llm, param):
                 self.original_params[param] = getattr(self.llm, param)
                 setattr(self.llm, param, value)

             elif hasattr(self.llm, "model_kwargs"):
-                if param not in self.llm.model_kwargs:
+                model_kwargs = getattr(self.llm, "model_kwargs", {})
+                if param not in model_kwargs:
                     log.warning(
                         "Parameter %s does not exist for %s. Passing to model_kwargs",
                         param,
@@ -81,9 +81,10 @@ def __enter__(self):

                     self.original_params[param] = None
                 else:
-                    self.original_params[param] = self.llm.model_kwargs[param]
+                    self.original_params[param] = model_kwargs[param]

-                self.llm.model_kwargs[param] = value
+                model_kwargs[param] = value
+                setattr(self.llm, "model_kwargs", model_kwargs)

             else:
                 log.warning(
@@ -92,7 +93,7 @@ def __enter__(self):
                     self.llm.__class__.__name__,
                 )

-    def __exit__(self, type, value, traceback):
+    def __exit__(self, exc_type, value, traceback):
         # Restore original parameters when exiting the context
         for param, value in self.original_params.items():
             if hasattr(self.llm, param):
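The class being cleaned here is a context manager that temporarily overrides attributes (or model_kwargs entries) on an LLM and restores the originals on exit. A self-contained sketch of that override-and-restore pattern, with an invented SimpleModel standing in for a real LLM object:

from typing import Any, Dict


class SimpleModel:
    def __init__(self) -> None:
        self.temperature = 0.7
        self.model_kwargs: Dict[str, Any] = {}


class TempParams:
    """Temporarily override attributes on a model, restoring them on exit."""

    def __init__(self, model: SimpleModel, **params: Any) -> None:
        self.model = model
        self.altered_params = params
        self.original_params: Dict[str, Any] = {}

    def __enter__(self) -> "TempParams":
        for name, value in self.altered_params.items():
            if hasattr(self.model, name):
                self.original_params[name] = getattr(self.model, name)
                setattr(self.model, name, value)
            else:
                # Unknown attributes fall through to model_kwargs.
                self.original_params[name] = self.model.model_kwargs.get(name)
                self.model.model_kwargs[name] = value
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        for name, value in self.original_params.items():
            if hasattr(self.model, name):
                setattr(self.model, name, value)
            else:
                self.model.model_kwargs[name] = value


model = SimpleModel()
with TempParams(model, temperature=0.1):
    assert model.temperature == 0.1
assert model.temperature == 0.7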

nemoguardrails/llm/providers/huggingface/pipeline.py

Lines changed: 30 additions & 9 deletions
@@ -13,14 +13,33 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import asyncio
 from typing import Any, List, Optional

 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
 from langchain.schema.output import GenerationChunk
-from langchain_community.llms import HuggingFacePipeline
+
+# Import HuggingFacePipeline with fallbacks for different LangChain versions
+HuggingFacePipeline = None  # type: ignore[assignment]
+
+try:
+    from langchain_community.llms import (
+        HuggingFacePipeline,  # type: ignore[attr-defined,no-redef]
+    )
+except ImportError:
+    # Fallback for older versions of langchain
+    try:
+        from langchain.llms import (
+            HuggingFacePipeline,  # type: ignore[attr-defined,no-redef]
+        )
+    except ImportError:
+        # Create a dummy class if HuggingFacePipeline is not available
+        class HuggingFacePipeline:  # type: ignore[misc,no-redef]
+            def __init__(self, *args, **kwargs):
+                raise ImportError("HuggingFacePipeline is not available")


 class HuggingFacePipelineCompatible(HuggingFacePipeline):
@@ -47,12 +66,13 @@ def _call(
         )

         # Streaming for NeMo Guardrails is not supported in sync calls.
-        if self.model_kwargs and self.model_kwargs.get("streaming"):
-            raise Exception(
+        model_kwargs = getattr(self, "model_kwargs", {})
+        if model_kwargs and model_kwargs.get("streaming"):
+            raise NotImplementedError(
                 "Streaming mode not supported for HuggingFacePipeline in NeMo Guardrails!"
             )

-        llm_result = self._generate(
+        llm_result = self._generate(  # type: ignore[attr-defined]
             [prompt],
             stop=stop,
             run_manager=run_manager,
@@ -78,11 +98,12 @@ async def _acall(
         )

         # Handle streaming, if the flag is set
-        if self.model_kwargs and self.model_kwargs.get("streaming"):
+        model_kwargs = getattr(self, "model_kwargs", {})
+        if model_kwargs and model_kwargs.get("streaming"):
             # Retrieve the streamer object, needs to be set in model_kwargs
-            streamer = self.model_kwargs.get("streamer")
+            streamer = model_kwargs.get("streamer")
             if not streamer:
-                raise Exception(
+                raise ValueError(
                     "Cannot stream, please add HuggingFace streamer object to model_kwargs!"
                 )

@@ -99,7 +120,7 @@ async def _acall(
                 run_manager=run_manager,
                 **kwargs,
             )
-            loop.create_task(self._agenerate(**generation_kwargs))
+            loop.create_task(getattr(self, "_agenerate")(**generation_kwargs))

             # And start waiting for the chunks to come in.
             completion = ""
@@ -111,7 +132,7 @@ async def _acall(

             return completion

-        llm_result = await self._agenerate(
+        llm_result = await getattr(self, "_agenerate")(
             [prompt],
             stop=stop,
             run_manager=run_manager,
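The import block follows the usual optional-dependency dance: try the newer package path, fall back to the older one, and finally install a stub that only fails when instantiated, so module import (and type checking) keeps working without the extra installed. A generic sketch of the pattern with placeholder package names (newlib and oldlib do not exist; the stub branch is what actually runs here):

try:
    # Preferred location in newer releases of the (hypothetical) library.
    from newlib.models import Pipeline  # type: ignore[no-redef]
except ImportError:
    try:
        # Older releases exposed the same class from the top-level package.
        from oldlib import Pipeline  # type: ignore[no-redef]
    except ImportError:

        class Pipeline:  # type: ignore[no-redef]
            """Stub that defers the failure until someone tries to use it."""

            def __init__(self, *args, **kwargs) -> None:
                raise ImportError(
                    "Pipeline requires 'newlib' or 'oldlib' to be installed."
                )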

nemoguardrails/llm/providers/huggingface/streamers.py

Lines changed: 22 additions & 4 deletions
@@ -14,11 +14,27 @@
 # limitations under the License.

 import asyncio
+from typing import TYPE_CHECKING, Optional

-from transformers.generation.streamers import TextStreamer
+TRANSFORMERS_AVAILABLE = True
+try:
+    from transformers.generation.streamers import (  # type: ignore[import-untyped]
+        TextStreamer,
+    )
+except ImportError:
+    # Fallback if transformers is not available
+    TRANSFORMERS_AVAILABLE = False

+    class TextStreamer:  # type: ignore[no-redef]
+        def __init__(self, *args, **kwargs):
+            pass

-class AsyncTextIteratorStreamer(TextStreamer):
+
+if TYPE_CHECKING:
+    from transformers import AutoTokenizer  # type: ignore[import-untyped]
+
+
+class AsyncTextIteratorStreamer(TextStreamer):  # type: ignore[misc]
     """
     Simple async implementation for HuggingFace Transformers streamers.

@@ -30,12 +46,14 @@ def __init__(
         self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs
     ):
         super().__init__(tokenizer, skip_prompt, **decode_kwargs)
-        self.text_queue = asyncio.Queue()
+        self.text_queue: asyncio.Queue[str] = asyncio.Queue()
         self.stop_signal = None
-        self.loop = None
+        self.loop: Optional[asyncio.AbstractEventLoop] = None

     def on_finalized_text(self, text: str, stream_end: bool = False):
         """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
+        if self.loop is None:
+            return
         if len(text) > 0:
             asyncio.run_coroutine_threadsafe(self.text_queue.put(text), self.loop)

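The streamer changes pin down two things for the checker: the queue's element type and the fact that loop may still be unset when text arrives from the generation thread. A minimal sketch of that thread-to-event-loop handoff (class and method names are illustrative):

import asyncio
import threading
from typing import Optional


class TextRelay:
    def __init__(self) -> None:
        self.text_queue: asyncio.Queue[str] = asyncio.Queue()
        self.loop: Optional[asyncio.AbstractEventLoop] = None

    def on_text(self, text: str) -> None:
        # Called from a worker thread; drop text if no event loop was registered.
        if self.loop is None:
            return
        asyncio.run_coroutine_threadsafe(self.text_queue.put(text), self.loop)


async def main() -> None:
    relay = TextRelay()
    relay.loop = asyncio.get_running_loop()
    threading.Thread(target=relay.on_text, args=("hello",)).start()
    print(await relay.text_queue.get())  # -> "hello"


asyncio.run(main())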

nemoguardrails/llm/taskmanager.py

Lines changed: 16 additions & 5 deletions
@@ -95,6 +95,8 @@ def __init__(self, config: RailsConfig):
     def _get_general_instructions(self):
         """Helper to extract the general instructions."""
         text = ""
+        if self.config.instructions is None:
+            return text
         for instruction in self.config.instructions:
             if instruction.type == "general":
                 text = instruction.content
@@ -266,7 +268,9 @@ def render_task_prompt(
             task_prompt = self._render_string(
                 prompt.content, context=context, events=events
            )
-            while len(task_prompt) > prompt.max_length:
+            while (
+                prompt.max_length is not None and len(task_prompt) > prompt.max_length
+            ):
                 if not events:
                     raise Exception(
                         f"Prompt exceeds max length of {prompt.max_length} characters even without history"
@@ -288,20 +292,27 @@ def render_task_prompt(

             return task_prompt
         else:
+            if prompt.messages is None:
+                return []
             task_messages = self._render_messages(
                 prompt.messages, context=context, events=events
             )
             task_prompt_length = self._get_messages_text_length(task_messages)
-            while task_prompt_length > prompt.max_length:
+            while (
+                prompt.max_length is not None and task_prompt_length > prompt.max_length
+            ):
                 if not events:
                     raise Exception(
                         f"Prompt exceeds max length of {prompt.max_length} characters even without history"
                     )
                 # Remove events from the beginning of the history until the prompt fits.
                 events = events[1:]
-                task_messages = self._render_messages(
-                    prompt.messages, context=context, events=events
-                )
+                if prompt.messages is not None:
+                    task_messages = self._render_messages(
+                        prompt.messages, context=context, events=events
+                    )
+                else:
+                    task_messages = []
                 task_prompt_length = self._get_messages_text_length(task_messages)
             return task_messages

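The reworked while conditions keep the is-not-None check inside the loop test: max_length is an Optional[int] attribute, so it has to be narrowed right where the comparison happens, and a None value naturally means "no limit, never loop". A reduced illustration with a stand-in Prompt class:

from dataclasses import dataclass
from typing import Optional


@dataclass
class Prompt:
    content: str
    max_length: Optional[int] = None


def truncate(prompt: Prompt, history: list[str]) -> str:
    rendered = prompt.content + "".join(history)
    # The None check in the loop condition both narrows the type for the
    # comparison and treats an unset limit as "always fits".
    while prompt.max_length is not None and len(rendered) > prompt.max_length:
        if not history:
            raise ValueError("Prompt exceeds max length even without history")
        history = history[1:]
        rendered = prompt.content + "".join(history)
    return rendered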
