diff --git a/docs/user-guides/community/openai.md b/docs/user-guides/community/openai.md
new file mode 100644
index 000000000..2150e4be8
--- /dev/null
+++ b/docs/user-guides/community/openai.md
@@ -0,0 +1,16 @@
+## OpenAI API Compatibility for NeMo Guardrails
+
+NeMo Guardrails provides server-side compatibility with the OpenAI API, so applications that use an OpenAI client can add guardrails to their LLM interactions. Point your OpenAI client to `http://localhost:8000` (or your server URL) and use the standard `/v1/chat/completions` endpoint.
+
+## Feature Support Matrix
+
+The following table outlines which OpenAI API features are currently supported when using NeMo Guardrails:
+
+| Feature | Status | Notes |
+| :------ | :----: | :---- |
+| **Basic Chat Completion** | ✔ Supported | Full support for standard chat completions with guardrails applied |
+| **Streaming Responses** | ✔ Supported | Server-Sent Events (SSE) streaming with `stream=true` |
+| **Multimodal Input** | ✖ Unsupported | Text and image inputs (vision models) are supported with guardrails, but not yet exposed through the OpenAI-compatible endpoint |
+| **Function Calling** | ✖ Unsupported | Not yet implemented; guardrails need structured output support |
+| **Tools** | ✖ Unsupported | Related to function calling; requires action flow integration |
+| **Response Format (JSON Mode)** | ✖ Unsupported | Structured output with guardrails requires additional validation logic |
diff --git a/nemoguardrails/colang/v2_x/runtime/runtime.py b/nemoguardrails/colang/v2_x/runtime/runtime.py
index 9b17a7e94..be3263565
--- a/nemoguardrails/colang/v2_x/runtime/runtime.py
+++ b/nemoguardrails/colang/v2_x/runtime/runtime.py
@@ -33,6 +33,7 @@
     ColangSyntaxError,
 )
 from nemoguardrails.colang.v2_x.runtime.flows import Event, FlowStatus
