@@ -134,3 +134,89 @@ async def raise_not_found(**_: Any) -> Any:
134134 # Verify the fallback client was used (not the third-party one)
135135 assert fallback_used is True # noqa: S101
136136 assert result .tripwire_triggered is False # noqa: S101
137+
138+
139+ @pytest .mark .asyncio
140+ async def test_moderation_uses_sync_context_client () -> None :
141+ """Moderation should support synchronous OpenAI clients from context."""
142+ from openai import OpenAI
143+
144+ # Track whether sync context client was used
145+ sync_client_used = False
146+
147+ def track_sync_create (** _ : Any ) -> Any :
148+ nonlocal sync_client_used
149+ sync_client_used = True
150+
151+ class _Result :
152+ def model_dump (self ) -> dict [str , Any ]:
153+ return {"categories" : {"hate" : False , "violence" : False }}
154+
155+ return SimpleNamespace (results = [_Result ()])
156+
157+ # Create a sync context client
158+ sync_client = OpenAI (api_key = "test-sync-key" , base_url = "https://api.openai.com/v1" )
159+ sync_client .moderations = SimpleNamespace (create = track_sync_create ) # type: ignore[assignment]
160+
161+ ctx = SimpleNamespace (guardrail_llm = sync_client )
162+
163+ cfg = ModerationCfg (categories = [Category .HATE , Category .VIOLENCE ])
164+ result = await moderation (ctx , "test text" , cfg )
165+
166+ # Verify the sync context client was used (via asyncio.to_thread)
167+ assert sync_client_used is True # noqa: S101
168+ assert result .tripwire_triggered is False # noqa: S101
169+
170+
171+ @pytest .mark .asyncio
172+ async def test_moderation_falls_back_for_azure_clients (monkeypatch : pytest .MonkeyPatch ) -> None :
173+ """Moderation should fall back to OpenAI client for Azure clients (no moderation endpoint)."""
174+ try :
175+ from openai import AsyncAzureOpenAI , NotFoundError
176+ except ImportError :
177+ pytest .skip ("Azure OpenAI not available" )
178+
179+ # Track whether fallback was used
180+ fallback_used = False
181+
182+ async def track_fallback_create (** _ : Any ) -> Any :
183+ nonlocal fallback_used
184+ fallback_used = True
185+
186+ class _Result :
187+ def model_dump (self ) -> dict [str , Any ]:
188+ return {"categories" : {"hate" : False , "violence" : False }}
189+
190+ return SimpleNamespace (results = [_Result ()])
191+
192+ # Mock the fallback client
193+ fallback_client = SimpleNamespace (moderations = SimpleNamespace (create = track_fallback_create ))
194+ monkeypatch .setattr ("guardrails.checks.text.moderation._get_moderation_client" , lambda : fallback_client )
195+
196+ # Create a mock httpx.Response for NotFoundError
197+ mock_response = SimpleNamespace (
198+ status_code = 404 ,
199+ headers = {},
200+ text = "404 page not found" ,
201+ json = lambda : {"error" : {"message" : "Not found" , "type" : "invalid_request_error" }},
202+ )
203+
204+ # Create an Azure context client that raises NotFoundError for moderation
205+ async def raise_not_found (** _ : Any ) -> Any :
206+ raise NotFoundError ("404 page not found" , response = mock_response , body = None ) # type: ignore[arg-type]
207+
208+ azure_client = AsyncAzureOpenAI (
209+ api_key = "test-azure-key" ,
210+ api_version = "2024-02-01" ,
211+ azure_endpoint = "https://test.openai.azure.com" ,
212+ )
213+ azure_client .moderations = SimpleNamespace (create = raise_not_found ) # type: ignore[assignment]
214+
215+ ctx = SimpleNamespace (guardrail_llm = azure_client )
216+
217+ cfg = ModerationCfg (categories = [Category .HATE , Category .VIOLENCE ])
218+ result = await moderation (ctx , "test text" , cfg )
219+
220+ # Verify the fallback client was used (not the Azure one)
221+ assert fallback_used is True # noqa: S101
222+ assert result .tripwire_triggered is False # noqa: S101
0 commit comments