Commit 908f066

Extend existing OpenAI types and add support for streaming chat completion
1 parent 59c1644 commit 908f066

File tree

6 files changed: +484 / -113 lines changed


nemoguardrails/server/api.py

Lines changed: 121 additions & 6 deletions
@@ -24,10 +24,12 @@
 import uuid
 import warnings
 from contextlib import asynccontextmanager
-from typing import Any, Callable, List, Optional
+from typing import Any, AsyncIterator, Callable, List, Optional, Union

 from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
+from openai.types.chat.chat_completion import ChatCompletion, Choice
+from openai.types.model import Model
 from pydantic import Field, root_validator, validator
 from starlette.responses import StreamingResponse
 from starlette.staticfiles import StaticFiles
@@ -36,10 +38,7 @@
 from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse
 from nemoguardrails.server.datastore.datastore import DataStore
 from nemoguardrails.server.schemas.openai import (
-    Choice,
-    Model,
     ModelsResponse,
-    OpenAIRequestFields,
     ResponseBody,
 )
 from nemoguardrails.streaming import StreamingHandler
@@ -195,7 +194,7 @@ async def root_handler():
 app.single_config_id = None


-class RequestBody(OpenAIRequestFields):
+class RequestBody(ChatCompletion):
     config_id: Optional[str] = Field(
         default=os.getenv("DEFAULT_CONFIG_ID", None),
         description="The id of the configuration to be used. If not set, the default configuration will be used.",
@@ -213,6 +212,50 @@ class RequestBody(OpenAIRequestFields):
         max_length=255,
         description="The id of an existing thread to which the messages should be added.",
     )
+    model: Optional[str] = Field(
+        default=None,
+        description="The model used for the chat completion.",
+    )
+    id: Optional[str] = Field(
+        default=None,
+        description="The id of the chat completion.",
+    )
+    object: Optional[str] = Field(
+        default="chat.completion",
+        description="The object type, which is always chat.completion",
+    )
+    created: Optional[int] = Field(
+        default=None,
+        description="The Unix timestamp (in seconds) of when the chat completion was created.",
+    )
+    choices: Optional[List[Choice]] = Field(
+        default=None,
+        description="The list of choices for the chat completion.",
+    )
+    max_tokens: Optional[int] = Field(
+        default=None,
+        description="The maximum number of tokens to generate.",
+    )
+    temperature: Optional[float] = Field(
+        default=None,
+        description="The temperature to use for the chat completion.",
+    )
+    top_p: Optional[float] = Field(
+        default=None,
+        description="The top p to use for the chat completion.",
+    )
+    stop: Optional[Union[str, List[str]]] = Field(
+        default=None,
+        description="The stop sequences to use for the chat completion.",
+    )
+    presence_penalty: Optional[float] = Field(
+        default=None,
+        description="The presence penalty to use for the chat completion.",
+    )
+    frequency_penalty: Optional[float] = Field(
+        default=None,
+        description="The frequency penalty to use for the chat completion.",
+    )
     messages: Optional[List[dict]] = Field(
         default=None, description="The list of messages in the current conversation."
     )
@@ -392,6 +435,73 @@ def _get_rails(config_ids: List[str]) -> LLMRails:
     return llm_rails


+async def _format_streaming_response(
+    streaming_handler: StreamingHandler, model_name: Optional[str]
+) -> AsyncIterator[str]:
+    while True:
+        try:
+            chunk = await streaming_handler.__anext__()
+        except StopAsyncIteration:
+            # When the stream ends, yield the [DONE] message
+            yield "data: [DONE]\n\n"
+            break
+
+        # Determine the payload format based on chunk type
+        if isinstance(chunk, dict):
+            # If chunk is a dict, wrap it in OpenAI chunk format with delta
+            payload = {
+                "id": None,
+                "object": "chat.completion.chunk",
+                "created": int(time.time()),
+                "model": model_name,
+                "choices": [
+                    {
+                        "delta": chunk,
+                        "index": None,
+                        "finish_reason": None,
+                    }
+                ],
+            }
+        elif isinstance(chunk, str):
+            try:
+                # Try parsing as JSON - if it parses, it might be a pre-formed payload
+                payload = json.loads(chunk)
+            except Exception:
+                # treat as plain text content token
+                payload = {
+                    "id": None,
+                    "object": "chat.completion.chunk",
+                    "created": int(time.time()),
+                    "model": model_name,
+                    "choices": [
+                        {
+                            "delta": {"content": chunk},
+                            "index": None,
+                            "finish_reason": None,
+                        }
+                    ],
+                }
+        else:
+            # For any other type, treat as plain content
+            payload = {
+                "id": None,
+                "object": "chat.completion.chunk",
+                "created": int(time.time()),
+                "model": model_name,
+                "choices": [
+                    {
+                        "delta": {"content": str(chunk)},
+                        "index": None,
+                        "finish_reason": None,
+                    }
+                ],
+            }
+
+        # Send the payload as JSON
+        data = json.dumps(payload, ensure_ascii=False)
+        yield f"data: {data}\n\n"
+
+
 @app.post(
     "/v1/chat/completions",
     response_model=ResponseBody,
@@ -523,7 +633,12 @@ async def chat_completion(body: RequestBody, request: Request):
                 )
             )

-            return StreamingResponse(streaming_handler)
+            return StreamingResponse(
+                _format_streaming_response(
+                    streaming_handler, model_name=config_ids[0] if config_ids else None
+                ),
+                media_type="text/event-stream",
+            )
         else:
             res = await llm_rails.generate_async(
                 messages=messages, options=body.options, state=body.state
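
For reference, a minimal client-side sketch of consuming the new SSE stream exposed above. This is not part of the commit: the base URL and config id are placeholders, the stream flag is handled by existing server code outside this diff, and httpx is just one possible HTTP client.

import asyncio
import json

import httpx  # assumed client; any library that can read SSE lines works


async def stream_chat() -> None:
    # Request body mirrors the extended RequestBody: config_id, messages, stream.
    payload = {
        "config_id": "my_config",  # hypothetical config id
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": True,
    }
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        async with client.stream("POST", "/v1/chat/completions", json=payload) as resp:
            async for line in resp.aiter_lines():
                # Each event is a "data: <json>" line followed by a blank line.
                if not line.startswith("data: "):
                    continue
                data = line[len("data: "):]
                if data == "[DONE]":
                    break
                chunk = json.loads(data)
                delta = chunk["choices"][0]["delta"]
                print(delta.get("content", ""), end="", flush=True)


asyncio.run(stream_chat())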

nemoguardrails/server/schemas/openai.py

Lines changed: 7 additions & 101 deletions
@@ -15,96 +15,19 @@

 """OpenAI API schema definitions for the NeMo Guardrails server."""

-from typing import List, Optional, Union
+from typing import List, Optional

+from openai.types.chat.chat_completion import (
+    ChatCompletion,
+    Choice
+)
+from openai.types.model import Model
 from pydantic import BaseModel, Field


-class OpenAIRequestFields(BaseModel):
-    """OpenAI API request fields that can be mixed into other request schemas."""
-
-    # Standard OpenAI completion parameters
-    model: Optional[str] = Field(
-        default=None,
-        description="The model to use for chat completion. Maps to config_id for backward compatibility.",
-    )
-    max_tokens: Optional[int] = Field(
-        default=None,
-        description="The maximum number of tokens to generate.",
-    )
-    temperature: Optional[float] = Field(
-        default=None,
-        description="Sampling temperature to use.",
-    )
-    top_p: Optional[float] = Field(
-        default=None,
-        description="Top-p sampling parameter.",
-    )
-    stop: Optional[Union[str, List[str]]] = Field(
-        default=None,
-        description="Stop sequences.",
-    )
-    presence_penalty: Optional[float] = Field(
-        default=None,
-        description="Presence penalty parameter.",
-    )
-    frequency_penalty: Optional[float] = Field(
-        default=None,
-        description="Frequency penalty parameter.",
-    )
-    function_call: Optional[dict] = Field(
-        default=None,
-        description="Function call parameter.",
-    )
-    logit_bias: Optional[dict] = Field(
-        default=None,
-        description="Logit bias parameter.",
-    )
-    log_probs: Optional[bool] = Field(
-        default=None,
-        description="Log probabilities parameter.",
-    )
-
-
-class Choice(BaseModel):
-    """OpenAI API choice structure in chat completion responses."""
-
-    index: Optional[int] = Field(
-        default=None, description="The index of the choice in the list of choices."
-    )
-    message: Optional[dict] = Field(
-        default=None, description="The message of the choice"
-    )
-    logprobs: Optional[dict] = Field(
-        default=None, description="The log probabilities of the choice"
-    )
-    finish_reason: Optional[str] = Field(
-        default=None, description="The reason the model stopped generating tokens."
-    )
-
-
-class ResponseBody(BaseModel):
+class ResponseBody(ChatCompletion):
     """OpenAI API response body with NeMo-Guardrails extensions."""

-    # OpenAI API fields
-    id: Optional[str] = Field(
-        default=None, description="A unique identifier for the chat completion."
-    )
-    object: str = Field(
-        default="chat.completion",
-        description="The object type, which is always chat.completion",
-    )
-    created: Optional[int] = Field(
-        default=None,
-        description="The Unix timestamp (in seconds) of when the chat completion was created.",
-    )
-    model: Optional[str] = Field(
-        default=None, description="The model used for the chat completion."
-    )
-    choices: Optional[List[Choice]] = Field(
-        default=None, description="A list of chat completion choices."
-    )
-    # NeMo-Guardrails specific fields for backward compatibility
     state: Optional[dict] = Field(
         default=None, description="State object for continuing the conversation."
     )
@@ -117,23 +40,6 @@ class ResponseBody(BaseModel):
     log: Optional[dict] = Field(default=None, description="Generation log data.")


-class Model(BaseModel):
-    """OpenAI API model representation."""
-
-    id: str = Field(
-        description="The model identifier, which can be referenced in the API endpoints."
-    )
-    object: str = Field(
-        default="model", description="The object type, which is always 'model'."
-    )
-    created: int = Field(
-        description="The Unix timestamp (in seconds) of when the model was created."
-    )
-    owned_by: str = Field(
-        default="nemo-guardrails", description="The organization that owns the model."
-    )
-
-
 class ModelsResponse(BaseModel):
     """OpenAI API models list response."""
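
Since ResponseBody now subclasses the official ChatCompletion model, responses can be built and serialized with the openai-python types directly. A minimal sketch, assuming the openai package's ChatCompletion/Choice/ChatCompletionMessage signatures; the id, model, and content values are placeholders, and the guardrails-specific fields (state, llm_output, output_data, log) remain optional extensions:

import time

from openai.types.chat import ChatCompletionMessage
from openai.types.chat.chat_completion import Choice

from nemoguardrails.server.schemas.openai import ResponseBody

response = ResponseBody(
    id="chatcmpl-123",  # placeholder id
    object="chat.completion",
    created=int(time.time()),
    model="my_config",  # config id used as the model name, mirroring _format_streaming_response
    choices=[
        Choice(
            index=0,
            message=ChatCompletionMessage(role="assistant", content="Hello!"),
            finish_reason="stop",
        )
    ],
    log=None,  # guardrails extension fields stay optional
)
print(response.model_dump_json(exclude_none=True))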

nemoguardrails/streaming.py

Lines changed: 51 additions & 4 deletions
@@ -174,18 +174,39 @@ async def __anext__(self):

     async def _process(
         self,
-        chunk: Union[str, object],
+        chunk: Union[str, dict, object],
         generation_info: Optional[Dict[str, Any]] = None,
     ):
-        """Process a chunk of text.
+        """Process a chunk of text or dict.

         If we're in buffering mode, record the text.
         Otherwise, update the full completion, check for stop tokens, and enqueue the chunk.
+        Dict chunks bypass completion tracking and go directly to the queue.
         """

         if self.include_generation_metadata and generation_info:
             self.current_generation_info = generation_info

+        # Dict chunks bypass buffering and completion tracking
+        if isinstance(chunk, dict):
+            if self.pipe_to:
+                asyncio.create_task(self.pipe_to.push_chunk(chunk))
+            else:
+                if self.include_generation_metadata:
+                    await self.queue.put(
+                        {
+                            "text": chunk,
+                            "generation_info": (
+                                self.current_generation_info.copy()
+                                if self.current_generation_info
+                                else {}
+                            ),
+                        }
+                    )
+                else:
+                    await self.queue.put(chunk)
+            return
+
         if self.enable_buffer:
             if chunk is not END_OF_STREAM:
                 self.buffer += chunk if chunk is not None else ""
@@ -259,10 +280,28 @@ async def _process(

     async def push_chunk(
         self,
-        chunk: Union[str, GenerationChunk, AIMessageChunk, ChatGenerationChunk, None],
+        chunk: Union[
+            str,
+            dict,
+            GenerationChunk,
+            AIMessageChunk,
+            ChatGenerationChunk,
+            None,
+            object,
+        ],
         generation_info: Optional[Dict[str, Any]] = None,
     ):
-        """Push a new chunk to the stream."""
+        """Push a new chunk to the stream.
+
+        Args:
+            chunk: The chunk to push. Can be:
+                - str: Plain text content
+                - dict: Dictionary with fields like role, content, etc.
+                - GenerationChunk/AIMessageChunk/ChatGenerationChunk: LangChain chunk types
+                - None: Signals end of stream (converted to END_OF_STREAM)
+                - object: END_OF_STREAM sentinel
+            generation_info: Optional metadata about the generation
+        """

         # if generation_info is not explicitly passed,
         # try to get it from the chunk itself if it's a GenerationChunk or ChatGenerationChunk
@@ -288,6 +327,9 @@ async def push_chunk(
         elif isinstance(chunk, str):
             # empty string is a valid chunk and should be processed normally
             pass
+        elif isinstance(chunk, dict):
+            # plain dict chunks are allowed (e.g., for OpenAI-compatible streaming)
+            pass
         else:
             raise Exception(f"Unsupported chunk type: {chunk.__class__.__name__}")

@@ -298,6 +340,11 @@ async def push_chunk(
         if self.include_generation_metadata and generation_info:
             self.current_generation_info = generation_info

+        # Dict chunks bypass prefix/suffix processing and go directly to _process
+        if isinstance(chunk, dict):
+            await self._process(chunk, generation_info)
+            return
+
         # Process prefix: accumulate until the expected prefix is received, then remove it.
         if self.prefix:
             if chunk is not None and chunk is not END_OF_STREAM:
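
A minimal sketch of exercising the new dict support on StreamingHandler directly, outside the server. It assumes the handler can be constructed with default arguments and iterated with async for, and that pushing None terminates the stream as described in the docstring above:

import asyncio

from nemoguardrails.streaming import StreamingHandler


async def main() -> None:
    handler = StreamingHandler()

    async def produce() -> None:
        # Dict chunks skip prefix/suffix handling and completion tracking.
        await handler.push_chunk({"role": "assistant", "content": "Hello"})
        await handler.push_chunk({"content": ", world!"})
        await handler.push_chunk(None)  # converted to END_OF_STREAM

    async def consume() -> None:
        # Dict chunks come back from the queue unchanged.
        async for chunk in handler:
            print(chunk)

    await asyncio.gather(produce(), consume())


asyncio.run(main())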
