@@ -31,12 +31,16 @@ class MyLLMOutput(LLMOutput):
3131
3232from __future__ import annotations
3333
34+ import asyncio
35+ import functools
36+ import inspect
3437import json
3538import logging
3639import textwrap
37- from typing import TYPE_CHECKING , TypeVar
40+ from collections .abc import Callable
41+ from typing import TYPE_CHECKING , Any , TypeVar
3842
39- from openai import AsyncOpenAI
43+ from openai import AsyncOpenAI , OpenAI
4044from pydantic import BaseModel , ConfigDict , Field
4145
4246from guardrails .registry import default_spec_registry
@@ -45,7 +49,13 @@ class MyLLMOutput(LLMOutput):
4549from guardrails .utils .output import OutputSchema
4650
4751if TYPE_CHECKING :
48- from openai import AsyncOpenAI
52+ from openai import AsyncAzureOpenAI , AzureOpenAI # type: ignore[unused-import]
53+ else :
54+ try :
55+ from openai import AsyncAzureOpenAI , AzureOpenAI # type: ignore
56+ except Exception : # pragma: no cover - optional dependency
57+ AsyncAzureOpenAI = object # type: ignore[assignment]
58+ AzureOpenAI = object # type: ignore[assignment]
4959
5060logger = logging .getLogger (__name__ )
5161
@@ -165,10 +175,46 @@ def _strip_json_code_fence(text: str) -> str:
165175 return candidate
166176
167177
178+ async def _invoke_openai_callable (
179+ method : Callable [..., Any ],
180+ / ,
181+ * args : Any ,
182+ ** kwargs : Any ,
183+ ) -> Any :
184+ """Invoke OpenAI SDK methods that may be sync or async."""
185+ if inspect .iscoroutinefunction (method ):
186+ return await method (* args , ** kwargs )
187+
188+ loop = asyncio .get_running_loop ()
189+ result = await loop .run_in_executor (
190+ None ,
191+ functools .partial (method , * args , ** kwargs ),
192+ )
193+ if inspect .isawaitable (result ):
194+ return await result
195+ return result
196+
197+
async def _request_chat_completion(
    client: AsyncOpenAI | OpenAI | AsyncAzureOpenAI | AzureOpenAI,
    *,
    messages: list[dict[str, str]],
    model: str,
    response_format: dict[str, Any],
) -> Any:
    """Issue a chat.completions.create call on any supported OpenAI client.

    Works with both synchronous and asynchronous client variants by routing
    the bound method through :func:`_invoke_openai_callable`.
    """
    create = client.chat.completions.create
    return await _invoke_openai_callable(
        create,
        messages=messages,
        model=model,
        response_format=response_format,
    )
213+
168214async def run_llm (
169215 text : str ,
170216 system_prompt : str ,
171- client : AsyncOpenAI ,
217+ client : AsyncOpenAI | OpenAI | AsyncAzureOpenAI | AzureOpenAI ,
172218 model : str ,
173219 output_model : type [LLMOutput ],
174220) -> LLMOutput :
@@ -180,7 +226,7 @@ async def run_llm(
180226 Args:
181227 text (str): Text to analyze.
182228 system_prompt (str): Prompt instructions for the LLM.
183- client (AsyncOpenAI): OpenAI client for LLM inference .
229+ client (AsyncOpenAI | OpenAI | AsyncAzureOpenAI | AzureOpenAI ): OpenAI client used for guardrails .
184230 model (str): Identifier for which LLM model to use.
185231 output_model (type[LLMOutput]): Model for parsing and validating the LLM's response.
186232
@@ -190,7 +236,8 @@ async def run_llm(
190236 full_prompt = _build_full_prompt (system_prompt )
191237
192238 try :
193- response = await client .chat .completions .create (
239+ response = await _request_chat_completion (
240+ client = client ,
194241 messages = [
195242 {"role" : "system" , "content" : full_prompt },
196243 {"role" : "user" , "content" : f"# Text\n \n { text } " },
0 commit comments