
Commit 4b9d826

Fix BadRequestError while using OpenAI O1 series models | Add option to bypass temperature (#2151)
We were about to use OpenAI's o1 model for the FactualCorrectness metric, since it is good at reasoning, but hit a BadRequestError in ragas even though the same model worked fine in LangChain and LlamaIndex. It turned out that ragas sets a temperature on the underlying LLM, which o1 models do not support, so this change adds a way to bypass temperature.

Reproducible code:

```python
from langchain_openai import ChatOpenAI
from llama_index.llms.openai import OpenAI
from ragas.llms import LlamaIndexLLMWrapper, LangchainLLMWrapper
from langchain_core.prompt_values import StringPromptValue

api_key = "<your_api_key>"
langchain_llm = ChatOpenAI(model="o1", api_key=api_key)
llama_index_llm = OpenAI(model="o1", api_key=api_key)

# Both will work fine
print(langchain_llm.invoke("hi"))
print(llama_index_llm.complete("hi"))

prompt = StringPromptValue(text="hi")

# Both will raise an error
await LlamaIndexLLMWrapper(llama_index_llm).agenerate_text(prompt)
await LangchainLLMWrapper(langchain_llm).agenerate_text(prompt)
```

Error:

```
BadRequestError: Error code: 400 - {'error': {'message': "Unsupported parameter: 'temperature' is not supported with this model.", 'type': 'invalid_request_error', 'param': 'temperature', 'code': 'unsupported_parameter'}}
```

After the fix:

```python
# Both will work as expected
await LlamaIndexLLMWrapper(llama_index_llm, bypass_temperature=True).agenerate_text(prompt)
await LangchainLLMWrapper(langchain_llm, bypass_temperature=True).agenerate_text(prompt)
```

For scalability, this adds a flag rather than checking the model name and stripping the temperature.

Co-authored-by: Vignesh Arivazhagan <vignesh.arivazhagan@solitontech.com>
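As a usage note on the motivating case, below is a minimal sketch of wiring the bypassed wrapper into the FactualCorrectness metric. The sample texts are made up, and the `SingleTurnSample` / `single_turn_ascore` usage follows ragas' single-turn metric API as commonly documented, so treat the exact imports and fields as assumptions for your ragas version:

```python
import asyncio

from langchain_openai import ChatOpenAI
from ragas.dataset_schema import SingleTurnSample
from ragas.llms import LangchainLLMWrapper
from ragas.metrics import FactualCorrectness


async def main():
    # o1 rejects the temperature parameter, so bypass it in the wrapper.
    evaluator_llm = LangchainLLMWrapper(
        ChatOpenAI(model="o1", api_key="<your_api_key>"),
        bypass_temperature=True,
    )
    metric = FactualCorrectness(llm=evaluator_llm)
    sample = SingleTurnSample(
        response="Einstein was born in Germany in 1879.",
        reference="Albert Einstein was born in Ulm, Germany, in 1879.",
    )
    print(await metric.single_turn_ascore(sample))


asyncio.run(main())
```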
1 parent 02e0482 · commit 4b9d826

1 file changed: 11 additions, 1 deletion

src/ragas/llms/base.py

Lines changed: 11 additions & 1 deletion
```diff
@@ -142,13 +142,16 @@ def __init__(
         run_config: t.Optional[RunConfig] = None,
         is_finished_parser: t.Optional[t.Callable[[LLMResult], bool]] = None,
         cache: t.Optional[CacheInterface] = None,
+        bypass_temperature: bool = False,
     ):
         super().__init__(cache=cache)
         self.langchain_llm = langchain_llm
         if run_config is None:
             run_config = RunConfig()
         self.set_run_config(run_config)
         self.is_finished_parser = is_finished_parser
+        # Certain LLMs (e.g., OpenAI o1 series) do not support temperature
+        self.bypass_temperature = bypass_temperature

     def is_finished(self, response: LLMResult) -> bool:
         """
@@ -252,7 +255,7 @@ async def agenerate_text(
         old_temperature: float | None = None
         if temperature is None:
             temperature = self.get_temperature(n=n)
-        if hasattr(self.langchain_llm, "temperature"):
+        if hasattr(self.langchain_llm, "temperature") and not self.bypass_temperature:
             self.langchain_llm.temperature = temperature  # type: ignore
             old_temperature = temperature

@@ -311,9 +314,12 @@ def __init__(
         llm: BaseLLM,
         run_config: t.Optional[RunConfig] = None,
         cache: t.Optional[CacheInterface] = None,
+        bypass_temperature: bool = False,
     ):
         super().__init__(cache=cache)
         self.llm = llm
+        # Certain LLMs (e.g., OpenAI o1 series) do not support temperature
+        self.bypass_temperature = bypass_temperature

         try:
             self._signature = type(self.llm).__name__.lower()
@@ -378,6 +384,10 @@ async def agenerate_text(
         temperature = self.get_temperature(n)

         kwargs = self.check_args(n, temperature, stop, callbacks)
+
+        if self.bypass_temperature:
+            kwargs.pop("temperature", None)
+
         li_response = await self.llm.acomplete(prompt.to_string(), **kwargs)

         return LLMResult(generations=[[Generation(text=li_response.text)]])
```
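On the LlamaIndex side, the flag works by popping `"temperature"` from the kwargs before `acomplete` is called, as the last hunk shows. A self-contained sketch of the post-fix call, with the API key as a placeholder and the script wrapper added only so the example runs outside a notebook:

```python
import asyncio

from langchain_core.prompt_values import StringPromptValue
from llama_index.llms.openai import OpenAI
from ragas.llms import LlamaIndexLLMWrapper


async def main():
    llm = OpenAI(model="o1", api_key="<your_api_key>")
    # With bypass_temperature=True the wrapper drops "temperature" from the
    # kwargs it builds, so acomplete() never sends the unsupported parameter.
    wrapper = LlamaIndexLLMWrapper(llm, bypass_temperature=True)
    result = await wrapper.agenerate_text(StringPromptValue(text="hi"))
    print(result.generations[0][0].text)


asyncio.run(main())
```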
