Skip to content

Commit f15ce67

Browse files
committed
Check in rails/ type-fixes
1 parent ce28483 commit f15ce67

File tree

7 files changed

+305
-92
lines changed

7 files changed

+305
-92
lines changed

nemoguardrails/actions/llm/generation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ class LLMGenerationActions:
8282
def __init__(
8383
self,
8484
config: RailsConfig,
85-
llm: Union[BaseLLM, BaseChatModel],
85+
llm: Optional[Union[BaseLLM, BaseChatModel]],
8686
llm_task_manager: LLMTaskManager,
8787
get_embedding_search_provider_instance: Callable[
8888
[Optional[EmbeddingSearchProvider]], EmbeddingsIndex

nemoguardrails/context.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,42 @@
1414
# limitations under the License.
1515

1616
import contextvars
17-
from typing import Optional
17+
from typing import TYPE_CHECKING, Any, Dict, List, Optional
1818

19-
streaming_handler_var = contextvars.ContextVar("streaming_handler", default=None)
19+
if TYPE_CHECKING:
20+
from nemoguardrails.logging.explain import ExplainInfo
21+
from nemoguardrails.rails.llm.options import GenerationOptions, LLMStats
22+
from nemoguardrails.streaming import StreamingHandler
23+
24+
streaming_handler_var: contextvars.ContextVar[
25+
Optional["StreamingHandler"]
26+
] = contextvars.ContextVar("streaming_handler", default=None)
2027

2128
# The object that holds additional explanation information.
22-
explain_info_var = contextvars.ContextVar("explain_info", default=None)
29+
explain_info_var: contextvars.ContextVar[
30+
Optional["ExplainInfo"]
31+
] = contextvars.ContextVar("explain_info", default=None)
2332

2433
# The current LLM call.
25-
llm_call_info_var = contextvars.ContextVar("llm_call_info", default=None)
34+
llm_call_info_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
35+
"llm_call_info", default=None
36+
)
2637

2738
# All the generation options applicable to the current context.
28-
generation_options_var = contextvars.ContextVar("generation_options", default=None)
39+
generation_options_var: contextvars.ContextVar[
40+
Optional["GenerationOptions"]
41+
] = contextvars.ContextVar("generation_options", default=None)
2942

3043
# The stats about the LLM calls.
31-
llm_stats_var = contextvars.ContextVar("llm_stats", default=None)
44+
llm_stats_var: contextvars.ContextVar[Optional["LLMStats"]] = contextvars.ContextVar(
45+
"llm_stats", default=None
46+
)
3247

3348
# The raw LLM request that comes from the user.
3449
# This is used in passthrough mode.
35-
raw_llm_request = contextvars.ContextVar("raw_llm_request", default=None)
50+
raw_llm_request: contextvars.ContextVar[Optional[Any]] = contextvars.ContextVar(
51+
"raw_llm_request", default=None
52+
)
3653

3754
reasoning_trace_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
3855
"reasoning_trace", default=None

nemoguardrails/rails/llm/buffer.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
# limitations under the License.
1515

1616
from abc import ABC, abstractmethod
17-
from typing import AsyncGenerator, List, NamedTuple
17+
from typing import TYPE_CHECKING, AsyncGenerator, List, NamedTuple
18+
19+
if TYPE_CHECKING:
20+
from collections.abc import AsyncIterator
1821

1922
from nemoguardrails.rails.llm.config import OutputRailsStreamingConfig
2023

@@ -111,9 +114,7 @@ def format_chunks(self, chunks: List[str]) -> str:
111114
...
112115

113116
@abstractmethod
114-
async def process_stream(
115-
self, streaming_handler
116-
) -> AsyncGenerator[ChunkBatch, None]:
117+
async def process_stream(self, streaming_handler):
117118
"""Process streaming chunks and yield chunk batches.
118119
119120
This is the main method that concrete buffer strategies must implement.
@@ -138,9 +139,9 @@ async def process_stream(
138139
... print(f"Processing: {context_formatted}")
139140
... print(f"User: {user_formatted}")
140141
"""
141-
...
142+
yield ChunkBatch([], []) # pragma: no cover
142143

143-
async def __call__(self, streaming_handler) -> AsyncGenerator[ChunkBatch, None]:
144+
async def __call__(self, streaming_handler):
144145
"""Callable interface that delegates to process_stream.
145146
146147
It delegates to the `process_stream` method and can
@@ -256,9 +257,7 @@ def from_config(cls, config: OutputRailsStreamingConfig):
256257
buffer_context_size=config.context_size, buffer_chunk_size=config.chunk_size
257258
)
258259

259-
async def process_stream(
260-
self, streaming_handler
261-
) -> AsyncGenerator[ChunkBatch, None]:
260+
async def process_stream(self, streaming_handler):
262261
"""Process streaming chunks using rolling buffer strategy.
263262
264263
This method implements the rolling buffer logic, accumulating chunks

nemoguardrails/rails/llm/config.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,7 +1128,9 @@ def _load_path(
11281128

11291129
# the first .railsignore file found from cwd down to its subdirectories
11301130
railsignore_path = utils.get_railsignore_path(config_path)
1131-
ignore_patterns = utils.get_railsignore_patterns(railsignore_path)
1131+
ignore_patterns = (
1132+
utils.get_railsignore_patterns(railsignore_path) if railsignore_path else set()
1133+
)
11321134

11331135
if os.path.isdir(config_path):
11341136
for root, _, files in os.walk(config_path, followlinks=True):
@@ -1245,8 +1247,8 @@ def _parse_colang_files_recursively(
12451247
current_file, current_path = colang_files[len(parsed_colang_files)]
12461248

12471249
with open(current_path, "r", encoding="utf-8") as f:
1250+
content = f.read()
12481251
try:
1249-
content = f.read()
12501252
_parsed_config = parse_colang_file(
12511253
current_file, content=content, version=colang_version
12521254
)
@@ -1748,7 +1750,7 @@ def streaming_supported(self):
17481750
# if we have output rails streaming enabled
17491751
# we keep it in case it was needed when we have
17501752
# support per rails
1751-
if self.rails.output.streaming.enabled:
1753+
if self.rails.output.streaming and self.rails.output.streaming.enabled:
17521754
return True
17531755
return False
17541756

0 commit comments

Comments
 (0)