Skip to content

Commit 1660d88

Browse files
Copilot and borchero authored
feat: Implement inner validation rules for Array columns (#222)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: borchero <22455425+borchero@users.noreply.github.com>
Co-authored-by: Oliver Borchert <oliver.borchert@quantco.com>
1 parent 355acf8 commit 1660d88

File tree

3 files changed

+79
-55
lines changed

3 files changed

+79
-55
lines changed

dataframely/columns/array.py

Lines changed: 20 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515

1616
from ._base import Check, Column
1717
from ._registry import column_from_dict, register
18-
from .struct import Struct
18+
from .list import _list_primary_key_check
1919

2020
if sys.version_info >= (3, 11):
2121
from typing import Self
@@ -42,7 +42,7 @@ def __init__(
4242
):
4343
"""
4444
Args:
45-
inner: The inner column type. No validation rules on the inner type are supported yet.
45+
inner: The inner column type.
4646
shape: The shape of the array.
4747
nullable: Whether this column may contain null values.
4848
primary_key: Whether this column is part of the primary key of the schema.
@@ -64,23 +64,6 @@ def __init__(
6464
names, the specified alias is the only valid name.
6565
metadata: A dictionary of metadata to attach to the column.
6666
"""
67-
if inner.primary_key or (
68-
isinstance(inner, Struct)
69-
and any(col.primary_key for col in inner.inner.values())
70-
):
71-
raise ValueError(
72-
"`primary_key=True` is not yet supported for inner types of the Array type."
73-
)
74-
75-
# We disallow validation rules on the inner type since Polars arrays currently don't support .eval(). Converting
76-
# to a list and calling .list.eval() is possible, however, since the shape can have multiple axes, the recursive
77-
# conversion could have significant performance impact. Hence, we simply disallow inner validation rules.
78-
# Another option would be to allow validation rules only for sampling, but not enforce them.
79-
if inner.validation_rules(pl.lit(None)):
80-
raise ValueError(
81-
"Validation rules on the inner type of Array are not yet supported."
82-
)
83-
8467
super().__init__(
8568
nullable=nullable,
8669
primary_key=False,
@@ -95,6 +78,24 @@ def __init__(
9578
def dtype(self) -> pl.DataType:
9679
return pl.Array(self.inner.dtype, self.shape)
9780

81+
def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]:
82+
inner_rules = {
83+
f"inner_{rule_name}": expr.arr.eval(inner_expr).arr.all()
84+
for rule_name, inner_expr in self.inner.validation_rules(
85+
pl.element()
86+
).items()
87+
}
88+
89+
array_rules: dict[str, pl.Expr] = {}
90+
if (rule := _list_primary_key_check(expr.arr, self.inner)) is not None:
91+
array_rules["primary_key"] = rule
92+
93+
return {
94+
**super().validation_rules(expr),
95+
**array_rules,
96+
**inner_rules,
97+
}
98+
9899
def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
99100
# NOTE: We might want to add support for PostgreSQL's ARRAY type or use JSON in the future.
100101
raise NotImplementedError("SQL column cannot have 'Array' type.")

dataframely/columns/list.py

Lines changed: 36 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,8 @@
88
from typing import Any, cast
99

1010
import polars as pl
11+
from polars.expr.array import ExprArrayNameSpace
12+
from polars.expr.list import ExprListNameSpace
1113

1214
from dataframely._compat import pa, sa, sa_TypeEngine
1315
from dataframely._polars import PolarsDataType
@@ -97,29 +99,8 @@ def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]:
9799
}
98100

