2828
2929from fastapi import FastAPI , Request
3030from fastapi .middleware .cors import CORSMiddleware
31- from openai .types .chat .chat_completion import ChatCompletion , Choice
31+ from openai .types .chat .chat_completion import Choice
32+ from openai .types .chat .chat_completion_message import ChatCompletionMessage
3233from openai .types .model import Model
3334from pydantic import BaseModel , Field , root_validator , validator
3435from starlette .responses import StreamingResponse
@@ -191,88 +192,82 @@ async def root_handler():
191192app .single_config_id = None
192193
193194
194- class RequestBody (ChatCompletion ):
195+ class RequestBody (BaseModel ):
195196 config_id : Optional [str ] = Field (
196197 default = os .getenv ("DEFAULT_CONFIG_ID" , None ),
197198 description = "The id of the configuration to be used. If not set, the default configuration will be used." ,
198199 )
199200 config_ids : Optional [List [str ]] = Field (
200201 default = None ,
201- description = "The list of configuration ids to be used. "
202- "If set, the configurations will be combined." ,
203- # alias="guardrails",
204- validate_default = True ,
202+ description = "The ids of the configurations to be used. If not set, the default configuration will be used." ,
205203 )
206204 thread_id : Optional [str ] = Field (
207205 default = None ,
208206 min_length = 16 ,
209207 max_length = 255 ,
210208 description = "The id of an existing thread to which the messages should be added." ,
211209 )
212- model : Optional [str ] = Field (
213- default = None ,
214- description = "The model used for the chat completion." ,
210+ messages : Optional [List [dict ]] = Field (
211+ default = None , description = "The list of messages in the current conversation."
215212 )
216- id : Optional [str ] = Field (
213+ context : Optional [dict ] = Field (
217214 default = None ,
218- description = "The id of the chat completion." ,
215+ description = "Additional context data to be added to the conversation." ,
216+ )
217+ stream : Optional [bool ] = Field (
218+ default = False ,
219+ description = "If set, partial message deltas will be sent, like in ChatGPT. "
220+ "Tokens will be sent as data-only server-sent events as they become "
221+ "available, with the stream terminated by a data: [DONE] message." ,
219222 )
220- object : Optional [ str ] = Field (
221- default = "chat.completion" ,
222- description = "The object type, which is always chat.completion " ,
223+ options : GenerationOptions = Field (
224+ default_factory = GenerationOptions ,
225+ description = "Additional options for controlling the generation. " ,
223226 )
224- created : Optional [int ] = Field (
227+ state : Optional [dict ] = Field (
225228 default = None ,
226- description = "The Unix timestamp (in seconds) of when the chat completion was created ." ,
229+ description = "A state object that should be used to continue the interaction ." ,
227230 )
228- choices : Optional [List [Choice ]] = Field (
231+ # Standard OpenAI completion parameters
232+ model : Optional [str ] = Field (
229233 default = None ,
230- description = "The list of choices for the chat completion." ,
234+ description = "The model to use for chat completion. Maps to config_id for backward compatibility ." ,
231235 )
232236 max_tokens : Optional [int ] = Field (
233237 default = None ,
234238 description = "The maximum number of tokens to generate." ,
235239 )
236240 temperature : Optional [float ] = Field (
237241 default = None ,
238- description = "The temperature to use for the chat completion ." ,
242+ description = "Sampling temperature to use." ,
239243 )
240244 top_p : Optional [float ] = Field (
241245 default = None ,
242- description = "The top p to use for the chat completion ." ,
246+ description = "Top-p sampling parameter ." ,
243247 )
244- stop : Optional [Union [ str , List [ str ]] ] = Field (
248+ stop : Optional [str ] = Field (
245249 default = None ,
246- description = "The stop sequences to use for the chat completion ." ,
250+ description = "Stop sequences." ,
247251 )
248252 presence_penalty : Optional [float ] = Field (
249253 default = None ,
250- description = "The presence penalty to use for the chat completion ." ,
254+ description = "Presence penalty parameter ." ,
251255 )
252256 frequency_penalty : Optional [float ] = Field (
253257 default = None ,
254- description = "The frequency penalty to use for the chat completion ." ,
258+ description = "Frequency penalty parameter ." ,
255259 )
256- messages : Optional [List [dict ]] = Field (
257- default = None , description = "The list of messages in the current conversation."
258- )
259- context : Optional [dict ] = Field (
260+ function_call : Optional [dict ] = Field (
260261 default = None ,
261- description = "Additional context data to be added to the conversation ." ,
262+ description = "Function call parameter ." ,
262263 )
263- stream : Optional [bool ] = Field (
264- default = False ,
265- description = "If set, partial message deltas will be sent, like in ChatGPT. "
266- "Tokens will be sent as data-only server-sent events as they become "
267- "available, with the stream terminated by a data: [DONE] message." ,
268- )
269- options : GenerationOptions = Field (
270- default_factory = GenerationOptions ,
271- description = "Additional options for controlling the generation." ,
264+ logit_bias : Optional [dict ] = Field (
265+ default = None ,
266+ description = "Logit bias parameter." ,
272267 )
273- state : Optional [dict ] = Field (
268+ log_probs : Optional [bool ] = Field (
274269 default = None ,
275- description = "A state object that should be used to continue the interaction ." ,
270+ description = "Log probabilities parameter ." ,
276271 )
277272
278273 @root_validator (pre = True )
@@ -537,16 +532,16 @@ async def chat_completion(body: RequestBody, request: Request):
537532 id = f"chatcmpl-{ uuid .uuid4 ()} " ,
538533 object = "chat.completion" ,
539534 created = int (time .time ()),
540- model = config_ids [0 ] if config_ids else None ,
535+ model = config_ids [0 ] if config_ids else "unknown" ,
541536 choices = [
542537 Choice (
543538 index = 0 ,
544- message = {
545- " content" : f"Could not load the { config_ids } guardrails configuration. "
539+ message = ChatCompletionMessage (
540+ content = f"Could not load the { config_ids } guardrails configuration. "
546541 f"An internal error has occurred." ,
547- " role" : "assistant" ,
548- } ,
549- finish_reason = "error " ,
542+ role = "assistant" ,
543+ ) ,
544+ finish_reason = "stop " ,
550545 logprobs = None ,
551546 )
552547 ],
@@ -570,15 +565,15 @@ async def chat_completion(body: RequestBody, request: Request):
570565 id = f"chatcmpl-{ uuid .uuid4 ()} " ,
571566 object = "chat.completion" ,
572567 created = int (time .time ()),
573- model = None ,
568+ model = config_ids [ 0 ] if config_ids else "unknown" ,
574569 choices = [
575570 Choice (
576571 index = 0 ,
577- message = {
578- " content" : "The `thread_id` must have a minimum length of 16 characters." ,
579- " role" : "assistant" ,
580- } ,
581- finish_reason = "error " ,
572+ message = ChatCompletionMessage (
573+ content = "The `thread_id` must have a minimum length of 16 characters." ,
574+ role = "assistant" ,
575+ ) ,
576+ finish_reason = "stop " ,
582577 logprobs = None ,
583578 )
584579 ],
@@ -625,7 +620,7 @@ async def chat_completion(body: RequestBody, request: Request):
625620 llm_rails .generate_async (
626621 messages = messages ,
627622 streaming_handler = streaming_handler ,
628- options = body . options ,
623+ options = generation_options ,
629624 state = body .state ,
630625 )
631626 )
@@ -638,7 +633,7 @@ async def chat_completion(body: RequestBody, request: Request):
638633 )
639634 else :
640635 res = await llm_rails .generate_async (
641- messages = messages , options = body . options , state = body .state
636+ messages = messages , options = generation_options , state = body .state
642637 )
643638
644639 if isinstance (res , GenerationResponse ):
@@ -662,11 +657,14 @@ async def chat_completion(body: RequestBody, request: Request):
662657 "id" : f"chatcmpl-{ uuid .uuid4 ()} " ,
663658 "object" : "chat.completion" ,
664659 "created" : int (time .time ()),
665- "model" : config_ids [0 ] if config_ids else None ,
660+ "model" : config_ids [0 ] if config_ids else "unknown" ,
666661 "choices" : [
667662 Choice (
668663 index = 0 ,
669- message = bot_message ,
664+ message = ChatCompletionMessage (
665+ role = "assistant" ,
666+ content = bot_message ["content" ],
667+ ),
670668 finish_reason = "stop" ,
671669 logprobs = None ,
672670 )
@@ -688,15 +686,15 @@ async def chat_completion(body: RequestBody, request: Request):
688686 id = f"chatcmpl-{ uuid .uuid4 ()} " ,
689687 object = "chat.completion" ,
690688 created = int (time .time ()),
691- model = None ,
689+ model = "unknown" ,
692690 choices = [
693691 Choice (
694692 index = 0 ,
695- message = {
696- " content" : "Internal server error" ,
697- " role" : "assistant" ,
698- } ,
699- finish_reason = "error " ,
693+ message = ChatCompletionMessage (
694+ content = "Internal server error" ,
695+ role = "assistant" ,
696+ ) ,
697+ finish_reason = "stop " ,
700698 logprobs = None ,
701699 )
702700 ],
0 commit comments