import uuid
import warnings
from contextlib import asynccontextmanager
-from typing import Any, Callable, List, Optional
+from typing import Any, AsyncIterator, Callable, List, Optional, Union

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
-from pydantic import Field, root_validator, validator
+from openai.types.chat.chat_completion import ChatCompletion, Choice
+from openai.types.model import Model
+from pydantic import BaseModel, Field, root_validator, validator
from starlette.responses import StreamingResponse
from starlette.staticfiles import StaticFiles

from nemoguardrails import LLMRails, RailsConfig, utils
from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse
from nemoguardrails.server.datastore.datastore import DataStore
-from nemoguardrails.server.schemas.openai import (
-    Choice,
-    Model,
-    ModelsResponse,
-    OpenAIRequestFields,
-    ResponseBody,
-)
+from nemoguardrails.server.schemas.openai import ModelsResponse, ResponseBody
from nemoguardrails.streaming import StreamingHandler

logging.basicConfig(level=logging.INFO)
@@ -195,7 +191,7 @@ async def root_handler():
app.single_config_id = None


-class RequestBody(OpenAIRequestFields):
+class RequestBody(ChatCompletion):
    config_id: Optional[str] = Field(
        default=os.getenv("DEFAULT_CONFIG_ID", None),
        description="The id of the configuration to be used. If not set, the default configuration will be used.",
@@ -213,6 +209,50 @@ class RequestBody(OpenAIRequestFields):
        max_length=255,
        description="The id of an existing thread to which the messages should be added.",
    )
+    model: Optional[str] = Field(
+        default=None,
+        description="The model used for the chat completion.",
+    )
+    id: Optional[str] = Field(
+        default=None,
+        description="The id of the chat completion.",
+    )
+    object: Optional[str] = Field(
+        default="chat.completion",
+        description="The object type, which is always chat.completion",
+    )
+    created: Optional[int] = Field(
+        default=None,
+        description="The Unix timestamp (in seconds) of when the chat completion was created.",
+    )
+    choices: Optional[List[Choice]] = Field(
+        default=None,
+        description="The list of choices for the chat completion.",
+    )
+    max_tokens: Optional[int] = Field(
+        default=None,
+        description="The maximum number of tokens to generate.",
+    )
+    temperature: Optional[float] = Field(
+        default=None,
+        description="The temperature to use for the chat completion.",
+    )
+    top_p: Optional[float] = Field(
+        default=None,
+        description="The top p to use for the chat completion.",
+    )
+    stop: Optional[Union[str, List[str]]] = Field(
+        default=None,
+        description="The stop sequences to use for the chat completion.",
+    )
+    presence_penalty: Optional[float] = Field(
+        default=None,
+        description="The presence penalty to use for the chat completion.",
+    )
+    frequency_penalty: Optional[float] = Field(
+        default=None,
+        description="The frequency penalty to use for the chat completion.",
+    )
    messages: Optional[List[dict]] = Field(
        default=None, description="The list of messages in the current conversation."
    )
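
With these additions, RequestBody accepts the standard OpenAI chat-completions request fields alongside the guardrails-specific ones, so an OpenAI-style payload can be posted to the endpoint directly. A minimal sketch of such a request; the server URL, config_id, model, and other values below are illustrative assumptions, not part of this change:

import requests

# Sketch only: URL, config_id, and model are assumed values.
payload = {
    "config_id": "my_config",             # guardrails-specific field
    "model": "gpt-3.5-turbo",             # OpenAI-compatible field added above
    "messages": [{"role": "user", "content": "Hello!"}],
    "max_tokens": 256,
    "temperature": 0.2,
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json())
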
@@ -392,6 +432,73 @@ def _get_rails(config_ids: List[str]) -> LLMRails:
    return llm_rails


+async def _format_streaming_response(
+    streaming_handler: StreamingHandler, model_name: Optional[str]
+) -> AsyncIterator[str]:
+    while True:
+        try:
+            chunk = await streaming_handler.__anext__()
+        except StopAsyncIteration:
+            # When the stream ends, yield the [DONE] message
+            yield "data: [DONE]\n\n"
+            break
+
+        # Determine the payload format based on chunk type
+        if isinstance(chunk, dict):
+            # If chunk is a dict, wrap it in OpenAI chunk format with delta
+            payload = {
+                "id": None,
+                "object": "chat.completion.chunk",
+                "created": int(time.time()),
+                "model": model_name,
+                "choices": [
+                    {
+                        "delta": chunk,
+                        "index": None,
+                        "finish_reason": None,
+                    }
+                ],
+            }
+        elif isinstance(chunk, str):
+            try:
+                # Try parsing as JSON - if it parses, it might be a pre-formed payload
+                payload = json.loads(chunk)
+            except Exception:
+                # Treat as a plain text content token
+                payload = {
+                    "id": None,
+                    "object": "chat.completion.chunk",
+                    "created": int(time.time()),
+                    "model": model_name,
+                    "choices": [
+                        {
+                            "delta": {"content": chunk},
+                            "index": None,
+                            "finish_reason": None,
+                        }
+                    ],
+                }
+        else:
+            # For any other type, treat as plain content
+            payload = {
+                "id": None,
+                "object": "chat.completion.chunk",
+                "created": int(time.time()),
+                "model": model_name,
+                "choices": [
+                    {
+                        "delta": {"content": str(chunk)},
+                        "index": None,
+                        "finish_reason": None,
+                    }
+                ],
+            }
+
+        # Send the payload as JSON
+        data = json.dumps(payload, ensure_ascii=False)
+        yield f"data: {data}\n\n"
+
+
@app.post(
    "/v1/chat/completions",
    response_model=ResponseBody,
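
The helper above wraps every chunk as an OpenAI chat.completion.chunk object and emits it as a server-sent event ("data: ...\n\n"), terminating the stream with "data: [DONE]". A quick way to see the framing without a running server is to drive it with a duck-typed stand-in for StreamingHandler (only __anext__ is used); this is a sketch for illustration, not part of the change, and it assumes api.py's module-level json/time imports:

import asyncio

class _FakeHandler:
    """Minimal stand-in for StreamingHandler: yields queued chunks via __anext__."""

    def __init__(self, chunks):
        self._chunks = iter(chunks)

    async def __anext__(self):
        try:
            return next(self._chunks)
        except StopIteration:
            raise StopAsyncIteration

async def _demo():
    async for event in _format_streaming_response(_FakeHandler(["Hello", " world"]), "my_config"):
        print(repr(event))

asyncio.run(_demo())
# Prints two 'data: {"id": null, "object": "chat.completion.chunk", ...}\n\n'
# events followed by 'data: [DONE]\n\n'.
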
@@ -523,7 +630,12 @@ async def chat_completion(body: RequestBody, request: Request):
            )
        )

-        return StreamingResponse(streaming_handler)
+        return StreamingResponse(
+            _format_streaming_response(
+                streaming_handler, model_name=config_ids[0] if config_ids else None
+            ),
+            media_type="text/event-stream",
+        )
    else:
        res = await llm_rails.generate_async(
            messages=messages, options=body.options, state=body.state
        )
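
On the wire, the streaming branch now returns a text/event-stream response whose events follow the OpenAI streaming format. A sketch of a client consuming it; the URL, the httpx library choice, and the stream request flag are assumptions (the flag itself is not shown in this diff):

import json
import httpx

# Sketch only: URL and payload values are illustrative assumptions.
payload = {
    "config_id": "my_config",
    "messages": [{"role": "user", "content": "Tell me a short story."}],
    "stream": True,  # assumed flag that selects the streaming branch above
}

with httpx.stream(
    "POST", "http://localhost:8000/v1/chat/completions", json=payload, timeout=None
) as response:
    for line in response.iter_lines():
        if not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        delta = chunk["choices"][0]["delta"]
        print(delta.get("content", ""), end="", flush=True)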