Commit 5e3c8e9

Add OpenAI docs and integration tests

1 parent 3e1d04a
File tree: 5 files changed (+255, -67 lines)
Lines changed: 16 additions & 0 deletions

```diff
@@ -0,0 +1,16 @@
+## OpenAI API Compatibility for NeMo Guardrails
+
+NeMo Guardrails provides server-side compatibility with OpenAI API endpoints, enabling applications that use OpenAI clients to seamlessly integrate with NeMo Guardrails for adding guardrails to LLM interactions. Point your OpenAI client to `http://localhost:8000` (or your server URL) and use the standard `/v1/chat/completions` endpoint.
+
+## Feature Support Matrix
+
+The following table outlines which OpenAI API features are currently supported when using NeMo Guardrails:
+
+| Feature | Status | Notes |
+| :------ | :----: | :---- |
+| **Basic Chat Completion** | ✔ Supported | Full support for standard chat completions with guardrails applied |
+| **Streaming Responses** | ✔ Supported | Server-Sent Events (SSE) streaming with `stream=true` |
+| **Multimodal Input** | ✖ Unsupported | Guardrails support text and image inputs (vision models), but this is not yet exposed through the OpenAI-compatible endpoint |
+| **Function Calling** | ✖ Unsupported | Not yet implemented; guardrails need structured output support |
+| **Tools** | ✖ Unsupported | Related to function calling; requires action flow integration |
+| **Response Format (JSON Mode)** | ✖ Unsupported | Structured output with guardrails requires additional validation logic |
```
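The new doc tells clients to point at the server and use `/v1/chat/completions`. A minimal sketch of what that looks like with the official `openai` Python client, assuming the endpoint is mounted under `/v1` and using a hypothetical guardrails config named `demo` (not shipped with the repo):

```python
from openai import OpenAI

# Point the standard OpenAI client at the guardrails server instead of
# api.openai.com. The api_key is a placeholder; "demo" is a hypothetical
# guardrails config id standing in for the model name.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

response = client.chat.completions.create(
    model="demo",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)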

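Per the support matrix, streaming is delivered as SSE when `stream=true`. A sketch continuing the snippet above (same hypothetical `demo` config, same `client`):

```python
# With stream=True the server sends data-only server-sent events
# terminated by "data: [DONE]"; the client yields them as chunks.
stream = client.chat.completions.create(
    model="demo",
    messages=[{"role": "user", "content": "Tell me a short story."}],
    stream=True,
)
for chunk in stream:
    # Some chunks may carry no content delta; guard before printing.
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
```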
nemoguardrails/server/api.py

Lines changed: 60 additions & 62 deletions
```diff
@@ -28,7 +28,8 @@
 
 from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
-from openai.types.chat.chat_completion import ChatCompletion, Choice
+from openai.types.chat.chat_completion import Choice
+from openai.types.chat.chat_completion_message import ChatCompletionMessage
 from openai.types.model import Model
 from pydantic import BaseModel, Field, root_validator, validator
 from starlette.responses import StreamingResponse
```
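The import swap matters because responses are now built from typed OpenAI objects rather than raw dicts. A minimal sketch of how these types compose, assuming the `openai` 1.x Python package; the id and config name below are illustrative, not from the commit:

```python
import time
import uuid

from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_message import ChatCompletionMessage

completion = ChatCompletion(
    id=f"chatcmpl-{uuid.uuid4()}",
    object="chat.completion",
    created=int(time.time()),
    model="my-config",  # hypothetical config id standing in for a model name
    choices=[
        Choice(
            index=0,
            message=ChatCompletionMessage(role="assistant", content="Hello!"),
            # "error" is not a valid finish_reason literal, which is why the
            # diff below normalizes error paths to "stop".
            finish_reason="stop",
            logprobs=None,
        )
    ],
)
print(completion.model_dump_json(indent=2))
```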
```diff
@@ -191,88 +192,82 @@ async def root_handler():
 app.single_config_id = None
 
 
-class RequestBody(ChatCompletion):
+class RequestBody(BaseModel):
     config_id: Optional[str] = Field(
         default=os.getenv("DEFAULT_CONFIG_ID", None),
         description="The id of the configuration to be used. If not set, the default configuration will be used.",
     )
     config_ids: Optional[List[str]] = Field(
         default=None,
-        description="The list of configuration ids to be used. "
-        "If set, the configurations will be combined.",
-        # alias="guardrails",
-        validate_default=True,
+        description="The ids of the configurations to be used. If not set, the default configuration will be used.",
     )
     thread_id: Optional[str] = Field(
         default=None,
         min_length=16,
         max_length=255,
         description="The id of an existing thread to which the messages should be added.",
     )
-    model: Optional[str] = Field(
-        default=None,
-        description="The model used for the chat completion.",
+    messages: Optional[List[dict]] = Field(
+        default=None, description="The list of messages in the current conversation."
     )
-    id: Optional[str] = Field(
+    context: Optional[dict] = Field(
         default=None,
-        description="The id of the chat completion.",
+        description="Additional context data to be added to the conversation.",
+    )
+    stream: Optional[bool] = Field(
+        default=False,
+        description="If set, partial message deltas will be sent, like in ChatGPT. "
+        "Tokens will be sent as data-only server-sent events as they become "
+        "available, with the stream terminated by a data: [DONE] message.",
     )
-    object: Optional[str] = Field(
-        default="chat.completion",
-        description="The object type, which is always chat.completion",
+    options: GenerationOptions = Field(
+        default_factory=GenerationOptions,
+        description="Additional options for controlling the generation.",
     )
-    created: Optional[int] = Field(
+    state: Optional[dict] = Field(
         default=None,
-        description="The Unix timestamp (in seconds) of when the chat completion was created.",
+        description="A state object that should be used to continue the interaction.",
     )
-    choices: Optional[List[Choice]] = Field(
+    # Standard OpenAI completion parameters
+    model: Optional[str] = Field(
         default=None,
-        description="The list of choices for the chat completion.",
+        description="The model to use for chat completion. Maps to config_id for backward compatibility.",
     )
     max_tokens: Optional[int] = Field(
         default=None,
         description="The maximum number of tokens to generate.",
     )
     temperature: Optional[float] = Field(
         default=None,
-        description="The temperature to use for the chat completion.",
+        description="Sampling temperature to use.",
     )
     top_p: Optional[float] = Field(
         default=None,
-        description="The top p to use for the chat completion.",
+        description="Top-p sampling parameter.",
     )
-    stop: Optional[Union[str, List[str]]] = Field(
+    stop: Optional[str] = Field(
         default=None,
-        description="The stop sequences to use for the chat completion.",
+        description="Stop sequences.",
     )
     presence_penalty: Optional[float] = Field(
         default=None,
-        description="The presence penalty to use for the chat completion.",
+        description="Presence penalty parameter.",
     )
     frequency_penalty: Optional[float] = Field(
         default=None,
-        description="The frequency penalty to use for the chat completion.",
+        description="Frequency penalty parameter.",
     )
-    messages: Optional[List[dict]] = Field(
-        default=None, description="The list of messages in the current conversation."
-    )
-    context: Optional[dict] = Field(
+    function_call: Optional[dict] = Field(
         default=None,
-        description="Additional context data to be added to the conversation.",
+        description="Function call parameter.",
     )
-    stream: Optional[bool] = Field(
-        default=False,
-        description="If set, partial message deltas will be sent, like in ChatGPT. "
-        "Tokens will be sent as data-only server-sent events as they become "
-        "available, with the stream terminated by a data: [DONE] message.",
-    )
-    options: GenerationOptions = Field(
-        default_factory=GenerationOptions,
-        description="Additional options for controlling the generation.",
+    logit_bias: Optional[dict] = Field(
+        default=None,
+        description="Logit bias parameter.",
     )
-    state: Optional[dict] = Field(
+    log_probs: Optional[bool] = Field(
         default=None,
-        description="A state object that should be used to continue the interaction.",
+        description="Log probabilities parameter.",
     )
 
     @root_validator(pre=True)
```
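The new `model` field's description says it maps to `config_id` for backward compatibility, but the validator that does this is not shown in the hunk. The following is therefore only an assumed sketch of how such a `pre` root validator could behave, not the commit's actual implementation:

```python
from typing import Optional

from pydantic import BaseModel, root_validator


class RequestBodySketch(BaseModel):
    # Reduced to the two fields relevant to the mapping; hypothetical class.
    config_id: Optional[str] = None
    model: Optional[str] = None

    @root_validator(pre=True)
    def map_model_to_config_id(cls, values):
        # If the caller sent an OpenAI-style `model` but no explicit
        # config_id, treat the model name as the guardrails config id.
        if values.get("model") and not values.get("config_id"):
            values["config_id"] = values["model"]
        return values
```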
```diff
@@ -537,16 +532,16 @@ async def chat_completion(body: RequestBody, request: Request):
             id=f"chatcmpl-{uuid.uuid4()}",
             object="chat.completion",
             created=int(time.time()),
-            model=config_ids[0] if config_ids else None,
+            model=config_ids[0] if config_ids else "unknown",
             choices=[
                 Choice(
                     index=0,
-                    message={
-                        "content": f"Could not load the {config_ids} guardrails configuration. "
+                    message=ChatCompletionMessage(
+                        content=f"Could not load the {config_ids} guardrails configuration. "
                         f"An internal error has occurred.",
-                        "role": "assistant",
-                    },
-                    finish_reason="error",
+                        role="assistant",
+                    ),
+                    finish_reason="stop",
                     logprobs=None,
                 )
             ],
@@ -570,15 +565,15 @@ async def chat_completion(body: RequestBody, request: Request):
             id=f"chatcmpl-{uuid.uuid4()}",
             object="chat.completion",
             created=int(time.time()),
-            model=None,
+            model=config_ids[0] if config_ids else "unknown",
             choices=[
                 Choice(
                     index=0,
-                    message={
-                        "content": "The `thread_id` must have a minimum length of 16 characters.",
-                        "role": "assistant",
-                    },
-                    finish_reason="error",
+                    message=ChatCompletionMessage(
+                        content="The `thread_id` must have a minimum length of 16 characters.",
+                        role="assistant",
+                    ),
+                    finish_reason="stop",
                     logprobs=None,
                 )
             ],
@@ -625,7 +620,7 @@ async def chat_completion(body: RequestBody, request: Request):
             llm_rails.generate_async(
                 messages=messages,
                 streaming_handler=streaming_handler,
-                options=body.options,
+                options=generation_options,
                 state=body.state,
             )
         )
@@ -638,7 +633,7 @@ async def chat_completion(body: RequestBody, request: Request):
         )
     else:
         res = await llm_rails.generate_async(
-            messages=messages, options=body.options, state=body.state
+            messages=messages, options=generation_options, state=body.state
         )
 
     if isinstance(res, GenerationResponse):
@@ -662,11 +657,14 @@ async def chat_completion(body: RequestBody, request: Request):
             "id": f"chatcmpl-{uuid.uuid4()}",
             "object": "chat.completion",
             "created": int(time.time()),
-            "model": config_ids[0] if config_ids else None,
+            "model": config_ids[0] if config_ids else "unknown",
             "choices": [
                 Choice(
                     index=0,
-                    message=bot_message,
+                    message=ChatCompletionMessage(
+                        role="assistant",
+                        content=bot_message["content"],
+                    ),
                     finish_reason="stop",
                     logprobs=None,
                 )
@@ -688,15 +686,15 @@ async def chat_completion(body: RequestBody, request: Request):
             id=f"chatcmpl-{uuid.uuid4()}",
             object="chat.completion",
             created=int(time.time()),
-            model=None,
+            model="unknown",
             choices=[
                 Choice(
                     index=0,
-                    message={
-                        "content": "Internal server error",
-                        "role": "assistant",
-                    },
-                    finish_reason="error",
+                    message=ChatCompletionMessage(
+                        content="Internal server error",
+                        role="assistant",
+                    ),
+                    finish_reason="stop",
                     logprobs=None,
                 )
             ],
```
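The commit message mentions integration tests. A minimal sketch of what such a test could look like against the OpenAI-compatible endpoint, assuming FastAPI's `TestClient`, the server `app` from `nemoguardrails.server.api`, and a hypothetical `demo` config available to the server; the actual tests in this commit are not shown here:

```python
from fastapi.testclient import TestClient

from nemoguardrails.server.api import app

client = TestClient(app)


def test_chat_completion_shape():
    response = client.post(
        "/v1/chat/completions",
        json={
            "model": "demo",  # hypothetical guardrails config id
            "messages": [{"role": "user", "content": "Hello!"}],
        },
    )
    assert response.status_code == 200
    body = response.json()
    # The response should follow the OpenAI chat completion schema.
    assert body["object"] == "chat.completion"
    assert body["choices"][0]["message"]["role"] == "assistant"
```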

nemoguardrails/server/schemas/openai.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -17,7 +17,7 @@
 
 from typing import List, Optional
 
-from openai.types.chat.chat_completion import ChatCompletion, Choice
+from openai.types.chat.chat_completion import ChatCompletion
 from openai.types.model import Model
 from pydantic import BaseModel, Field
 
```
poetry.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default.
