 import uuid
 import warnings
 from contextlib import asynccontextmanager
-from typing import Any, Callable, List, Optional
+from typing import Any, AsyncIterator, Callable, List, Optional, Union

 from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
+from openai.types.chat.chat_completion import ChatCompletion, Choice
+from openai.types.model import Model
 from pydantic import Field, root_validator, validator
 from starlette.responses import StreamingResponse
 from starlette.staticfiles import StaticFiles
 from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse
 from nemoguardrails.server.datastore.datastore import DataStore
 from nemoguardrails.server.schemas.openai import (
-    Choice,
-    Model,
     ModelsResponse,
-    OpenAIRequestFields,
     ResponseBody,
 )
 from nemoguardrails.streaming import StreamingHandler
@@ -195,7 +194,7 @@ async def root_handler():
 app.single_config_id = None


-class RequestBody(OpenAIRequestFields):
+class RequestBody(ChatCompletion):
     config_id: Optional[str] = Field(
         default=os.getenv("DEFAULT_CONFIG_ID", None),
         description="The id of the configuration to be used. If not set, the default configuration will be used.",
@@ -213,6 +212,50 @@ class RequestBody(OpenAIRequestFields):
         max_length=255,
         description="The id of an existing thread to which the messages should be added.",
     )
+    model: Optional[str] = Field(
+        default=None,
+        description="The model used for the chat completion.",
+    )
+    id: Optional[str] = Field(
+        default=None,
+        description="The id of the chat completion.",
+    )
+    object: Optional[str] = Field(
+        default="chat.completion",
+        description="The object type, which is always chat.completion.",
+    )
+    created: Optional[int] = Field(
+        default=None,
+        description="The Unix timestamp (in seconds) of when the chat completion was created.",
+    )
+    choices: Optional[List[Choice]] = Field(
+        default=None,
+        description="The list of choices for the chat completion.",
+    )
+    max_tokens: Optional[int] = Field(
+        default=None,
+        description="The maximum number of tokens to generate.",
+    )
+    temperature: Optional[float] = Field(
+        default=None,
+        description="The temperature to use for the chat completion.",
+    )
+    top_p: Optional[float] = Field(
+        default=None,
+        description="The top-p value to use for the chat completion.",
+    )
+    stop: Optional[Union[str, List[str]]] = Field(
+        default=None,
+        description="The stop sequences to use for the chat completion.",
+    )
+    presence_penalty: Optional[float] = Field(
+        default=None,
+        description="The presence penalty to use for the chat completion.",
+    )
+    frequency_penalty: Optional[float] = Field(
+        default=None,
+        description="The frequency penalty to use for the chat completion.",
+    )
     messages: Optional[List[dict]] = Field(
         default=None, description="The list of messages in the current conversation."
     )
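With these additions, RequestBody accepts the usual OpenAI chat-completion sampling parameters alongside the guardrails-specific fields such as config_id and thread_id. A minimal sketch of a payload the extended schema should accept, assuming a locally running server on the default port and a hypothetical config id:

import requests  # assumption: any HTTP client works; requests is used for brevity

payload = {
    "config_id": "my_config",  # guardrails-specific field; the id is hypothetical
    "messages": [{"role": "user", "content": "Hello!"}],
    "temperature": 0.2,  # OpenAI-style sampling parameters added above
    "max_tokens": 128,
    "stream": False,
}
response = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(response.json())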
@@ -392,6 +435,73 @@ def _get_rails(config_ids: List[str]) -> LLMRails:
     return llm_rails


+async def _format_streaming_response(
+    streaming_handler: StreamingHandler, model_name: Optional[str]
+) -> AsyncIterator[str]:
+    while True:
+        try:
+            chunk = await streaming_handler.__anext__()
+        except StopAsyncIteration:
+            # When the stream ends, yield the [DONE] message.
+            yield "data: [DONE]\n\n"
+            break
+
+        # Determine the payload format based on the chunk type.
+        if isinstance(chunk, dict):
+            # If the chunk is a dict, wrap it in the OpenAI chunk format with a delta.
+            payload = {
+                "id": None,
+                "object": "chat.completion.chunk",
+                "created": int(time.time()),
+                "model": model_name,
+                "choices": [
+                    {
+                        "delta": chunk,
+                        "index": None,
+                        "finish_reason": None,
+                    }
+                ],
+            }
+        elif isinstance(chunk, str):
+            try:
+                # Try parsing as JSON; if it parses, it might be a pre-formed payload.
+                payload = json.loads(chunk)
+            except Exception:
+                # Otherwise, treat it as a plain text content token.
+                payload = {
+                    "id": None,
+                    "object": "chat.completion.chunk",
+                    "created": int(time.time()),
+                    "model": model_name,
+                    "choices": [
+                        {
+                            "delta": {"content": chunk},
+                            "index": None,
+                            "finish_reason": None,
+                        }
+                    ],
+                }
+        else:
+            # For any other type, treat it as plain content.
+            payload = {
+                "id": None,
+                "object": "chat.completion.chunk",
+                "created": int(time.time()),
+                "model": model_name,
+                "choices": [
+                    {
+                        "delta": {"content": str(chunk)},
+                        "index": None,
+                        "finish_reason": None,
+                    }
+                ],
+            }
+
+        # Send the payload as a JSON server-sent event.
+        data = json.dumps(payload, ensure_ascii=False)
+        yield f"data: {data}\n\n"
+
+
 @app.post(
     "/v1/chat/completions",
     response_model=ResponseBody,
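Each value the helper yields is one server-sent event. For a plain-text token such as "Hello", the line written to the response would look roughly as follows (the created timestamp and model name are illustrative), followed by the terminating sentinel once the handler is exhausted:

data: {"id": null, "object": "chat.completion.chunk", "created": 1700000000, "model": "my_config", "choices": [{"delta": {"content": "Hello"}, "index": null, "finish_reason": null}]}

data: [DONE]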
@@ -523,7 +633,12 @@ async def chat_completion(body: RequestBody, request: Request):
                 )
             )

-            return StreamingResponse(streaming_handler)
+            return StreamingResponse(
+                _format_streaming_response(
+                    streaming_handler, model_name=config_ids[0] if config_ids else None
+                ),
+                media_type="text/event-stream",
+            )
         else:
             res = await llm_rails.generate_async(
                 messages=messages, options=body.options, state=body.state
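On the client side, the streamed response can be consumed as regular SSE. A rough sketch under the assumption of a server on localhost:8000; httpx is used here only for illustration and the helper name is hypothetical:

import asyncio
import json

import httpx  # assumption: any streaming-capable HTTP client works


async def read_stream(url: str, payload: dict) -> str:
    # Accumulate the text content from the chat.completion.chunk events.
    text = ""
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", url, json=payload) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue  # skip the blank separator lines between events
                data = line[len("data: "):]
                if data == "[DONE]":
                    break  # end-of-stream sentinel emitted by the server
                event = json.loads(data)
                delta = event["choices"][0]["delta"]
                if isinstance(delta, dict):
                    text += delta.get("content", "") or ""
    return text


print(
    asyncio.run(
        read_stream(
            "http://localhost:8000/v1/chat/completions",
            {
                "config_id": "my_config",  # hypothetical config id
                "stream": True,
                "messages": [{"role": "user", "content": "Hello!"}],
            },
        )
    )
)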