 import logging
 import time
 from concurrent.futures import ThreadPoolExecutor
+from json import JSONDecodeError
 
+import requests
 from tqdm import tqdm
 
 from lighteval.data import GenerativeTaskDataset
@@ -39,14 +41,15 @@
 
 if is_package_available("litellm"):
     import litellm
-    from litellm import encode
-    from litellm.caching.caching import Cache
+    from litellm import encode, supports_reasoning
+    from litellm.caching.caching import Cache, LiteLLMCacheType
     from litellm.utils import ModelResponse as LitellmModelResponse
+    from litellm.utils import get_max_tokens
 
     logging.getLogger("LiteLLM").setLevel(logging.WARNING)
     logging.getLogger("LiteLLM").handlers.clear()
 
-    litellm.cache = Cache(type="disk")
+    litellm.cache = Cache(type=LiteLLMCacheType.DISK)
 else:
     from unittest.mock import Mock
 
@@ -81,6 +84,18 @@ class LiteLLMModelConfig(ModelConfig):
             Maximum number of concurrent API requests to execute in parallel.
             Higher values can improve throughput for batch processing but may hit rate limits
             or exhaust API quotas faster. Default is 10.
+        verbose (bool):
+            Whether to enable verbose logging. Default is False.
+        max_model_length (int | None):
+            Maximum context length for the model. If None, the model's maximum length is inferred on first use.
+        api_max_retry (int):
+            Maximum number of retries for API requests. Default is 8.
+        api_retry_sleep (float):
+            Initial sleep time (in seconds) between retries. Default is 1.0.
+        api_retry_multiplier (float):
+            Multiplier applied to the sleep time between successive retries. Default is 2.0.
+        timeout (float | None):
+            Request timeout in seconds. Default is None (no timeout).
         generation_parameters (GenerationParameters, optional, defaults to empty GenerationParameters):
             Configuration parameters that control text generation behavior, including
             temperature, top_p, max_new_tokens, etc.
@@ -108,6 +123,13 @@ class LiteLLMModelConfig(ModelConfig):
     base_url: str | None = None
     api_key: str | None = None
     concurrent_requests: int = 10
+    verbose: bool = False
+    max_model_length: int | None = None
+
+    api_max_retry: int = 8
+    api_retry_sleep: float = 1.0
+    api_retry_multiplier: float = 2.0
+    timeout: float | None = None
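+    # Illustrative sketch only (not part of this diff), assuming the usual
+    # `model_name` field inherited from ModelConfig; a config with tighter
+    # retry/timeout settings might look like:
+    #   LiteLLMModelConfig(model_name="openai/gpt-4o-mini", api_max_retry=10,
+    #                      api_retry_sleep=2.0, timeout=120.0)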
 
 
 @requires("litellm")
@@ -125,15 +147,17 @@ def __init__(self, config: LiteLLMModelConfig) -> None:
         self.api_key = config.api_key
         self.generation_parameters = config.generation_parameters
         self.concurrent_requests = config.concurrent_requests
+        self._max_length = config.max_model_length
 
-        self.API_MAX_RETRY = 5
-        self.API_RETRY_SLEEP = 3
-        self.API_RETRY_MULTIPLIER = 2
+        self.API_MAX_RETRY = config.api_max_retry
+        self.API_RETRY_SLEEP = config.api_retry_sleep
+        self.API_RETRY_MULTIPLIER = config.api_retry_multiplier
+        self.timeout = config.timeout
 
         self._tokenizer = encode
         self.pairwise_tokenization = False
         litellm.drop_params = True
-        litellm.set_verbose = False
+        litellm.verbose = config.verbose
         self.prompt_manager = PromptManager(
             use_chat_template=True, tokenizer=self.tokenizer, system_prompt=config.system_prompt
         )
@@ -149,58 +173,65 @@ def _prepare_stop_sequence(self, stop_sequence):
         stop_sequence = [s for s in stop_sequence if s and s.strip()]
         return stop_sequence
 
-    def _prepare_max_new_tokens(self, max_new_tokens):
+    def _prepare_max_new_tokens(self, max_new_tokens) -> int | None:
         """Calculate completion tokens based on max_new_tokens."""
         if not max_new_tokens or max_new_tokens <= 0:
             return None
 
-        if "o1" in self.model:
+        if supports_reasoning(self.model):
             # We need to allow more tokens to include reasoning tokens
-            max_new_tokens = min(max_new_tokens * 10, 32000)
+            max_new_tokens = min(max_new_tokens * 10, self.max_length)
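+            # e.g. a request for 2048 new tokens on a model with a 32,768-token
+            # context becomes min(2048 * 10, 32768) = 20480 completion tokens
+            # (illustrative numbers, not taken from this diff)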
+
+            logger.warning(
+                f"Reasoning model detected, increasing max_new_tokens to {max_new_tokens} to allow for reasoning tokens",
+            )
+
         return max_new_tokens
 
     def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_sequence):  # noqa: C901
         """Make API call with retries."""
         response = LitellmModelResponse()
-        for attempt in range(self.API_MAX_RETRY):
-            try:
-                stop_sequence = self._prepare_stop_sequence(stop_sequence)
-                max_new_tokens = self._prepare_max_new_tokens(max_new_tokens)
-
-                if return_logits and not self.provider == "openai":
-                    logger.warning("Returning logits is not supported for this provider, ignoring.")
-
-                # Prepare kwargs for completion call
-                kwargs = {
-                    "model": self.model,
-                    "messages": prompt,
-                    "logprobs": return_logits if self.provider == "openai" else None,
-                    "base_url": self.base_url,
-                    "n": num_samples,
-                    "caching": True,
-                    "api_key": self.api_key,
-                }
+        stop_sequence = self._prepare_stop_sequence(stop_sequence)
+        max_new_tokens = self._prepare_max_new_tokens(max_new_tokens)
+
+        if return_logits and not self.provider == "openai":
+            logger.warning("Returning logits is not supported for this provider, ignoring.")
+
+        # Prepare kwargs for completion call
+        kwargs = {
+            "model": self.model,
+            "messages": prompt,
+            "response_format": {"type": "text"},
+            "max_tokens": max_new_tokens,
+            "logprobs": return_logits if self.provider == "openai" else None,
+            "stop": stop_sequence,
+            "base_url": self.base_url,
+            "api_key": self.api_key,
+            "n": num_samples,
+            "caching": True,
+            "timeout": self.timeout,
+        }
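+        # Note: caching=True goes through litellm's response cache, which this module
+        # configures at import time as an on-disk cache (Cache(type=LiteLLMCacheType.DISK)),
+        # so identical requests may be answered from disk rather than a new API call.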
 
-                if num_samples > 1 and self.generation_parameters.temperature == 0:
-                    raise ValueError(
-                        "num_samples > 1 but temperature is set to 0, this will not sample different outputs."
-                    )
-
-                if "o1" in self.model:
-                    logger.warning("O1 models do not support temperature, top_p, stop sequence. Disabling.")
-                else:
-                    kwargs.update(self.generation_parameters.to_litellm_dict())
+        if "o1" in self.model:
+            logger.warning("O1 models do not support temperature, top_p, stop sequence. Disabling.")
+        else:
+            kwargs.update(self.generation_parameters.to_litellm_dict())
 
-                if kwargs.get("max_completion_tokens", None) is None:
-                    kwargs["max_completion_tokens"] = max_new_tokens
+        if kwargs.get("max_completion_tokens", None) is None:
+            kwargs["max_completion_tokens"] = max_new_tokens
 
+        for attempt in range(self.API_MAX_RETRY):
+            try:
                 response = litellm.completion(**kwargs)
+                content = response.choices[0].message.content
 
                 # If response is empty, retry without caching (maybe the error is recoverable and solved with a retry)
-                if response.choices[0].message.content is None:
-                    kwargs["caching"] = False
+                if not content:
                     logger.info("Response is empty, retrying without caching")
+                    kwargs["caching"] = False
                     response = litellm.completion(**kwargs)
+                    content = response.choices[0].message.content
+
                 return response
             except litellm.BadRequestError as e:
                 if "message" in e.__dict__:
@@ -211,7 +242,9 @@ def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_sequence):
                         logger.warning(f"{error_string}. Returning empty response.")
                         return LitellmModelResponse()
             except Exception as e:
-                wait_time = min(64, self.API_RETRY_SLEEP * (2**attempt))  # Exponential backoff with max 64s
+                wait_time = min(
+                    64, self.API_RETRY_SLEEP * (self.API_RETRY_MULTIPLIER**attempt)
+                )  # Exponential backoff with max 64s
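+                # with the config defaults (api_retry_sleep=1.0, api_retry_multiplier=2.0,
+                # api_max_retry=8) the successive waits are 1, 2, 4, 8, 16, 32, 64, 64 seconds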
                 logger.warning(
                     f"Error in API call: {e}, waiting {wait_time} seconds before retry {attempt + 1}/{self.API_MAX_RETRY}"
                 )
@@ -259,6 +292,38 @@ def __call_api_parallel(
 
         return results
 
+    def estimate_context_length(self) -> int:
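+        # Assumed OpenRouter response shape, matching how it is parsed below:
+        #   {"data": {"endpoints": [{"provider_name": ..., "context_length": ...}, ...]}}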
+        def fallback():
+            logger.warning("Failed to fetch model endpoint info from OpenRouter, returning default max length.")
+            return self._DEFAULT_MAX_LENGTH
+
+        # If the model is used through openrouter, the actual model name comes after the prefix
+        model_name = self.model.removeprefix("openrouter/")
+        endpoint_info_response = requests.get(
+            f"https://openrouter.ai/api/v1/models/{model_name}/endpoints",
+            headers={},
+        )
+        if endpoint_info_response.ok:
+            try:
+                endpoint_info = endpoint_info_response.json()
+                context_lengths = {
+                    endpoint["provider_name"]: endpoint["context_length"]
+                    for endpoint in endpoint_info["data"]["endpoints"]
+                }
+
+                if self.provider in context_lengths:
+                    return context_lengths[self.provider]
+
+                min_length = min(context_lengths.values())
+                logger.warning(
+                    f"Estimating model context length as the minimum context length from available OpenRouter providers: {min_length}"
+                )
+                return min_length
+            except (KeyError, TypeError, ValueError, JSONDecodeError):
+                return fallback()
+
+        return fallback()
+
     @cached(SamplingMethod.GENERATIVE)
     def greedy_until(
         self,
@@ -322,7 +387,21 @@ def add_special_tokens(self) -> bool:
     @property
     def max_length(self) -> int:
         """Return the maximum sequence length of the model."""
-        return 4096
+        if self._max_length is not None:
+            return self._max_length
+
+        try:
+            max_tokens = get_max_tokens(self.model)
+        except Exception:
+            logger.error(
+                f"Unable to get the maximum sequence length for model {self.model} from litellm. Fetching information from OpenRouter instead."
+            )
+            max_tokens = self.estimate_context_length()
+
+        # Avoid future requests
+        self._max_length = max_tokens
+
+        return max_tokens
 
     @cached(SamplingMethod.LOGPROBS)
     def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]: