
Commit 3e1d04a

Extend existing OpenAI types and add support for streaming chat completion
1 parent c58771c commit 3e1d04a

File tree

7 files changed (+496, -119 lines)


nemoguardrails/server/api.py

Lines changed: 123 additions & 11 deletions

@@ -24,24 +24,20 @@
 import uuid
 import warnings
 from contextlib import asynccontextmanager
-from typing import Any, Callable, List, Optional
+from typing import Any, AsyncIterator, Callable, List, Optional, Union

 from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
-from pydantic import Field, root_validator, validator
+from openai.types.chat.chat_completion import ChatCompletion, Choice
+from openai.types.model import Model
+from pydantic import BaseModel, Field, root_validator, validator
 from starlette.responses import StreamingResponse
 from starlette.staticfiles import StaticFiles

 from nemoguardrails import LLMRails, RailsConfig, utils
 from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse
 from nemoguardrails.server.datastore.datastore import DataStore
-from nemoguardrails.server.schemas.openai import (
-    Choice,
-    Model,
-    ModelsResponse,
-    OpenAIRequestFields,
-    ResponseBody,
-)
+from nemoguardrails.server.schemas.openai import ModelsResponse, ResponseBody
 from nemoguardrails.streaming import StreamingHandler

 logging.basicConfig(level=logging.INFO)

@@ -195,7 +191,7 @@ async def root_handler():
 app.single_config_id = None


-class RequestBody(OpenAIRequestFields):
+class RequestBody(ChatCompletion):
     config_id: Optional[str] = Field(
         default=os.getenv("DEFAULT_CONFIG_ID", None),
         description="The id of the configuration to be used. If not set, the default configuration will be used.",

@@ -213,6 +209,50 @@ class RequestBody(OpenAIRequestFields):
         max_length=255,
         description="The id of an existing thread to which the messages should be added.",
     )
+    model: Optional[str] = Field(
+        default=None,
+        description="The model used for the chat completion.",
+    )
+    id: Optional[str] = Field(
+        default=None,
+        description="The id of the chat completion.",
+    )
+    object: Optional[str] = Field(
+        default="chat.completion",
+        description="The object type, which is always chat.completion",
+    )
+    created: Optional[int] = Field(
+        default=None,
+        description="The Unix timestamp (in seconds) of when the chat completion was created.",
+    )
+    choices: Optional[List[Choice]] = Field(
+        default=None,
+        description="The list of choices for the chat completion.",
+    )
+    max_tokens: Optional[int] = Field(
+        default=None,
+        description="The maximum number of tokens to generate.",
+    )
+    temperature: Optional[float] = Field(
+        default=None,
+        description="The temperature to use for the chat completion.",
+    )
+    top_p: Optional[float] = Field(
+        default=None,
+        description="The top p to use for the chat completion.",
+    )
+    stop: Optional[Union[str, List[str]]] = Field(
+        default=None,
+        description="The stop sequences to use for the chat completion.",
+    )
+    presence_penalty: Optional[float] = Field(
+        default=None,
+        description="The presence penalty to use for the chat completion.",
+    )
+    frequency_penalty: Optional[float] = Field(
+        default=None,
+        description="The frequency penalty to use for the chat completion.",
+    )
     messages: Optional[List[dict]] = Field(
         default=None, description="The list of messages in the current conversation."
     )

@@ -392,6 +432,73 @@ def _get_rails(config_ids: List[str]) -> LLMRails:
     return llm_rails


+async def _format_streaming_response(
+    streaming_handler: StreamingHandler, model_name: Optional[str]
+) -> AsyncIterator[str]:
+    while True:
+        try:
+            chunk = await streaming_handler.__anext__()
+        except StopAsyncIteration:
+            # When the stream ends, yield the [DONE] message
+            yield "data: [DONE]\n\n"
+            break
+
+        # Determine the payload format based on chunk type
+        if isinstance(chunk, dict):
+            # If chunk is a dict, wrap it in OpenAI chunk format with delta
+            payload = {
+                "id": None,
+                "object": "chat.completion.chunk",
+                "created": int(time.time()),
+                "model": model_name,
+                "choices": [
+                    {
+                        "delta": chunk,
+                        "index": None,
+                        "finish_reason": None,
+                    }
+                ],
+            }
+        elif isinstance(chunk, str):
+            try:
+                # Try parsing as JSON - if it parses, it might be a pre-formed payload
+                payload = json.loads(chunk)
+            except Exception:
+                # treat as plain text content token
+                payload = {
+                    "id": None,
+                    "object": "chat.completion.chunk",
+                    "created": int(time.time()),
+                    "model": model_name,
+                    "choices": [
+                        {
+                            "delta": {"content": chunk},
+                            "index": None,
+                            "finish_reason": None,
+                        }
+                    ],
+                }
+        else:
+            # For any other type, treat as plain content
+            payload = {
+                "id": None,
+                "object": "chat.completion.chunk",
+                "created": int(time.time()),
+                "model": model_name,
+                "choices": [
+                    {
+                        "delta": {"content": str(chunk)},
+                        "index": None,
+                        "finish_reason": None,
+                    }
+                ],
+            }
+
+        # Send the payload as JSON
+        data = json.dumps(payload, ensure_ascii=False)
+        yield f"data: {data}\n\n"
+
+
 @app.post(
     "/v1/chat/completions",
     response_model=ResponseBody,

@@ -523,7 +630,12 @@ async def chat_completion(body: RequestBody, request: Request):
             )
         )

-        return StreamingResponse(streaming_handler)
+        return StreamingResponse(
+            _format_streaming_response(
+                streaming_handler, model_name=config_ids[0] if config_ids else None
+            ),
+            media_type="text/event-stream",
+        )
     else:
         res = await llm_rails.generate_async(
             messages=messages, options=body.options, state=body.state
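
For reference, a client consumes the new SSE endpoint by reading "data: ..." lines until the "[DONE]" marker. The following is a minimal sketch, not part of this commit, assuming a guardrails server running at http://localhost:8000, an existing configuration named "my_config", and the request body's existing stream flag:

import json
import requests

payload = {
    "config_id": "my_config",  # hypothetical config id
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": True,
}

with requests.post(
    "http://localhost:8000/v1/chat/completions", json=payload, stream=True
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        # Each server-sent event is a single "data: <json>" line.
        if not line or not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        delta = chunk["choices"][0]["delta"]
        print(delta.get("content", ""), end="", flush=True)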

nemoguardrails/server/schemas/openai.py

Lines changed: 4 additions & 101 deletions

@@ -15,96 +15,16 @@

 """OpenAI API schema definitions for the NeMo Guardrails server."""

-from typing import List, Optional, Union
+from typing import List, Optional

+from openai.types.chat.chat_completion import ChatCompletion, Choice
+from openai.types.model import Model
 from pydantic import BaseModel, Field


-class OpenAIRequestFields(BaseModel):
-    """OpenAI API request fields that can be mixed into other request schemas."""
-
-    # Standard OpenAI completion parameters
-    model: Optional[str] = Field(
-        default=None,
-        description="The model to use for chat completion. Maps to config_id for backward compatibility.",
-    )
-    max_tokens: Optional[int] = Field(
-        default=None,
-        description="The maximum number of tokens to generate.",
-    )
-    temperature: Optional[float] = Field(
-        default=None,
-        description="Sampling temperature to use.",
-    )
-    top_p: Optional[float] = Field(
-        default=None,
-        description="Top-p sampling parameter.",
-    )
-    stop: Optional[Union[str, List[str]]] = Field(
-        default=None,
-        description="Stop sequences.",
-    )
-    presence_penalty: Optional[float] = Field(
-        default=None,
-        description="Presence penalty parameter.",
-    )
-    frequency_penalty: Optional[float] = Field(
-        default=None,
-        description="Frequency penalty parameter.",
-    )
-    function_call: Optional[dict] = Field(
-        default=None,
-        description="Function call parameter.",
-    )
-    logit_bias: Optional[dict] = Field(
-        default=None,
-        description="Logit bias parameter.",
-    )
-    log_probs: Optional[bool] = Field(
-        default=None,
-        description="Log probabilities parameter.",
-    )
-
-
-class Choice(BaseModel):
-    """OpenAI API choice structure in chat completion responses."""
-
-    index: Optional[int] = Field(
-        default=None, description="The index of the choice in the list of choices."
-    )
-    message: Optional[dict] = Field(
-        default=None, description="The message of the choice"
-    )
-    logprobs: Optional[dict] = Field(
-        default=None, description="The log probabilities of the choice"
-    )
-    finish_reason: Optional[str] = Field(
-        default=None, description="The reason the model stopped generating tokens."
-    )
-
-
-class ResponseBody(BaseModel):
+class ResponseBody(ChatCompletion):
     """OpenAI API response body with NeMo-Guardrails extensions."""

-    # OpenAI API fields
-    id: Optional[str] = Field(
-        default=None, description="A unique identifier for the chat completion."
-    )
-    object: str = Field(
-        default="chat.completion",
-        description="The object type, which is always chat.completion",
-    )
-    created: Optional[int] = Field(
-        default=None,
-        description="The Unix timestamp (in seconds) of when the chat completion was created.",
-    )
-    model: Optional[str] = Field(
-        default=None, description="The model used for the chat completion."
-    )
-    choices: Optional[List[Choice]] = Field(
-        default=None, description="A list of chat completion choices."
-    )
-    # NeMo-Guardrails specific fields for backward compatibility
     state: Optional[dict] = Field(
         default=None, description="State object for continuing the conversation."
     )

@@ -117,23 +37,6 @@ class ResponseBody(BaseModel):
     log: Optional[dict] = Field(default=None, description="Generation log data.")


-class Model(BaseModel):
-    """OpenAI API model representation."""
-
-    id: str = Field(
-        description="The model identifier, which can be referenced in the API endpoints."
-    )
-    object: str = Field(
-        default="model", description="The object type, which is always 'model'."
-    )
-    created: int = Field(
-        description="The Unix timestamp (in seconds) of when the model was created."
-    )
-    owned_by: str = Field(
-        default="nemo-guardrails", description="The organization that owns the model."
-    )
-
-
 class ModelsResponse(BaseModel):
     """OpenAI API models list response."""
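
Because ResponseBody now subclasses the upstream ChatCompletion type, a response must carry the fields that type requires while the guardrails extension fields (such as state and log) remain available. A minimal construction sketch, assuming the installed openai package exposes Choice and ChatCompletionMessage with the fields used below (illustrative values only, not part of this commit):

from openai.types.chat import ChatCompletionMessage
from openai.types.chat.chat_completion import Choice

from nemoguardrails.server.schemas.openai import ResponseBody

response = ResponseBody(
    id="chatcmpl-123",  # illustrative id
    object="chat.completion",
    created=1700000000,
    model="my_config",  # hypothetical config id reported as the model name
    choices=[
        Choice(
            index=0,
            finish_reason="stop",
            message=ChatCompletionMessage(role="assistant", content="Hi there!"),
        )
    ],
    # NeMo-Guardrails extension field kept for backward compatibility
    state=None,
)
print(response.model_dump_json(indent=2))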

nemoguardrails/streaming.py

Lines changed: 51 additions & 4 deletions

@@ -174,18 +174,39 @@ async def __anext__(self):

     async def _process(
         self,
-        chunk: Union[str, object],
+        chunk: Union[str, dict, object],
         generation_info: Optional[Dict[str, Any]] = None,
     ):
-        """Process a chunk of text.
+        """Process a chunk of text or dict.

         If we're in buffering mode, record the text.
         Otherwise, update the full completion, check for stop tokens, and enqueue the chunk.
+        Dict chunks bypass completion tracking and go directly to the queue.
         """

         if self.include_generation_metadata and generation_info:
             self.current_generation_info = generation_info

+        # Dict chunks bypass buffering and completion tracking
+        if isinstance(chunk, dict):
+            if self.pipe_to:
+                asyncio.create_task(self.pipe_to.push_chunk(chunk))
+            else:
+                if self.include_generation_metadata:
+                    await self.queue.put(
+                        {
+                            "text": chunk,
+                            "generation_info": (
+                                self.current_generation_info.copy()
+                                if self.current_generation_info
+                                else {}
+                            ),
+                        }
+                    )
+                else:
+                    await self.queue.put(chunk)
+            return
+
         if self.enable_buffer:
             if chunk is not END_OF_STREAM:
                 self.buffer += chunk if chunk is not None else ""

@@ -259,10 +280,28 @@ async def _process(

     async def push_chunk(
         self,
-        chunk: Union[str, GenerationChunk, AIMessageChunk, ChatGenerationChunk, None],
+        chunk: Union[
+            str,
+            dict,
+            GenerationChunk,
+            AIMessageChunk,
+            ChatGenerationChunk,
+            None,
+            object,
+        ],
         generation_info: Optional[Dict[str, Any]] = None,
     ):
-        """Push a new chunk to the stream."""
+        """Push a new chunk to the stream.
+
+        Args:
+            chunk: The chunk to push. Can be:
+                - str: Plain text content
+                - dict: Dictionary with fields like role, content, etc.
+                - GenerationChunk/AIMessageChunk/ChatGenerationChunk: LangChain chunk types
+                - None: Signals end of stream (converted to END_OF_STREAM)
+                - object: END_OF_STREAM sentinel
+            generation_info: Optional metadata about the generation
+        """

         # if generation_info is not explicitly passed,
         # try to get it from the chunk itself if it's a GenerationChunk or ChatGenerationChunk

@@ -288,6 +327,9 @@ async def push_chunk(
         elif isinstance(chunk, str):
             # empty string is a valid chunk and should be processed normally
             pass
+        elif isinstance(chunk, dict):
+            # plain dict chunks are allowed (e.g., for OpenAI-compatible streaming)
+            pass
         else:
             raise Exception(f"Unsupported chunk type: {chunk.__class__.__name__}")

@@ -298,6 +340,11 @@ async def push_chunk(
         if self.include_generation_metadata and generation_info:
             self.current_generation_info = generation_info

+        # Dict chunks bypass prefix/suffix processing and go directly to _process
+        if isinstance(chunk, dict):
+            await self._process(chunk, generation_info)
+            return
+
         # Process prefix: accumulate until the expected prefix is received, then remove it.
         if self.prefix:
             if chunk is not None and chunk is not END_OF_STREAM:
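
The net effect is that dict chunks pass through StreamingHandler unchanged, which is what lets the server emit OpenAI-style delta payloads. A minimal sketch of the new behavior, assuming the handler can be constructed with defaults and consumed as an async iterator (illustrative, not part of this commit):

import asyncio

from nemoguardrails.streaming import StreamingHandler


async def main():
    handler = StreamingHandler()

    # Dict chunks skip prefix/suffix handling and completion tracking
    # and are enqueued as-is.
    await handler.push_chunk({"role": "assistant", "content": "Hello"})
    await handler.push_chunk({"content": " world"})
    await handler.push_chunk(None)  # None signals the end of the stream

    async for chunk in handler:
        print(chunk)


asyncio.run(main())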
