Skip to content

Commit b2d7a81

Browse files
authored
Add unit tests (#17)
1 parent b36f902 commit b2d7a81

24 files changed

+3815
-81
lines changed

pyproject.toml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ dev = [
5858
"pymdown-extensions>=10.0.0",
5959
"coverage>=7.8.0",
6060
"hypothesis>=6.131.20",
61+
"pytest-cov>=6.3.0",
6162
]
6263

6364
[tool.uv.workspace]
@@ -103,8 +104,24 @@ convention = "google"
103104
[tool.ruff.format]
104105
docstring-code-format = true
105106

107+
[tool.coverage.run]
108+
source = ["guardrails"]
109+
omit = [
110+
"src/guardrails/evals/*",
111+
]
112+
106113
[tool.mypy]
107114
strict = true
108115
disallow_incomplete_defs = false
109116
disallow_untyped_defs = false
110117
disallow_untyped_calls = false
118+
exclude = [
119+
"examples",
120+
"src/guardrails/evals",
121+
]
122+
123+
[tool.pyright]
124+
ignore = [
125+
"examples",
126+
"src/guardrails/evals",
127+
]

tests/conftest.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
"""Shared pytest fixtures for guardrails tests.
2+
3+
These fixtures provide deterministic test environments by stubbing the OpenAI
4+
client library, seeding environment variables, and preventing accidental live
5+
network activity during the suite.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import logging
11+
import sys
12+
import types
13+
from collections.abc import Iterator
14+
from dataclasses import dataclass
15+
from types import SimpleNamespace
16+
from typing import Any
17+
18+
import pytest
19+
20+
21+
class _StubOpenAIBase:
22+
"""Base stub with attribute bag behaviour for OpenAI client classes."""
23+
24+
def __init__(self, **kwargs: Any) -> None:
25+
self._client_kwargs = kwargs
26+
self.chat = SimpleNamespace()
27+
self.responses = SimpleNamespace()
28+
self.api_key = kwargs.get("api_key", "test-key")
29+
self.base_url = kwargs.get("base_url")
30+
self.organization = kwargs.get("organization")
31+
self.timeout = kwargs.get("timeout")
32+
self.max_retries = kwargs.get("max_retries")
33+
34+
def __getattr__(self, item: str) -> Any:
35+
"""Return None for unknown attributes to emulate real client laziness."""
36+
return None
37+
38+
39+
class _StubAsyncOpenAI(_StubOpenAIBase):
    """Fake replacement for ``openai.AsyncOpenAI`` (and its Azure variant)."""
41+
42+
43+
class _StubSyncOpenAI(_StubOpenAIBase):
    """Fake replacement for ``openai.OpenAI`` (and its Azure variant)."""
45+
46+
47+
@dataclass(frozen=True, slots=True)
48+
class _DummyResponse:
49+
"""Minimal response type with choices and output."""
50+
51+
choices: list[Any] | None = None
52+
output: list[Any] | None = None
53+
output_text: str | None = None
54+
type: str | None = None
55+
delta: str | None = None
56+
57+
58+
# Build a fake ``openai`` package hierarchy and install it in sys.modules at
# import time, before any guardrails module gets a chance to import the real
# library. Submodules are also bound as attributes on their parents: when a
# module is already cached in sys.modules, ``import openai.types`` skips the
# parent-attribute binding step, so ``openai.types`` attribute access would
# otherwise raise AttributeError.
_STUB_OPENAI_MODULE = types.ModuleType("openai")
_STUB_OPENAI_MODULE.AsyncOpenAI = _StubAsyncOpenAI
_STUB_OPENAI_MODULE.OpenAI = _StubSyncOpenAI
_STUB_OPENAI_MODULE.AsyncAzureOpenAI = _StubAsyncOpenAI
_STUB_OPENAI_MODULE.AzureOpenAI = _StubSyncOpenAI
# Unique sentinel mirroring openai.NOT_GIVEN.
_STUB_OPENAI_MODULE.NOT_GIVEN = object()


class APITimeoutError(Exception):
    """Stub API timeout error."""


_STUB_OPENAI_MODULE.APITimeoutError = APITimeoutError

_OPENAI_TYPES_MODULE = types.ModuleType("openai.types")
_OPENAI_TYPES_MODULE.Completion = _DummyResponse
_OPENAI_TYPES_MODULE.Response = _DummyResponse

_OPENAI_CHAT_MODULE = types.ModuleType("openai.types.chat")
_OPENAI_CHAT_MODULE.ChatCompletion = _DummyResponse
_OPENAI_CHAT_MODULE.ChatCompletionChunk = _DummyResponse

_OPENAI_RESPONSES_MODULE = types.ModuleType("openai.types.responses")
_OPENAI_RESPONSES_MODULE.Response = _DummyResponse
# Param/item types are only used as annotations; plain dict is enough.
_OPENAI_RESPONSES_MODULE.ResponseInputItemParam = dict  # type: ignore[attr-defined]
_OPENAI_RESPONSES_MODULE.ResponseOutputItem = dict  # type: ignore[attr-defined]
_OPENAI_RESPONSES_MODULE.ResponseStreamEvent = dict  # type: ignore[attr-defined]


_OPENAI_RESPONSES_RESPONSE_MODULE = types.ModuleType("openai.types.responses.response")
_OPENAI_RESPONSES_RESPONSE_MODULE.Response = _DummyResponse


class _ResponseTextConfigParam(dict):
    """Stub config param used for response formatting."""


_OPENAI_RESPONSES_MODULE.ResponseTextConfigParam = _ResponseTextConfigParam

# Bind each submodule onto its parent so attribute-chain access
# (e.g. ``openai.types.chat``) works the same as after a real import.
_STUB_OPENAI_MODULE.types = _OPENAI_TYPES_MODULE
_OPENAI_TYPES_MODULE.chat = _OPENAI_CHAT_MODULE
_OPENAI_TYPES_MODULE.responses = _OPENAI_RESPONSES_MODULE
_OPENAI_RESPONSES_MODULE.response = _OPENAI_RESPONSES_RESPONSE_MODULE

sys.modules["openai"] = _STUB_OPENAI_MODULE
sys.modules["openai.types"] = _OPENAI_TYPES_MODULE
sys.modules["openai.types.chat"] = _OPENAI_CHAT_MODULE
sys.modules["openai.types.responses"] = _OPENAI_RESPONSES_MODULE
sys.modules["openai.types.responses.response"] = _OPENAI_RESPONSES_RESPONSE_MODULE
102+
103+
104+
@pytest.fixture(autouse=True)
def stub_openai_module(monkeypatch: pytest.MonkeyPatch) -> Iterator[types.ModuleType]:
    """Swap the OpenAI client classes for stubs across the guardrails modules."""
    # Imported lazily so the sys.modules stubs above are already installed.
    from guardrails import _base_client, client, types as guardrail_types  # type: ignore

    patches = [
        (_base_client, "AsyncOpenAI", _StubAsyncOpenAI),
        (_base_client, "OpenAI", _StubSyncOpenAI),
        (client, "AsyncOpenAI", _StubAsyncOpenAI),
        (client, "OpenAI", _StubSyncOpenAI),
        (client, "AsyncAzureOpenAI", _StubAsyncOpenAI),
        (client, "AzureOpenAI", _StubSyncOpenAI),
        (guardrail_types, "AsyncOpenAI", _StubAsyncOpenAI),
        (guardrail_types, "OpenAI", _StubSyncOpenAI),
        (guardrail_types, "AsyncAzureOpenAI", _StubAsyncOpenAI),
        (guardrail_types, "AzureOpenAI", _StubSyncOpenAI),
    ]
    # raising=False tolerates modules that never imported a given symbol.
    for module, attribute, replacement in patches:
        monkeypatch.setattr(module, attribute, replacement, raising=False)

    # A dummy key keeps client construction from failing on a missing secret.
    monkeypatch.setenv("OPENAI_API_KEY", "test-key")

    yield _STUB_OPENAI_MODULE
124+
125+
126+
@pytest.fixture(autouse=True)
def configure_logging() -> None:
    """Default the root logger to DEBUG so log-based assertions are deterministic."""
    logging.basicConfig(level=logging.DEBUG)

tests/integration/test_suite.py

Lines changed: 9 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -378,11 +378,7 @@ async def run_test(
378378
else:
379379
# Find the triggered result
380380
triggered_result = next(
381-
(
382-
r
383-
for r in response.guardrail_results.all_results
384-
if r.tripwire_triggered
385-
),
381+
(r for r in response.guardrail_results.all_results if r.tripwire_triggered),
386382
None,
387383
)
388384
info = triggered_result.info if triggered_result else None
@@ -394,9 +390,7 @@ async def run_test(
394390
"details": {"result": info},
395391
},
396392
)
397-
print(
398-
f"❌ {test.name} - Passing case {idx} triggered when it shouldn't"
399-
)
393+
print(f"❌ {test.name} - Passing case {idx} triggered when it shouldn't")
400394
if info:
401395
print(f" Info: {info}")
402396

@@ -427,11 +421,7 @@ async def run_test(
427421
if tripwire_triggered:
428422
# Find the triggered result
429423
triggered_result = next(
430-
(
431-
r
432-
for r in response.guardrail_results.all_results
433-
if r.tripwire_triggered
434-
),
424+
(r for r in response.guardrail_results.all_results if r.tripwire_triggered),
435425
None,
436426
)
437427
info = triggered_result.info if triggered_result else None
@@ -517,17 +507,9 @@ async def run_test_suite(
517507
results["tests"].append(outcome)
518508

519509
# Calculate test status
520-
passing_fails = sum(
521-
1 for c in outcome["passing_cases"] if c["status"] == "FAIL"
522-
)
523-
failing_fails = sum(
524-
1 for c in outcome["failing_cases"] if c["status"] == "FAIL"
525-
)
526-
errors = sum(
527-
1
528-
for c in outcome["passing_cases"] + outcome["failing_cases"]
529-
if c["status"] == "ERROR"
530-
)
510+
passing_fails = sum(1 for c in outcome["passing_cases"] if c["status"] == "FAIL")
511+
failing_fails = sum(1 for c in outcome["failing_cases"] if c["status"] == "FAIL")
512+
errors = sum(1 for c in outcome["passing_cases"] + outcome["failing_cases"] if c["status"] == "ERROR")
531513

532514
if errors > 0:
533515
results["summary"]["error_tests"] += 1
@@ -538,16 +520,8 @@ async def run_test_suite(
538520

539521
# Count case results
540522
total_cases = len(outcome["passing_cases"]) + len(outcome["failing_cases"])
541-
passed_cases = sum(
542-
1
543-
for c in outcome["passing_cases"] + outcome["failing_cases"]
544-
if c["status"] == "PASS"
545-
)
546-
failed_cases = sum(
547-
1
548-
for c in outcome["passing_cases"] + outcome["failing_cases"]
549-
if c["status"] == "FAIL"
550-
)
523+
passed_cases = sum(1 for c in outcome["passing_cases"] + outcome["failing_cases"] if c["status"] == "PASS")
524+
failed_cases = sum(1 for c in outcome["passing_cases"] + outcome["failing_cases"] if c["status"] == "FAIL")
551525
error_cases = errors
552526

553527
results["summary"]["total_cases"] += total_cases
@@ -564,9 +538,7 @@ def print_summary(results: dict[str, Any]) -> None:
564538
print("GUARDRAILS TEST SUMMARY")
565539
print("=" * 50)
566540
print(
567-
f"Tests: {summary['passed_tests']} passed, "
568-
f"{summary['failed_tests']} failed, "
569-
f"{summary['error_tests']} errors",
541+
f"Tests: {summary['passed_tests']} passed, " f"{summary['failed_tests']} failed, " f"{summary['error_tests']} errors",
570542
)
571543
print(
572544
f"Cases: {summary['total_cases']} total, "

tests/unit/checks/test_keywords.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
"""Tests for keyword-based guardrail helpers."""
2+
3+
from __future__ import annotations
4+
5+
import pytest
6+
from pydantic import ValidationError
7+
8+
from guardrails.checks.text.competitors import CompetitorCfg, competitors
9+
from guardrails.checks.text.keywords import KeywordCfg, keywords, match_keywords
10+
from guardrails.types import GuardrailResult
11+
12+
13+
def test_match_keywords_sanitizes_trailing_punctuation() -> None:
    """Trailing punctuation must be stripped from keywords before matching."""
    cfg = KeywordCfg(keywords=["token.", "secret!", "KEY?"])
    outcome = match_keywords("Leaked token appears here.", cfg, guardrail_name="Test Guardrail")

    assert outcome.tripwire_triggered is True  # noqa: S101
    assert outcome.info["sanitized_keywords"] == ["token", "secret", "KEY"]  # noqa: S101
    assert outcome.info["matched"] == ["token"]  # noqa: S101
    assert outcome.info["guardrail_name"] == "Test Guardrail"  # noqa: S101
    assert outcome.info["checked_text"] == "Leaked token appears here."  # noqa: S101
23+
24+
25+
def test_match_keywords_deduplicates_case_insensitive_matches() -> None:
    """Case-variant hits of one keyword should collapse to a single match."""
    cfg = KeywordCfg(keywords=["Alert"])
    outcome = match_keywords("alert ALERT Alert", cfg, guardrail_name="Keyword Filter")

    assert outcome.tripwire_triggered is True  # noqa: S101
    assert outcome.info["matched"] == ["alert"]  # noqa: S101
32+
33+
34+
@pytest.mark.asyncio
async def test_keywords_guardrail_wraps_match_keywords() -> None:
    """The async guardrail entry point should behave like match_keywords."""
    cfg = KeywordCfg(keywords=["breach"])
    outcome = await keywords(ctx=None, data="Potential breach detected", config=cfg)

    assert isinstance(outcome, GuardrailResult)  # noqa: S101
    assert outcome.tripwire_triggered is True  # noqa: S101
    assert outcome.info["guardrail_name"] == "Keyword Filter"  # noqa: S101
43+
44+
45+
@pytest.mark.asyncio
async def test_competitors_uses_keyword_matching() -> None:
    """Competitors should reuse keyword matching under its own guardrail name."""
    cfg = CompetitorCfg(keywords=["ACME Corp"])
    outcome = await competitors(ctx=None, data="Comparing against ACME Corp today", config=cfg)

    assert outcome.tripwire_triggered is True  # noqa: S101
    assert outcome.info["guardrail_name"] == "Competitors"  # noqa: S101
    assert outcome.info["matched"] == ["ACME Corp"]  # noqa: S101
54+
55+
56+
def test_keyword_cfg_requires_non_empty_keywords() -> None:
    """Building a KeywordCfg with no keywords must raise a validation error."""
    with pytest.raises(ValidationError):
        KeywordCfg(keywords=[])
60+
61+
62+
@pytest.mark.asyncio
async def test_keywords_does_not_trigger_on_benign_text() -> None:
    """The tripwire must stay untriggered when no keyword appears in the text."""
    cfg = KeywordCfg(keywords=["restricted"])
    outcome = await keywords(ctx=None, data="Safe content", config=cfg)

    assert outcome.tripwire_triggered is False  # noqa: S101

0 commit comments

Comments
 (0)