Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 20 additions & 19 deletions dataframely/columns/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from ._base import Check, Column
from ._registry import column_from_dict, register
from .struct import Struct
from .list import _list_primary_key_check

if sys.version_info >= (3, 11):
from typing import Self
Expand All @@ -42,7 +42,7 @@ def __init__(
):
"""
Args:
inner: The inner column type. No validation rules on the inner type are supported yet.
inner: The inner column type.
shape: The shape of the array.
nullable: Whether this column may contain null values.
primary_key: Whether this column is part of the primary key of the schema.
Expand All @@ -64,23 +64,6 @@ def __init__(
names, the specified alias is the only valid name.
metadata: A dictionary of metadata to attach to the column.
"""
if inner.primary_key or (
isinstance(inner, Struct)
and any(col.primary_key for col in inner.inner.values())
):
raise ValueError(
"`primary_key=True` is not yet supported for inner types of the Array type."
)

# We disallow validation rules on the inner type since Polars arrays currently don't support .eval(). Converting
# to a list and calling .list.eval() is possible, however, since the shape can have multiple axes, the recursive
# conversion could have significant performance impact. Hence, we simply disallow inner validation rules.
# Another option would be to allow validation rules only for sampling, but not enforce them.
if inner.validation_rules(pl.lit(None)):
raise ValueError(
"Validation rules on the inner type of Array are not yet supported."
)

super().__init__(
nullable=nullable,
primary_key=False,
Expand All @@ -95,6 +78,24 @@ def __init__(
def dtype(self) -> pl.DataType:
    # A fixed-shape polars Array of the inner column's dtype.
    return pl.Array(self.inner.dtype, self.shape)

def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]:
    """Build the validation rules for this array column.

    Combines three groups of rules:

    - the base-class rules (e.g. nullability) from ``super()``;
    - an optional ``primary_key`` rule ensuring inner primary-key values are
      unique within each array element (shared with the List column via
      ``_list_primary_key_check``);
    - the inner column's own rules, applied element-wise via ``arr.eval`` and
      prefixed with ``inner_`` — an element passes only if *all* entries of
      the array satisfy the inner rule (``arr.all()``).
    """
    # Evaluate each inner rule against every array entry; the rule holds for
    # the row only if it holds for all entries of the array.
    inner_rules = {
        f"inner_{rule_name}": expr.arr.eval(inner_expr).arr.all()
        for rule_name, inner_expr in self.inner.validation_rules(
            pl.element()
        ).items()
    }

    array_rules: dict[str, pl.Expr] = {}
    if (rule := _list_primary_key_check(expr.arr, self.inner)) is not None:
        array_rules["primary_key"] = rule

    return {
        **super().validation_rules(expr),
        **array_rules,
        **inner_rules,
    }

def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
    """Always raises: there is no general SQL equivalent of a fixed-shape array."""
    # NOTE: We might want to add support for PostgreSQL's ARRAY type or use JSON in the future.
    raise NotImplementedError("SQL column cannot have 'Array' type.")
Expand Down
59 changes: 36 additions & 23 deletions dataframely/columns/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from typing import Any, cast

import polars as pl
from polars.expr.array import ExprArrayNameSpace
from polars.expr.list import ExprListNameSpace

from dataframely._compat import pa, sa, sa_TypeEngine
from dataframely._polars import PolarsDataType
Expand Down Expand Up @@ -97,29 +99,8 @@ def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]:
}

list_rules: dict[str, pl.Expr] = {}
if self.inner.primary_key:
list_rules["primary_key"] = ~expr.list.eval(
pl.element().is_duplicated()
).list.any()
elif isinstance(self.inner, Struct) and any(
col.primary_key for col in self.inner.inner.values()
):
primary_key_columns = [
name for name, col in self.inner.inner.items() if col.primary_key
]
# NOTE: We optimize for a single primary key column here as it is much
# faster to run duplication checks for non-struct types in polars 1.22.
if len(primary_key_columns) == 1:
list_rules["primary_key"] = ~expr.list.eval(
pl.element().struct.field(primary_key_columns[0]).is_duplicated()
).list.any()
else:
list_rules["primary_key"] = ~expr.list.eval(
pl.struct(
pl.element().struct.field(primary_key_columns)
).is_duplicated()
).list.any()

