
Commit a8c4b43

adaption for NVIDIA API: stream=True
1 parent bc66c5b commit a8c4b43

File tree: 1 file changed (+22, -1)

openevolve/llm/openai.py

Lines changed: 22 additions & 1 deletion
@@ -38,6 +38,7 @@ def __init__(
         self.client = openai.OpenAI(
             api_key=self.api_key,
             base_url=self.api_base,
+            timeout=int(self.timeout)
         )
 
         # Only log unique models to reduce duplication
@@ -72,6 +73,16 @@ async def generate_with_context(
                 "messages": formatted_messages,
                 "max_completion_tokens": kwargs.get("max_tokens", self.max_tokens),
             }
+        elif self.api_base == "https://integrate.api.nvidia.com/v1":
+            # add the branch for NVbuild
+            params = {
+                "model": self.model,
+                "messages": formatted_messages,
+                "temperature": kwargs.get("temperature", self.temperature),
+                "top_p": kwargs.get("top_p", self.top_p),
+                "max_tokens": kwargs.get("max_tokens", self.max_tokens),
+                "stream": True
+            }
         else:
             params = {
                 "model": self.model,
@@ -130,4 +141,14 @@ async def _call_api(self, params: Dict[str, Any]) -> str:
         logger = logging.getLogger(__name__)
         logger.debug(f"API parameters: {params}")
         logger.debug(f"API response: {response.choices[0].message.content}")
-        return response.choices[0].message.content
+
+        if self.api_base == "https://integrate.api.nvidia.com/v1":
+            #print(f"{params['model']}")
+            output=""
+            for chunk in response:
+                if chunk.choices[0].delta.content is not None:
+                    output += chunk.choices[0].delta.content
+        else:
+            output=response.choices[0].message.content
+
+        return output
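The new branch requests a streaming response from NVIDIA's OpenAI-compatible endpoint and then concatenates the incremental deltas in _call_api. Below is a minimal standalone sketch of that pattern, assuming the openai>=1.0 Python SDK; the NVIDIA_API_KEY environment variable and the model id are illustrative placeholders, not values taken from this commit.

# Minimal sketch of the streaming pattern added in this commit.
import os
import openai

client = openai.OpenAI(
    base_url="https://integrate.api.nvidia.com/v1",
    api_key=os.environ["NVIDIA_API_KEY"],  # assumed env var name
    timeout=60,
)

# With stream=True the endpoint returns chunks instead of a single message.
stream = client.chat.completions.create(
    model="meta/llama-3.1-8b-instruct",  # example model id, not from this commit
    messages=[{"role": "user", "content": "Say hello"}],
    temperature=0.7,
    top_p=0.95,
    max_tokens=128,
    stream=True,
)

# Concatenate the incremental deltas, mirroring the loop the commit adds to _call_api.
output = ""
for chunk in stream:
    if chunk.choices[0].delta.content is not None:
        output += chunk.choices[0].delta.content
print(output)

For other base URLs the client still returns a single completion object, which is why _call_api keeps the non-streaming path that reads response.choices[0].message.content.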