+from nemoguardrails.colang.v2_x.runtime.serialization import json_to_state
 from nemoguardrails.colang.v2_x.runtime.statemachine import (
     FlowConfig,
     InternalEvent,
@@ -439,10 +440,13 @@ async def process_events(
             )
             initialize_state(state)
         elif isinstance(state, dict):
-            # TODO: Implement dict to State conversion
-            raise NotImplementedError()
-            # if isinstance(state, dict):
-            # state = State.from_dict(state)
+            # Convert dict to State object
+            if state.get("version") == "2.x" and "state" in state:
+                # Handle the serialized state format from API calls
+                state = json_to_state(state["state"])
+            else:
+                # TODO: Implement other dict to State conversion formats if needed
+                raise NotImplementedError("Unsupported state dict format")

         assert isinstance(state, State)
         assert state.main_flow_state is not None
diff --git a/nemoguardrails/server/api.py b/nemoguardrails/server/api.py
index 6769dec1e..41a30a950
--- a/nemoguardrails/server/api.py
+++ b/nemoguardrails/server/api.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+ import asyncio import contextvars import importlib.util @@ -20,23 +21,24 @@ import os.path import re import time +import uuid import warnings from contextlib import asynccontextmanager -from typing import Any, Callable, List, Optional +from typing import Any, AsyncIterator, Callable, List, Optional, Union from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware +from openai.types.chat.chat_completion import Choice +from openai.types.chat.chat_completion_message import ChatCompletionMessage +from openai.types.model import Model from pydantic import BaseModel, Field, root_validator, validator from starlette.responses import StreamingResponse from starlette.staticfiles import StaticFiles from nemoguardrails import LLMRails, RailsConfig, utils -from nemoguardrails.rails.llm.options import ( - GenerationLog, - GenerationOptions, - GenerationResponse, -) +from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse from nemoguardrails.server.datastore.datastore import DataStore +from nemoguardrails.server.schemas.openai import ModelsResponse, ResponseBody from nemoguardrails.streaming import StreamingHandler logging.basicConfig(level=logging.INFO) @@ -197,10 +199,7 @@ class RequestBody(BaseModel): ) config_ids: Optional[List[str]] = Field( default=None, - description="The list of configuration ids to be used. " - "If set, the configurations will be combined.", - # alias="guardrails", - validate_default=True, + description="The ids of the configurations to be used. If not set, the default configuration will be used.", ) thread_id: Optional[str] = Field( default=None, @@ -229,10 +228,53 @@ class RequestBody(BaseModel): default=None, description="A state object that should be used to continue the interaction.", ) + # Standard OpenAI completion parameters + model: Optional[str] = Field( + default=None, + description="The model to use for chat completion. 
Maps to config_id for backward compatibility.", + ) + max_tokens: Optional[int] = Field( + default=None, + description="The maximum number of tokens to generate.", + ) + temperature: Optional[float] = Field( + default=None, + description="Sampling temperature to use.", + ) + top_p: Optional[float] = Field( + default=None, + description="Top-p sampling parameter.", + ) + stop: Optional[str] = Field( + default=None, + description="Stop sequences.", + ) + presence_penalty: Optional[float] = Field( + default=None, + description="Presence penalty parameter.", + ) + frequency_penalty: Optional[float] = Field( + default=None, + description="Frequency penalty parameter.", + ) + function_call: Optional[dict] = Field( + default=None, + description="Function call parameter.", + ) + logit_bias: Optional[dict] = Field( + default=None, + description="Logit bias parameter.", + ) + log_probs: Optional[bool] = Field( + default=None, + description="Log probabilities parameter.", + ) @root_validator(pre=True) def ensure_config_id(cls, data: Any) -> Any: if isinstance(data, dict): + if data.get("model") is not None and data.get("config_id") is None: + data["config_id"] = data["model"] if data.get("config_id") is not None and data.get("config_ids") is not None: raise ValueError( "Only one of config_id or config_ids should be specified" @@ -253,25 +295,44 @@ def ensure_config_ids(cls, v, values): return v -class ResponseBody(BaseModel): - messages: Optional[List[dict]] = Field( - default=None, description="The new messages in the conversation" - ) - llm_output: Optional[dict] = Field( - default=None, - description="Contains any additional output coming from the LLM.", - ) - output_data: Optional[dict] = Field( - default=None, - description="The output data, i.e. a dict with the values corresponding to the `output_vars`.", - ) - log: Optional[GenerationLog] = Field( - default=None, description="Additional logging information." - ) - state: Optional[dict] = Field( - default=None, - description="A state object that should be used to continue the interaction in the future.", - ) +@app.get( + "/v1/models", + response_model=ModelsResponse, + summary="List available models", + description="Lists the currently available models, mapping guardrails configurations to OpenAI-compatible model format.", +) +async def get_models(): + """Returns the list of available models (guardrails configurations) in OpenAI-compatible format.""" + + # Use the same logic as get_rails_configs to find available configurations + if app.single_config_mode: + config_ids = [app.single_config_id] if app.single_config_id else [] + else: + config_ids = [ + f + for f in os.listdir(app.rails_config_path) + if os.path.isdir(os.path.join(app.rails_config_path, f)) + and f[0] != "." + and f[0] != "_" + # Filter out all the configs for which there is no `config.yml` file. 
+ and ( + os.path.exists(os.path.join(app.rails_config_path, f, "config.yml")) + or os.path.exists(os.path.join(app.rails_config_path, f, "config.yaml")) + ) + ] + + # Convert configurations to OpenAI model format + models = [] + for config_id in config_ids: + model = Model( + id=config_id, + object="model", + created=int(time.time()), # Use current time as created timestamp + owned_by="nemo-guardrails", + ) + models.append(model) + + return ModelsResponse(data=models) @app.get( @@ -366,6 +427,73 @@ def _get_rails(config_ids: List[str]) -> LLMRails: return llm_rails +async def _format_streaming_response( + streaming_handler: StreamingHandler, model_name: Optional[str] +) -> AsyncIterator[str]: + while True: + try: + chunk = await streaming_handler.__anext__() + except StopAsyncIteration: + # When the stream ends, yield the [DONE] message + yield "data: [DONE]\n\n" + break + + # Determine the payload format based on chunk type + if isinstance(chunk, dict): + # If chunk is a dict, wrap it in OpenAI chunk format with delta + payload = { + "id": None, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model_name, + "choices": [ + { + "delta": chunk, + "index": 0, + "finish_reason": None, + } + ], + } + elif isinstance(chunk, str): + try: + # Try parsing as JSON - if it parses, it might be a pre-formed payload + payload = json.loads(chunk) + except Exception: + # treat as plain text content token + payload = { + "id": None, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model_name, + "choices": [ + { + "delta": {"content": chunk}, + "index": 0, + "finish_reason": None, + } + ], + } + else: + # For any other type, treat as plain content + payload = { + "id": None, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model_name, + "choices": [ + { + "delta": {"content": str(chunk)}, + "index": 0, + "finish_reason": None, + } + ], + } + + # Send the payload as JSON + data = json.dumps(payload, ensure_ascii=False) + yield f"data: {data}\n\n" + + @app.post( "/v1/chat/completions", response_model=ResponseBody, @@ -401,13 +529,22 @@ async def chat_completion(body: RequestBody, request: Request): except ValueError as ex: log.exception(ex) return ResponseBody( - messages=[ - { - "role": "assistant", - "content": f"Could not load the {config_ids} guardrails configuration. " - f"An internal error has occurred.", - } - ] + id=f"chatcmpl-{uuid.uuid4()}", + object="chat.completion", + created=int(time.time()), + model=config_ids[0] if config_ids else "unknown", + choices=[ + Choice( + index=0, + message=ChatCompletionMessage( + content=f"Could not load the {config_ids} guardrails configuration. " + f"An internal error has occurred.", + role="assistant", + ), + finish_reason="stop", + logprobs=None, + ) + ], ) try: @@ -425,12 +562,21 @@ async def chat_completion(body: RequestBody, request: Request): # We make sure the `thread_id` meets the minimum complexity requirement. 
if len(body.thread_id) < 16: return ResponseBody( - messages=[ - { - "role": "assistant", - "content": "The `thread_id` must have a minimum length of 16 characters.", - } - ] + id=f"chatcmpl-{uuid.uuid4()}", + object="chat.completion", + created=int(time.time()), + model=config_ids[0] if config_ids else "unknown", + choices=[ + Choice( + index=0, + message=ChatCompletionMessage( + content="The `thread_id` must have a minimum length of 16 characters.", + role="assistant", + ), + finish_reason="stop", + logprobs=None, + ) + ], ) # Fetch the existing thread messages. For easier management, we prepend @@ -441,6 +587,26 @@ async def chat_completion(body: RequestBody, request: Request): # And prepend them. messages = thread_messages + messages + generation_options = body.options + + # Initialize llm_params if not already set + if generation_options.llm_params is None: + generation_options.llm_params = {} + + # Set OpenAI-compatible parameters in llm_params + if body.max_tokens: + generation_options.llm_params["max_tokens"] = body.max_tokens + if body.temperature is not None: + generation_options.llm_params["temperature"] = body.temperature + if body.top_p is not None: + generation_options.llm_params["top_p"] = body.top_p + if body.stop: + generation_options.llm_params["stop"] = body.stop + if body.presence_penalty is not None: + generation_options.llm_params["presence_penalty"] = body.presence_penalty + if body.frequency_penalty is not None: + generation_options.llm_params["frequency_penalty"] = body.frequency_penalty + if ( body.stream and llm_rails.config.streaming_supported @@ -454,17 +620,20 @@ async def chat_completion(body: RequestBody, request: Request): llm_rails.generate_async( messages=messages, streaming_handler=streaming_handler, - options=body.options, + options=generation_options, state=body.state, ) ) - # TODO: Add support for thread_ids in streaming mode - - return StreamingResponse(streaming_handler) + return StreamingResponse( + _format_streaming_response( + streaming_handler, model_name=config_ids[0] if config_ids else None + ), + media_type="text/event-stream", + ) else: res = await llm_rails.generate_async( - messages=messages, options=body.options, state=body.state + messages=messages, options=generation_options, state=body.state ) if isinstance(res, GenerationResponse): @@ -483,21 +652,52 @@ async def chat_completion(body: RequestBody, request: Request): if body.thread_id and datastore is not None and datastore_key is not None: await datastore.set(datastore_key, json.dumps(messages + [bot_message])) - result = ResponseBody(messages=[bot_message]) + # Build the response with OpenAI-compatible format plus NeMo-Guardrails extensions + response_kwargs = { + "id": f"chatcmpl-{uuid.uuid4()}", + "object": "chat.completion", + "created": int(time.time()), + "model": config_ids[0] if config_ids else "unknown", + "choices": [ + Choice( + index=0, + message=ChatCompletionMessage( + role="assistant", + content=bot_message["content"], + ), + finish_reason="stop", + logprobs=None, + ) + ], + } - # If we have additional GenerationResponse fields, we return as well + # If we have additional GenerationResponse fields, include them for backward compatibility if isinstance(res, GenerationResponse): - result.llm_output = res.llm_output - result.output_data = res.output_data - result.log = res.log - result.state = res.state + response_kwargs["llm_output"] = res.llm_output + response_kwargs["output_data"] = res.output_data + response_kwargs["log"] = res.log + response_kwargs["state"] = 
res.state - return result + return ResponseBody(**response_kwargs) except Exception as ex: log.exception(ex) return ResponseBody( - messages=[{"role": "assistant", "content": "Internal server error."}] + id=f"chatcmpl-{uuid.uuid4()}", + object="chat.completion", + created=int(time.time()), + model="unknown", + choices=[ + Choice( + index=0, + message=ChatCompletionMessage( + content="Internal server error", + role="assistant", + ), + finish_reason="stop", + logprobs=None, + ) + ], ) diff --git a/nemoguardrails/server/schemas/openai.py b/nemoguardrails/server/schemas/openai.py new file mode 100644 index 000000000..a935a9c91 --- /dev/null +++ b/nemoguardrails/server/schemas/openai.py @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OpenAI API schema definitions for the NeMo Guardrails server.""" + +from typing import List, Optional + +from openai.types.chat.chat_completion import ChatCompletion +from openai.types.model import Model +from pydantic import BaseModel, Field + + +class ResponseBody(ChatCompletion): + """OpenAI API response body with NeMo-Guardrails extensions.""" + + state: Optional[dict] = Field( + default=None, description="State object for continuing the conversation." + ) + llm_output: Optional[dict] = Field( + default=None, description="Additional LLM output data." + ) + output_data: Optional[dict] = Field( + default=None, description="Additional output data." + ) + log: Optional[dict] = Field(default=None, description="Generation log data.") + + +class ModelsResponse(BaseModel): + """OpenAI API models list response.""" + + object: str = Field( + default="list", description="The object type, which is always 'list'." + ) + data: List[Model] = Field(description="The list of models.") diff --git a/nemoguardrails/streaming.py b/nemoguardrails/streaming.py index 06ad3ee93..bc1b472a1 100644 --- a/nemoguardrails/streaming.py +++ b/nemoguardrails/streaming.py @@ -174,18 +174,39 @@ async def __anext__(self): async def _process( self, - chunk: Union[str, object], + chunk: Union[str, dict, object], generation_info: Optional[Dict[str, Any]] = None, ): - """Process a chunk of text. + """Process a chunk of text or dict. If we're in buffering mode, record the text. Otherwise, update the full completion, check for stop tokens, and enqueue the chunk. + Dict chunks bypass completion tracking and go directly to the queue. 
""" if self.include_generation_metadata and generation_info: self.current_generation_info = generation_info + # Dict chunks bypass buffering and completion tracking + if isinstance(chunk, dict): + if self.pipe_to: + asyncio.create_task(self.pipe_to.push_chunk(chunk)) + else: + if self.include_generation_metadata: + await self.queue.put( + { + "text": chunk, + "generation_info": ( + self.current_generation_info.copy() + if self.current_generation_info + else {} + ), + } + ) + else: + await self.queue.put(chunk) + return + if self.enable_buffer: if chunk is not END_OF_STREAM: self.buffer += chunk if chunk is not None else "" @@ -259,10 +280,28 @@ async def _process( async def push_chunk( self, - chunk: Union[str, GenerationChunk, AIMessageChunk, ChatGenerationChunk, None], + chunk: Union[ + str, + dict, + GenerationChunk, + AIMessageChunk, + ChatGenerationChunk, + None, + object, + ], generation_info: Optional[Dict[str, Any]] = None, ): - """Push a new chunk to the stream.""" + """Push a new chunk to the stream. + + Args: + chunk: The chunk to push. Can be: + - str: Plain text content + - dict: Dictionary with fields like role, content, etc. + - GenerationChunk/AIMessageChunk/ChatGenerationChunk: LangChain chunk types + - None: Signals end of stream (converted to END_OF_STREAM) + - object: END_OF_STREAM sentinel + generation_info: Optional metadata about the generation + """ # if generation_info is not explicitly passed, # try to get it from the chunk itself if it's a GenerationChunk or ChatGenerationChunk @@ -288,6 +327,9 @@ async def push_chunk( elif isinstance(chunk, str): # empty string is a valid chunk and should be processed normally pass + elif isinstance(chunk, dict): + # plain dict chunks are allowed (e.g., for OpenAI-compatible streaming) + pass else: raise Exception(f"Unsupported chunk type: {chunk.__class__.__name__}") @@ -298,6 +340,11 @@ async def push_chunk( if self.include_generation_metadata and generation_info: self.current_generation_info = generation_info + # Dict chunks bypass prefix/suffix processing and go directly to _process + if isinstance(chunk, dict): + await self._process(chunk, generation_info) + return + # Process prefix: accumulate until the expected prefix is received, then remove it. if self.prefix: if chunk is not None and chunk is not END_OF_STREAM: diff --git a/poetry.lock b/poetry.lock index 2222572f7..267d7b70c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "accessible-pygments" @@ -22,7 +22,7 @@ tests = ["hypothesis", "pytest"] name = "aiofiles" version = "24.1.0" description = "File support for asyncio." 
-optional = true +optional = false python-versions = ">=3.8" files = [ {file = "aiofiles-24.1.0-py3-none-any.whl", hash = "sha256:b4ec55f4195e3eb5d7abd1bf7e061763e864dd4954231fb8539a0ef8bb8260e5"}, @@ -993,7 +993,7 @@ files = [ name = "distro" version = "1.9.0" description = "Distro - an OS platform information API" -optional = true +optional = false python-versions = ">=3.6" files = [ {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"}, @@ -1421,6 +1421,8 @@ files = [ {file = "greenlet-3.2.4-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c2ca18a03a8cfb5b25bc1cbe20f3d9a4c80d8c3b13ba3df49ac3961af0b1018d"}, {file = "greenlet-3.2.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9fe0a28a7b952a21e2c062cd5756d34354117796c6d9215a87f55e38d15402c5"}, {file = "greenlet-3.2.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8854167e06950ca75b898b104b63cc646573aa5fef1353d4508ecdd1ee76254f"}, + {file = "greenlet-3.2.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f47617f698838ba98f4ff4189aef02e7343952df3a615f847bb575c3feb177a7"}, + {file = "greenlet-3.2.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:af41be48a4f60429d5cad9d22175217805098a9ef7c40bfef44f7669fb9d74d8"}, {file = "greenlet-3.2.4-cp310-cp310-win_amd64.whl", hash = "sha256:73f49b5368b5359d04e18d15828eecc1806033db5233397748f4ca813ff1056c"}, {file = "greenlet-3.2.4-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:96378df1de302bc38e99c3a9aa311967b7dc80ced1dcc6f171e99842987882a2"}, {file = "greenlet-3.2.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:1ee8fae0519a337f2329cb78bd7a8e128ec0f881073d43f023c7b8d4831d5246"}, @@ -1430,6 +1432,8 @@ files = [ {file = "greenlet-3.2.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2523e5246274f54fdadbce8494458a2ebdcdbc7b802318466ac5606d3cded1f8"}, {file = "greenlet-3.2.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1987de92fec508535687fb807a5cea1560f6196285a4cde35c100b8cd632cc52"}, {file = "greenlet-3.2.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:55e9c5affaa6775e2c6b67659f3a71684de4c549b3dd9afca3bc773533d284fa"}, + {file = "greenlet-3.2.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c9c6de1940a7d828635fbd254d69db79e54619f165ee7ce32fda763a9cb6a58c"}, + {file = "greenlet-3.2.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:03c5136e7be905045160b1b9fdca93dd6727b180feeafda6818e6496434ed8c5"}, {file = "greenlet-3.2.4-cp311-cp311-win_amd64.whl", hash = "sha256:9c40adce87eaa9ddb593ccb0fa6a07caf34015a29bf8d344811665b573138db9"}, {file = "greenlet-3.2.4-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:3b67ca49f54cede0186854a008109d6ee71f66bd57bb36abd6d0a0267b540cdd"}, {file = "greenlet-3.2.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ddf9164e7a5b08e9d22511526865780a576f19ddd00d62f8a665949327fde8bb"}, @@ -1439,6 +1443,8 @@ files = [ {file = "greenlet-3.2.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b3812d8d0c9579967815af437d96623f45c0f2ae5f04e366de62a12d83a8fb0"}, {file = "greenlet-3.2.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:abbf57b5a870d30c4675928c37278493044d7c14378350b3aa5d484fa65575f0"}, {file = "greenlet-3.2.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:20fb936b4652b6e307b8f347665e2c615540d4b42b3b4c8a321d8286da7e520f"}, + {file = 
"greenlet-3.2.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ee7a6ec486883397d70eec05059353b8e83eca9168b9f3f9a361971e77e0bcd0"}, + {file = "greenlet-3.2.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:326d234cbf337c9c3def0676412eb7040a35a768efc92504b947b3e9cfc7543d"}, {file = "greenlet-3.2.4-cp312-cp312-win_amd64.whl", hash = "sha256:a7d4e128405eea3814a12cc2605e0e6aedb4035bf32697f72deca74de4105e02"}, {file = "greenlet-3.2.4-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:1a921e542453fe531144e91e1feedf12e07351b1cf6c9e8a3325ea600a715a31"}, {file = "greenlet-3.2.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cd3c8e693bff0fff6ba55f140bf390fa92c994083f838fece0f63be121334945"}, @@ -1448,6 +1454,8 @@ files = [ {file = "greenlet-3.2.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:23768528f2911bcd7e475210822ffb5254ed10d71f4028387e5a99b4c6699671"}, {file = "greenlet-3.2.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:00fadb3fedccc447f517ee0d3fd8fe49eae949e1cd0f6a611818f4f6fb7dc83b"}, {file = "greenlet-3.2.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:d25c5091190f2dc0eaa3f950252122edbbadbb682aa7b1ef2f8af0f8c0afefae"}, + {file = "greenlet-3.2.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6e343822feb58ac4d0a1211bd9399de2b3a04963ddeec21530fc426cc121f19b"}, + {file = "greenlet-3.2.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ca7f6f1f2649b89ce02f6f229d7c19f680a6238af656f61e0115b24857917929"}, {file = "greenlet-3.2.4-cp313-cp313-win_amd64.whl", hash = "sha256:554b03b6e73aaabec3745364d6239e9e012d64c68ccd0b8430c64ccc14939a8b"}, {file = "greenlet-3.2.4-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:49a30d5fda2507ae77be16479bdb62a660fa51b1eb4928b524975b3bde77b3c0"}, {file = "greenlet-3.2.4-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:299fd615cd8fc86267b47597123e3f43ad79c9d8a22bebdce535e53550763e2f"}, @@ -1455,6 +1463,8 @@ files = [ {file = "greenlet-3.2.4-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:b4a1870c51720687af7fa3e7cda6d08d801dae660f75a76f3845b642b4da6ee1"}, {file = "greenlet-3.2.4-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:061dc4cf2c34852b052a8620d40f36324554bc192be474b9e9770e8c042fd735"}, {file = "greenlet-3.2.4-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:44358b9bf66c8576a9f57a590d5f5d6e72fa4228b763d0e43fee6d3b06d3a337"}, + {file = "greenlet-3.2.4-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:2917bdf657f5859fbf3386b12d68ede4cf1f04c90c3a6bc1f013dd68a22e2269"}, + {file = "greenlet-3.2.4-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:015d48959d4add5d6c9f6c5210ee3803a830dce46356e3bc326d6776bde54681"}, {file = "greenlet-3.2.4-cp314-cp314-win_amd64.whl", hash = "sha256:e37ab26028f12dbb0ff65f29a8d3d44a765c61e729647bf2ddfbbed621726f01"}, {file = "greenlet-3.2.4-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:b6a7c19cf0d2742d0809a4c05975db036fdff50cd294a93632d6a310bf9ac02c"}, {file = "greenlet-3.2.4-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:27890167f55d2387576d1f41d9487ef171849ea0359ce1510ca6e06c8bece11d"}, @@ -1464,6 +1474,8 @@ files = [ {file = "greenlet-3.2.4-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c9913f1a30e4526f432991f89ae263459b1c64d1608c0d22a5c79c287b3c70df"}, {file = "greenlet-3.2.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = 
"sha256:b90654e092f928f110e0007f572007c9727b5265f7632c2fa7415b4689351594"}, {file = "greenlet-3.2.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:81701fd84f26330f0d5f4944d4e92e61afe6319dcd9775e39396e39d7c3e5f98"}, + {file = "greenlet-3.2.4-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:28a3c6b7cd72a96f61b0e4b2a36f681025b60ae4779cc73c1535eb5f29560b10"}, + {file = "greenlet-3.2.4-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:52206cd642670b0b320a1fd1cbfd95bca0e043179c1d8a045f2c6109dfe973be"}, {file = "greenlet-3.2.4-cp39-cp39-win32.whl", hash = "sha256:65458b409c1ed459ea899e939f0e1cdb14f58dbc803f2f93c5eab5694d32671b"}, {file = "greenlet-3.2.4-cp39-cp39-win_amd64.whl", hash = "sha256:d2e685ade4dafd447ede19c31277a224a239a0a1a4eca4e6390efedf20260cfb"}, {file = "greenlet-3.2.4.tar.gz", hash = "sha256:0dca0d95ff849f9a364385f36ab49f50065d76964944638be9691e1832e9f86d"}, @@ -1800,7 +1812,7 @@ i18n = ["Babel (>=2.7)"] name = "jiter" version = "0.10.0" description = "Fast iterable JSON parser." -optional = true +optional = false python-versions = ">=3.9" files = [ {file = "jiter-0.10.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:cd2fb72b02478f06a900a5782de2ef47e0396b3e1f7d5aba30daeb1fce66f303"}, @@ -3070,7 +3082,7 @@ sympy = "*" name = "openai" version = "1.102.0" description = "The official Python library for the openai API" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "openai-1.102.0-py3-none-any.whl", hash = "sha256:d751a7e95e222b5325306362ad02a7aa96e1fab3ed05b5888ce1c7ca63451345"}, @@ -4212,13 +4224,13 @@ dev = ["build", "flake8", "mypy", "pytest", "twine"] [[package]] name = "pyright" -version = "1.1.405" +version = "1.1.407" description = "Command line wrapper for pyright" optional = false python-versions = ">=3.7" files = [ - {file = "pyright-1.1.405-py3-none-any.whl", hash = "sha256:a2cb13700b5508ce8e5d4546034cb7ea4aedb60215c6c33f56cec7f53996035a"}, - {file = "pyright-1.1.405.tar.gz", hash = "sha256:5c2a30e1037af27eb463a1cc0b9f6d65fec48478ccf092c1ac28385a15c55763"}, + {file = "pyright-1.1.407-py3-none-any.whl", hash = "sha256:6dd419f54fcc13f03b52285796d65e639786373f433e243f8b94cf93a7444d21"}, + {file = "pyright-1.1.407.tar.gz", hash = "sha256:099674dba5c10489832d4a4b2d302636152a9a42d317986c38474c76fe562262"}, ] [package.dependencies] @@ -6366,16 +6378,16 @@ files = [ cffi = ["cffi (>=1.17)"] [extras] -all = ["aiofiles", "google-cloud-language", "langchain-nvidia-ai-endpoints", "langchain-openai", "numpy", "numpy", "numpy", "numpy", "opentelemetry-api", "presidio-analyzer", "presidio-anonymizer", "streamlit", "tqdm", "yara-python"] +all = ["google-cloud-language", "langchain-nvidia-ai-endpoints", "langchain-openai", "numpy", "numpy", "numpy", "numpy", "opentelemetry-api", "presidio-analyzer", "presidio-anonymizer", "streamlit", "tqdm", "yara-python"] eval = ["numpy", "numpy", "numpy", "numpy", "streamlit", "tornado", "tqdm"] gcp = ["google-cloud-language"] jailbreak = ["yara-python"] nvidia = ["langchain-nvidia-ai-endpoints"] openai = ["langchain-openai"] sdd = ["presidio-analyzer", "presidio-anonymizer"] -tracing = ["aiofiles", "opentelemetry-api"] +tracing = ["opentelemetry-api"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.14" -content-hash = "64c8714671cb7f73952e4c8bfd53291a5e1ae13eac7c993286aca1409a13bf76" +content-hash = "8d456424d7a10f6e08c69755568b81cd8d2779bae98baa8d29f2be06098c3bf5" diff --git a/pyproject.toml b/pyproject.toml index 1d252a8bf..e782688e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ 
-71,10 +71,11 @@ starlette = ">=0.49.1" typer = ">=0.8" uvicorn = ">=0.23" watchdog = ">=3.0.0," +aiofiles = ">=24.1.0" +openai = ">=1.0.0, <2.0.0" # tracing opentelemetry-api = { version = ">=1.27.0,<2.0.0", optional = true } -aiofiles = { version = ">=24.1.0", optional = true } # openai langchain-openai = { version = ">=0.1.0", optional = true } @@ -110,7 +111,7 @@ sdd = ["presidio-analyzer", "presidio-anonymizer"] eval = ["tqdm", "numpy", "streamlit", "tornado"] openai = ["langchain-openai"] gcp = ["google-cloud-language"] -tracing = ["opentelemetry-api", "aiofiles"] +tracing = ["opentelemetry-api"] nvidia = ["langchain-nvidia-ai-endpoints"] jailbreak = ["yara-python"] # Poetry does not support recursive dependencies, so we need to add all the dependencies here. @@ -125,7 +126,6 @@ all = [ "langchain-openai", "google-cloud-language", "opentelemetry-api", - "aiofiles", "langchain-nvidia-ai-endpoints", "yara-python", ] diff --git a/tests/test_api.py b/tests/test_api.py index 759af575f..c1ff4293c 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -13,13 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio +import json import os import pytest from fastapi.testclient import TestClient from nemoguardrails.server import api -from nemoguardrails.server.api import RequestBody +from nemoguardrails.server.api import RequestBody, _format_streaming_response +from nemoguardrails.streaming import END_OF_STREAM, StreamingHandler client = TestClient(api.app) @@ -43,6 +46,26 @@ def test_get(): assert len(result) > 0 +def test_get_models(): + """Test the OpenAI-compatible /v1/models endpoint.""" + response = client.get("/v1/models") + assert response.status_code == 200 + + result = response.json() + + # Check OpenAI models list format + assert result["object"] == "list" + assert "data" in result + assert len(result["data"]) > 0 + + # Check each model has the required OpenAI format + for model in result["data"]: + assert "id" in model + assert model["object"] == "model" + assert "created" in model + assert model["owned_by"] == "nemo-guardrails" + + @pytest.mark.skip(reason="Should only be run locally as it needs OpenAI key.") def test_chat_completion(): response = client.post( @@ -59,8 +82,14 @@ def test_chat_completion(): ) assert response.status_code == 200 res = response.json() - assert len(res["messages"]) == 1 - assert res["messages"][0]["content"] + # Check OpenAI-compatible response structure + assert res["object"] == "chat.completion" + assert "id" in res + assert "created" in res + assert "model" in res + assert len(res["choices"]) == 1 + assert res["choices"][0]["message"]["content"] + assert res["choices"][0]["message"]["role"] == "assistant" @pytest.mark.skip(reason="Should only be run locally as it needs OpenAI key.") @@ -80,8 +109,14 @@ def test_chat_completion_with_default_configs(): ) assert response.status_code == 200 res = response.json() - assert len(res["messages"]) == 1 - assert res["messages"][0]["content"] + # Check OpenAI-compatible response structure + assert res["object"] == "chat.completion" + assert "id" in res + assert "created" in res + assert "model" in res + assert len(res["choices"]) == 1 + assert res["choices"][0]["message"]["content"] + assert res["choices"][0]["message"]["role"] == "assistant" def test_request_body_validation(): @@ -117,6 +152,31 @@ def test_request_body_validation(): assert request_body.config_ids is None +def test_openai_model_field_mapping(): + """Test OpenAI-compatible 
model field mapping to config_id.""" + + # Test model field maps to config_id + data = { + "model": "test_model", + "messages": [{"role": "user", "content": "Hello"}], + } + request_body = RequestBody.model_validate(data) + assert request_body.model == "test_model" + assert request_body.config_id == "test_model" + assert request_body.config_ids == ["test_model"] + + # Test model and config_id both provided (config_id takes precedence) + data = { + "model": "test_model", + "config_id": "test_config", + "messages": [{"role": "user", "content": "Hello"}], + } + request_body = RequestBody.model_validate(data) + assert request_body.model == "test_model" + assert request_body.config_id == "test_config" + assert request_body.config_ids == ["test_config"] + + def test_request_body_state(): """Test RequestBody state handling.""" data = { @@ -146,3 +206,301 @@ def test_request_body_messages(): } request_body = RequestBody.model_validate(data) assert len(request_body.messages) == 1 + + +@pytest.mark.asyncio +async def test_openai_sse_format_basic_chunks(): + """Test basic string chunks are properly formatted as SSE events.""" + handler = StreamingHandler() + + # Collect yielded SSE messages + collected = [] + + async def collector(): + async for b in _format_streaming_response(handler, model_name=None): + collected.append(b) + + task = asyncio.create_task(collector()) + + # Push a couple of chunks and then signal completion + await handler.push_chunk("Hello ") + await handler.push_chunk("world") + await handler.push_chunk(END_OF_STREAM) + + # Wait for the collector task to finish + await task + + # We expect three messages: two data: {json}\n\n events and final data: [DONE]\n\n + assert len(collected) == 3 + # First two are JSON SSE events + evt1 = collected[0] + evt2 = collected[1] + done = collected[2] + + assert evt1.startswith("data: ") + j1 = json.loads(evt1[len("data: ") :].strip()) + assert j1["object"] == "chat.completion.chunk" + assert j1["choices"][0]["delta"]["content"] == "Hello " + + assert evt2.startswith("data: ") + j2 = json.loads(evt2[len("data: ") :].strip()) + assert j2["choices"][0]["delta"]["content"] == "world" + + assert done == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_openai_sse_format_with_model_name(): + """Test that model name is properly included in the response.""" + handler = StreamingHandler() + collected = [] + + async def collector(): + async for b in _format_streaming_response(handler, model_name="gpt-4"): + collected.append(b) + + task = asyncio.create_task(collector()) + + await handler.push_chunk("Test") + await handler.push_chunk(END_OF_STREAM) + + await task + + assert len(collected) == 2 + evt = collected[0] + j = json.loads(evt[len("data: ") :].strip()) + assert j["model"] == "gpt-4" + assert j["choices"][0]["delta"]["content"] == "Test" + assert collected[1] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_openai_sse_format_with_dict_chunk(): + """Test that dict chunks with role and content are properly formatted.""" + handler = StreamingHandler() + collected = [] + + async def collector(): + async for b in _format_streaming_response(handler, model_name=None): + collected.append(b) + + task = asyncio.create_task(collector()) + + # Push a dict chunk that includes role and content + await handler.push_chunk({"role": "assistant", "content": "Hi!"}) + await handler.push_chunk(None) + + await task + + # We expect two messages: one data chunk and final data: [DONE] + assert len(collected) == 2 + evt = collected[0] + j = 
json.loads(evt[len("data: ") :].strip()) + assert j["object"] == "chat.completion.chunk" + assert j["choices"][0]["delta"]["role"] == "assistant" + assert j["choices"][0]["delta"]["content"] == "Hi!" + assert collected[1] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_openai_sse_format_empty_string(): + """Test that empty strings are handled correctly.""" + handler = StreamingHandler() + collected = [] + + async def collector(): + async for b in _format_streaming_response(handler, model_name=None): + collected.append(b) + + task = asyncio.create_task(collector()) + + await handler.push_chunk("") + await handler.push_chunk(END_OF_STREAM) + + await task + + assert len(collected) == 2 + evt = collected[0] + j = json.loads(evt[len("data: ") :].strip()) + assert j["choices"][0]["delta"]["content"] == "" + assert collected[1] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_openai_sse_format_none_triggers_done(): + """Test that None (converted to END_OF_STREAM) triggers [DONE].""" + handler = StreamingHandler() + collected = [] + + async def collector(): + async for b in _format_streaming_response(handler, model_name=None): + collected.append(b) + + task = asyncio.create_task(collector()) + + await handler.push_chunk("Content") + await handler.push_chunk(None) # None converts to END_OF_STREAM + + await task + + assert len(collected) == 2 + evt = collected[0] + j = json.loads(evt[len("data: ") :].strip()) + assert j["choices"][0]["delta"]["content"] == "Content" + assert collected[1] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_openai_sse_format_multiple_dict_chunks(): + """Test multiple dict chunks with different fields.""" + handler = StreamingHandler() + collected = [] + + async def collector(): + async for b in _format_streaming_response(handler, model_name="test-model"): + collected.append(b) + + task = asyncio.create_task(collector()) + + # Push multiple dict chunks + await handler.push_chunk({"role": "assistant"}) + await handler.push_chunk({"content": "Hello"}) + await handler.push_chunk({"content": " world"}) + await handler.push_chunk(END_OF_STREAM) + + await task + + assert len(collected) == 4 + + # Check first chunk (role only) + j1 = json.loads(collected[0][len("data: ") :].strip()) + assert j1["choices"][0]["delta"]["role"] == "assistant" + assert "content" not in j1["choices"][0]["delta"] + + # Check second chunk (content only) + j2 = json.loads(collected[1][len("data: ") :].strip()) + assert j2["choices"][0]["delta"]["content"] == "Hello" + + # Check third chunk (content only) + j3 = json.loads(collected[2][len("data: ") :].strip()) + assert j3["choices"][0]["delta"]["content"] == " world" + + # Check [DONE] message + assert collected[3] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_openai_sse_format_special_characters(): + """Test that special characters are properly escaped in JSON.""" + handler = StreamingHandler() + collected = [] + + async def collector(): + async for b in _format_streaming_response(handler, model_name=None): + collected.append(b) + + task = asyncio.create_task(collector()) + + # Push chunks with special characters + await handler.push_chunk("Line 1\nLine 2") + await handler.push_chunk('Quote: "test"') + await handler.push_chunk(END_OF_STREAM) + + await task + + assert len(collected) == 3 + + # Verify first chunk with newline + j1 = json.loads(collected[0][len("data: ") :].strip()) + assert j1["choices"][0]["delta"]["content"] == "Line 1\nLine 2" + + # Verify second chunk with quotes + j2 = 
json.loads(collected[1][len("data: ") :].strip()) + assert j2["choices"][0]["delta"]["content"] == 'Quote: "test"' + + assert collected[2] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_openai_sse_format_events(): + """Test that all events follow proper SSE format.""" + handler = StreamingHandler() + collected = [] + + async def collector(): + async for b in _format_streaming_response(handler, model_name=None): + collected.append(b) + + task = asyncio.create_task(collector()) + + await handler.push_chunk("Test") + await handler.push_chunk(END_OF_STREAM) + + await task + + # All events except [DONE] should be valid JSON with proper SSE format + for event in collected[:-1]: + assert event.startswith("data: ") + assert event.endswith("\n\n") + # Verify it's valid JSON + json_str = event[len("data: ") :].strip() + j = json.loads(json_str) + assert "object" in j + assert "choices" in j + assert isinstance(j["choices"], list) + assert len(j["choices"]) > 0 + + # Last event should be [DONE] + assert collected[-1] == "data: [DONE]\n\n" + + +@pytest.mark.asyncio +async def test_openai_sse_format_chunk_metadata(): + """Test that chunk metadata is properly formatted.""" + handler = StreamingHandler() + collected = [] + + async def collector(): + async for b in _format_streaming_response(handler, model_name="test-model"): + collected.append(b) + + task = asyncio.create_task(collector()) + + await handler.push_chunk("Test") + await handler.push_chunk(END_OF_STREAM) + + await task + + evt = collected[0] + j = json.loads(evt[len("data: ") :].strip()) + + # Verify all required fields are present + assert j["id"] is None # id can be None for chunks + assert j["object"] == "chat.completion.chunk" + assert isinstance(j["created"], int) + assert j["model"] == "test-model" + assert isinstance(j["choices"], list) + assert len(j["choices"]) == 1 + + choice = j["choices"][0] + assert "delta" in choice + assert choice["index"] == 0 + assert choice["finish_reason"] is None + + +@pytest.mark.skip(reason="Should only be run locally as it needs OpenAI key.") +def test_chat_completion_with_streaming(): + response = client.post( + "/v1/chat/completions", + json={ + "config_id": "general", + "messages": [{"role": "user", "content": "Hello"}], + "stream": True, + }, + ) + assert response.status_code == 200 + assert response.headers["Content-Type"] == "text/event-stream" + for chunk in response.iter_lines(): + assert chunk.startswith("data: ") + assert chunk.endswith("\n\n") + assert "data: [DONE]\n\n" in response.text diff --git a/tests/test_openai_integration.py b/tests/test_openai_integration.py new file mode 100644 index 000000000..735651a66 --- /dev/null +++ b/tests/test_openai_integration.py @@ -0,0 +1,167 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os + +import pytest +from fastapi.testclient import TestClient +from openai import OpenAI + +from nemoguardrails.server import api + + +@pytest.fixture(scope="function", autouse=True) +def set_rails_config_path(): + """Set the rails config path to the test configs directory.""" + original_path = api.app.rails_config_path + api.app.rails_config_path = os.path.normpath( + os.path.join(os.path.dirname(__file__), "test_configs/simple_server") + ) + yield + + # Restore the original path and clear cache after the test + api.app.rails_config_path = original_path + api.llm_rails_instances.clear() + api.llm_rails_events_history_cache.clear() + + +@pytest.fixture(scope="function") +def test_client(): + """Create a FastAPI TestClient for the guardrails server.""" + return TestClient(api.app) + + +@pytest.fixture(scope="function") +def openai_client(test_client): + client = OpenAI( + api_key="dummy-key", + base_url="http://dummy-url/v1", + http_client=test_client, + ) + return client + + +def test_openai_client_list_models(openai_client): + models = openai_client.models.list() + + # Verify the response structure matches OpenAI's ModelList + assert models is not None + assert hasattr(models, "data") + assert len(models.data) > 0 + + # Check first model has required fields + model = models.data[0] + assert hasattr(model, "id") + assert hasattr(model, "object") + assert model.object == "model" + assert hasattr(model, "created") + assert hasattr(model, "owned_by") + assert model.owned_by == "nemo-guardrails" + + +def test_openai_client_chat_completion(openai_client): + response = openai_client.chat.completions.create( + model="config_1", + messages=[{"role": "user", "content": "hi"}], + stream=False, + ) + + # Verify response structure matches OpenAI's ChatCompletion object + assert response is not None + assert hasattr(response, "id") + assert response.id is not None + assert hasattr(response, "object") + assert response.object == "chat.completion" + assert hasattr(response, "created") + assert response.created > 0 + assert hasattr(response, "model") + assert response.model == "config_1" + + # Verify choices structure + assert hasattr(response, "choices") + assert len(response.choices) == 1 + choice = response.choices[0] + assert hasattr(choice, "index") + assert choice.index == 0 + assert hasattr(choice, "message") + assert hasattr(choice.message, "role") + assert choice.message.role == "assistant" + assert hasattr(choice.message, "content") + assert choice.message.content is not None + assert isinstance(choice.message.content, str) + assert len(choice.message.content) > 0 + assert hasattr(choice, "finish_reason") + assert choice.finish_reason == "stop" + + +def test_openai_client_chat_completion_parameterized(openai_client): + response = openai_client.chat.completions.create( + model="config_1", + messages=[{"role": "user", "content": "hi"}], + temperature=0.7, + max_tokens=100, + stream=False, + ) + + # Verify response exists + assert response is not None + assert response.choices[0].message.content is not None + + +def test_openai_client_chat_completion_input_rails(openai_client): + response = openai_client.chat.completions.create( + model="input_rails", + messages=[{"role": "user", "content": "Hello, how are you?"}], + stream=False, + ) + + # Verify response exists + assert response is not None + assert response.choices[0].message.content is not None + assert isinstance(response.choices[0].message.content, str) + + +@pytest.mark.skip(reason="Should only be run locally as it needs OpenAI 
key.") +def test_openai_client_chat_completion_streaming(openai_client): + stream = openai_client.chat.completions.create( + model="input_rails", + messages=[{"role": "user", "content": "Tell me a short joke."}], + stream=True, + ) + + chunks = list(stream) + assert len(chunks) > 0 + + # Verify at least one chunk has content + has_content = any( + hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content + for chunk in chunks + ) + assert has_content, "At least one chunk should contain content" + + +def test_openai_client_error_handling_invalid_model(openai_client): + response = openai_client.chat.completions.create( + model="nonexistent_config", + messages=[{"role": "user", "content": "hi"}], + stream=False, + ) + + # The server should return a response (not raise an exception) + assert response is not None + # The error should be in the content + assert ( + "Could not load" in response.choices[0].message.content + or "error" in response.choices[0].message.content.lower() + ) diff --git a/tests/test_server_calls_with_state.py b/tests/test_server_calls_with_state.py index 9560a9511..2157c720b 100644 --- a/tests/test_server_calls_with_state.py +++ b/tests/test_server_calls_with_state.py @@ -37,12 +37,15 @@ def _test_call(config_id): ) assert response.status_code == 200 res = response.json() - assert len(res["messages"]) == 1 - assert res["messages"][0]["content"] == "Hello!" + print(res) + assert len(res["choices"][0]["message"]) == 2 + assert res["choices"][0]["message"]["content"] == "Hello!" assert res.get("state") # When making a second call with the returned state, the conversations should continue # and we should get the "Hello again!" message. + # For Colang 2.x, we only send the new user message, not the conversation history + # since the state maintains the conversation context. response = client.post( "/v1/chat/completions", json={ @@ -57,7 +60,7 @@ def _test_call(config_id): }, ) res = response.json() - assert res["messages"][0]["content"] == "Hello again!" + assert res["choices"][0]["message"]["content"] == "Hello again!" def test_1(): diff --git a/tests/test_threads.py b/tests/test_threads.py index 4903e07bb..34f175dc6 100644 --- a/tests/test_threads.py +++ b/tests/test_threads.py @@ -53,8 +53,9 @@ def test_1(): ) assert response.status_code == 200 res = response.json() - assert len(res["messages"]) == 1 - assert res["messages"][0]["content"] == "Hello!" + assert "choices" in res + assert "message" in res["choices"][0] + assert res["choices"][0]["message"]["content"] == "Hello!" # When making a second call with the same thread_id, the conversations should continue # and we should get the "Hello again!" message. @@ -72,7 +73,7 @@ def test_1(): }, ) res = response.json() - assert res["messages"][0]["content"] == "Hello again!" + assert res["choices"][0]["message"]["content"] == "Hello again!" @pytest.mark.parametrize( @@ -140,4 +141,4 @@ def test_with_redis(): }, ) res = response.json() - assert res["messages"][0]["content"] == "Hello again!" + assert res["choices"][0]["message"]["content"] == "Hello again!"