@@ -594,6 +594,110 @@ def __init__(
         # Check if client is async-capable at initialization
         self.is_async = self._check_client_async()
 
+    def _map_provider_params(self) -> t.Dict[str, t.Any]:
+        """Route to provider-specific parameter mapping.
+
+        Each provider may have different parameter requirements:
+        - Google: Wraps parameters in generation_config and renames max_tokens
+        - OpenAI: Maps max_tokens to max_completion_tokens for reasoning models
+        - Anthropic: No special handling required (pass-through)
+        - LiteLLM: No special handling required (routes internally, pass-through)
+        """
+        provider_lower = self.provider.lower()
+
+        if provider_lower == "google":
+            return self._map_google_params()
+        elif provider_lower == "openai":
+            return self._map_openai_params()
+        else:
+            # Anthropic, LiteLLM - pass through unchanged
+            return self.model_args.copy()
+
+    def _map_openai_params(self) -> t.Dict[str, t.Any]:
+        """Map max_tokens to max_completion_tokens for OpenAI reasoning models.
+
+        Reasoning models (o-series and GPT-5 series) require max_completion_tokens
+        instead of the deprecated max_tokens parameter when using the Chat Completions API.
+
+        Legacy OpenAI models (gpt-4, gpt-4o, etc.) continue to use max_tokens unchanged.
+
+        Pattern-based matching for future-proof coverage:
+        - O-series: o1 through o9, with optional suffixes (o1-mini, o1-2024-12-17, ...)
+        - GPT-5 series: gpt-5, gpt-5-*, and later generations up to gpt-19
+        - Other: codex-mini
+        """
+        mapped_args = self.model_args.copy()
+
+        model_lower = self.model.lower()
+
+        # Pattern-based detection for reasoning models that require max_completion_tokens.
+        # Uses prefix matching to cover current and future model variants.
+        def is_reasoning_model(model_str: str) -> bool:
+            """Check if model is a reasoning model requiring max_completion_tokens."""
+            # O-series reasoning models (o1, o1-mini, o1-2024-12-17, ..., o9).
+            # Pattern: "o" followed by a single digit 1-9, then "-", "_", or end of string.
+            # TODO: Update to support o10+ when OpenAI releases models beyond o9
+            if (
+                len(model_str) >= 2
+                and model_str[0] == "o"
+                and model_str[1] in "123456789"
+            ):
+                # Allow single-digit o-series: o1, o2, ..., o9
+                if len(model_str) == 2 or model_str[2] in ("-", "_"):
+                    return True
+
+            # GPT-5 and newer generation models (gpt-5, gpt-5-*, gpt-6, ..., gpt-19).
+            # Pattern: "gpt-" followed by a single- or double-digit version >= 5, max 19
+            # TODO: Update to support gpt-20+ when OpenAI releases models beyond gpt-19
+            if model_str.startswith("gpt-"):
+                version_str = (
+                    model_str[4:].split("-")[0].split("_")[0]
+                )  # Get the bare version number
+                try:
+                    version = int(version_str)
+                    if 5 <= version <= 19:
+                        return True
+                except ValueError:
+                    pass
+
+            # Other specific reasoning models
+            if model_str == "codex-mini":
+                return True
+
+            return False
+
+        requires_max_completion_tokens = is_reasoning_model(model_lower)
+
+        # If max_tokens is provided and the model requires max_completion_tokens, map it
+        if requires_max_completion_tokens and "max_tokens" in mapped_args:
+            mapped_args["max_completion_tokens"] = mapped_args.pop("max_tokens")
+
+        return mapped_args
+
+    def _map_google_params(self) -> t.Dict[str, t.Any]:
+        """Map parameters for Google Gemini models.
+
+        Google models require parameters to be wrapped in a generation_config dict,
+        and max_tokens is renamed to max_output_tokens.
+        """
+        google_kwargs = {}
+        generation_config_keys = {"temperature", "max_tokens", "top_p", "top_k"}
+        generation_config = {}
+
+        for key, value in self.model_args.items():
+            if key in generation_config_keys:
+                if key == "max_tokens":
+                    generation_config["max_output_tokens"] = value
+                else:
+                    generation_config[key] = value
+            else:
+                google_kwargs[key] = value
+
+        if generation_config:
+            google_kwargs["generation_config"] = generation_config
+
+        return google_kwargs
+
     def _check_client_async(self) -> bool:
         """Determine if the client is async-capable."""
         try:
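A quick sanity sketch of what `_map_provider_params` returns for each provider, assuming an instance whose `model_args` is `{"temperature": 0.2, "max_tokens": 512}` (the `llm` name below is illustrative, not part of this diff):

    llm.model_args = {"temperature": 0.2, "max_tokens": 512}

    # provider="openai", model="o1-mini" (reasoning model): max_tokens is renamed
    # -> {"temperature": 0.2, "max_completion_tokens": 512}

    # provider="openai", model="gpt-4o" (legacy model): passed through unchanged
    # -> {"temperature": 0.2, "max_tokens": 512}

    # provider="google": wrapped in generation_config, max_tokens renamed
    # -> {"generation_config": {"temperature": 0.2, "max_output_tokens": 512}}

    # provider="anthropic" or "litellm": pass-through copy of model_args
    # -> {"temperature": 0.2, "max_tokens": 512}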
@@ -676,34 +780,22 @@ def generate(
                 self.agenerate(prompt, response_model)
             )
         else:
-            if self.provider.lower() == "google":
-                google_kwargs = {}
-                generation_config_keys = {"temperature", "max_tokens", "top_p", "top_k"}
-                generation_config = {}
-
-                for key, value in self.model_args.items():
-                    if key in generation_config_keys:
-                        if key == "max_tokens":
-                            generation_config["max_output_tokens"] = value
-                        else:
-                            generation_config[key] = value
-                    else:
-                        google_kwargs[key] = value
-
-                if generation_config:
-                    google_kwargs["generation_config"] = generation_config
+            # Map parameters based on provider requirements
+            provider_kwargs = self._map_provider_params()
 
+            if self.provider.lower() == "google":
                 result = self.client.create(
                     messages=messages,
                     response_model=response_model,
-                    **google_kwargs,
+                    **provider_kwargs,
                 )
             else:
+                # OpenAI, Anthropic, LiteLLM
                 result = self.client.chat.completions.create(
                     model=self.model,
                     messages=messages,
                     response_model=response_model,
-                    **self.model_args,
+                    **provider_kwargs,
                 )
 
         # Track the usage
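For context, a hedged usage sketch of the synchronous path (the class name `InstructorLLM` and its constructor are assumptions for illustration; this diff only shows generate() and the mapping helpers):

    from pydantic import BaseModel

    class Answer(BaseModel):
        text: str

    # Hypothetical construction; the real constructor is outside this diff.
    llm = InstructorLLM(provider="openai", model="o3-mini", max_tokens=256)

    # generate() now calls _map_provider_params() once, so the o-series model
    # is sent max_completion_tokens=256 via client.chat.completions.create().
    answer = llm.generate(prompt="Say hi", response_model=Answer)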
@@ -732,34 +824,22 @@ async def agenerate(
732824 "Cannot use agenerate() with a synchronous client. Use generate() instead."
733825 )
734826
735- if self .provider .lower () == "google" :
736- google_kwargs = {}
737- generation_config_keys = {"temperature" , "max_tokens" , "top_p" , "top_k" }
738- generation_config = {}
739-
740- for key , value in self .model_args .items ():
741- if key in generation_config_keys :
742- if key == "max_tokens" :
743- generation_config ["max_output_tokens" ] = value
744- else :
745- generation_config [key ] = value
746- else :
747- google_kwargs [key ] = value
748-
749- if generation_config :
750- google_kwargs ["generation_config" ] = generation_config
827+ # Map parameters based on provider requirements
828+ provider_kwargs = self ._map_provider_params ()
751829
830+ if self .provider .lower () == "google" :
752831 result = await self .client .create (
753832 messages = messages ,
754833 response_model = response_model ,
755- ** google_kwargs ,
834+ ** provider_kwargs ,
756835 )
757836 else :
837+ # OpenAI, Anthropic, LiteLLM
758838 result = await self .client .chat .completions .create (
759839 model = self .model ,
760840 messages = messages ,
761841 response_model = response_model ,
762- ** self . model_args ,
842+ ** provider_kwargs ,
763843 )
764844
765845 # Track the usage
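The async path behaves the same under the same assumptions; agenerate() raises for synchronous clients (per the guard above), and Google parameters arrive wrapped in generation_config:

    import asyncio

    async def main() -> None:
        # Hypothetical construction, as in the sketch above.
        llm = InstructorLLM(provider="google", model="gemini-1.5-pro",
                            temperature=0.2, max_tokens=512)
        # _map_provider_params() yields
        # {"generation_config": {"temperature": 0.2, "max_output_tokens": 512}},
        # which is passed to the awaited client.create() call.
        answer = await llm.agenerate(prompt="Say hi", response_model=Answer)

    asyncio.run(main())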