Skip to content

Commit 33f6ab7

Browse files
yiftizurYftach Zurclaude
authored
Fix: Support nested struct field filtering (#2628)
Fixes #953

# Rationale for this change

Fixes filtering on nested struct fields when using PyArrow for scan operations.

## Are these changes tested?

Yes, the full test suite + new tests

## Are there any user-facing changes?

Now, filtering a scan using a nested field will work

## Problem

When filtering on nested struct fields (e.g., `parentField.childField == 'value'`), PyArrow would fail with:

```
ArrowInvalid: No match for FieldRef.Name(childField) in ...
```

The issue occurred because PyArrow requires nested field references as tuples (e.g., `("parent", "child")`) rather than dotted strings (e.g., `"parent.child"`).

## Solution

1. Modified `_ConvertToArrowExpression` to accept an optional `Schema` parameter
2. Added `_get_field_name()` method that converts dotted field paths to tuples for nested struct fields
3. Updated `expression_to_pyarrow()` to accept and pass the schema parameter
4. Updated all call sites to pass the schema when available

## Changes

- `pyiceberg/io/pyarrow.py`:
  - Modified `_ConvertToArrowExpression` class to handle nested field paths
  - Updated `expression_to_pyarrow()` signature to accept schema
  - Updated `_expression_to_complementary_pyarrow()` signature
- `pyiceberg/table/__init__.py`:
  - Updated call to `_expression_to_complementary_pyarrow()` to pass schema
- Tests:
  - Added `test_ref_binding_nested_struct_field()` for comprehensive nested field testing
  - Enhanced `test_nested_fields()` with issue #953 scenarios

## Example

```python
# Now works correctly:
table.scan(row_filter="parent.child == 'abc123'").to_polars()
```

The fix converts the field reference from:

- ❌ `FieldRef.Name(run_id)` (fails - field not found)
- ✅ `FieldRef.Nested(FieldRef.Name(mazeMetadata) FieldRef.Name(run_id))` (works!)

---------

Co-authored-by: Yftach Zur <yftach@atlas.security>
Co-authored-by: Claude <noreply@anthropic.com>
1 parent f4d65e0 commit 33f6ab7

File tree

4 files changed

+117
-20
lines changed

4 files changed

+117
-20
lines changed

pyiceberg/io/pyarrow.py

Lines changed: 61 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -808,51 +808,83 @@ def _convert_scalar(value: Any, iceberg_type: IcebergType) -> pa.scalar:
808808

809809

810810
class _ConvertToArrowExpression(BoundBooleanExpressionVisitor[pc.Expression]):
811+
"""Convert Iceberg bound expressions to PyArrow expressions.
812+
813+
Args:
814+
schema: Optional Iceberg schema to resolve full field paths for nested fields.
815+
If not provided, only the field name will be used (not dotted path).
816+
"""
817+
818+
_schema: Schema | None
819+
820+
def __init__(self, schema: Schema | None = None):
821+
self._schema = schema
822+
823+
def _get_field_name(self, term: BoundTerm[Any]) -> str | Tuple[str, ...]:
824+
"""Get the field name or nested field path for a bound term.
825+
826+
For nested struct fields, returns a tuple of field names (e.g., ("mazeMetadata", "run_id")).
827+
For top-level fields, returns just the field name as a string.
828+
829+
PyArrow requires nested field references as tuples, not dotted strings.
830+
"""
831+
if self._schema is not None:
832+
# Use the schema to get the full dotted path for nested fields
833+
full_name = self._schema.find_column_name(term.ref().field.field_id)
834+
if full_name is not None:
835+
# A dot in the full name is treated as a nesting separator (assumes no individual field name contains a literal dot)
836+
# Convert "parent.child" to ("parent", "child") for PyArrow
837+
if "." in full_name:
838+
return tuple(full_name.split("."))
839+
return full_name
840+
# Fall back to just the field name if no schema was given or the field ID is not found in it
841+
return term.ref().field.name
842+
811843
def visit_in(self, term: BoundTerm[Any], literals: Set[Any]) -> pc.Expression:
812844
pyarrow_literals = pa.array(literals, type=schema_to_pyarrow(term.ref().field.field_type))
813-
return pc.field(term.ref().field.name).isin(pyarrow_literals)
845+
return pc.field(self._get_field_name(term)).isin(pyarrow_literals)
814846

815847
def visit_not_in(self, term: BoundTerm[Any], literals: Set[Any]) -> pc.Expression:
816848
pyarrow_literals = pa.array(literals, type=schema_to_pyarrow(term.ref().field.field_type))
817-
return ~pc.field(term.ref().field.name).isin(pyarrow_literals)
849+
return ~pc.field(self._get_field_name(term)).isin(pyarrow_literals)
818850

819851
def visit_is_nan(self, term: BoundTerm[Any]) -> pc.Expression:
820-
ref = pc.field(term.ref().field.name)
852+
ref = pc.field(self._get_field_name(term))
821853
return pc.is_nan(ref)
822854

823855
def visit_not_nan(self, term: BoundTerm[Any]) -> pc.Expression:
824-
ref = pc.field(term.ref().field.name)
856+
ref = pc.field(self._get_field_name(term))
825857
return ~pc.is_nan(ref)
826858

827859
def visit_is_null(self, term: BoundTerm[Any]) -> pc.Expression:
828-
return pc.field(term.ref().field.name).is_null(nan_is_null=False)
860+
return pc.field(self._get_field_name(term)).is_null(nan_is_null=False)
829861

830862
def visit_not_null(self, term: BoundTerm[Any]) -> pc.Expression:
831-
return pc.field(term.ref().field.name).is_valid()
863+
return pc.field(self._get_field_name(term)).is_valid()
832864

833865
def visit_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
834-
return pc.field(term.ref().field.name) == _convert_scalar(literal.value, term.ref().field.field_type)
866+
return pc.field(self._get_field_name(term)) == _convert_scalar(literal.value, term.ref().field.field_type)
835867

836868
def visit_not_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
837-
return pc.field(term.ref().field.name) != _convert_scalar(literal.value, term.ref().field.field_type)
869+
return pc.field(self._get_field_name(term)) != _convert_scalar(literal.value, term.ref().field.field_type)
838870

839871
def visit_greater_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
840-
return pc.field(term.ref().field.name) >= _convert_scalar(literal.value, term.ref().field.field_type)
872+
return pc.field(self._get_field_name(term)) >= _convert_scalar(literal.value, term.ref().field.field_type)
841873

842874
def visit_greater_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
843-
return pc.field(term.ref().field.name) > _convert_scalar(literal.value, term.ref().field.field_type)
875+
return pc.field(self._get_field_name(term)) > _convert_scalar(literal.value, term.ref().field.field_type)
844876

845877
def visit_less_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
846-
return pc.field(term.ref().field.name) < _convert_scalar(literal.value, term.ref().field.field_type)
878+
return pc.field(self._get_field_name(term)) < _convert_scalar(literal.value, term.ref().field.field_type)
847879

848880
def visit_less_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
849-
return pc.field(term.ref().field.name) <= _convert_scalar(literal.value, term.ref().field.field_type)
881+
return pc.field(self._get_field_name(term)) <= _convert_scalar(literal.value, term.ref().field.field_type)
850882

851883
def visit_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
852-
return pc.starts_with(pc.field(term.ref().field.name), literal.value)
884+
return pc.starts_with(pc.field(self._get_field_name(term)), literal.value)
853885

854886
def visit_not_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
855-
return ~pc.starts_with(pc.field(term.ref().field.name), literal.value)
887+
return ~pc.starts_with(pc.field(self._get_field_name(term)), literal.value)
856888

857889
def visit_true(self) -> pc.Expression:
858890
return pc.scalar(True)
@@ -988,11 +1020,21 @@ def collect(
9881020
boolean_expression_visit(expr, self)
9891021

9901022

991-
def expression_to_pyarrow(expr: BooleanExpression) -> pc.Expression:
992-
return boolean_expression_visit(expr, _ConvertToArrowExpression())
1023+
def expression_to_pyarrow(expr: BooleanExpression, schema: Schema | None = None) -> pc.Expression:
1024+
"""Convert an Iceberg boolean expression to a PyArrow expression.
1025+
1026+
Args:
1027+
expr: The Iceberg boolean expression to convert.
1028+
schema: Optional Iceberg schema to resolve full field paths for nested fields.
1029+
If provided, nested struct fields are resolved to their full path and passed to PyArrow as a tuple (e.g., ("parent", "child")).
1030+
1031+
Returns:
1032+
A PyArrow compute expression.
1033+
"""
1034+
return boolean_expression_visit(expr, _ConvertToArrowExpression(schema))
9931035

9941036

995-
def _expression_to_complementary_pyarrow(expr: BooleanExpression) -> pc.Expression:
1037+
def _expression_to_complementary_pyarrow(expr: BooleanExpression, schema: Schema | None = None) -> pc.Expression:
9961038
"""Complementary filter conversion function of expression_to_pyarrow.
9971039
9981040
Could not use expression_to_pyarrow(Not(expr)) to achieve this complementary effect because ~ in pyarrow.compute.Expression does not handle null.
@@ -1013,7 +1055,7 @@ def _expression_to_complementary_pyarrow(expr: BooleanExpression) -> pc.Expressi
10131055
preserve_expr = Or(preserve_expr, BoundIsNull(term=term))
10141056
for term in nan_unmentioned_bound_terms:
10151057
preserve_expr = Or(preserve_expr, BoundIsNaN(term=term))
1016-
return expression_to_pyarrow(preserve_expr)
1058+
return expression_to_pyarrow(preserve_expr, schema)
10171059

10181060

10191061
@lru_cache
@@ -1553,7 +1595,7 @@ def _task_to_record_batches(
15531595
bound_row_filter, file_schema, case_sensitive=case_sensitive, projected_field_values=projected_missing_fields
15541596
)
15551597
bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive)
1556-
pyarrow_filter = expression_to_pyarrow(bound_file_filter)
1598+
pyarrow_filter = expression_to_pyarrow(bound_file_filter, file_schema)
15571599

15581600
file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)
15591601

pyiceberg/table/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -674,7 +674,7 @@ def delete(
674674
# Check if there are any files that require an actual rewrite of a data file
675675
if delete_snapshot.rewrites_needed is True:
676676
bound_delete_filter = bind(self.table_metadata.schema(), delete_filter, case_sensitive)
677-
preserve_row_filter = _expression_to_complementary_pyarrow(bound_delete_filter)
677+
preserve_row_filter = _expression_to_complementary_pyarrow(bound_delete_filter, self.table_metadata.schema())
678678

679679
file_scan = self._scan(row_filter=delete_filter, case_sensitive=case_sensitive)
680680
if branch is not None:

tests/expressions/test_expressions.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,58 @@ def test_ref_binding_case_insensitive_failure(table_schema_simple: Schema) -> No
228228
ref.bind(table_schema_simple, case_sensitive=False)
229229

230230

231+
def test_ref_binding_nested_struct_field() -> None:
232+
"""Test binding references to nested struct fields (issue #953)."""
233+
schema = Schema(
234+
NestedField(field_id=1, name="age", field_type=IntegerType(), required=True),
235+
NestedField(
236+
field_id=2,
237+
name="employment",
238+
field_type=StructType(
239+
NestedField(field_id=3, name="status", field_type=StringType(), required=False),
240+
NestedField(field_id=4, name="company", field_type=StringType(), required=False),
241+
),
242+
required=False,
243+
),
244+
NestedField(
245+
field_id=5,
246+
name="contact",
247+
field_type=StructType(
248+
NestedField(field_id=6, name="email", field_type=StringType(), required=False),
249+
),
250+
required=False,
251+
),
252+
schema_id=1,
253+
)
254+
255+
# Test that nested field names are in the index
256+
assert "employment.status" in schema._name_to_id
257+
assert "employment.company" in schema._name_to_id
258+
assert "contact.email" in schema._name_to_id
259+
260+
# Test binding a reference to nested fields
261+
ref = Reference("employment.status")
262+
bound = ref.bind(schema, case_sensitive=True)
263+
assert bound.field.field_id == 3
264+
assert bound.field.name == "status"
265+
266+
# Test with different nested field
267+
ref2 = Reference("contact.email")
268+
bound2 = ref2.bind(schema, case_sensitive=True)
269+
assert bound2.field.field_id == 6
270+
assert bound2.field.name == "email"
271+
272+
# Test case-insensitive binding
273+
ref3 = Reference("EMPLOYMENT.STATUS")
274+
bound3 = ref3.bind(schema, case_sensitive=False)
275+
assert bound3.field.field_id == 3
276+
277+
# Test that binding fails for non-existent nested field
278+
ref4 = Reference("employment.department")
279+
with pytest.raises(ValueError):
280+
ref4.bind(schema, case_sensitive=True)
281+
282+
231283
def test_in_to_eq() -> None:
232284
assert In("x", (34.56,)) == EqualTo("x", 34.56)
233285

tests/expressions/test_parser.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,9 @@ def test_with_function() -> None:
225225
def test_nested_fields() -> None:
226226
assert EqualTo("foo.bar", "data") == parser.parse("foo.bar = 'data'")
227227
assert LessThan("location.x", DecimalLiteral(Decimal(52.00))) == parser.parse("location.x < 52.00")
228+
# Test issue #953 scenario - nested struct field filtering
229+
assert EqualTo("employment.status", "Employed") == parser.parse("employment.status = 'Employed'")
230+
assert EqualTo("contact.email", "test@example.com") == parser.parse("contact.email = 'test@example.com'")
228231

229232

230233
def test_quoted_column_with_dots() -> None:

0 commit comments

Comments
 (0)