Skip to content

Commit 355acf8

Browse files
Copilotborchero
andauthored
fix: Properly set nested nullability when converting to pyarrow (#217)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: borchero <22455425+borchero@users.noreply.github.com> Co-authored-by: Oliver Borchert <oliver.borchert@quantco.com>
1 parent b9b365e commit 355acf8

File tree

4 files changed

+103
-13
lines changed

4 files changed

+103
-13
lines changed

dataframely/columns/array.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,16 +99,17 @@ def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
9999
# NOTE: We might want to add support for PostgreSQL's ARRAY type or use JSON in the future.
100100
raise NotImplementedError("SQL column cannot have 'Array' type.")
101101

102-
def _pyarrow_dtype_of_shape(self, shape: Sequence[int]) -> pa.DataType:
102+
def _pyarrow_field_of_shape(self, shape: Sequence[int]) -> pa.Field:
103103
if shape:
104104
size, *rest = shape
105-
return pa.list_(self._pyarrow_dtype_of_shape(rest), size)
105+
inner_type = self._pyarrow_field_of_shape(rest)
106+
return pa.field("item", pa.list_(inner_type, size), nullable=True)
106107
else:
107-
return self.inner.pyarrow_dtype
108+
return self.inner.pyarrow_field("item")
108109

109110
@property
110111
def pyarrow_dtype(self) -> pa.DataType:
111-
return self._pyarrow_dtype_of_shape(self.shape)
112+
return self._pyarrow_field_of_shape(self.shape).type
112113

113114
def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
114115
# Sample the inner elements in a flat series

dataframely/columns/list.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
145145
@property
146146
def pyarrow_dtype(self) -> pa.DataType:
147147
# NOTE: Polars uses `large_list`s by default.
148-
return pa.large_list(self.inner.pyarrow_dtype)
148+
return pa.large_list(self.inner.pyarrow_field("item"))
149149

150150
def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
151151
# First, sample the number of items per list element

dataframely/columns/struct.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
112112

113113
@property
114114
def pyarrow_dtype(self) -> pa.DataType:
115-
return pa.struct({name: col.pyarrow_dtype for name, col in self.inner.items()})
115+
return pa.struct([col.pyarrow_field(name) for name, col in self.inner.items()])
116116

117117
def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
118118
series = (

tests/columns/test_pyarrow.py

Lines changed: 96 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,10 @@ def test_equal_polars_schema_enum(categories: list[str]) -> None:
5959

6060
@pytest.mark.parametrize(
6161
"inner",
62-
[c() for c in ALL_COLUMN_TYPES]
63-
+ [dy.List(t()) for t in ALL_COLUMN_TYPES]
64-
+ [
65-
dy.Array(t() if t == dy.Any else t(nullable=True), 1)
66-
for t in NO_VALIDATION_COLUMN_TYPES
67-
]
68-
+ [dy.Struct({"a": t()}) for t in ALL_COLUMN_TYPES],
62+
[_nullable(c) for c in ALL_COLUMN_TYPES]
63+
+ [dy.List(_nullable(t), nullable=True) for t in ALL_COLUMN_TYPES]
64+
+ [dy.Array(_nullable(t), 1, nullable=True) for t in NO_VALIDATION_COLUMN_TYPES]
65+
+ [dy.Struct({"a": _nullable(t)}, nullable=True) for t in ALL_COLUMN_TYPES],
6966
)
7067
def test_equal_polars_schema_list(inner: Column) -> None:
7168
schema = create_schema("test", {"a": dy.List(inner, nullable=True)})
@@ -161,6 +158,98 @@ def test_nullability_information_struct(inner: Column, nullable: bool) -> None:
161158
assert ("not null" in str(schema.to_pyarrow_schema())) != nullable
162159

163160

161+
@pytest.mark.parametrize("column_type", COLUMN_TYPES)
162+
@pytest.mark.parametrize("inner_nullable", [True, False])
163+
def test_inner_nullability_struct(
164+
column_type: type[Column], inner_nullable: bool
165+
) -> None:
166+
inner = column_type(nullable=inner_nullable)
167+
schema = create_schema("test", {"a": dy.Struct({"a": inner})})
168+
pa_schema = schema.to_pyarrow_schema()
169+
struct_field = pa_schema.field("a")
170+
inner_field = struct_field.type[0]
171+
assert inner_field.nullable == inner_nullable
172+
173+
174+
@pytest.mark.parametrize("column_type", COLUMN_TYPES)
175+
@pytest.mark.parametrize("inner_nullable", [True, False])
176+
def test_inner_nullability_list(
177+
column_type: type[Column], inner_nullable: bool
178+
) -> None:
179+
inner = column_type(nullable=inner_nullable)
180+
schema = create_schema("test", {"a": dy.List(inner)})
181+
pa_schema = schema.to_pyarrow_schema()
182+
list_field = pa_schema.field("a")
183+
inner_field = list_field.type.value_field
184+
assert inner_field.nullable == inner_nullable
185+
186+
187+
def test_nested_struct_in_list_preserves_nullability() -> None:
188+
"""Test that nested struct fields in lists preserve nullability."""
189+
schema = create_schema(
190+
"test",
191+
{
192+
"a": dy.List(
193+
dy.Struct(
194+
{
195+
"required": dy.String(nullable=False),
196+
"optional": dy.String(nullable=True),
197+
},
198+
nullable=True,
199+
),
200+
nullable=True,
201+
)
202+
},
203+
)
204+
pa_schema = schema.to_pyarrow_schema()
205+
list_field = pa_schema.field("a")
206+
struct_type = list_field.type.value_field.type
207+
assert not struct_type[0].nullable
208+
assert struct_type[1].nullable
209+
210+
211+
def test_nested_list_in_struct_preserves_nullability() -> None:
212+
"""Test that nested list fields in structs preserve nullability."""
213+
schema = create_schema(
214+
"test",
215+
{
216+
"a": dy.Struct(
217+
{"list_field": dy.List(dy.String(nullable=False), nullable=True)},
218+
nullable=True,
219+
)
220+
},
221+
)
222+
pa_schema = schema.to_pyarrow_schema()
223+
struct_field = pa_schema.field("a")
224+
list_type = struct_field.type[0].type
225+
assert not list_type.value_field.nullable
226+
227+
228+
def test_deeply_nested_nullability() -> None:
229+
schema = create_schema(
230+
"test",
231+
{
232+
"a": dy.Struct(
233+
{
234+
"nested": dy.Struct(
235+
{
236+
"required": dy.String(nullable=False),
237+
"optional": dy.String(nullable=True),
238+
},
239+
nullable=True,
240+
),
241+
},
242+
nullable=True,
243+
)
244+
},
245+
)
246+
pa_schema = schema.to_pyarrow_schema()
247+
outer_struct = pa_schema.field("a").type
248+
inner_struct = outer_struct[0].type
249+
assert not inner_struct[0].nullable # required field
250+
assert inner_struct[1].nullable # optional field
251+
252+
164253
def test_multiple_columns() -> None:
165254
schema = create_schema(
166255
"test", {"a": dy.Int32(nullable=False), "b": dy.Integer(nullable=True)}

0 commit comments

Comments
 (0)