@@ -38,6 +38,7 @@ def __init__(
3838 self .client = openai .OpenAI (
3939 api_key = self .api_key ,
4040 base_url = self .api_base ,
41+ timeout = int (self .timeout )
4142 )
4243
4344 # Only log unique models to reduce duplication
@@ -72,6 +73,16 @@ async def generate_with_context(
7273 "messages" : formatted_messages ,
7374 "max_completion_tokens" : kwargs .get ("max_tokens" , self .max_tokens ),
7475 }
76+ elif self .api_base == "https://integrate.api.nvidia.com/v1" :
77+ # Branch for the NVIDIA endpoint (integrate.api.nvidia.com): request a streamed response
78+ params = {
79+ "model" : self .model ,
80+ "messages" : formatted_messages ,
81+ "temperature" : kwargs .get ("temperature" , self .temperature ),
82+ "top_p" : kwargs .get ("top_p" , self .top_p ),
83+ "max_tokens" : kwargs .get ("max_tokens" , self .max_tokens ),
84+ "stream" : True
85+ }
7586 else :
7687 params = {
7788 "model" : self .model ,
@@ -130,4 +141,14 @@ async def _call_api(self, params: Dict[str, Any]) -> str:
130141 logger = logging .getLogger (__name__ )
131142 logger .debug (f"API parameters: { params } " )
132143 logger .debug (f"API response: { response .choices [0 ].message .content } " )
133- return response .choices [0 ].message .content
144+
145+ if self .api_base == "https://integrate.api.nvidia.com/v1" :
146+ #print(f"{params['model']}")
147+ output = ""
148+ for chunk in response :
149+ if chunk .choices [0 ].delta .content is not None :
150+ output += chunk .choices [0 ].delta .content
151+ else :
152+ output = response .choices [0 ].message .content
153+
154+ return output
0 commit comments