Skip to content

Commit 66b40d0

Browse files
authored
Remove redundant client type check (#39)
* Remove redundant client type check * Handle sync vs async and oai vs azure clients * Remove unused imports
1 parent aac8b8f commit 66b40d0

File tree

2 files changed

+121
-15
lines changed

2 files changed

+121
-15
lines changed

src/guardrails/checks/text/moderation.py

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
from __future__ import annotations
2929

30+
import asyncio
3031
import logging
3132
from enum import Enum
3233
from functools import cache
@@ -130,11 +131,11 @@ def _get_moderation_client() -> AsyncOpenAI:
130131
return AsyncOpenAI()
131132

132133

133-
async def _call_moderation_api(client: AsyncOpenAI, data: str) -> Any:
134-
"""Call the OpenAI moderation API.
134+
async def _call_moderation_api_async(client: Any, data: str) -> Any:
135+
"""Call the OpenAI moderation API asynchronously.
135136
136137
Args:
137-
client: The OpenAI client to use.
138+
client: The async OpenAI or Azure OpenAI client to use.
138139
data: The text to analyze.
139140
140141
Returns:
@@ -146,6 +147,22 @@ async def _call_moderation_api(client: AsyncOpenAI, data: str) -> Any:
146147
)
147148

148149

150+
def _call_moderation_api_sync(client: Any, data: str) -> Any:
151+
"""Call the OpenAI moderation API synchronously.
152+
153+
Args:
154+
client: The sync OpenAI or Azure OpenAI client to use.
155+
data: The text to analyze.
156+
157+
Returns:
158+
The moderation API response.
159+
"""
160+
return client.moderations.create(
161+
model="omni-moderation-latest",
162+
input=data,
163+
)
164+
165+
149166
async def moderation(
150167
ctx: Any,
151168
data: str,
@@ -165,29 +182,32 @@ async def moderation(
165182
Returns:
166183
GuardrailResult: Indicates if tripwire was triggered, and details of flagged categories.
167184
"""
168-
client = None
169-
if ctx is not None:
170-
candidate = getattr(ctx, "guardrail_llm", None)
171-
if isinstance(candidate, AsyncOpenAI):
172-
client = candidate
185+
# Try context client first (if provided), fall back on 404
186+
client = getattr(ctx, "guardrail_llm", None) if ctx is not None else None
173187

174-
# Try the context client first, fall back if moderation endpoint doesn't exist
175188
if client is not None:
189+
# Determine if client is async or sync
190+
is_async = isinstance(client, AsyncOpenAI)
191+
176192
try:
177-
resp = await _call_moderation_api(client, data)
193+
if is_async:
194+
resp = await _call_moderation_api_async(client, data)
195+
else:
196+
# Sync client - run in thread pool to avoid blocking event loop
197+
resp = await asyncio.to_thread(_call_moderation_api_sync, client, data)
178198
except NotFoundError as e:
179-
# Moderation endpoint doesn't exist on this provider (e.g., third-party)
180-
# Fall back to the OpenAI client
199+
# Moderation endpoint doesn't exist (e.g., Azure, third-party)
200+
# Fall back to OpenAI client with OPENAI_API_KEY env var
181201
logger.debug(
182202
"Moderation endpoint not available on context client, falling back to OpenAI: %s",
183203
e,
184204
)
185205
client = _get_moderation_client()
186-
resp = await _call_moderation_api(client, data)
206+
resp = await _call_moderation_api_async(client, data)
187207
else:
188-
# No context client, use fallback
208+
# No context client - use fallback OpenAI client
189209
client = _get_moderation_client()
190-
resp = await _call_moderation_api(client, data)
210+
resp = await _call_moderation_api_async(client, data)
191211
results = resp.results or []
192212
if not results:
193213
return GuardrailResult(

tests/unit/checks/test_moderation.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,89 @@ async def raise_not_found(**_: Any) -> Any:
134134
# Verify the fallback client was used (not the third-party one)
135135
assert fallback_used is True # noqa: S101
136136
assert result.tripwire_triggered is False # noqa: S101
137+
138+
139+
@pytest.mark.asyncio
140+
async def test_moderation_uses_sync_context_client() -> None:
141+
"""Moderation should support synchronous OpenAI clients from context."""
142+
from openai import OpenAI
143+
144+
# Track whether sync context client was used
145+
sync_client_used = False
146+
147+
def track_sync_create(**_: Any) -> Any:
148+
nonlocal sync_client_used
149+
sync_client_used = True
150+
151+
class _Result:
152+
def model_dump(self) -> dict[str, Any]:
153+
return {"categories": {"hate": False, "violence": False}}
154+
155+
return SimpleNamespace(results=[_Result()])
156+
157+
# Create a sync context client
158+
sync_client = OpenAI(api_key="test-sync-key", base_url="https://api.openai.com/v1")
159+
sync_client.moderations = SimpleNamespace(create=track_sync_create) # type: ignore[assignment]
160+
161+
ctx = SimpleNamespace(guardrail_llm=sync_client)
162+
163+
cfg = ModerationCfg(categories=[Category.HATE, Category.VIOLENCE])
164+
result = await moderation(ctx, "test text", cfg)
165+
166+
# Verify the sync context client was used (via asyncio.to_thread)
167+
assert sync_client_used is True # noqa: S101
168+
assert result.tripwire_triggered is False # noqa: S101
169+
170+
171+
@pytest.mark.asyncio
172+
async def test_moderation_falls_back_for_azure_clients(monkeypatch: pytest.MonkeyPatch) -> None:
173+
"""Moderation should fall back to OpenAI client for Azure clients (no moderation endpoint)."""
174+
try:
175+
from openai import AsyncAzureOpenAI, NotFoundError
176+
except ImportError:
177+
pytest.skip("Azure OpenAI not available")
178+
179+
# Track whether fallback was used
180+
fallback_used = False
181+
182+
async def track_fallback_create(**_: Any) -> Any:
183+
nonlocal fallback_used
184+
fallback_used = True
185+
186+
class _Result:
187+
def model_dump(self) -> dict[str, Any]:
188+
return {"categories": {"hate": False, "violence": False}}
189+
190+
return SimpleNamespace(results=[_Result()])
191+
192+
# Mock the fallback client
193+
fallback_client = SimpleNamespace(moderations=SimpleNamespace(create=track_fallback_create))
194+
monkeypatch.setattr("guardrails.checks.text.moderation._get_moderation_client", lambda: fallback_client)
195+
196+
# Create a mock httpx.Response for NotFoundError
197+
mock_response = SimpleNamespace(
198+
status_code=404,
199+
headers={},
200+
text="404 page not found",
201+
json=lambda: {"error": {"message": "Not found", "type": "invalid_request_error"}},
202+
)
203+
204+
# Create an Azure context client that raises NotFoundError for moderation
205+
async def raise_not_found(**_: Any) -> Any:
206+
raise NotFoundError("404 page not found", response=mock_response, body=None) # type: ignore[arg-type]
207+
208+
azure_client = AsyncAzureOpenAI(
209+
api_key="test-azure-key",
210+
api_version="2024-02-01",
211+
azure_endpoint="https://test.openai.azure.com",
212+
)
213+
azure_client.moderations = SimpleNamespace(create=raise_not_found) # type: ignore[assignment]
214+
215+
ctx = SimpleNamespace(guardrail_llm=azure_client)
216+
217+
cfg = ModerationCfg(categories=[Category.HATE, Category.VIOLENCE])
218+
result = await moderation(ctx, "test text", cfg)
219+
220+
# Verify the fallback client was used (not the Azure one)
221+
assert fallback_used is True # noqa: S101
222+
assert result.tripwire_triggered is False # noqa: S101

0 commit comments

Comments
 (0)