
Commit 1e00d3a

feat: Modify endpoints for OpenAI compatibility
1 parent b5b7579 commit 1e00d3a

4 files changed: 298 additions & 46 deletions


nemoguardrails/server/api.py

Lines changed: 224 additions & 34 deletions
@@ -20,6 +20,7 @@
 import os.path
 import re
 import time
+import uuid
 import warnings
 from contextlib import asynccontextmanager
 from typing import Any, Callable, List, Optional
@@ -229,10 +230,53 @@ class RequestBody(BaseModel):
         default=None,
         description="A state object that should be used to continue the interaction.",
     )
+    # Standard OpenAI completion parameters
+    model: Optional[str] = Field(
+        default=None,
+        description="The model to use for chat completion. Maps to config_id for backward compatibility.",
+    )
+    max_tokens: Optional[int] = Field(
+        default=None,
+        description="The maximum number of tokens to generate.",
+    )
+    temperature: Optional[float] = Field(
+        default=None,
+        description="Sampling temperature to use.",
+    )
+    top_p: Optional[float] = Field(
+        default=None,
+        description="Top-p sampling parameter.",
+    )
+    stop: Optional[str] = Field(
+        default=None,
+        description="Stop sequences.",
+    )
+    presence_penalty: Optional[float] = Field(
+        default=None,
+        description="Presence penalty parameter.",
+    )
+    frequency_penalty: Optional[float] = Field(
+        default=None,
+        description="Frequency penalty parameter.",
+    )
+    function_call: Optional[dict] = Field(
+        default=None,
+        description="Function call parameter.",
+    )
+    logit_bias: Optional[dict] = Field(
+        default=None,
+        description="Logit bias parameter.",
+    )
+    log_probs: Optional[bool] = Field(
+        default=None,
+        description="Log probabilities parameter.",
+    )

     @root_validator(pre=True)
     def ensure_config_id(cls, data: Any) -> Any:
         if isinstance(data, dict):
+            if data.get("model") is not None and data.get("config_id") is None:
+                data["config_id"] = data["model"]
             if data.get("config_id") is not None and data.get("config_ids") is not None:
                 raise ValueError(
                     "Only one of config_id or config_ids should be specified"
@@ -253,25 +297,113 @@ def ensure_config_ids(cls, v, values):
         return v


+class Choice(BaseModel):
+    index: Optional[int] = Field(
+        default=None, description="The index of the choice in the list of choices."
+    )
+    messages: Optional[dict] = Field(
+        default=None, description="The message of the choice"
+    )
+    logprobs: Optional[dict] = Field(
+        default=None, description="The log probabilities of the choice"
+    )
+    finish_reason: Optional[str] = Field(
+        default=None, description="The reason the model stopped generating tokens."
+    )
+
+
 class ResponseBody(BaseModel):
-    messages: Optional[List[dict]] = Field(
-        default=None, description="The new messages in the conversation"
+    # OpenAI-compatible fields
+    id: Optional[str] = Field(
+        default=None, description="A unique identifier for the chat completion."
     )
-    llm_output: Optional[dict] = Field(
-        default=None,
-        description="Contains any additional output coming from the LLM.",
+    object: str = Field(
+        default="chat.completion",
+        description="The object type, which is always chat.completion",
     )
-    output_data: Optional[dict] = Field(
+    created: Optional[int] = Field(
         default=None,
-        description="The output data, i.e. a dict with the values corresponding to the `output_vars`.",
+        description="The Unix timestamp (in seconds) of when the chat completion was created.",
+    )
+    model: Optional[str] = Field(
+        default=None, description="The model used for the chat completion."
     )
-    log: Optional[GenerationLog] = Field(
-        default=None, description="Additional logging information."
+    choices: Optional[List[Choice]] = Field(
+        default=None, description="A list of chat completion choices."
     )
+    # NeMo-Guardrails specific fields for backward compatibility
     state: Optional[dict] = Field(
-        default=None,
-        description="A state object that should be used to continue the interaction in the future.",
+        default=None, description="State object for continuing the conversation."
+    )
+    llm_output: Optional[dict] = Field(
+        default=None, description="Additional LLM output data."
+    )
+    output_data: Optional[dict] = Field(
+        default=None, description="Additional output data."
+    )
+    log: Optional[dict] = Field(default=None, description="Generation log data.")
+
+
+class Model(BaseModel):
+    id: str = Field(
+        description="The model identifier, which can be referenced in the API endpoints."
     )
+    object: str = Field(
+        default="model", description="The object type, which is always 'model'."
+    )
+    created: int = Field(
+        description="The Unix timestamp (in seconds) of when the model was created."
+    )
+    owned_by: str = Field(
+        default="nemo-guardrails", description="The organization that owns the model."
+    )
+
+
+class ModelsResponse(BaseModel):
+    object: str = Field(
+        default="list", description="The object type, which is always 'list'."
+    )
+    data: List[Model] = Field(description="The list of models.")
+
+
+@app.get(
+    "/v1/models",
+    response_model=ModelsResponse,
+    summary="List available models",
+    description="Lists the currently available models, mapping guardrails configurations to OpenAI-compatible model format.",
+)
+async def get_models():
+    """Returns the list of available models (guardrails configurations) in OpenAI-compatible format."""
+
+    # Use the same logic as get_rails_configs to find available configurations
+    if app.single_config_mode:
+        config_ids = [app.single_config_id] if app.single_config_id else []
+    else:
+        config_ids = [
+            f
+            for f in os.listdir(app.rails_config_path)
+            if os.path.isdir(os.path.join(app.rails_config_path, f))
+            and f[0] != "."
+            and f[0] != "_"
+            # Filter out all the configs for which there is no `config.yml` file.
+            and (
+                os.path.exists(os.path.join(app.rails_config_path, f, "config.yml"))
+                or os.path.exists(os.path.join(app.rails_config_path, f, "config.yaml"))
+            )
+        ]
+
+    # Convert configurations to OpenAI model format
+    models = []
+    for config_id in config_ids:
+        model = Model(
+            id=config_id,
+            object="model",
+            created=int(time.time()),  # Use current time as created timestamp
+            owned_by="nemo-guardrails",
+        )
+        models.append(model)
+
+    return ModelsResponse(data=models)


 @app.get(
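Note: a minimal sketch of listing the available guardrails configurations through the new endpoint, assuming a server at http://localhost:8000 (hypothetical host and port):

import requests

# Each guardrails configuration directory is exposed as an OpenAI-style model entry.
models = requests.get("http://localhost:8000/v1/models").json()
for model in models["data"]:
    print(model["id"], model["object"], model["owned_by"])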
@@ -401,13 +533,22 @@ async def chat_completion(body: RequestBody, request: Request):
     except ValueError as ex:
         log.exception(ex)
         return ResponseBody(
-            messages=[
-                {
-                    "role": "assistant",
-                    "content": f"Could not load the {config_ids} guardrails configuration. "
-                    f"An internal error has occurred.",
-                }
-            ]
+            id=f"chatcmpl-{uuid.uuid4()}",
+            object="chat.completion",
+            created=int(time.time()),
+            model=config_ids[0] if config_ids else None,
+            choices=[
+                Choice(
+                    index=0,
+                    messages={
+                        "content": f"Could not load the {config_ids} guardrails configuration. "
+                        f"An internal error has occurred.",
+                        "role": "assistant",
+                    },
+                    finish_reason="error",
+                    logprobs=None,
+                )
+            ],
         )

     try:
@@ -425,12 +566,21 @@ async def chat_completion(body: RequestBody, request: Request):
             # We make sure the `thread_id` meets the minimum complexity requirement.
             if len(body.thread_id) < 16:
                 return ResponseBody(
-                    messages=[
-                        {
-                            "role": "assistant",
-                            "content": "The `thread_id` must have a minimum length of 16 characters.",
-                        }
-                    ]
+                    id=f"chatcmpl-{uuid.uuid4()}",
+                    object="chat.completion",
+                    created=int(time.time()),
+                    model=None,
+                    choices=[
+                        Choice(
+                            index=0,
+                            messages={
+                                "content": "The `thread_id` must have a minimum length of 16 characters.",
+                                "role": "assistant",
+                            },
+                            finish_reason="error",
+                            logprobs=None,
+                        )
+                    ],
                 )

             # Fetch the existing thread messages. For easier management, we prepend
@@ -441,6 +591,20 @@ async def chat_completion(body: RequestBody, request: Request):
             # And prepend them.
             messages = thread_messages + messages

+        generation_options = body.options
+        if body.max_tokens:
+            generation_options.max_tokens = body.max_tokens
+        if body.temperature is not None:
+            generation_options.temperature = body.temperature
+        if body.top_p is not None:
+            generation_options.top_p = body.top_p
+        if body.stop:
+            generation_options.stop = body.stop
+        if body.presence_penalty is not None:
+            generation_options.presence_penalty = body.presence_penalty
+        if body.frequency_penalty is not None:
+            generation_options.frequency_penalty = body.frequency_penalty
+
         if (
             body.stream
             and llm_rails.config.streaming_supported
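Note: the added block copies each OpenAI sampling parameter onto the existing generation options one by one. A self-contained, hypothetical sketch of the same forwarding pattern, using SimpleNamespace stand-ins for the request body and options and skipping unset parameters uniformly with `is not None`:

from types import SimpleNamespace

# Stand-ins for the request body and generation options used in the diff above.
body = SimpleNamespace(max_tokens=128, temperature=0.2, top_p=None, stop=None,
                       presence_penalty=None, frequency_penalty=0.1)
generation_options = SimpleNamespace()

# Copy only the parameters the client actually set.
for name in ("max_tokens", "temperature", "top_p", "stop",
             "presence_penalty", "frequency_penalty"):
    value = getattr(body, name)
    if value is not None:
        setattr(generation_options, name, value)

print(vars(generation_options))
# {'max_tokens': 128, 'temperature': 0.2, 'frequency_penalty': 0.1}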
@@ -459,8 +623,6 @@ async def chat_completion(body: RequestBody, request: Request):
                 )
             )

-            # TODO: Add support for thread_ids in streaming mode
-
             return StreamingResponse(streaming_handler)
         else:
             res = await llm_rails.generate_async(
@@ -483,21 +645,49 @@ async def chat_completion(body: RequestBody, request: Request):
         if body.thread_id and datastore is not None and datastore_key is not None:
             await datastore.set(datastore_key, json.dumps(messages + [bot_message]))

-        result = ResponseBody(messages=[bot_message])
+        # Build the response with OpenAI-compatible format plus NeMo-Guardrails extensions
+        response_kwargs = {
+            "id": f"chatcmpl-{uuid.uuid4()}",
+            "object": "chat.completion",
+            "created": int(time.time()),
+            "model": config_ids[0] if config_ids else None,
+            "choices": [
+                Choice(
+                    index=0,
+                    messages=bot_message,
+                    finish_reason="stop",
+                    logprobs=None,
+                )
+            ],
+        }

-        # If we have additional GenerationResponse fields, we return as well
+        # If we have additional GenerationResponse fields, include them for backward compatibility
         if isinstance(res, GenerationResponse):
-            result.llm_output = res.llm_output
-            result.output_data = res.output_data
-            result.log = res.log
-            result.state = res.state
+            response_kwargs["llm_output"] = res.llm_output
+            response_kwargs["output_data"] = res.output_data
+            response_kwargs["log"] = res.log
+            response_kwargs["state"] = res.state

-        return result
+        return ResponseBody(**response_kwargs)

     except Exception as ex:
         log.exception(ex)
         return ResponseBody(
-            messages=[{"role": "assistant", "content": "Internal server error."}]
+            id=f"chatcmpl-{uuid.uuid4()}",
+            object="chat.completion",
+            created=int(time.time()),
+            model=None,
+            choices=[
+                Choice(
+                    index=0,
+                    messages={
+                        "content": "Internal server error",
+                        "role": "assistant",
+                    },
+                    finish_reason="error",
+                    logprobs=None,
+                )
+            ],
         )

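Note: with these changes a successful call returns an OpenAI-style chat.completion object plus the NeMo-Guardrails extension fields (state, llm_output, output_data, log). A minimal client-side sketch, again assuming a hypothetical local server and config name:

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={"model": "my_config", "messages": [{"role": "user", "content": "Hi"}]},
).json()

# OpenAI-compatible fields
print(resp["id"], resp["object"], resp["created"], resp["model"])
# Note: in this commit the assistant reply sits under `choices[0].messages`
# (a dict), not the standard OpenAI `message` key.
print(resp["choices"][0]["messages"]["content"])

# NeMo-Guardrails extension fields kept for backward compatibility
state = resp.get("state")
llm_output = resp.get("llm_output")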