
Commit e98e463

rolshoven and NathanHB authored
Added litellm model config options and improved _prepare_max_new_tokens (#967)
* Added litellm model config options and made increase of `max_tokens` for reasoning models more general

* Updated type hint for `timeout` model config option

Co-authored-by: Nathan Habib <30601243+NathanHB@users.noreply.github.com>

* Removed redundant reasoning tag stripping from litellm model

---------

Co-authored-by: Nathan Habib <30601243+NathanHB@users.noreply.github.com>
1 parent 7a456fd commit e98e463

File tree

pyproject.toml
src/lighteval/models/endpoints/litellm_model.py

2 files changed: +123 -44 lines

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -88,7 +88,7 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
-litellm = ["litellm[caching]", "diskcache"]
+litellm = ["litellm[caching]>=1.66.0", "diskcache"]
 tgi = ["text-generation>=0.7.0"]
 optimum = ["optimum==1.12.0"]
 quantization = ["bitsandbytes>=0.41.0", "auto-gptq>=0.4.2"]
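The new version floor presumably tracks the newer litellm APIs used in the model file below (supports_reasoning, LiteLLMCacheType, get_max_tokens). A quick, optional sanity check for an existing environment, sketched here and not part of the change itself:

# Sketch only: verifies an installed litellm against the new ">=1.66.0" floor.
# Assumes the `packaging` library is available (it usually ships alongside pip).
from importlib.metadata import version
from packaging.version import Version

assert Version(version("litellm")) >= Version("1.66.0"), "litellm is older than the lighteval litellm extra now requires"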

src/lighteval/models/endpoints/litellm_model.py

Lines changed: 122 additions & 43 deletions
@@ -23,7 +23,9 @@
 import logging
 import time
 from concurrent.futures import ThreadPoolExecutor
+from json import JSONDecodeError
 
+import requests
 from tqdm import tqdm
 
 from lighteval.data import GenerativeTaskDataset
@@ -39,14 +41,15 @@
 
 if is_package_available("litellm"):
     import litellm
-    from litellm import encode
-    from litellm.caching.caching import Cache
+    from litellm import encode, supports_reasoning
+    from litellm.caching.caching import Cache, LiteLLMCacheType
     from litellm.utils import ModelResponse as LitellmModelResponse
+    from litellm.utils import get_max_tokens
 
     logging.getLogger("LiteLLM").setLevel(logging.WARNING)
     logging.getLogger("LiteLLM").handlers.clear()
 
-    litellm.cache = Cache(type="disk")
+    litellm.cache = Cache(type=LiteLLMCacheType.DISK)
 else:
     from unittest.mock import Mock
 
@@ -81,6 +84,18 @@ class LiteLLMModelConfig(ModelConfig):
             Maximum number of concurrent API requests to execute in parallel.
             Higher values can improve throughput for batch processing but may hit rate limits
             or exhaust API quotas faster. Default is 10.
+        verbose (bool):
+            Whether to enable verbose logging. Default is False.
+        max_model_length (int | None):
+            Maximum context length for the model. If None, infers the model's default max length.
+        api_max_retry (int):
+            Maximum number of retries for API requests. Default is 8.
+        api_retry_sleep (float):
+            Initial sleep time (in seconds) between retries. Default is 1.0.
+        api_retry_multiplier (float):
+            Multiplier for increasing sleep time between retries. Default is 2.0.
+        timeout (float):
+            Request timeout in seconds. Default is None (no timeout).
         generation_parameters (GenerationParameters, optional, defaults to empty GenerationParameters):
             Configuration parameters that control text generation behavior, including
             temperature, top_p, max_new_tokens, etc.
@@ -108,6 +123,13 @@ class LiteLLMModelConfig(ModelConfig):
     base_url: str | None = None
     api_key: str | None = None
     concurrent_requests: int = 10
+    verbose: bool = False
+    max_model_length: int | None = None
+
+    api_max_retry: int = 8
+    api_retry_sleep: float = 1.0
+    api_retry_multiplier: float = 2.0
+    timeout: float | None = None
 
 
 @requires("litellm")
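For orientation, a minimal sketch of how the new options might be passed when building the config. Only the fields added in this diff are taken from the change; `model_name`, the example model id, and the import path's public usage are assumptions about the surrounding lighteval code:

# Hypothetical usage sketch; `model_name` is an assumed field from the base config.
from lighteval.models.endpoints.litellm_model import LiteLLMModelConfig

config = LiteLLMModelConfig(
    model_name="openrouter/deepseek/deepseek-r1",  # assumed example model id
    concurrent_requests=10,
    verbose=False,             # now forwarded to litellm.verbose
    max_model_length=None,     # None -> resolved via get_max_tokens / OpenRouter lookup
    api_max_retry=8,           # number of retry attempts
    api_retry_sleep=1.0,       # first backoff in seconds
    api_retry_multiplier=2.0,  # exponential backoff factor
    timeout=None,              # per-request timeout in seconds (None = no timeout)
)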
@@ -125,15 +147,17 @@ def __init__(self, config: LiteLLMModelConfig) -> None:
         self.api_key = config.api_key
         self.generation_parameters = config.generation_parameters
         self.concurrent_requests = config.concurrent_requests
+        self._max_length = config.max_model_length
 
-        self.API_MAX_RETRY = 5
-        self.API_RETRY_SLEEP = 3
-        self.API_RETRY_MULTIPLIER = 2
+        self.API_MAX_RETRY = config.api_max_retry
+        self.API_RETRY_SLEEP = config.api_retry_sleep
+        self.API_RETRY_MULTIPLIER = config.api_retry_multiplier
+        self.timeout = config.timeout
 
         self._tokenizer = encode
         self.pairwise_tokenization = False
         litellm.drop_params = True
-        litellm.set_verbose = False
+        litellm.verbose = config.verbose
         self.prompt_manager = PromptManager(
             use_chat_template=True, tokenizer=self.tokenizer, system_prompt=config.system_prompt
         )
@@ -149,58 +173,65 @@ def _prepare_stop_sequence(self, stop_sequence):
         stop_sequence = [s for s in stop_sequence if s and s.strip()]
         return stop_sequence
 
-    def _prepare_max_new_tokens(self, max_new_tokens):
+    def _prepare_max_new_tokens(self, max_new_tokens) -> int | None:
         """Calculate completion tokens based on max_new_tokens."""
         if not max_new_tokens or max_new_tokens <= 0:
             return None
 
-        if "o1" in self.model:
+        if supports_reasoning(self.model):
             # We need to allow more tokens to include reasoning tokens
-            max_new_tokens = min(max_new_tokens * 10, 32000)
+            max_new_tokens = min(max_new_tokens * 10, self.max_length)
+
+            logger.warning(
+                f"Reasoning model detected, increasing max_new_tokens to {max_new_tokens} to allow for reasoning tokens",
+            )
+
         return max_new_tokens
 
     def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_sequence):  # noqa: C901
         """Make API call with retries."""
         response = LitellmModelResponse()
-        for attempt in range(self.API_MAX_RETRY):
-            try:
-                stop_sequence = self._prepare_stop_sequence(stop_sequence)
-                max_new_tokens = self._prepare_max_new_tokens(max_new_tokens)
-
-                if return_logits and not self.provider == "openai":
-                    logger.warning("Returning logits is not supported for this provider, ignoring.")
-
-                # Prepare kwargs for completion call
-                kwargs = {
-                    "model": self.model,
-                    "messages": prompt,
-                    "logprobs": return_logits if self.provider == "openai" else None,
-                    "base_url": self.base_url,
-                    "n": num_samples,
-                    "caching": True,
-                    "api_key": self.api_key,
-                }
+        stop_sequence = self._prepare_stop_sequence(stop_sequence)
+        max_new_tokens = self._prepare_max_new_tokens(max_new_tokens)
+
+        if return_logits and not self.provider == "openai":
+            logger.warning("Returning logits is not supported for this provider, ignoring.")
+
+        # Prepare kwargs for completion call
+        kwargs = {
+            "model": self.model,
+            "messages": prompt,
+            "response_format": {"type": "text"},
+            "max_tokens": max_new_tokens,
+            "logprobs": return_logits if self.provider == "openai" else None,
+            "stop": stop_sequence,
+            "base_url": self.base_url,
+            "api_key": self.api_key,
+            "n": num_samples,
+            "caching": True,
+            "timeout": self.timeout,
+        }
 
-                if num_samples > 1 and self.generation_parameters.temperature == 0:
-                    raise ValueError(
-                        "num_samples > 1 but temperature is set to 0, this will not sample different outputs."
-                    )
-
-                if "o1" in self.model:
-                    logger.warning("O1 models do not support temperature, top_p, stop sequence. Disabling.")
-                else:
-                    kwargs.update(self.generation_parameters.to_litellm_dict())
+        if "o1" in self.model:
+            logger.warning("O1 models do not support temperature, top_p, stop sequence. Disabling.")
+        else:
+            kwargs.update(self.generation_parameters.to_litellm_dict())
 
-                if kwargs.get("max_completion_tokens", None) is None:
-                    kwargs["max_completion_tokens"] = max_new_tokens
+        if kwargs.get("max_completion_tokens", None) is None:
+            kwargs["max_completion_tokens"] = max_new_tokens
 
+        for attempt in range(self.API_MAX_RETRY):
+            try:
                 response = litellm.completion(**kwargs)
+                content = response.choices[0].message.content
 
                 # If response is empty, retry without caching (maybe the error is recoverable and solved with a retry)
-                if response.choices[0].message.content is None:
-                    kwargs["caching"] = False
+                if not content:
                     logger.info("Response is empty, retrying without caching")
+                    kwargs["caching"] = False
                     response = litellm.completion(**kwargs)
+                    content = response.choices[0].message.content
+
                 return response
             except litellm.BadRequestError as e:
                 if "message" in e.__dict__:
@@ -211,7 +242,9 @@ def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_se
                         logger.warning(f"{error_string}. Returning empty response.")
                         return LitellmModelResponse()
             except Exception as e:
-                wait_time = min(64, self.API_RETRY_SLEEP * (2**attempt))  # Exponential backoff with max 64s
+                wait_time = min(
+                    64, self.API_RETRY_SLEEP * (self.API_RETRY_MULTIPLIER**attempt)
+                )  # Exponential backoff with max 64s
                 logger.warning(
                     f"Error in API call: {e}, waiting {wait_time} seconds before retry {attempt + 1}/{self.API_MAX_RETRY}"
                 )
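With the new defaults (api_retry_sleep=1.0, api_retry_multiplier=2.0, api_max_retry=8), the wait between attempts grows geometrically and is clipped at 64 seconds. A short sketch of the resulting schedule, mirroring the expression above:

# Backoff schedule implied by the default config values.
sleep, multiplier, max_retry = 1.0, 2.0, 8
for attempt in range(max_retry):
    wait = min(64, sleep * multiplier**attempt)
    print(f"attempt {attempt + 1}: wait {wait:g}s")  # 1, 2, 4, 8, 16, 32, 64, 64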
@@ -259,6 +292,38 @@ def __call_api_parallel(
 
         return results
 
+    def estimate_context_length(self) -> int:
+        def fallback():
+            logger.warning("Failed to fetch model endpoint info from OpenRouter, returning default max length.")
+            return self._DEFAULT_MAX_LENGTH
+
+        # If the model is used through openrouter, the actual model name comes after the prefix
+        model_name = self.model.removeprefix("openrouter/")
+        endpoint_info_response = requests.get(
+            f"https://openrouter.ai/api/v1/models/{model_name}/endpoints",
+            headers={},
+        )
+        if endpoint_info_response.ok:
+            try:
+                endpoint_info = endpoint_info_response.json()
+                context_lengths = {
+                    endpoint["provider_name"]: endpoint["context_length"]
+                    for endpoint in endpoint_info["data"]["endpoints"]
+                }
+
+                if self.provider in context_lengths:
+                    return context_lengths[self.provider]
+
+                min_length = min(context_lengths.values())
+                logger.warning(
+                    f"Estimating model context length as the minimum context length from available OpenRouter providers: {min_length}"
+                )
+                return min_length
+            except (KeyError, TypeError, ValueError, JSONDecodeError):
+                return fallback()
+
+        return fallback()
+
     @cached(SamplingMethod.GENERATIVE)
     def greedy_until(
         self,
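For reference, a rough sketch of the parsing `estimate_context_length` performs, run against a hand-written stand-in payload. The response shape is only inferred from the keys the method accesses; the real OpenRouter endpoints API returns more fields:

# Stand-in payload with the keys read above: data -> endpoints -> provider_name / context_length.
endpoint_info = {
    "data": {
        "endpoints": [
            {"provider_name": "DeepInfra", "context_length": 131072},
            {"provider_name": "Together", "context_length": 32768},
        ]
    }
}
context_lengths = {e["provider_name"]: e["context_length"] for e in endpoint_info["data"]["endpoints"]}
provider = "Fireworks"  # hypothetical provider not present in the payload
# Exact provider match if available, otherwise the conservative minimum across providers.
print(context_lengths.get(provider, min(context_lengths.values())))  # -> 32768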
@@ -322,7 +387,21 @@ def add_special_tokens(self) -> bool:
     @property
     def max_length(self) -> int:
         """Return the maximum sequence length of the model."""
-        return 4096
+        if self._max_length is not None:
+            return self._max_length
+
+        try:
+            max_tokens = get_max_tokens(self.model)
+        except Exception:
+            logger.error(
+                f"Unable to get the maximum sequence length for model {self.model} from litellm. Fetching information from OpenRouter instead."
+            )
+            max_tokens = self.estimate_context_length()
+
+        # Avoid future requests
+        self._max_length = max_tokens
+
+        return max_tokens
 
     @cached(SamplingMethod.LOGPROBS)
     def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]:
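`get_max_tokens` reads from litellm's bundled model cost map, and the property wraps it in try/except so that unknown models fall back to the OpenRouter estimate. A minimal illustrative check (values depend on the installed litellm version):

# Sketch of the litellm lookup that max_length now relies on.
from litellm.utils import get_max_tokens

try:
    print(get_max_tokens("gpt-3.5-turbo"))  # e.g. 4096 in current cost maps
except Exception:
    print("model not in litellm's cost map; the OpenRouter fallback would be used")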
