Skip to content

Commit 1bfd82b

Browse files
authored
Correctly passing API key to moderation (#36)
* extract call_moderation helper
1 parent 45e958c commit 1bfd82b

File tree

3 files changed

+128
-30
lines changed

3 files changed

+128
-30
lines changed

src/guardrails/checks/text/moderation.py

Lines changed: 39 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from functools import cache
3333
from typing import Any
3434

35-
from openai import AsyncOpenAI
35+
from openai import AsyncOpenAI, NotFoundError
3636
from pydantic import BaseModel, ConfigDict, Field
3737

3838
from guardrails.registry import default_spec_registry
@@ -132,6 +132,22 @@ def _get_moderation_client() -> AsyncOpenAI:
132132
return AsyncOpenAI(**prepare_openai_kwargs({}))
133133

134134

135+
async def _call_moderation_api(client: AsyncOpenAI, data: str) -> Any:
136+
"""Call the OpenAI moderation API.
137+
138+
Args:
139+
client: The OpenAI client to use.
140+
data: The text to analyze.
141+
142+
Returns:
143+
The moderation API response.
144+
"""
145+
return await client.moderations.create(
146+
model="omni-moderation-latest",
147+
input=data,
148+
)
149+
150+
135151
async def moderation(
136152
ctx: Any,
137153
data: str,
@@ -151,36 +167,29 @@ async def moderation(
151167
Returns:
152168
GuardrailResult: Indicates if tripwire was triggered, and details of flagged categories.
153169
"""
154-
155-
# Prefer reusing an existing OpenAI client from context ONLY if it targets the
156-
# official OpenAI API. If it's any other provider (e.g., Ollama via base_url),
157-
# fall back to the default OpenAI moderation client.
158-
def _maybe_reuse_openai_client_from_ctx(context: Any) -> AsyncOpenAI | None:
170+
client = None
171+
if ctx is not None:
172+
candidate = getattr(ctx, "guardrail_llm", None)
173+
if isinstance(candidate, AsyncOpenAI):
174+
client = candidate
175+
176+
# Try the context client first, fall back if moderation endpoint doesn't exist
177+
if client is not None:
159178
try:
160-
candidate = getattr(context, "guardrail_llm", None)
161-
if not isinstance(candidate, AsyncOpenAI):
162-
return None
163-
164-
# Attempt to discover the effective base URL in a best-effort way
165-
base_url = getattr(candidate, "base_url", None)
166-
if base_url is None:
167-
inner = getattr(candidate, "_client", None)
168-
base_url = getattr(inner, "base_url", None) or getattr(inner, "_base_url", None)
169-
170-
# Reuse only when clearly the official OpenAI endpoint
171-
if base_url is None:
172-
return candidate
173-
if isinstance(base_url, str) and "api.openai.com" in base_url:
174-
return candidate
175-
return None
176-
except Exception:
177-
return None
178-
179-
client = _maybe_reuse_openai_client_from_ctx(ctx) or _get_moderation_client()
180-
resp = await client.moderations.create(
181-
model="omni-moderation-latest",
182-
input=data,
183-
)
179+
resp = await _call_moderation_api(client, data)
180+
except NotFoundError as e:
181+
# Moderation endpoint doesn't exist on this provider (e.g., third-party)
182+
# Fall back to the OpenAI client
183+
logger.debug(
184+
"Moderation endpoint not available on context client, falling back to OpenAI: %s",
185+
e,
186+
)
187+
client = _get_moderation_client()
188+
resp = await _call_moderation_api(client, data)
189+
else:
190+
# No context client, use fallback
191+
client = _get_moderation_client()
192+
resp = await _call_moderation_api(client, data)
184193
results = resp.results or []
185194
if not results:
186195
return GuardrailResult(

tests/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,18 @@ class APITimeoutError(Exception):
6767
"""Stub API timeout error."""
6868

6969

70+
class NotFoundError(Exception):
71+
"""Stub 404 not found error."""
72+
73+
def __init__(self, message: str, *, response: Any = None, body: Any = None) -> None:
74+
"""Initialize NotFoundError with OpenAI-compatible signature."""
75+
super().__init__(message)
76+
self.response = response
77+
self.body = body
78+
79+
7080
_STUB_OPENAI_MODULE.APITimeoutError = APITimeoutError
81+
_STUB_OPENAI_MODULE.NotFoundError = NotFoundError
7182

7283
_OPENAI_TYPES_MODULE = types.ModuleType("openai.types")
7384
_OPENAI_TYPES_MODULE.Completion = _DummyResponse

tests/unit/checks/test_moderation.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,81 @@ async def create_empty(**_: Any) -> Any:
5656

5757
assert result.tripwire_triggered is False # noqa: S101
5858
assert result.info["error"] == "No moderation results returned" # noqa: S101
59+
60+
61+
@pytest.mark.asyncio
62+
async def test_moderation_uses_context_client() -> None:
63+
"""Moderation should use the client from context when available."""
64+
from openai import AsyncOpenAI
65+
66+
# Track whether context client was used
67+
context_client_used = False
68+
69+
async def track_create(**_: Any) -> Any:
70+
nonlocal context_client_used
71+
context_client_used = True
72+
73+
class _Result:
74+
def model_dump(self) -> dict[str, Any]:
75+
return {"categories": {"hate": False, "violence": False}}
76+
77+
return SimpleNamespace(results=[_Result()])
78+
79+
# Create a context with a guardrail_llm client
80+
context_client = AsyncOpenAI(api_key="test-context-key", base_url="https://api.openai.com/v1")
81+
context_client.moderations = SimpleNamespace(create=track_create) # type: ignore[assignment]
82+
83+
ctx = SimpleNamespace(guardrail_llm=context_client)
84+
85+
cfg = ModerationCfg(categories=[Category.HATE])
86+
result = await moderation(ctx, "test text", cfg)
87+
88+
# Verify the context client was used
89+
assert context_client_used is True # noqa: S101
90+
assert result.tripwire_triggered is False # noqa: S101
91+
92+
93+
@pytest.mark.asyncio
94+
async def test_moderation_falls_back_for_third_party_provider(monkeypatch: pytest.MonkeyPatch) -> None:
95+
"""Moderation should fall back to environment client for third-party providers."""
96+
from openai import AsyncOpenAI, NotFoundError
97+
98+
# Create fallback client that tracks usage
99+
fallback_used = False
100+
101+
async def track_fallback_create(**_: Any) -> Any:
102+
nonlocal fallback_used
103+
fallback_used = True
104+
105+
class _Result:
106+
def model_dump(self) -> dict[str, Any]:
107+
return {"categories": {"hate": False}}
108+
109+
return SimpleNamespace(results=[_Result()])
110+
111+
fallback_client = SimpleNamespace(moderations=SimpleNamespace(create=track_fallback_create))
112+
monkeypatch.setattr("guardrails.checks.text.moderation._get_moderation_client", lambda: fallback_client)
113+
114+
# Create a mock httpx.Response for NotFoundError
115+
mock_response = SimpleNamespace(
116+
status_code=404,
117+
headers={},
118+
text="404 page not found",
119+
json=lambda: {"error": {"message": "Not found", "type": "invalid_request_error"}},
120+
)
121+
122+
# Create a context client that simulates a third-party provider
123+
# When moderation is called, it should raise NotFoundError
124+
async def raise_not_found(**_: Any) -> Any:
125+
raise NotFoundError("404 page not found", response=mock_response, body=None) # type: ignore[arg-type]
126+
127+
third_party_client = AsyncOpenAI(api_key="third-party-key", base_url="https://localhost:8080/v1")
128+
third_party_client.moderations = SimpleNamespace(create=raise_not_found) # type: ignore[assignment]
129+
ctx = SimpleNamespace(guardrail_llm=third_party_client)
130+
131+
cfg = ModerationCfg(categories=[Category.HATE])
132+
result = await moderation(ctx, "test text", cfg)
133+
134+
# Verify the fallback client was used (not the third-party one)
135+
assert fallback_used is True # noqa: S101
136+
assert result.tripwire_triggered is False # noqa: S101

0 commit comments

Comments
 (0)