if (rule := _list_primary_key_check(expr.list, self.inner)) is not None:
list_rules["primary_key"] = rule
if self.min_length is not None:
list_rules["min_length"] = (
pl.when(expr.is_null())
Expand Down Expand Up @@ -187,3 +168,35 @@ def as_dict(self, expr: pl.Expr) -> dict[str, Any]:
def from_dict(cls, data: dict[str, Any]) -> Self:
    """Deserialize the column, rebuilding the nested inner column first."""
    # The inner column is stored as a plain dict; reconstruct it via the
    # column registry before delegating to the base-class deserializer.
    data["inner"] = column_from_dict(data["inner"])
    return super().from_dict(data)


def _list_primary_key_check(
    list_expr: ExprListNameSpace | ExprArrayNameSpace, inner: Column
) -> pl.Expr | None:
    """Build a rule checking that inner primary-key values are unique per element.

    Works for both List and Array columns: ``list_expr`` is the ``.list`` or
    ``.arr`` namespace of the outer expression, and the returned rule evaluates
    to ``True`` for rows whose element contains no duplicate primary keys.
    Returns ``None`` when the inner column declares no primary key at all.
    """
    is_list = isinstance(list_expr, ExprListNameSpace)

    def any_within(expr: pl.Expr) -> pl.Expr:
        # Fold the per-entry booleans with the namespace matching the outer type.
        return expr.list.any() if is_list else expr.arr.any()

    # Determine which per-entry expression marks duplicated primary keys.
    duplicated: pl.Expr | None = None
    if inner.primary_key:
        # The inner values themselves are the key.
        duplicated = pl.element().is_duplicated()
    elif isinstance(inner, Struct):
        key_fields = [name for name, col in inner.inner.items() if col.primary_key]
        if len(key_fields) == 1:
            # NOTE: We optimize for a single primary key column here as it is much
            # faster to run duplication checks for non-struct types in polars 1.22.
            duplicated = (
                pl.element().struct.field(key_fields[0]).is_duplicated()
            )
        elif key_fields:
            duplicated = pl.struct(
                pl.element().struct.field(key_fields)
            ).is_duplicated()

    if duplicated is None:
        return None
    # A row is valid iff no entry within its element is duplicated.
    return ~list_expr.eval(duplicated).pipe(any_within)
36 changes: 23 additions & 13 deletions tests/column_types/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import dataframely as dy
from dataframely.columns._base import Column
from dataframely.testing import create_schema
from dataframely.testing import create_schema, validation_mask


@pytest.mark.parametrize(
Expand Down Expand Up @@ -132,20 +132,30 @@ def test_nested_array() -> None:
)


def test_array_with_inner_pk() -> None:
with pytest.raises(ValueError):
column = dy.Array(dy.String(primary_key=True), 2)
create_schema(
"test",
{"a": column},
)
def test_array_with_rules() -> None:
    # Rules on the inner column (nullability, min_length) must now be applied
    # element-wise instead of being rejected at schema-construction time.
    column = dy.Array(dy.String(min_length=2, nullable=False), 1)
    schema = create_schema("test", {"a": column})
    frame = pl.DataFrame(
        {"a": [["ab"], ["a"], [None]]},
        schema={"a": pl.Array(pl.String, 1)},
    )
    _, invalid = schema.filter(frame)
    mask = validation_mask(frame, invalid)
    assert mask.to_list() == [True, False, False]
    assert invalid.counts() == {"a|inner_nullability": 1, "a|inner_min_length": 1}


def test_array_with_rules() -> None:
with pytest.raises(ValueError):
create_schema(
"test", {"a": dy.Array(dy.String(min_length=2, nullable=False), 1)}
)
def test_array_with_primary_key_rule() -> None:
    # A primary key on the inner column requires uniqueness within each array
    # element: the first row repeats "ab" and must fail.
    column = dy.Array(dy.String(min_length=2, primary_key=True), 2)
    schema = create_schema("test", {"a": column})
    frame = pl.DataFrame(
        {"a": [["ab", "ab"], ["cd", "de"], ["def", "ghi"]]},
        schema={"a": pl.Array(pl.String, 2)},
    )
    _, invalid = schema.filter(frame)
    mask = validation_mask(frame, invalid)
    assert mask.to_list() == [False, True, True]
    assert invalid.counts() == {"a|primary_key": 1}


def test_outer_nullability() -> None:
Expand Down
Loading