Commit cffb1e9

fix: handle max_completion_tokens error for newer openai models (#2413)
1 parent 4cb829d commit cffb1e9

File tree

1 file changed: +116 -36 lines changed

src/ragas/llms/base.py

@@ -594,6 +594,110 @@ def __init__(
         # Check if client is async-capable at initialization
         self.is_async = self._check_client_async()
 
+    def _map_provider_params(self) -> t.Dict[str, t.Any]:
+        """Route to provider-specific parameter mapping.
+
+        Each provider may have different parameter requirements:
+        - Google: Wraps parameters in generation_config and renames max_tokens
+        - OpenAI: Maps max_tokens to max_completion_tokens for o-series models
+        - Anthropic: No special handling required (pass-through)
+        - LiteLLM: No special handling required (routes internally, pass-through)
+        """
+        provider_lower = self.provider.lower()
+
+        if provider_lower == "google":
+            return self._map_google_params()
+        elif provider_lower == "openai":
+            return self._map_openai_params()
+        else:
+            # Anthropic, LiteLLM - pass through unchanged
+            return self.model_args.copy()
+
+    def _map_openai_params(self) -> t.Dict[str, t.Any]:
+        """Map max_tokens to max_completion_tokens for OpenAI reasoning models.
+
+        Reasoning models (o-series and gpt-5 series) require max_completion_tokens
+        instead of the deprecated max_tokens parameter when using Chat Completions API.
+
+        Legacy OpenAI models (gpt-4, gpt-4o, etc.) continue to use max_tokens unchanged.
+
+        Pattern-based matching for future-proof coverage:
+        - O-series: o1, o2, o3, o4, o5, ... (all reasoning versions)
+        - GPT-5 series: gpt-5, gpt-5-*, gpt-6, gpt-7, ... (all GPT-5+ models)
+        - Other: codex-mini
+        """
+        mapped_args = self.model_args.copy()
+
+        model_lower = self.model.lower()
+
+        # Pattern-based detection for reasoning models that require max_completion_tokens
+        # Uses prefix matching to cover current and future model variants
+        def is_reasoning_model(model_str: str) -> bool:
+            """Check if model is a reasoning model requiring max_completion_tokens."""
+            # O-series reasoning models (o1, o1-mini, o1-2024-12-17, o2, o3, o4, o5, o6, o7, o8, o9)
+            # Pattern: "o" followed by single digit 1-9, then optional "-" or end of string
+            # TODO: Update to support o10+ when OpenAI releases models beyond o9
+            if (
+                len(model_str) >= 2
+                and model_str[0] == "o"
+                and model_str[1] in "123456789"
+            ):
+                # Allow single digit o-series: o1, o2, ..., o9
+                if len(model_str) == 2 or model_str[2] in ("-", "_"):
+                    return True
+
+            # GPT-5 and newer generation models (gpt-5, gpt-5-*, gpt-6, gpt-7, ..., gpt-19)
+            # Pattern: "gpt-" followed by single or double digit >= 5, max 19
+            # TODO: Update to support gpt-20+ when OpenAI releases models beyond gpt-19
+            if model_str.startswith("gpt-"):
+                version_str = (
+                    model_str[4:].split("-")[0].split("_")[0]
+                )  # Get version number
+                try:
+                    version = int(version_str)
+                    if 5 <= version <= 19:
+                        return True
+                except ValueError:
+                    pass
+
+            # Other specific reasoning models
+            if model_str == "codex-mini":
+                return True
+
+            return False
+
+        requires_max_completion_tokens = is_reasoning_model(model_lower)
+
+        # If max_tokens is provided and model requires max_completion_tokens, map it
+        if requires_max_completion_tokens and "max_tokens" in mapped_args:
+            mapped_args["max_completion_tokens"] = mapped_args.pop("max_tokens")
+
+        return mapped_args
+
+    def _map_google_params(self) -> t.Dict[str, t.Any]:
+        """Map parameters for Google Gemini models.
+
+        Google models require parameters to be wrapped in a generation_config dict,
+        and max_tokens is renamed to max_output_tokens.
+        """
+        google_kwargs = {}
+        generation_config_keys = {"temperature", "max_tokens", "top_p", "top_k"}
+        generation_config = {}
+
+        for key, value in self.model_args.items():
+            if key in generation_config_keys:
+                if key == "max_tokens":
+                    generation_config["max_output_tokens"] = value
+                else:
+                    generation_config[key] = value
+            else:
+                google_kwargs[key] = value
+
+        if generation_config:
+            google_kwargs["generation_config"] = generation_config
+
+        return google_kwargs
+
     def _check_client_async(self) -> bool:
         """Determine if the client is async-capable."""
         try:
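
For readers skimming the diff: OpenAI's Chat Completions API rejects max_tokens for its reasoning models (o-series and the gpt-5 family) and requires max_completion_tokens instead, which is the error this commit addresses. Below is a minimal standalone sketch of the mapping rules implemented by _map_openai_params and _map_google_params above; the free functions, sample model names, and sample values are illustrative only and are not part of the repository.

```python
# Illustrative sketch only: plain-function equivalents of the mapping logic above.

def is_reasoning_model(model: str) -> bool:
    """True for models that need max_completion_tokens (o1..o9, gpt-5..gpt-19, codex-mini)."""
    model = model.lower()
    # o-series: "o" + single digit 1-9, optionally followed by "-" or "_" (o3, o1-mini, ...)
    if len(model) >= 2 and model[0] == "o" and model[1] in "123456789":
        if len(model) == 2 or model[2] in ("-", "_"):
            return True
    # gpt-5 through gpt-19 (gpt-5, gpt-5-turbo, ...); older gpt-4* models keep max_tokens
    if model.startswith("gpt-"):
        version = model[4:].split("-")[0].split("_")[0]
        if version.isdigit() and 5 <= int(version) <= 19:
            return True
    return model == "codex-mini"


def map_openai_params(model: str, kwargs: dict) -> dict:
    """Rename max_tokens -> max_completion_tokens for reasoning models."""
    mapped = dict(kwargs)
    if is_reasoning_model(model) and "max_tokens" in mapped:
        mapped["max_completion_tokens"] = mapped.pop("max_tokens")
    return mapped


def map_google_params(kwargs: dict) -> dict:
    """Wrap generation parameters in generation_config; rename max_tokens."""
    gen_keys = {"temperature", "max_tokens", "top_p", "top_k"}
    out = {k: v for k, v in kwargs.items() if k not in gen_keys}
    gen_cfg = {
        ("max_output_tokens" if k == "max_tokens" else k): v
        for k, v in kwargs.items()
        if k in gen_keys
    }
    if gen_cfg:
        out["generation_config"] = gen_cfg
    return out


# Expected behaviour (hypothetical model names and values):
assert map_openai_params("o3-mini", {"max_tokens": 1024}) == {"max_completion_tokens": 1024}
assert map_openai_params("gpt-5", {"max_tokens": 1024}) == {"max_completion_tokens": 1024}
assert map_openai_params("gpt-4o", {"max_tokens": 1024}) == {"max_tokens": 1024}
assert map_google_params({"max_tokens": 256, "temperature": 0.2, "timeout": 30}) == {
    "timeout": 30,
    "generation_config": {"max_output_tokens": 256, "temperature": 0.2},
}
```
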
@@ -676,34 +780,22 @@ def generate(
                 self.agenerate(prompt, response_model)
             )
         else:
-            if self.provider.lower() == "google":
-                google_kwargs = {}
-                generation_config_keys = {"temperature", "max_tokens", "top_p", "top_k"}
-                generation_config = {}
-
-                for key, value in self.model_args.items():
-                    if key in generation_config_keys:
-                        if key == "max_tokens":
-                            generation_config["max_output_tokens"] = value
-                        else:
-                            generation_config[key] = value
-                    else:
-                        google_kwargs[key] = value
-
-                if generation_config:
-                    google_kwargs["generation_config"] = generation_config
+            # Map parameters based on provider requirements
+            provider_kwargs = self._map_provider_params()
 
+            if self.provider.lower() == "google":
                 result = self.client.create(
                     messages=messages,
                     response_model=response_model,
-                    **google_kwargs,
+                    **provider_kwargs,
                 )
             else:
+                # OpenAI, Anthropic, LiteLLM
                 result = self.client.chat.completions.create(
                     model=self.model,
                     messages=messages,
                     response_model=response_model,
-                    **self.model_args,
+                    **provider_kwargs,
                 )
 
         # Track the usage
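
The response_model keyword on client.chat.completions.create indicates an instructor-patched client. Here is a minimal, hypothetical end-to-end illustration of the synchronous path after this change, assuming an instructor-wrapped OpenAI client, a placeholder response model, and a placeholder reasoning-model name (none of this is repository code):

```python
# Hypothetical usage sketch; requires the instructor and openai packages
# and an OPENAI_API_KEY in the environment.
import instructor
from openai import OpenAI
from pydantic import BaseModel


class Answer(BaseModel):
    text: str


client = instructor.from_openai(OpenAI())

# For a reasoning model, the mapped kwargs carry max_completion_tokens
# instead of the rejected max_tokens parameter.
provider_kwargs = {"max_completion_tokens": 512}

result = client.chat.completions.create(
    model="o3-mini",  # placeholder reasoning-model name
    messages=[{"role": "user", "content": "Say hi."}],
    response_model=Answer,
    **provider_kwargs,
)
print(result.text)
```
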
@@ -732,34 +824,22 @@ async def agenerate(
                 "Cannot use agenerate() with a synchronous client. Use generate() instead."
             )
 
-        if self.provider.lower() == "google":
-            google_kwargs = {}
-            generation_config_keys = {"temperature", "max_tokens", "top_p", "top_k"}
-            generation_config = {}
-
-            for key, value in self.model_args.items():
-                if key in generation_config_keys:
-                    if key == "max_tokens":
-                        generation_config["max_output_tokens"] = value
-                    else:
-                        generation_config[key] = value
-                else:
-                    google_kwargs[key] = value
-
-            if generation_config:
-                google_kwargs["generation_config"] = generation_config
+        # Map parameters based on provider requirements
+        provider_kwargs = self._map_provider_params()
 
+        if self.provider.lower() == "google":
             result = await self.client.create(
                 messages=messages,
                 response_model=response_model,
-                **google_kwargs,
+                **provider_kwargs,
             )
         else:
+            # OpenAI, Anthropic, LiteLLM
             result = await self.client.chat.completions.create(
                 model=self.model,
                 messages=messages,
                 response_model=response_model,
-                **self.model_args,
+                **provider_kwargs,
            )
 
        # Track the usage
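
The async hunk mirrors the sync one, only awaiting the call on an async client. A corresponding hypothetical sketch, again assuming instructor with AsyncOpenAI rather than repository code:

```python
# Hypothetical async usage sketch; same assumptions as the sync example above.
import asyncio

import instructor
from openai import AsyncOpenAI
from pydantic import BaseModel


class Answer(BaseModel):
    text: str


async def main() -> None:
    client = instructor.from_openai(AsyncOpenAI())
    result = await client.chat.completions.create(
        model="o3-mini",  # placeholder reasoning-model name
        messages=[{"role": "user", "content": "Say hi."}],
        response_model=Answer,
        max_completion_tokens=512,
    )
    print(result.text)


asyncio.run(main())
```
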
