 import os.path
 import re
 import time
+import uuid
 import warnings
 from contextlib import asynccontextmanager
 from typing import Any, Callable, List, Optional
@@ -229,10 +230,53 @@ class RequestBody(BaseModel):
         default=None,
         description="A state object that should be used to continue the interaction.",
     )
+    # Standard OpenAI completion parameters
+    model: Optional[str] = Field(
+        default=None,
+        description="The model to use for chat completion. Maps to config_id for backward compatibility.",
+    )
+    max_tokens: Optional[int] = Field(
+        default=None,
+        description="The maximum number of tokens to generate.",
+    )
+    temperature: Optional[float] = Field(
+        default=None,
+        description="Sampling temperature to use.",
+    )
+    top_p: Optional[float] = Field(
+        default=None,
+        description="Top-p sampling parameter.",
+    )
+    stop: Optional[str] = Field(
+        default=None,
+        description="Stop sequences.",
+    )
+    presence_penalty: Optional[float] = Field(
+        default=None,
+        description="Presence penalty parameter.",
+    )
+    frequency_penalty: Optional[float] = Field(
+        default=None,
+        description="Frequency penalty parameter.",
+    )
+    function_call: Optional[dict] = Field(
+        default=None,
+        description="Function call parameter.",
+    )
+    logit_bias: Optional[dict] = Field(
+        default=None,
+        description="Logit bias parameter.",
+    )
+    logprobs: Optional[bool] = Field(
+        default=None,
+        description="Whether to return log probabilities of the output tokens.",
+    )
 
     @root_validator(pre=True)
     def ensure_config_id(cls, data: Any) -> Any:
         if isinstance(data, dict):
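+            # Treat the OpenAI `model` parameter as an alias for `config_id`,
+            # so standard OpenAI clients can select a guardrails configuration.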
+            if data.get("model") is not None and data.get("config_id") is None:
+                data["config_id"] = data["model"]
             if data.get("config_id") is not None and data.get("config_ids") is not None:
                 raise ValueError(
                     "Only one of config_id or config_ids should be specified"
@@ -253,25 +297,113 @@ def ensure_config_ids(cls, v, values):
         return v
 
 
+class Choice(BaseModel):
+    index: Optional[int] = Field(
+        default=None, description="The index of the choice in the list of choices."
+    )
+    message: Optional[dict] = Field(
+        default=None, description="The message of the choice."
+    )
+    logprobs: Optional[dict] = Field(
+        default=None, description="The log probabilities of the choice."
+    )
+    finish_reason: Optional[str] = Field(
+        default=None, description="The reason the model stopped generating tokens."
+    )
+
+
 class ResponseBody(BaseModel):
-    messages: Optional[List[dict]] = Field(
-        default=None, description="The new messages in the conversation"
+    # OpenAI-compatible fields
+    id: Optional[str] = Field(
+        default=None, description="A unique identifier for the chat completion."
     )
-    llm_output: Optional[dict] = Field(
-        default=None,
-        description="Contains any additional output coming from the LLM.",
+    object: str = Field(
+        default="chat.completion",
+        description="The object type, which is always 'chat.completion'.",
     )
-    output_data: Optional[dict] = Field(
+    created: Optional[int] = Field(
         default=None,
-        description="The output data, i.e. a dict with the values corresponding to the `output_vars`.",
+        description="The Unix timestamp (in seconds) of when the chat completion was created.",
+    )
+    model: Optional[str] = Field(
+        default=None, description="The model used for the chat completion."
     )
-    log: Optional[GenerationLog] = Field(
-        default=None, description="Additional logging information."
+    choices: Optional[List[Choice]] = Field(
+        default=None, description="A list of chat completion choices."
     )
+    # NeMo Guardrails-specific fields for backward compatibility
     state: Optional[dict] = Field(
-        default=None,
-        description="A state object that should be used to continue the interaction in the future.",
+        default=None, description="State object for continuing the conversation."
+    )
+    llm_output: Optional[dict] = Field(
+        default=None, description="Additional LLM output data."
+    )
+    output_data: Optional[dict] = Field(
+        default=None, description="Additional output data."
+    )
+    log: Optional[dict] = Field(default=None, description="Generation log data.")
+
+
+class Model(BaseModel):
+    id: str = Field(
+        description="The model identifier, which can be referenced in the API endpoints."
     )
+    object: str = Field(
+        default="model", description="The object type, which is always 'model'."
+    )
+    created: int = Field(
+        description="The Unix timestamp (in seconds) of when the model was created."
+    )
+    owned_by: str = Field(
+        default="nemo-guardrails", description="The organization that owns the model."
+    )
+
+
+class ModelsResponse(BaseModel):
+    object: str = Field(
+        default="list", description="The object type, which is always 'list'."
+    )
+    data: List[Model] = Field(description="The list of models.")
+
+
+@app.get(
+    "/v1/models",
+    response_model=ModelsResponse,
+    summary="List available models",
+    description="Lists the currently available models, mapping guardrails configurations to the OpenAI-compatible model format.",
+)
+async def get_models():
+    """Returns the list of available models (guardrails configurations) in OpenAI-compatible format."""
+
+    # Use the same logic as get_rails_configs to find the available configurations.
+    if app.single_config_mode:
+        config_ids = [app.single_config_id] if app.single_config_id else []
+    else:
+        config_ids = [
+            f
+            for f in os.listdir(app.rails_config_path)
+            if os.path.isdir(os.path.join(app.rails_config_path, f))
+            and f[0] != "."
+            and f[0] != "_"
+            # Filter out all the configs for which there is no `config.yml` file.
+            and (
+                os.path.exists(os.path.join(app.rails_config_path, f, "config.yml"))
+                or os.path.exists(os.path.join(app.rails_config_path, f, "config.yaml"))
+            )
+        ]
+
+    # Convert the configurations to the OpenAI model format.
+    models = []
+    for config_id in config_ids:
+        model = Model(
+            id=config_id,
+            object="model",
+            created=int(time.time()),  # Use the current time as the created timestamp.
+            owned_by="nemo-guardrails",
+        )
+        models.append(model)
+
+    return ModelsResponse(data=models)
 
 
 @app.get(
@@ -401,13 +533,22 @@ async def chat_completion(body: RequestBody, request: Request):
     except ValueError as ex:
         log.exception(ex)
         return ResponseBody(
-            messages=[
-                {
-                    "role": "assistant",
-                    "content": f"Could not load the {config_ids} guardrails configuration. "
-                    f"An internal error has occurred.",
-                }
-            ]
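+            # Note: finish_reason="error" is a Guardrails-specific value; the
+            # OpenAI API itself uses values such as "stop" or "length".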
+            id=f"chatcmpl-{uuid.uuid4()}",
+            object="chat.completion",
+            created=int(time.time()),
+            model=config_ids[0] if config_ids else None,
+            choices=[
+                Choice(
+                    index=0,
+                    message={
+                        "content": f"Could not load the {config_ids} guardrails configuration. "
+                        f"An internal error has occurred.",
+                        "role": "assistant",
+                    },
+                    finish_reason="error",
+                    logprobs=None,
+                )
+            ],
         )
 
     try:
@@ -425,12 +566,21 @@ async def chat_completion(body: RequestBody, request: Request):
             # We make sure the `thread_id` meets the minimum complexity requirement.
             if len(body.thread_id) < 16:
                 return ResponseBody(
-                    messages=[
-                        {
-                            "role": "assistant",
-                            "content": "The `thread_id` must have a minimum length of 16 characters.",
-                        }
-                    ]
+                    id=f"chatcmpl-{uuid.uuid4()}",
+                    object="chat.completion",
+                    created=int(time.time()),
+                    model=None,
+                    choices=[
+                        Choice(
+                            index=0,
+                            message={
+                                "content": "The `thread_id` must have a minimum length of 16 characters.",
+                                "role": "assistant",
+                            },
+                            finish_reason="error",
+                            logprobs=None,
+                        )
+                    ],
                 )
 
             # Fetch the existing thread messages. For easier management, we prepend
@@ -441,6 +591,20 @@ async def chat_completion(body: RequestBody, request: Request):
             # And prepend them.
             messages = thread_messages + messages
 
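+        # Fold the OpenAI-style sampling parameters from the request body into
+        # the generation options.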
+        generation_options = body.options
+        if body.max_tokens is not None:
+            generation_options.max_tokens = body.max_tokens
+        if body.temperature is not None:
+            generation_options.temperature = body.temperature
+        if body.top_p is not None:
+            generation_options.top_p = body.top_p
+        if body.stop is not None:
+            generation_options.stop = body.stop
+        if body.presence_penalty is not None:
+            generation_options.presence_penalty = body.presence_penalty
+        if body.frequency_penalty is not None:
+            generation_options.frequency_penalty = body.frequency_penalty
+
         if (
             body.stream
             and llm_rails.config.streaming_supported
@@ -459,8 +623,6 @@ async def chat_completion(body: RequestBody, request: Request):
                 )
             )
 
-            # TODO: Add support for thread_ids in streaming mode
-
             return StreamingResponse(streaming_handler)
         else:
             res = await llm_rails.generate_async(
@@ -483,21 +645,49 @@ async def chat_completion(body: RequestBody, request: Request):
         if body.thread_id and datastore is not None and datastore_key is not None:
             await datastore.set(datastore_key, json.dumps(messages + [bot_message]))
 
-        result = ResponseBody(messages=[bot_message])
+        # Build the response with the OpenAI-compatible format plus the NeMo Guardrails extensions.
+        response_kwargs = {
+            "id": f"chatcmpl-{uuid.uuid4()}",
+            "object": "chat.completion",
+            "created": int(time.time()),
+            "model": config_ids[0] if config_ids else None,
+            "choices": [
+                Choice(
+                    index=0,
+                    message=bot_message,
+                    finish_reason="stop",
+                    logprobs=None,
+                )
+            ],
+        }
 
-        # If we have additional GenerationResponse fields, we return as well
+        # If we have additional GenerationResponse fields, include them for backward compatibility.
         if isinstance(res, GenerationResponse):
-            result.llm_output = res.llm_output
-            result.output_data = res.output_data
-            result.log = res.log
-            result.state = res.state
+            response_kwargs["llm_output"] = res.llm_output
+            response_kwargs["output_data"] = res.output_data
+            response_kwargs["log"] = res.log
+            response_kwargs["state"] = res.state
 
-        return result
+        return ResponseBody(**response_kwargs)
 
     except Exception as ex:
         log.exception(ex)
         return ResponseBody(
-            messages=[{"role": "assistant", "content": "Internal server error."}]
+            id=f"chatcmpl-{uuid.uuid4()}",
+            object="chat.completion",
+            created=int(time.time()),
+            model=None,
+            choices=[
+                Choice(
+                    index=0,
+                    message={
+                        "content": "Internal server error.",
+                        "role": "assistant",
+                    },
+                    finish_reason="error",
+                    logprobs=None,
+                )
+            ],
         )
 
 
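For reviewers, a minimal sketch of how the new OpenAI-compatible surface could be exercised with the official openai Python client. The base URL, the dummy API key, the /v1/chat/completions mount point, and the "demo" configuration name are illustrative assumptions, not part of this diff:

    # A minimal sketch, not part of the PR. Assumes a guardrails server running
    # locally on port 8000 and an existing configuration named "demo".
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

    # The new /v1/models endpoint lists each guardrails configuration as a model.
    for model in client.models.list():
        print(model.id)

    # The `model` parameter is mapped to `config_id` by the ensure_config_id
    # validator, so a configuration can be selected by name.
    completion = client.chat.completions.create(
        model="demo",
        messages=[{"role": "user", "content": "Hello!"}],
        temperature=0.2,
    )
    print(completion.choices[0].message)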