diff --git a/dataframely/columns/array.py b/dataframely/columns/array.py index db232ba..ddd9e9d 100644 --- a/dataframely/columns/array.py +++ b/dataframely/columns/array.py @@ -15,7 +15,7 @@ from ._base import Check, Column from ._registry import column_from_dict, register -from .struct import Struct +from .list import _list_primary_key_check if sys.version_info >= (3, 11): from typing import Self @@ -42,7 +42,7 @@ def __init__( ): """ Args: - inner: The inner column type. No validation rules on the inner type are supported yet. + inner: The inner column type. shape: The shape of the array. nullable: Whether this column may contain null values. primary_key: Whether this column is part of the primary key of the schema. @@ -64,23 +64,6 @@ def __init__( names, the specified alias is the only valid name. metadata: A dictionary of metadata to attach to the column. """ - if inner.primary_key or ( - isinstance(inner, Struct) - and any(col.primary_key for col in inner.inner.values()) - ): - raise ValueError( - "`primary_key=True` is not yet supported for inner types of the Array type." - ) - - # We disallow validation rules on the inner type since Polars arrays currently don't support .eval(). Converting - # to a list and calling .list.eval() is possible, however, since the shape can have multiple axes, the recursive - # conversion could have significant performance impact. Hence, we simply disallow inner validation rules. - # Another option would be to allow validation rules only for sampling, but not enforce them. - if inner.validation_rules(pl.lit(None)): - raise ValueError( - "Validation rules on the inner type of Array are not yet supported." - ) - super().__init__( nullable=nullable, primary_key=False, @@ -95,6 +78,24 @@ def __init__( def dtype(self) -> pl.DataType: return pl.Array(self.inner.dtype, self.shape) + def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]: + inner_rules = { + f"inner_{rule_name}": expr.arr.eval(inner_expr).arr.all() + for rule_name, inner_expr in self.inner.validation_rules( + pl.element() + ).items() + } + + array_rules: dict[str, pl.Expr] = {} + if (rule := _list_primary_key_check(expr.arr, self.inner)) is not None: + array_rules["primary_key"] = rule + + return { + **super().validation_rules(expr), + **array_rules, + **inner_rules, + } + def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: # NOTE: We might want to add support for PostgreSQL's ARRAY type or use JSON in the future. raise NotImplementedError("SQL column cannot have 'Array' type.") diff --git a/dataframely/columns/list.py b/dataframely/columns/list.py index e1dd88b..5c22d24 100644 --- a/dataframely/columns/list.py +++ b/dataframely/columns/list.py @@ -8,6 +8,8 @@ from typing import Any, cast import polars as pl +from polars.expr.array import ExprArrayNameSpace +from polars.expr.list import ExprListNameSpace from dataframely._compat import pa, sa, sa_TypeEngine from dataframely._polars import PolarsDataType @@ -97,29 +99,8 @@ def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]: } list_rules: dict[str, pl.Expr] = {} - if self.inner.primary_key: - list_rules["primary_key"] = ~expr.list.eval( - pl.element().is_duplicated() - ).list.any() - elif isinstance(self.inner, Struct) and any( - col.primary_key for col in self.inner.inner.values() - ): - primary_key_columns = [ - name for name, col in self.inner.inner.items() if col.primary_key - ] - # NOTE: We optimize for a single primary key column here as it is much - # faster to run duplication checks for non-struct types in polars 1.22. - if len(primary_key_columns) == 1: - list_rules["primary_key"] = ~expr.list.eval( - pl.element().struct.field(primary_key_columns[0]).is_duplicated() - ).list.any() - else: - list_rules["primary_key"] = ~expr.list.eval( - pl.struct( - pl.element().struct.field(primary_key_columns) - ).is_duplicated() - ).list.any() - + if (rule := _list_primary_key_check(expr.list, self.inner)) is not None: + list_rules["primary_key"] = rule if self.min_length is not None: list_rules["min_length"] = ( pl.when(expr.is_null()) @@ -187,3 +168,35 @@ def as_dict(self, expr: pl.Expr) -> dict[str, Any]: def from_dict(cls, data: dict[str, Any]) -> Self: data["inner"] = column_from_dict(data["inner"]) return super().from_dict(data) + + +def _list_primary_key_check( + list_expr: ExprListNameSpace | ExprArrayNameSpace, inner: Column +) -> pl.Expr | None: + def list_any(expr: pl.Expr) -> pl.Expr: + if isinstance(list_expr, ExprListNameSpace): + return expr.list.any() + return expr.arr.any() + + if inner.primary_key: + return ~list_expr.eval(pl.element().is_duplicated()).pipe(list_any) + elif isinstance(inner, Struct) and any( + col.primary_key for col in inner.inner.values() + ): + primary_key_columns = [ + name for name, col in inner.inner.items() if col.primary_key + ] + # NOTE: We optimize for a single primary key column here as it is much + # faster to run duplication checks for non-struct types in polars 1.22. + if len(primary_key_columns) == 1: + return ~list_expr.eval( + pl.element().struct.field(primary_key_columns[0]).is_duplicated() + ).pipe(list_any) + else: + return ~list_expr.eval( + pl.struct( + pl.element().struct.field(primary_key_columns) + ).is_duplicated() + ).pipe(list_any) + + return None diff --git a/tests/column_types/test_array.py b/tests/column_types/test_array.py index 4b53911..fdb1c3f 100644 --- a/tests/column_types/test_array.py +++ b/tests/column_types/test_array.py @@ -6,7 +6,7 @@ import dataframely as dy from dataframely.columns._base import Column -from dataframely.testing import create_schema +from dataframely.testing import create_schema, validation_mask @pytest.mark.parametrize( @@ -132,20 +132,30 @@ def test_nested_array() -> None: ) -def test_array_with_inner_pk() -> None: - with pytest.raises(ValueError): - column = dy.Array(dy.String(primary_key=True), 2) - create_schema( - "test", - {"a": column}, - ) +def test_array_with_rules() -> None: + schema = create_schema( + "test", {"a": dy.Array(dy.String(min_length=2, nullable=False), 1)} + ) + df = pl.DataFrame( + {"a": [["ab"], ["a"], [None]]}, + schema={"a": pl.Array(pl.String, 1)}, + ) + _, failures = schema.filter(df) + assert validation_mask(df, failures).to_list() == [True, False, False] + assert failures.counts() == {"a|inner_nullability": 1, "a|inner_min_length": 1} -def test_array_with_rules() -> None: - with pytest.raises(ValueError): - create_schema( - "test", {"a": dy.Array(dy.String(min_length=2, nullable=False), 1)} - ) +def test_array_with_primary_key_rule() -> None: + schema = create_schema( + "test", {"a": dy.Array(dy.String(min_length=2, primary_key=True), 2)} + ) + df = pl.DataFrame( + {"a": [["ab", "ab"], ["cd", "de"], ["def", "ghi"]]}, + schema={"a": pl.Array(pl.String, 2)}, + ) + _, failures = schema.filter(df) + assert validation_mask(df, failures).to_list() == [False, True, True] + assert failures.counts() == {"a|primary_key": 1} def test_outer_nullability() -> None: