Skip to content

Commit 9a9620c

Browse files
committed
adding many tests for validation and tool creation in general, fix some issues that came up
1 parent 99c8e1d commit 9a9620c

File tree

5 files changed

+232
-56
lines changed

5 files changed

+232
-56
lines changed

src/fenic/core/mcp/_tools.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from fenic.core._logical_plan.plans.base import LogicalPlan
1414
from fenic.core._utils.type_inference import infer_pytype_from_dtype
1515
from fenic.core.error import PlanError
16-
from fenic.core.mcp._validators import get_param_validator, maybe_get_param_validator
16+
from fenic.core.mcp._validators import get_param_validator
1717
from fenic.core.mcp.types import (
1818
BoundToolParam,
1919
TableFormat,
@@ -85,9 +85,10 @@ def bind_tool(
8585
try:
8686
validator = get_param_validator(validator_name)
8787
if unresolved_expr.data_type not in validator.data_types():
88+
supported_data_types = ", ".join([str(dt) for dt in validator.data_types()])
8889
raise PlanError(
89-
f"Param Validator {validator_name} supports data types {validator.data_types()}, "
90-
f"but the parameter {unresolved_expr_name} has data type {unresolved_expr.data_type}."
90+
f"Param Validator `{validator_name}` supports data types ({supported_data_types}), "
91+
f"but the parameter `{unresolved_expr_name}` has data type {unresolved_expr.data_type}."
9192
)
9293
validators.append(validator)
9394
except KeyError:
@@ -132,19 +133,11 @@ def _infer_base_type(p: BoundToolParam):
132133
if isinstance(p.data_type, ArrayType):
133134
return list[literal_type] # type: ignore[valid-type]
134135
return literal_type
136+
if isinstance(p.data_type, ArrayType):
137+
inner_type = infer_pytype_from_dtype(p.data_type.element_type)
138+
return list[inner_type] # type: ignore[valid-type]
135139
return infer_pytype_from_dtype(p.data_type)
136140

137-
def _wrap_with_validator(base_t, validator_name: Optional[str]):
138-
if not validator_name:
139-
return base_t
140-
pv = maybe_get_param_validator(validator_name)
141-
if pv is None:
142-
return base_t
143-
def _wrap(v, _pv=pv):
144-
_pv.validate(v)
145-
return v
146-
return TypingAnnotated[base_t, AfterValidator(_wrap)] # type: ignore[valid-type]
147-
148141
def _field_kwargs(p: BoundToolParam, include_default: bool) -> dict:
149142
kwargs: dict = {"description": p.description}
150143
constraints = p.constraints

src/fenic/core/mcp/_validators.py

Lines changed: 13 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import re
2-
from typing import Dict, List, Optional, Protocol, Union, runtime_checkable
2+
from typing import Dict, List, Protocol, Union, runtime_checkable
33

4+
from fenic._polars_plugins import py_validate_regex # noqa: F401
45
from fenic.core.error import (
56
ValidationError,
67
)
@@ -21,7 +22,7 @@ def data_types(self) -> List[DataType]:
2122
"""The data types that the validator operates on."""
2223
...
2324

24-
def validate(self, value: Union[str, int, float, bool, list, dict]) -> bool:
25+
def validate(self, value: Union[str, int, float, bool, list, dict]):
2526
"""Validate an argument value.
2627
2728
Args:
@@ -62,14 +63,6 @@ def validate(self, user_query: str):
6263
if len(query) > MAX_REGEX_LENGTH:
6364
raise ValidationError(f"Regex too long (>{MAX_REGEX_LENGTH} characters)")
6465

65-
# Support /pattern/flags and capture flags
66-
query, flags = self._strip_slash_delimiters(query)
67-
unsupported_flags = {f for f in flags if f not in {"i", "m", "s", "x"}}
68-
if unsupported_flags:
69-
raise ValidationError(
70-
f"Unsupported regex flags: {''.join(sorted(unsupported_flags))}"
71-
)
72-
7366
# Strip inline flags at start like (?i), (?m), combined, to avoid duplication
7467
query = re.sub(r"^\(\?[aiLmsux]+\)", "", query)
7568

@@ -89,15 +82,16 @@ def validate(self, user_query: str):
8982
except ValueError:
9083
raise ValidationError("Invalid quantifier bounds") from None
9184
if m_val > MAX_QUANTIFIER_VALUE or n_val > MAX_QUANTIFIER_VALUE:
92-
raise ValidationError("Quantifier bounds too large")
85+
raise ValidationError(f"Quantifier bounds {m_val} or {n_val} > {MAX_QUANTIFIER_VALUE}")
9386
if n and n_val < m_val:
94-
raise ValidationError("Quantifier upper bound less than lower bound")
87+
raise ValidationError(f"Quantifier upper bound {n_val} < lower bound {m_val}")
9588

9689
# Limit alternations
97-
if query.count("|") > MAX_ALTERNATIONS:
98-
raise ValidationError("Too many alternations in regex")
90+
alternations = query.count("|")
91+
if alternations > MAX_ALTERNATIONS:
92+
raise ValidationError(f"Too many alternations ({alternations} > {MAX_ALTERNATIONS})")
9993

100-
# Disallow backreferences (simple and robust detection)
94+
# Disallow backreferences
10195
if any(f"\\{d}" in query for d in "123456789"):
10296
raise ValidationError("Backreferences are not supported")
10397

@@ -121,11 +115,11 @@ def validate(self, user_query: str):
121115
if re.search(r"\{\s*\d+\s*,\s*\d+\s*,", query):
122116
raise ValidationError("Invalid quantifier syntax")
123117

124-
# Ensure it compiles in Python as a basic sanity check
118+
# Final check, ensure that the regex is valid for `rlike`
125119
try:
126-
re.compile(query)
127-
except re.error as err:
128-
raise ValidationError(f"Invalid regex syntax: {err}") from None
120+
py_validate_regex(query)
121+
except Exception as err:
122+
raise ValidationError(f"Invalid regex syntax: {query}") from err
129123

130124
return
131125

@@ -146,20 +140,6 @@ def _is_balanced(self, s: str, open_char: str, close_char: str) -> bool:
146140
i += 1
147141
return depth == 0
148142

149-
150-
def _strip_slash_delimiters(self, pattern: str) -> tuple[str, set[str]]:
151-
"""Support /pattern/flags syntax; return (pattern, flags).
152-
153-
Only recognize i,m,s,x flags; others are rejected later.
154-
"""
155-
if len(pattern) >= 2 and pattern.startswith("/") and pattern.rfind("/") > 0:
156-
last = pattern.rfind("/")
157-
core = pattern[1:last]
158-
flags = set(pattern[last + 1 :].lower())
159-
return core, flags
160-
return pattern, set()
161-
162-
163143
# -- Registry for reusable ParamValidators --
164144
_PARAM_VALIDATOR_REGISTRY: Dict[str, ParamValidator] = {}
165145

@@ -186,11 +166,5 @@ def get_param_validator(name: str) -> ParamValidator:
186166
raise KeyError(f"No ParamValidator registered under name '{name}'") from err
187167

188168

189-
def maybe_get_param_validator(name: Optional[str]) -> Optional[ParamValidator]:
190-
if name is None:
191-
return None
192-
return get_param_validator(name)
193-
194-
195169
# Pre-register common validators
196170
register_param_validator("regex", RegexValidator())

tests/api/mcp/test_server.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,58 @@
55

66
from fenic import SystemTool, SystemToolConfig
77
from fenic.api.mcp._tool_generation_utils import auto_generate_system_tools_from_tables
8+
from fenic.api.functions import col, tool_param
89
from fenic.api.mcp.server import create_mcp_server
910
from fenic.api.session.session import Session
1011
from fenic.core._utils.misc import to_snake_case
12+
from fenic.core.mcp._tools import bind_tool
13+
from fenic.core.mcp.types import ToolParam, ToolParamConstraints
14+
from fenic.core.types.datatypes import ArrayType, IntegerType, StringType
1115
from tests.api.mcp.utils import create_table_with_rows
1216

1317

18+
def test_server_generation_with_parameterized_tools(local_session: Session):
19+
pytest.importorskip("fastmcp")
20+
df = local_session.create_dataframe({"city": ["SF"], "age": [10], "user_name": ["Alice"]})
21+
query = df.filter(
22+
(col("city") == tool_param("city_name", StringType))
23+
& (col("age") >= tool_param("age", IntegerType))
24+
& (col("user_name").is_in(tool_param("user_names", ArrayType(StringType))))
25+
)._logical_plan
26+
27+
parameterized_tool = bind_tool(
28+
name="tool_x",
29+
description="table one",
30+
params=[
31+
ToolParam(name="city_name", description="City name", constraints=ToolParamConstraints(pattern="^SF$")),
32+
ToolParam(name="age", description="Age", constraints=ToolParamConstraints(gt=0, lt=120, multiple_of=2)),
33+
ToolParam(name="user_names", description="User names", constraints=ToolParamConstraints(min_length=1, max_length=5)),
34+
],
35+
result_limit=10,
36+
query=query,
37+
)
38+
39+
server = create_mcp_server(local_session, "Test Server", parameterized_tools=[parameterized_tool])
40+
server_tools = asyncio.run(server.mcp.get_tools())
41+
assert len(server_tools) == 1
42+
parameter_schema = server_tools["tool_x"].parameters['properties']
43+
city_name_param = parameter_schema['city_name']
44+
assert city_name_param['type'] == 'string'
45+
assert city_name_param['pattern'] == '^SF$'
46+
assert city_name_param['description'] == "City name"
47+
age_param = parameter_schema['age']
48+
assert age_param['type'] == 'integer'
49+
assert age_param['exclusiveMinimum'] == 0
50+
assert age_param['exclusiveMaximum'] == 120
51+
assert age_param['multipleOf'] == 2
52+
assert age_param['description'] == "Age"
53+
user_names_param = parameter_schema['user_names']
54+
assert user_names_param['type'] == 'array'
55+
assert user_names_param['items']['type'] == 'string'
56+
assert user_names_param['maxItems'] == 5
57+
assert user_names_param['minItems'] == 1
58+
assert user_names_param['description'] == "User names"
59+
1460
def test_server_generation(local_session: Session):
1561
pytest.importorskip("fastmcp")
1662
create_table_with_rows(local_session, "t1", [1, 2, 3], description="table one")

tests/core/mcp/test_tools.py

Lines changed: 101 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1+
import re
12

23
import pytest
34
from pydantic import BaseModel
45
from pydantic import ValidationError as PydValidationError
56

67
from fenic.api.functions import col, tool_param
7-
from fenic.core.error import PlanError
8+
from fenic.core.error import PlanError, ValidationError
89
from fenic.core.mcp._tools import bind_tool, create_pydantic_model_for_tool
9-
from fenic.core.mcp.types import ToolParam
10-
from fenic.core.types.datatypes import IntegerType, StringType
10+
from fenic.core.mcp.types import ToolParam, ToolParamConstraints
11+
from fenic.core.types.datatypes import ArrayType, IntegerType, StringType
1112

1213

1314
def test_toolparam_required_and_default_validation():
@@ -51,6 +52,103 @@ def test_resolve_tool_validates_unresolved_params(local_session):
5152
query=query,
5253
)
5354

55+
def test_resolve_tool_validates_mistyped_validators(local_session):
56+
df = local_session.create_dataframe({"name": ["Alice", "Bob"], "age": [25, 30], "city": ["SF", "SEA"]})
57+
query = df.filter((col("age") >= tool_param("min_age", IntegerType)) & (col("city") == tool_param("city_name", StringType)))._logical_plan
58+
59+
with pytest.raises(PlanError, match="Param Validator `regex` supports data types \(StringType\), but the parameter `min_age` has data type IntegerType."):
60+
bind_tool(
61+
name="users_by_city",
62+
description="Filter users",
63+
params=[
64+
ToolParam(name="min_age", description="Minimum age", validator_names=["regex"]),
65+
ToolParam(name="city_name", description="City name", validator_names=["regex"]),
66+
],
67+
result_limit=50,
68+
query=query,
69+
)
70+
71+
def test_resolve_tool_validates_missing_validators(local_session):
72+
df = local_session.create_dataframe({"name": ["Alice", "Bob"], "age": [25, 30], "city": ["SF", "SEA"]})
73+
query = df.filter((col("age") >= tool_param("min_age", IntegerType)) & (col("city") == tool_param("city_name", StringType)))._logical_plan
74+
75+
with pytest.raises(PlanError, match="Could not find a ParamValidator for the following validator names: \['non_existent'\]"):
76+
bind_tool(
77+
name="users_by_city",
78+
description="Filter users",
79+
params=[
80+
ToolParam(name="min_age", description="Minimum age"),
81+
ToolParam(name="city_name", description="City name", validator_names=["non_existent"]),
82+
],
83+
result_limit=50,
84+
query=query,
85+
)
86+
87+
def test_create_pydantic_model_for_tool_applies_validators(local_session):
88+
df = local_session.create_dataframe({"name": ["Alice", "Bob"], "age": [25, 30], "city": ["SF", "SEA"]})
89+
query = df.filter(
90+
(col("age") >= tool_param("min_age", IntegerType)) &
91+
(col("city") == tool_param("city_name", StringType))
92+
)._logical_plan
93+
94+
tool = bind_tool(
95+
name="users_by_city",
96+
description="Filter users",
97+
params=[
98+
ToolParam(name="min_age", description="Minimum age"),
99+
ToolParam(name="city_name", description="City name", validator_names=["regex"]),
100+
],
101+
result_limit=50,
102+
query=query,
103+
)
104+
105+
Model: type[BaseModel] = create_pydantic_model_for_tool(tool)
106+
107+
with pytest.raises(ValidationError, match="Unbalanced curly braces"):
108+
Model(city_name="{+---", min_age=25)
109+
110+
with pytest.raises(ValidationError, match="Too many alternations \(21 > 20\)"):
111+
Model(city_name="SF|SEA|OAK|PHX|LAS|ORD|XRD|PRD|IAD|CRD|FRA|LON|UMEA|BOS|YYZ|DOG|BAT|BAN|LAP|LAX|TYO|HND", min_age=25)
112+
113+
114+
def test_create_pydantic_model_for_tool_applies_field_validators(local_session):
115+
df = local_session.create_dataframe({"city": ["SF"], "age": [10], "user_name": ["Alice"]})
116+
query = df.filter(
117+
(col("city") == tool_param("city_name", StringType))
118+
& (col("age") >= tool_param("age", IntegerType))
119+
& (col("user_name").is_in(tool_param("user_names", ArrayType(StringType))))
120+
)._logical_plan
121+
122+
tool = bind_tool(
123+
name="tool_x",
124+
description="",
125+
params=[
126+
ToolParam(name="city_name", description="City name", constraints=ToolParamConstraints(pattern="^SF$")),
127+
ToolParam(name="age", description="Age", constraints=ToolParamConstraints(gt=0, lt=120, multiple_of=2)),
128+
ToolParam(name="user_names", description="User names", constraints=ToolParamConstraints(min_length=1, max_length=5)),
129+
],
130+
result_limit=10,
131+
query=query,
132+
)
133+
134+
Model: type[BaseModel] = create_pydantic_model_for_tool(tool)
135+
#should pass validation
136+
Model(city_name="SF", age=10, user_names=["Alice", "Bob"])
137+
with pytest.raises(PydValidationError, match=re.escape("String should match pattern '^SF$'")):
138+
Model(city_name="SEA", age=10, user_names=["Alice", "Bob"])
139+
140+
with pytest.raises(PydValidationError, match=re.escape("Input should be greater than 0")):
141+
Model(city_name="SF", age=0, user_names=["Alice", "Bob"])
142+
143+
with pytest.raises(PydValidationError, match=re.escape("Input should be a multiple of 2")):
144+
Model(city_name="SF", age=11, user_names=["Alice", "Bob"])
145+
146+
with pytest.raises(PydValidationError, match=re.escape("List should have at most 5 items after validation, not 6")):
147+
Model(city_name="SF", age=10, user_names=["Alice", "Bob", "Charlie", "David", "Eve", "Frank"])
148+
149+
with pytest.raises(PydValidationError, match=re.escape("List should have at least 1 item after validation, not 0")):
150+
Model(city_name="SF", age=10, user_names=[])
151+
54152

55153
def test_create_pydantic_model_for_tool_defaults_and_required(local_session):
56154
df = local_session.create_dataframe({"city": ["SF"], "age": [10]})

tests/core/mcp/test_validators.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import pytest
2+
3+
from fenic.core.error import ValidationError
4+
from fenic.core.mcp._validators import (
5+
RegexValidator,
6+
get_param_validator,
7+
register_param_validator,
8+
)
9+
10+
11+
def test_regex_validator_accepts_simple_pattern():
12+
v = RegexValidator()
13+
v.validate("foo|bar")
14+
15+
16+
def test_regex_validator_supports_slash_delimiters_and_flags():
17+
v = RegexValidator()
18+
v.validate("/foo.*/i")
19+
20+
21+
@pytest.mark.parametrize(
22+
"pattern",
23+
[
24+
" ", # whitespace-only
25+
"(", # unbalanced paren
26+
"[a-", # unbalanced bracket
27+
"{1,2,3}", # invalid quantifier syntax
28+
],
29+
)
30+
def test_regex_validator_rejects_basic_invalid_patterns(pattern):
31+
v = RegexValidator()
32+
with pytest.raises(ValidationError):
33+
v.validate(pattern)
34+
35+
36+
@pytest.mark.parametrize(
37+
"pattern",
38+
[
39+
r"(.+)+", # nested quantifier
40+
r"(.*)+", # nested quantifier
41+
r"(?:.+){1001}", # excessive bounded quantifier
42+
r"(a|b|c|d|e|f|g|h|i|j|k|l|m|n|o|p|q|r|s|t|u|v|w|x|y|z|aa|ab|ac)", # too many alternations
43+
r"\1", # backreference
44+
r"(?<=a)b", # lookbehind
45+
r"(\.\*){2,1}", # m > n for matching between m and n repeats of a character
46+
],
47+
)
48+
def test_regex_validator_rejects_redos_like_and_unsupported_constructs(pattern):
49+
v = RegexValidator()
50+
with pytest.raises(ValidationError):
51+
v.validate(pattern)
52+
53+
54+
def test_registry_has_default_regex_validator():
55+
v = get_param_validator("regex")
56+
assert isinstance(v, RegexValidator)
57+
58+
59+
def test_registry_register_and_lookup_custom():
60+
import uuid
61+
62+
unique_name = f"custom_regex_{uuid.uuid4().hex}"
63+
register_param_validator(unique_name, RegexValidator())
64+
v = get_param_validator(unique_name)
65+
assert isinstance(v, RegexValidator)

0 commit comments

Comments
 (0)