
Commit 515bd41

Handle sync guardrail calls to avoid awaitable error (#21)
1 parent a251298 · commit 515bd41


5 files changed, +110 -15 lines changed

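The error this commit fixes shows up when a guardrail is configured with a synchronous OpenAI client: the checks awaited the client's return value directly, and awaiting a non-awaitable raises TypeError. A minimal, self-contained sketch of that failure mode follows; the _SyncResponses fake is illustrative, not taken from the repository.

# Illustrative only: awaiting the plain return value of a synchronous
# method raises TypeError ("object ... can't be used in 'await' expression").
import asyncio
from types import SimpleNamespace


class _SyncResponses:
    """Stand-in for a synchronous client's responses API."""

    def parse(self, **kwargs):
        return SimpleNamespace(output_parsed="ok")  # plain value, not an awaitable


async def main() -> None:
    responses = _SyncResponses()
    try:
        await responses.parse(input="hello")  # roughly what the old code did
    except TypeError as exc:
        print(f"TypeError: {exc}")


asyncio.run(main())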

src/guardrails/checks/text/hallucination_detection.py

Lines changed: 4 additions & 6 deletions
@@ -52,10 +52,7 @@
 from guardrails.spec import GuardrailSpecMetadata
 from guardrails.types import GuardrailLLMContextProto, GuardrailResult

-from .llm_base import (
-    LLMConfig,
-    LLMOutput,
-)
+from .llm_base import LLMConfig, LLMOutput, _invoke_openai_callable

 logger = logging.getLogger(__name__)

@@ -210,9 +207,10 @@ async def hallucination_detection(
     validation_query = f"{VALIDATION_PROMPT}\n\nText to validate:\n{candidate}"

     # Use the Responses API with file search and structured output
-    response = await ctx.guardrail_llm.responses.parse(
-        model=config.model,
+    response = await _invoke_openai_callable(
+        ctx.guardrail_llm.responses.parse,
         input=validation_query,
+        model=config.model,
         text_format=HallucinationDetectionOutput,
         tools=[{"type": "file_search", "vector_store_ids": [config.knowledge_source]}],
     )

src/guardrails/checks/text/llm_base.py

Lines changed: 53 additions & 6 deletions
@@ -31,12 +31,16 @@ class MyLLMOutput(LLMOutput):

 from __future__ import annotations

+import asyncio
+import functools
+import inspect
 import json
 import logging
 import textwrap
-from typing import TYPE_CHECKING, TypeVar
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Any, TypeVar

-from openai import AsyncOpenAI
+from openai import AsyncOpenAI, OpenAI
 from pydantic import BaseModel, ConfigDict, Field

 from guardrails.registry import default_spec_registry

@@ -45,7 +49,13 @@ class MyLLMOutput(LLMOutput):
 from guardrails.utils.output import OutputSchema

 if TYPE_CHECKING:
-    from openai import AsyncOpenAI
+    from openai import AsyncAzureOpenAI, AzureOpenAI  # type: ignore[unused-import]
+else:
+    try:
+        from openai import AsyncAzureOpenAI, AzureOpenAI  # type: ignore
+    except Exception:  # pragma: no cover - optional dependency
+        AsyncAzureOpenAI = object  # type: ignore[assignment]
+        AzureOpenAI = object  # type: ignore[assignment]

 logger = logging.getLogger(__name__)

@@ -165,10 +175,46 @@ def _strip_json_code_fence(text: str) -> str:
     return candidate


+async def _invoke_openai_callable(
+    method: Callable[..., Any],
+    /,
+    *args: Any,
+    **kwargs: Any,
+) -> Any:
+    """Invoke OpenAI SDK methods that may be sync or async."""
+    if inspect.iscoroutinefunction(method):
+        return await method(*args, **kwargs)
+
+    loop = asyncio.get_running_loop()
+    result = await loop.run_in_executor(
+        None,
+        functools.partial(method, *args, **kwargs),
+    )
+    if inspect.isawaitable(result):
+        return await result
+    return result
+
+
+async def _request_chat_completion(
+    client: AsyncOpenAI | OpenAI | AsyncAzureOpenAI | AzureOpenAI,
+    *,
+    messages: list[dict[str, str]],
+    model: str,
+    response_format: dict[str, Any],
+) -> Any:
+    """Invoke chat.completions.create on sync or async OpenAI clients."""
+    return await _invoke_openai_callable(
+        client.chat.completions.create,
+        messages=messages,
+        model=model,
+        response_format=response_format,
+    )
+
+
 async def run_llm(
     text: str,
     system_prompt: str,
-    client: AsyncOpenAI,
+    client: AsyncOpenAI | OpenAI | AsyncAzureOpenAI | AzureOpenAI,
     model: str,
     output_model: type[LLMOutput],
 ) -> LLMOutput:

@@ -180,7 +226,7 @@ async def run_llm(
     Args:
         text (str): Text to analyze.
         system_prompt (str): Prompt instructions for the LLM.
-        client (AsyncOpenAI): OpenAI client for LLM inference.
+        client (AsyncOpenAI | OpenAI | AsyncAzureOpenAI | AzureOpenAI): OpenAI client used for guardrails.
         model (str): Identifier for which LLM model to use.
         output_model (type[LLMOutput]): Model for parsing and validating the LLM's response.

@@ -190,7 +236,8 @@
     full_prompt = _build_full_prompt(system_prompt)

     try:
-        response = await client.chat.completions.create(
+        response = await _request_chat_completion(
+            client=client,
             messages=[
                 {"role": "system", "content": full_prompt},
                 {"role": "user", "content": f"# Text\n\n{text}"},

src/guardrails/checks/text/prompt_injection_detection.py

Lines changed: 4 additions & 3 deletions
@@ -36,7 +36,7 @@
 from guardrails.spec import GuardrailSpecMetadata
 from guardrails.types import GuardrailLLMContextProto, GuardrailResult

-from .llm_base import LLMConfig, LLMOutput
+from .llm_base import LLMConfig, LLMOutput, _invoke_openai_callable

 __all__ = ["prompt_injection_detection", "PromptInjectionDetectionOutput"]

@@ -373,9 +373,10 @@ def _create_skip_result(

 async def _call_prompt_injection_detection_llm(ctx: GuardrailLLMContextProto, prompt: str, config: LLMConfig) -> PromptInjectionDetectionOutput:
     """Call LLM for prompt injection detection analysis."""
-    parsed_response = await ctx.guardrail_llm.responses.parse(
-        model=config.model,
+    parsed_response = await _invoke_openai_callable(
+        ctx.guardrail_llm.responses.parse,
         input=prompt,
+        model=config.model,
         text_format=PromptInjectionDetectionOutput,
     )
     return parsed_response.output_parsed

tests/unit/checks/test_llm_base.py

Lines changed: 31 additions & 0 deletions
@@ -34,6 +34,20 @@ def __init__(self, content: str | None) -> None:
         self.chat = SimpleNamespace(completions=_FakeCompletions(content))


+class _FakeSyncCompletions:
+    def __init__(self, content: str | None) -> None:
+        self._content = content
+
+    def create(self, **kwargs: Any) -> Any:
+        _ = kwargs
+        return SimpleNamespace(choices=[SimpleNamespace(message=SimpleNamespace(content=self._content))])
+
+
+class _FakeSyncClient:
+    def __init__(self, content: str | None) -> None:
+        self.chat = SimpleNamespace(completions=_FakeSyncCompletions(content))
+
+
 def test_strip_json_code_fence_removes_wrapping() -> None:
     """Valid JSON code fences should be removed."""
     fenced = """```json

@@ -64,6 +78,23 @@ async def test_run_llm_returns_valid_output() -> None:
     assert result.flagged is True and result.confidence == 0.9  # noqa: S101


+@pytest.mark.asyncio
+async def test_run_llm_supports_sync_clients() -> None:
+    """run_llm should invoke synchronous clients without awaiting them."""
+    client = _FakeSyncClient('{"flagged": false, "confidence": 0.25}')
+
+    result = await run_llm(
+        text="General text",
+        system_prompt="Assess text.",
+        client=client,  # type: ignore[arg-type]
+        model="gpt-test",
+        output_model=LLMOutput,
+    )
+
+    assert isinstance(result, LLMOutput)  # noqa: S101
+    assert result.flagged is False and result.confidence == 0.25  # noqa: S101
+
+
 @pytest.mark.asyncio
 async def test_run_llm_handles_content_filter_error(monkeypatch: pytest.MonkeyPatch) -> None:
     """Content filter errors should return LLMErrorOutput with flagged=True."""

tests/unit/checks/test_prompt_injection_detection.py

Lines changed: 18 additions & 0 deletions
@@ -147,3 +147,21 @@ async def failing_llm(*_args: Any, **_kwargs: Any) -> PromptInjectionDetectionOu

     assert result.tripwire_triggered is False  # noqa: S101
     assert "Error during prompt injection detection check" in result.info["observation"]  # noqa: S101
+
+
+@pytest.mark.asyncio
+async def test_prompt_injection_detection_llm_supports_sync_responses() -> None:
+    """Underlying responses.parse may be synchronous for some clients."""
+    analysis = PromptInjectionDetectionOutput(flagged=True, confidence=0.4, observation="Action summary")
+
+    class _SyncResponses:
+        def parse(self, **kwargs: Any) -> Any:
+            _ = kwargs
+            return SimpleNamespace(output_parsed=analysis)
+
+    context = SimpleNamespace(guardrail_llm=SimpleNamespace(responses=_SyncResponses()))
+    config = LLMConfig(model="gpt-test", confidence_threshold=0.5)
+
+    parsed = await pid_module._call_prompt_injection_detection_llm(context, "prompt", config)
+
+    assert parsed is analysis  # noqa: S101