99101
list_rules: dict[str, pl.Expr] = {}
100-
if self.inner.primary_key:
101-
list_rules["primary_key"] = ~expr.list.eval(
102-
pl.element().is_duplicated()
103-
).list.any()
104-
elif isinstance(self.inner, Struct) and any(
105-
col.primary_key for col in self.inner.inner.values()
106-
):
107-
primary_key_columns = [
108-
name for name, col in self.inner.inner.items() if col.primary_key
109-
]
110-
# NOTE: We optimize for a single primary key column here as it is much
111-
# faster to run duplication checks for non-struct types in polars 1.22.
112-
if len(primary_key_columns) == 1:
113-
list_rules["primary_key"] = ~expr.list.eval(
114-
pl.element().struct.field(primary_key_columns[0]).is_duplicated()
115-
).list.any()
116-
else:
117-
list_rules["primary_key"] = ~expr.list.eval(
118-
pl.struct(
119-
pl.element().struct.field(primary_key_columns)
120-
).is_duplicated()
121-
).list.any()
122-
102+
if (rule := _list_primary_key_check(expr.list, self.inner)) is not None:
103+
list_rules["primary_key"] = rule
123104
if self.min_length is not None:
124105
list_rules["min_length"] = (
125106
pl.when(expr.is_null())
@@ -187,3 +168,35 @@ def as_dict(self, expr: pl.Expr) -> dict[str, Any]:
187168
def from_dict(cls, data: dict[str, Any]) -> Self:
188169
data["inner"] = column_from_dict(data["inner"])
189170
return super().from_dict(data)
171+
172+
173+
def _list_primary_key_check(
174+
list_expr: ExprListNameSpace | ExprArrayNameSpace, inner: Column
175+
) -> pl.Expr | None:
176+
def list_any(expr: pl.Expr) -> pl.Expr:
177+
if isinstance(list_expr, ExprListNameSpace):
178+
return expr.list.any()
179+
return expr.arr.any()
180+
181+
if inner.primary_key:
182+
return ~list_expr.eval(pl.element().is_duplicated()).pipe(list_any)
183+
elif isinstance(inner, Struct) and any(
184+
col.primary_key for col in inner.inner.values()
185+
):
186+
primary_key_columns = [
187+
name for name, col in inner.inner.items() if col.primary_key
188+
]
189+
# NOTE: We optimize for a single primary key column here as it is much
190+
# faster to run duplication checks for non-struct types in polars 1.22.
191+
if len(primary_key_columns) == 1:
192+
return ~list_expr.eval(
193+
pl.element().struct.field(primary_key_columns[0]).is_duplicated()
194+
).pipe(list_any)
195+
else:
196+
return ~list_expr.eval(
197+
pl.struct(
198+
pl.element().struct.field(primary_key_columns)
199+
).is_duplicated()
200+
).pipe(list_any)
201+
202+
return None

tests/column_types/test_array.py

Lines changed: 23 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66

77
import dataframely as dy
88
from dataframely.columns._base import Column
9-
from dataframely.testing import create_schema
9+
from dataframely.testing import create_schema, validation_mask
1010

1111

1212
@pytest.mark.parametrize(
@@ -132,20 +132,30 @@ def test_nested_array() -> None:
132132
)
133133

134134

135-
def test_array_with_inner_pk() -> None:
136-
with pytest.raises(ValueError):
137-
column = dy.Array(dy.String(primary_key=True), 2)
138-
create_schema(
139-
"test",
140-
{"a": column},
141-
)
135+
def test_array_with_rules() -> None:
136+
schema = create_schema(
137+
"test", {"a": dy.Array(dy.String(min_length=2, nullable=False), 1)}
138+
)
139+
df = pl.DataFrame(
140+
{"a": [["ab"], ["a"], [None]]},
141+
schema={"a": pl.Array(pl.String, 1)},
142+
)
143+
_, failures = schema.filter(df)
144+
assert validation_mask(df, failures).to_list() == [True, False, False]
145+
assert failures.counts() == {"a|inner_nullability": 1, "a|inner_min_length": 1}
142146

143147

144-
def test_array_with_rules() -> None:
145-
with pytest.raises(ValueError):
146-
create_schema(
147-
"test", {"a": dy.Array(dy.String(min_length=2, nullable=False), 1)}
148-
)
148+
def test_array_with_primary_key_rule() -> None:
149+
schema = create_schema(
150+
"test", {"a": dy.Array(dy.String(min_length=2, primary_key=True), 2)}
151+
)
152+
df = pl.DataFrame(
153+
{"a": [["ab", "ab"], ["cd", "de"], ["def", "ghi"]]},
154+
schema={"a": pl.Array(pl.String, 2)},
155+
)
156+
_, failures = schema.filter(df)
157+
assert validation_mask(df, failures).to_list() == [False, True, True]
158+
assert failures.counts() == {"a|primary_key": 1}
149159

150160

151161
def test_outer_nullability() -> None:

0 commit comments

Comments
 (0)