
Commit 848dbb1

Pouyanpi authored and tgasser-nv committed
feat(llm): pass llm params directly (#1387)
* feat(llm): add llm_params option to llm_call

  Extend llm_call to accept an optional llm_params dictionary for passing configuration parameters (e.g., temperature, max_tokens) to the language model. This enables more flexible control over LLM behavior during calls.

* refactor(llm): replace llm_params context manager with argument

  Update all usages of the llm_params context manager to pass llm_params as an argument to llm_call instead. This simplifies parameter handling and improves code clarity for LLM calls.

* docs: clarify prompt customization and llm_params usage

* update LLMChain config usage
1 parent 3629aed commit 848dbb1
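For orientation, here is a minimal sketch of the calling-convention change described above, assuming an already-constructed LangChain `llm` object; the `check_input` wrapper name is purely illustrative:

```python
from nemoguardrails.actions.llm.utils import llm_call


async def check_input(llm, prompt: str) -> str:
    """Illustrative wrapper showing the new llm_params argument."""
    # Old style (before this commit): wrap the call in the llm_params
    # context manager from nemoguardrails.llm.params:
    #
    #     with llm_params(llm, temperature=0.0):
    #         result = await llm_call(llm, prompt)
    #
    # New style: pass the parameters directly; llm_call applies them to
    # the model for this call only.
    return await llm_call(llm, prompt, llm_params={"temperature": 0.0})
```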


21 files changed: +781 -183 lines changed


docs/user-guides/advanced/prompt-customization.md
Lines changed: 4 additions & 4 deletions

@@ -55,6 +55,7 @@ To override the prompt for any other custom purpose, you can specify the `mode`
 As an example of this, let's consider the case of compacting. Some applications might need concise prompts, for instance to avoid handling long contexts, and lower latency at the risk of slightly degraded performance due to the smaller context. For this, you might want to have multiple versions of a prompt for the same task and same model. This can be achieved as follows:

 Task configuration:
+
 ```yaml
 models:
   - type: main
@@ -65,6 +66,7 @@ prompting_mode: "compact" # Default value is "standard"
 ```

 Prompts configuration:
+
 ```yaml
 prompts:
   - task: generate_user_intent
@@ -117,6 +119,7 @@ prompts:
       content: ...
   # ...
 ```
+
 For each task, you can also specify the maximum length of the prompt to be used for the LLM call in terms of the number of characters. This is useful if you want to limit the number of tokens used by the LLM or when you want to make sure that the prompt length does not exceed the maximum context length. When the maximum length is exceeded, the prompt is truncated by removing older turns from the conversation history until length of the prompt is less than or equal to the maximum length. The default maximum length is 16000 characters.

 For example, for the `generate_user_intent` task, you can specify the following:
@@ -129,7 +132,6 @@ prompts:
     max_length: 3000
 ```

-
 ### Content Template

 The content for a completion prompt or the body for a message in a chat prompt is a string that can also include variables and potentially other types of constructs. NeMo Guardrails uses [Jinja2](https://jinja.palletsprojects.com/) as the templating engine. Check out the [Jinja Synopsis](https://jinja.palletsprojects.com/en/3.1.x/templates/#synopsis) for a quick introduction.
@@ -200,7 +202,6 @@ Optionally, the output from the LLM can be parsed using an *output parser*. The
 - `bot_message`: parse the bot message, i.e., removes the "Bot message:" prefix if present;
 - `verbose_v1`: parse the output of the `verbose_v1` filter.

-
 ## Predefined Prompts

 Currently, the NeMo Guardrails toolkit includes prompts for `openai/gpt-3.5-turbo-instruct`, `openai/gpt-3.5-turbo`, `openai/gpt-4`, `databricks/dolly-v2-3b`, `cohere/command`, `cohere/command-light`, `cohere/command-light-nightly`.
@@ -232,8 +233,7 @@ prompt = llm_task_manager.render_task_prompt(
     },
 )

-with llm_params(llm, temperature=0.0):
-    check = await llm_call(llm, prompt)
+check = await llm_call(llm, prompt, llm_params={"temperature": 0.0})
 ...
 ```
nemoguardrails/actions/llm/generation.py
Lines changed: 70 additions & 75 deletions

@@ -55,7 +55,6 @@
 )
 from nemoguardrails.embeddings.index import EmbeddingsIndex, IndexItem
 from nemoguardrails.kb.kb import KnowledgeBase
-from nemoguardrails.llm.params import llm_params
 from nemoguardrails.llm.prompts import get_prompt
 from nemoguardrails.llm.taskmanager import LLMTaskManager, ParsedTaskOutput
 from nemoguardrails.llm.types import Task
@@ -436,8 +435,9 @@ async def generate_user_intent(
         llm_call_info_var.set(LLMCallInfo(task=Task.GENERATE_USER_INTENT.value))

         # We make this call with temperature 0 to have it as deterministic as possible.
-        with llm_params(llm, temperature=self.config.lowest_temperature):
-            result = await llm_call(llm, prompt)
+        result = await llm_call(
+            llm, prompt, llm_params={"temperature": self.config.lowest_temperature}
+        )

         # Parse the output using the associated parser
         result = self.llm_task_manager.parse_task_output(
@@ -518,17 +518,15 @@ async def generate_user_intent(
                 llm_call_info_var.set(LLMCallInfo(task=Task.GENERAL.value))

                 generation_options: GenerationOptions = generation_options_var.get()
-                with llm_params(
+                llm_params = (
+                    generation_options and generation_options.llm_params
+                ) or {}
+                text = await llm_call(
                     llm,
-                    **(
-                        (generation_options and generation_options.llm_params) or {}
-                    ),
-                ):
-                    text = await llm_call(
-                        llm,
-                        prompt,
-                        custom_callback_handlers=[streaming_handler_var.get()],
-                    )
+                    prompt,
+                    custom_callback_handlers=[streaming_handler_var.get()],
+                    llm_params=llm_params,
+                )
                 text = self.llm_task_manager.parse_task_output(
                     Task.GENERAL, output=text
                 )
@@ -558,16 +556,16 @@ async def generate_user_intent(
                 )

                 generation_options: GenerationOptions = generation_options_var.get()
-                with llm_params(
+                llm_params = (
+                    generation_options and generation_options.llm_params
+                ) or {}
+                result = await llm_call(
                     llm,
-                    **((generation_options and generation_options.llm_params) or {}),
-                ):
-                    result = await llm_call(
-                        llm,
-                        prompt,
-                        custom_callback_handlers=[streaming_handler_var.get()],
-                        stop=["User:"],
-                    )
+                    prompt,
+                    custom_callback_handlers=[streaming_handler_var.get()],
+                    stop=["User:"],
+                    llm_params=llm_params,
+                )

                 text = self.llm_task_manager.parse_task_output(
                     Task.GENERAL, output=result
@@ -662,8 +660,9 @@ async def generate_next_step(
         llm_call_info_var.set(LLMCallInfo(task=Task.GENERATE_NEXT_STEPS.value))

         # We use temperature 0 for next step prediction as well
-        with llm_params(llm, temperature=self.config.lowest_temperature):
-            result = await llm_call(llm, prompt)
+        result = await llm_call(
+            llm, prompt, llm_params={"temperature": self.config.lowest_temperature}
+        )

         # Parse the output using the associated parser
         result = self.llm_task_manager.parse_task_output(
@@ -924,23 +923,23 @@ async def generate_bot_message(
             prompt = context.get("user_message")

             generation_options: GenerationOptions = generation_options_var.get()
-            with llm_params(
+            llm_params = (
+                generation_options and generation_options.llm_params
+            ) or {}
+            result = await llm_call(
                 llm,
-                **(
-                    (generation_options and generation_options.llm_params) or {}
-                ),
-            ):
-                result = await llm_call(
-                    llm, prompt, custom_callback_handlers=[streaming_handler]
-                )
+                prompt,
+                custom_callback_handlers=[streaming_handler],
+                llm_params=llm_params,
+            )

-                result = self.llm_task_manager.parse_task_output(
-                    Task.GENERAL, output=result
-                )
+            result = self.llm_task_manager.parse_task_output(
+                Task.GENERAL, output=result
+            )

-                result = _process_parsed_output(
-                    result, self._include_reasoning_traces()
-                )
+            result = _process_parsed_output(
+                result, self._include_reasoning_traces()
+            )

             log.info(
                 "--- :: LLM Bot Message Generation passthrough call took %.2f seconds",
@@ -987,13 +986,15 @@ async def generate_bot_message(
         llm_call_info_var.set(LLMCallInfo(task=Task.GENERATE_BOT_MESSAGE.value))

         generation_options: GenerationOptions = generation_options_var.get()
-        with llm_params(
+        llm_params = (
+            generation_options and generation_options.llm_params
+        ) or {}
+        result = await llm_call(
             llm,
-            **((generation_options and generation_options.llm_params) or {}),
-        ):
-            result = await llm_call(
-                llm, prompt, custom_callback_handlers=[streaming_handler]
-            )
+            prompt,
+            custom_callback_handlers=[streaming_handler],
+            llm_params=llm_params,
+        )

         log.info(
             "--- :: LLM Bot Message Generation call took %.2f seconds",
@@ -1094,8 +1095,9 @@ async def generate_value(
         # Initialize the LLMCallInfo object
         llm_call_info_var.set(LLMCallInfo(task=Task.GENERATE_VALUE.value))

-        with llm_params(llm, temperature=self.config.lowest_temperature):
-            result = await llm_call(llm, prompt)
+        result = await llm_call(
+            llm, prompt, llm_params={"temperature": self.config.lowest_temperature}
+        )

         # Parse the output using the associated parser
         result = self.llm_task_manager.parse_task_output(
@@ -1269,32 +1271,28 @@ async def generate_intent_steps_message(
             # We buffer the content, so we can get a chance to look at the
             # first k lines.
             await _streaming_handler.enable_buffering()
-            with llm_params(llm, temperature=self.config.lowest_temperature):
-                asyncio.create_task(
-                    llm_call(
-                        llm,
-                        prompt,
-                        custom_callback_handlers=[_streaming_handler],
-                        stop=["\nuser ", "\nUser "],
-                    )
+            asyncio.create_task(
+                llm_call(
+                    llm,
+                    prompt,
+                    custom_callback_handlers=[_streaming_handler],
+                    stop=["\nuser ", "\nUser "],
+                    llm_params={"temperature": self.config.lowest_temperature},
                 )
-                result = await _streaming_handler.wait_top_k_nonempty_lines(k=2)
+            )
+            result = await _streaming_handler.wait_top_k_nonempty_lines(k=2)

-                # We also mark that the message is still being generated
-                # by a streaming handler.
-                result += (
-                    f'\nBot message: "<<STREAMING[{_streaming_handler.uid}]>>"'
-                )
+            # We also mark that the message is still being generated
+            # by a streaming handler.
+            result += f'\nBot message: "<<STREAMING[{_streaming_handler.uid}]>>"'

-                # Moving forward we need to set the expected pattern to correctly
-                # parse the message.
-                # TODO: Figure out a more generic way to deal with this.
-                if prompt_config.output_parser == "verbose_v1":
-                    _streaming_handler.set_pattern(
-                        prefix='Bot message: "', suffix='"'
-                    )
-                else:
-                    _streaming_handler.set_pattern(prefix=' "', suffix='"')
+            # Moving forward we need to set the expected pattern to correctly
+            # parse the message.
+            # TODO: Figure out a more generic way to deal with this.
+            if prompt_config.output_parser == "verbose_v1":
+                _streaming_handler.set_pattern(prefix='Bot message: "', suffix='"')
+            else:
+                _streaming_handler.set_pattern(prefix=' "', suffix='"')
         else:
             # Initialize the LLMCallInfo object
             llm_call_info_var.set(
@@ -1306,8 +1304,7 @@ async def generate_intent_steps_message(
                 **((generation_options and generation_options.llm_params) or {}),
                 "temperature": self.config.lowest_temperature,
             }
-            with llm_params(llm, **additional_params):
-                result = await llm_call(llm, prompt)
+            result = await llm_call(llm, prompt, llm_params=additional_params)

             # Parse the output using the associated parser
             result = self.llm_task_manager.parse_task_output(
@@ -1388,10 +1385,8 @@ async def generate_intent_steps_message(

         # We make this call with temperature 0 to have it as deterministic as possible.
         generation_options: GenerationOptions = generation_options_var.get()
-        with llm_params(
-            llm, **((generation_options and generation_options.llm_params) or {})
-        ):
-            result = await llm_call(llm, prompt)
+        llm_params = (generation_options and generation_options.llm_params) or {}
+        result = await llm_call(llm, prompt, llm_params=llm_params)

         result = self.llm_task_manager.parse_task_output(
             Task.GENERAL, output=result
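The hunks above apply one recurring pattern: any per-request overrides supplied through `GenerationOptions.llm_params` are collected into a plain dict, optionally merged with a task-specific temperature, and forwarded to `llm_call` instead of wrapping the call in a context manager. A condensed sketch of that pattern, with `generation_options` and `lowest_temperature` standing in for the values used in the diff:

```python
from nemoguardrails.actions.llm.utils import llm_call


async def _call_with_overrides(llm, prompt, generation_options, lowest_temperature):
    # Per-request overrides from the generation options, if any.
    additional_params = {
        **((generation_options and generation_options.llm_params) or {}),
        # Task-specific override, e.g. deterministic intent/value generation.
        "temperature": lowest_temperature,
    }
    # Forward the merged dict instead of entering
    # `with llm_params(llm, **additional_params): ...`.
    return await llm_call(llm, prompt, llm_params=additional_params)
```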
nemoguardrails/actions/llm/utils.py
Lines changed: 18 additions & 1 deletion

@@ -73,11 +73,28 @@ async def llm_call(
     model_provider: Optional[str] = None,
     stop: Optional[List[str]] = None,
     custom_callback_handlers: Optional[List[AsyncCallbackHandler]] = None,
+    llm_params: Optional[dict] = None,
 ) -> str:
-    """Calls the LLM with a prompt and returns the generated text."""
+    """Calls the LLM with a prompt and returns the generated text.
+
+    Args:
+        llm: The language model instance to use
+        prompt: The prompt string or list of messages
+        model_name: Optional model name for tracking
+        model_provider: Optional model provider for tracking
+        stop: Optional list of stop tokens
+        custom_callback_handlers: Optional list of callback handlers
+        llm_params: Optional configuration dictionary to pass to the LLM (e.g., temperature, max_tokens)
+
+    Returns:
+        The generated text response
+    """
     _setup_llm_call_info(llm, model_name, model_provider)
     all_callbacks = _prepare_callbacks(custom_callback_handlers)

+    if llm_params and llm is not None:
+        llm = llm.bind(**llm_params)
+
     if isinstance(prompt, str):
         response = await _invoke_with_string_prompt(llm, prompt, all_callbacks, stop)
     else: