Propagate inner-column nullability into nested PyArrow types.

Array, List and Struct columns now build their child types from
`pyarrow_field(...)` (a `pa.Field`) instead of the bare `pyarrow_dtype`, and
tests are added that inspect the `nullable` flag of the nested fields.

diff --git a/dataframely/columns/array.py b/dataframely/columns/array.py
index 2ff1b06..db232ba 100644
--- a/dataframely/columns/array.py
+++ b/dataframely/columns/array.py
@@ -99,16 +99,17 @@ def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
         # NOTE: We might want to add support for PostgreSQL's ARRAY type or use JSON in the future.
         raise NotImplementedError("SQL column cannot have 'Array' type.")
 
-    def _pyarrow_dtype_of_shape(self, shape: Sequence[int]) -> pa.DataType:
+    def _pyarrow_field_of_shape(self, shape: Sequence[int]) -> pa.Field:
         if shape:
             size, *rest = shape
-            return pa.list_(self._pyarrow_dtype_of_shape(rest), size)
+            inner_type = self._pyarrow_field_of_shape(rest)
+            return pa.field("item", pa.list_(inner_type, size), nullable=True)
         else:
-            return self.inner.pyarrow_dtype
+            return self.inner.pyarrow_field("item")
 
     @property
     def pyarrow_dtype(self) -> pa.DataType:
-        return self._pyarrow_dtype_of_shape(self.shape)
+        return self._pyarrow_field_of_shape(self.shape).type
 
     def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
         # Sample the inner elements in a flat series
diff --git a/dataframely/columns/list.py b/dataframely/columns/list.py
index 0804cf1..e1dd88b 100644
--- a/dataframely/columns/list.py
+++ b/dataframely/columns/list.py
@@ -145,7 +145,7 @@ def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
     @property
     def pyarrow_dtype(self) -> pa.DataType:
         # NOTE: Polars uses `large_list`s by default.
-        return pa.large_list(self.inner.pyarrow_dtype)
+        return pa.large_list(self.inner.pyarrow_field("item"))
 
     def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
         # First, sample the number of items per list element
diff --git a/dataframely/columns/struct.py b/dataframely/columns/struct.py
index ea6a687..b8aecf9 100644
--- a/dataframely/columns/struct.py
+++ b/dataframely/columns/struct.py
@@ -112,7 +112,7 @@ def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
 
     @property
     def pyarrow_dtype(self) -> pa.DataType:
-        return pa.struct({name: col.pyarrow_dtype for name, col in self.inner.items()})
+        return pa.struct([col.pyarrow_field(name) for name, col in self.inner.items()])
 
     def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
         series = (
diff --git a/tests/columns/test_pyarrow.py b/tests/columns/test_pyarrow.py
index 5ae35ea..9c485d2 100644
--- a/tests/columns/test_pyarrow.py
+++ b/tests/columns/test_pyarrow.py
@@ -59,13 +59,10 @@ def test_equal_polars_schema_enum(categories: list[str]) -> None:
 
 @pytest.mark.parametrize(
     "inner",
-    [c() for c in ALL_COLUMN_TYPES]
-    + [dy.List(t()) for t in ALL_COLUMN_TYPES]
-    + [
-        dy.Array(t() if t == dy.Any else t(nullable=True), 1)
-        for t in NO_VALIDATION_COLUMN_TYPES
-    ]
-    + [dy.Struct({"a": t()}) for t in ALL_COLUMN_TYPES],
+    [_nullable(c) for c in ALL_COLUMN_TYPES]
+    + [dy.List(_nullable(t), nullable=True) for t in ALL_COLUMN_TYPES]
+    + [dy.Array(_nullable(t), 1, nullable=True) for t in NO_VALIDATION_COLUMN_TYPES]
+    + [dy.Struct({"a": _nullable(t)}, nullable=True) for t in ALL_COLUMN_TYPES],
 )
 def test_equal_polars_schema_list(inner: Column) -> None:
     schema = create_schema("test", {"a": dy.List(inner, nullable=True)})
@@ -161,6 +158,98 @@ def test_nullability_information_struct(inner: Column, nullable: bool) -> None:
     assert ("not null" in str(schema.to_pyarrow_schema())) != nullable
 
 
+@pytest.mark.parametrize("column_type", COLUMN_TYPES)
+@pytest.mark.parametrize("inner_nullable", [True, False])
+def test_inner_nullability_struct(
+    column_type: type[Column], inner_nullable: bool
+) -> None:
+    inner = column_type(nullable=inner_nullable)
+    schema = create_schema("test", {"a": dy.Struct({"a": inner})})
+    pa_schema = schema.to_pyarrow_schema()
+    struct_field = pa_schema.field("a")
+    inner_field = struct_field.type[0]
+    assert inner_field.nullable == inner_nullable
+
+
+@pytest.mark.parametrize("column_type", COLUMN_TYPES)
+@pytest.mark.parametrize("inner_nullable", [True, False])
+def test_inner_nullability_list(
+    column_type: type[Column], inner_nullable: bool
+) -> None:
+    inner = column_type(nullable=inner_nullable)
+    schema = create_schema("test", {"a": dy.List(inner)})
+    pa_schema = schema.to_pyarrow_schema()
+    list_field = pa_schema.field("a")
+    inner_field = list_field.type.value_field
+    assert inner_field.nullable == inner_nullable
+
+
+def test_nested_struct_in_list_preserves_nullability() -> None:
+    """Test that nested struct fields in lists preserve nullability."""
+    schema = create_schema(
+        "test",
+        {
+            "a": dy.List(
+                dy.Struct(
+                    {
+                        "required": dy.String(nullable=False),
+                        "optional": dy.String(nullable=True),
+                    },
+                    nullable=True,
+                ),
+                nullable=True,
+            )
+        },
+    )
+    pa_schema = schema.to_pyarrow_schema()
+    list_field = pa_schema.field("a")
+    struct_type = list_field.type.value_field.type
+    assert not struct_type[0].nullable
+    assert struct_type[1].nullable
+
+
+def test_nested_list_in_struct_preserves_nullability() -> None:
+    """Test that nested list fields in structs preserve nullability."""
+    schema = create_schema(
+        "test",
+        {
+            "a": dy.Struct(
+                {"list_field": dy.List(dy.String(nullable=False), nullable=True)},
+                nullable=True,
+            )
+        },
+    )
+    pa_schema = schema.to_pyarrow_schema()
+    struct_field = pa_schema.field("a")
+    list_type = struct_field.type[0].type
+    assert not list_type.value_field.nullable
+
+
+def test_deeply_nested_nullability() -> None:
+    schema = create_schema(
+        "test",
+        {
+            "a": dy.Struct(
+                {
+                    "nested": dy.Struct(
+                        {
+                            "required": dy.String(nullable=False),
+                            "optional": dy.String(nullable=True),
+                        },
+                        nullable=True,
+                    ),
+                },
+                nullable=True,
+            )
+        },
+    )
+    pa_schema = schema.to_pyarrow_schema()
+    outer_struct = pa_schema.field("a").type
+    inner_struct = outer_struct[0].type
+    assert not inner_struct[0].nullable  # required field
+    assert inner_struct[1].nullable  # optional field
+
+
 def test_multiple_columns() -> None:
     schema = create_schema(
         "test", {"a": dy.Int32(nullable=False), "b": dy.Integer(nullable=True)}