From 1321c1b64ab95fb3b357e6a4bdadb5a7296606fb Mon Sep 17 00:00:00 2001 From: Matteo De Wint Date: Thu, 4 Sep 2025 22:04:40 +0200 Subject: [PATCH 1/7] fix: upsert with null values in join columns --- pyiceberg/table/upsert_util.py | 49 +++++++++++++++++++++++++--------- tests/table/test_upsert.py | 40 +++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 13 deletions(-) diff --git a/pyiceberg/table/upsert_util.py b/pyiceberg/table/upsert_util.py index cefdd101a0..9f2139fdf7 100644 --- a/pyiceberg/table/upsert_util.py +++ b/pyiceberg/table/upsert_util.py @@ -14,8 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import functools -import operator +from math import isnan +from typing import Any import pyarrow as pa from pyarrow import Table as pyarrow_table @@ -23,29 +23,52 @@ from pyiceberg.expressions import ( AlwaysFalse, + And, BooleanExpression, EqualTo, In, + IsNaN, + IsNull, Or, ) def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression: unique_keys = df.select(join_cols).group_by(join_cols).aggregate([]) + filters = [] if len(join_cols) == 1: - return In(join_cols[0], unique_keys[0].to_pylist()) + column = join_cols[0] + values = set(unique_keys[0].to_pylist()) + + if None in values: + filters.append(IsNull(column)) + values.remove(None) + + if nans := {v for v in values if isinstance(v, float) and isnan(v)}: + filters.append(IsNaN(column)) + values -= nans + + filters.append(In(column, values)) + else: + + def equals(column: str, value: Any) -> BooleanExpression: + if value is None: + return IsNull(column) + + if isinstance(value, float) and isnan(value): + return IsNaN(column) + + return EqualTo(column, value) + + filters = [And(*[equals(col, row[col]) for col in join_cols]) for row in unique_keys.to_pylist()] + + if len(filters) == 0: + return AlwaysFalse() + elif len(filters) == 1: + return filters[0] else: - filters = [ - functools.reduce(operator.and_, [EqualTo(col, row[col]) for col in join_cols]) for row in unique_keys.to_pylist() - ] - - if len(filters) == 0: - return AlwaysFalse() - elif len(filters) == 1: - return filters[0] - else: - return Or(*filters) + return Or(*filters) def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool: diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 891d4bbac7..4f5a5fd25b 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -710,6 +710,46 @@ def test_upsert_with_nulls(catalog: Catalog) -> None: schema=schema, ) + # upsert table with null value + data_with_null = pa.Table.from_pylist( + [ + {"foo": None, "bar": 1, "baz": False}, + ], + schema=schema, + ) + upd = table.upsert(data_with_null, join_cols=["foo"]) + assert upd.rows_updated == 0 + assert upd.rows_inserted == 1 + assert table.scan().to_arrow() == pa.Table.from_pylist( + [ + {"foo": None, "bar": 1, "baz": False}, + {"foo": "apple", "bar": 7, "baz": False}, + {"foo": "banana", "bar": None, "baz": False}, + ], + schema=schema, + ) + + # upsert table with null and non-null values, in two join columns + data_with_null = pa.Table.from_pylist( + [ + {"foo": None, "bar": 1, "baz": True}, + {"foo": "lemon", "bar": None, "baz": False}, + ], + schema=schema, + ) + upd = table.upsert(data_with_null, join_cols=["foo", "bar"]) + assert upd.rows_updated == 1 + assert upd.rows_inserted == 1 + assert table.scan().to_arrow() == pa.Table.from_pylist( + [ + {"foo": "lemon", "bar": None, "baz": False}, + {"foo": None, "bar": 1, "baz": True}, + {"foo": "apple", "bar": 7, "baz": False}, + {"foo": "banana", "bar": None, "baz": False}, + ], + schema=schema, + ) + def test_transaction(catalog: Catalog) -> None: """Test the upsert within a Transaction. Make sure that if something fails the entire Transaction is From 4df175f42a1c1218271d1c91a9344824ca07d194 Mon Sep 17 00:00:00 2001 From: Matteo De Wint Date: Fri, 5 Sep 2025 09:08:07 +0200 Subject: [PATCH 2/7] test: add test cases for create_match_filter --- tests/table/test_upsert.py | 68 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 66 insertions(+), 2 deletions(-) diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 4f5a5fd25b..53e595734d 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -23,8 +23,8 @@ from pyiceberg.catalog import Catalog from pyiceberg.exceptions import NoSuchTableError -from pyiceberg.expressions import AlwaysTrue, And, EqualTo, Reference -from pyiceberg.expressions.literals import LongLiteral +from pyiceberg.expressions import AlwaysTrue, And, BooleanExpression, EqualTo, In, IsNaN, IsNull, Or, Reference +from pyiceberg.expressions.literals import DoubleLiteral, LongLiteral from pyiceberg.io.pyarrow import schema_to_pyarrow from pyiceberg.schema import Schema from pyiceberg.table import UpsertResult @@ -440,6 +440,70 @@ def test_create_match_filter_single_condition() -> None: ) +@pytest.mark.parametrize( + "data, expected", + [ + pytest.param( + [{"x": 1.0}, {"x": 2.0}, {"x": None}, {"x": 4.0}, {"x": float("nan")}], + Or( + left=IsNull(term=Reference(name="x")), + right=Or( + left=IsNaN(term=Reference(name="x")), + right=In(Reference(name="x"), {DoubleLiteral(1.0), DoubleLiteral(2.0), DoubleLiteral(4.0)}), + ), + ), + id="single-column", + ), + pytest.param( + [ + {"x": 1.0, "y": 9.0}, + {"x": 2.0, "y": None}, + {"x": None, "y": 7.0}, + {"x": 4.0, "y": float("nan")}, + {"x": float("nan"), "y": 0.0}, + ], + Or( + left=Or( + left=And( + left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(1.0)), + right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(9.0)), + ), + right=And( + left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(2.0)), + right=IsNull(term=Reference(name="y")), + ), + ), + right=Or( + left=And( + left=IsNull(term=Reference(name="x")), + right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(7.0)), + ), + right=Or( + left=And( + left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(4.0)), + right=IsNaN(term=Reference(name="y")), + ), + right=And( + left=IsNaN(term=Reference(name="x")), + right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(0.0)), + ), + ), + ), + ), + id="multi-column", + ), + ], +) +def test_create_match_filter_with_nulls(data: list[dict], expected: BooleanExpression) -> None: + schema = pa.schema([pa.field("x", pa.float64()), pa.field("y", pa.float64())]) + table = pa.Table.from_pylist(data, schema=schema) + join_cols = sorted({col for record in data for col in record}) + + expr = create_match_filter(table, join_cols) + + assert expr == expected + + def test_upsert_with_duplicate_rows_in_table(catalog: Catalog) -> None: identifier = "default.test_upsert_with_duplicate_rows_in_table" From 763042f74277cbdb74049e116ab7186303bd932e Mon Sep 17 00:00:00 2001 From: Matteo De Wint Date: Fri, 5 Sep 2025 11:33:46 +0200 Subject: [PATCH 3/7] fix: respect null values in inner join in get_rows_to_update --- pyiceberg/table/upsert_util.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pyiceberg/table/upsert_util.py b/pyiceberg/table/upsert_util.py index 9f2139fdf7..86b38321a4 100644 --- a/pyiceberg/table/upsert_util.py +++ b/pyiceberg/table/upsert_util.py @@ -121,13 +121,16 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols target_index = target_table.select(join_cols_set).append_column(TARGET_INDEX_COLUMN_NAME, pa.array(range(len(target_table)))) # Step 3: Perform an inner join to find which rows from source exist in target - matching_indices = source_index.join(target_index, keys=list(join_cols_set), join_type="inner") + # PyArrow joins ignore null values, and we want null==null to hold, so we compute the join in Python. + # This is equivalent to: + # matching_indices = source_index.join(target_index, keys=list(join_cols_set), join_type="inner") + source_indices = {tuple(row[col] for col in join_cols): row[SOURCE_INDEX_COLUMN_NAME] for row in source_index.to_pylist()} + target_indices = {tuple(row[col] for col in join_cols): row[TARGET_INDEX_COLUMN_NAME] for row in target_index.to_pylist()} + matching_indices = [(s, t) for key, s in source_indices.items() if (t := target_indices.get(key)) is not None] # Step 4: Compare all rows using Python to_update_indices = [] - for source_idx, target_idx in zip( - matching_indices[SOURCE_INDEX_COLUMN_NAME].to_pylist(), matching_indices[TARGET_INDEX_COLUMN_NAME].to_pylist() - ): + for source_idx, target_idx in matching_indices: source_row = source_table.slice(source_idx, 1) target_row = target_table.slice(target_idx, 1) From a32bb06c4e82c75dd68fde4c036b9c6c09ac538f Mon Sep 17 00:00:00 2001 From: Matteo De Wint Date: Fri, 5 Sep 2025 11:59:39 +0200 Subject: [PATCH 4/7] fix: type hints --- pyiceberg/table/upsert_util.py | 2 +- tests/table/test_upsert.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pyiceberg/table/upsert_util.py b/pyiceberg/table/upsert_util.py index 86b38321a4..918e47b806 100644 --- a/pyiceberg/table/upsert_util.py +++ b/pyiceberg/table/upsert_util.py @@ -35,7 +35,7 @@ def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression: unique_keys = df.select(join_cols).group_by(join_cols).aggregate([]) - filters = [] + filters: list[BooleanExpression] = [] if len(join_cols) == 1: column = join_cols[0] diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 53e595734d..9762b8b1d7 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. from pathlib import PosixPath +from typing import Any import pyarrow as pa import pytest @@ -494,7 +495,7 @@ def test_create_match_filter_single_condition() -> None: ), ], ) -def test_create_match_filter_with_nulls(data: list[dict], expected: BooleanExpression) -> None: +def test_create_match_filter_with_nulls(data: list[dict[str, Any]], expected: BooleanExpression) -> None: schema = pa.schema([pa.field("x", pa.float64()), pa.field("y", pa.float64())]) table = pa.Table.from_pylist(data, schema=schema) join_cols = sorted({col for record in data for col in record}) From 9af1b3d26f5367c10f184673f18b173890c0c9e9 Mon Sep 17 00:00:00 2001 From: Matteo De Wint Date: Wed, 24 Sep 2025 11:36:33 +0200 Subject: [PATCH 5/7] test: add test case for create_match_filter without null --- tests/table/test_upsert.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 9762b8b1d7..1ef6f65e83 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -444,6 +444,11 @@ def test_create_match_filter_single_condition() -> None: @pytest.mark.parametrize( "data, expected", [ + pytest.param( + [{"x": 1.0}, {"x": 2.0}, {"x": 3.0}], + In(Reference(name="x"), {DoubleLiteral(1.0), DoubleLiteral(2.0), DoubleLiteral(3.0)}), + id="single-column-without-null", + ), pytest.param( [{"x": 1.0}, {"x": 2.0}, {"x": None}, {"x": 4.0}, {"x": float("nan")}], Or( @@ -453,7 +458,7 @@ def test_create_match_filter_single_condition() -> None: right=In(Reference(name="x"), {DoubleLiteral(1.0), DoubleLiteral(2.0), DoubleLiteral(4.0)}), ), ), - id="single-column", + id="single-column-with-null", ), pytest.param( [ @@ -491,11 +496,11 @@ def test_create_match_filter_single_condition() -> None: ), ), ), - id="multi-column", + id="multi-column-with-null", ), ], ) -def test_create_match_filter_with_nulls(data: list[dict[str, Any]], expected: BooleanExpression) -> None: +def test_create_match_filter(data: list[dict[str, Any]], expected: BooleanExpression) -> None: schema = pa.schema([pa.field("x", pa.float64()), pa.field("y", pa.float64())]) table = pa.Table.from_pylist(data, schema=schema) join_cols = sorted({col for record in data for col in record}) From 6d772b938b3091f57a1fda99ff35bc3a6b5c9f6d Mon Sep 17 00:00:00 2001 From: Matteo De Wint Date: Wed, 24 Sep 2025 11:45:37 +0200 Subject: [PATCH 6/7] test: separate test_upsert_with_nulls_in_join_columns --- tests/table/test_upsert.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 1ef6f65e83..0dcd7c907f 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -780,6 +780,20 @@ def test_upsert_with_nulls(catalog: Catalog) -> None: schema=schema, ) + +def test_upsert_with_nulls_in_join_columns(catalog: Catalog) -> None: + identifier = "default.test_upsert_with_nulls_in_join_columns" + _drop_table(catalog, identifier) + + schema = pa.schema( + [ + ("foo", pa.string()), + ("bar", pa.int32()), + ("baz", pa.bool_()), + ] + ) + table = catalog.create_table(identifier, schema) + # upsert table with null value data_with_null = pa.Table.from_pylist( [ @@ -793,8 +807,6 @@ def test_upsert_with_nulls(catalog: Catalog) -> None: assert table.scan().to_arrow() == pa.Table.from_pylist( [ {"foo": None, "bar": 1, "baz": False}, - {"foo": "apple", "bar": 7, "baz": False}, - {"foo": "banana", "bar": None, "baz": False}, ], schema=schema, ) @@ -814,8 +826,6 @@ def test_upsert_with_nulls(catalog: Catalog) -> None: [ {"foo": "lemon", "bar": None, "baz": False}, {"foo": None, "bar": 1, "baz": True}, - {"foo": "apple", "bar": 7, "baz": False}, - {"foo": "banana", "bar": None, "baz": False}, ], schema=schema, ) From cf3d68e18e2616c8108a3deedf2542091b8c625f Mon Sep 17 00:00:00 2001 From: Matteo De Wint Date: Wed, 15 Oct 2025 10:30:35 +0200 Subject: [PATCH 7/7] test: unroll parametrized tests for clarity --- tests/table/test_upsert.py | 130 +++++++++++++++++++------------------ 1 file changed, 68 insertions(+), 62 deletions(-) diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 0dcd7c907f..334c026233 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. from pathlib import PosixPath -from typing import Any import pyarrow as pa import pytest @@ -24,7 +23,7 @@ from pyiceberg.catalog import Catalog from pyiceberg.exceptions import NoSuchTableError -from pyiceberg.expressions import AlwaysTrue, And, BooleanExpression, EqualTo, In, IsNaN, IsNull, Or, Reference +from pyiceberg.expressions import AlwaysTrue, And, EqualTo, In, IsNaN, IsNull, Or, Reference from pyiceberg.expressions.literals import DoubleLiteral, LongLiteral from pyiceberg.io.pyarrow import schema_to_pyarrow from pyiceberg.schema import Schema @@ -441,73 +440,80 @@ def test_create_match_filter_single_condition() -> None: ) -@pytest.mark.parametrize( - "data, expected", - [ - pytest.param( - [{"x": 1.0}, {"x": 2.0}, {"x": 3.0}], - In(Reference(name="x"), {DoubleLiteral(1.0), DoubleLiteral(2.0), DoubleLiteral(3.0)}), - id="single-column-without-null", +def test_create_match_filter_single_column_without_null() -> None: + data = [{"x": 1.0}, {"x": 2.0}, {"x": 3.0}] + + schema = pa.schema([pa.field("x", pa.float64())]) + table = pa.Table.from_pylist(data, schema=schema) + + expr = create_match_filter(table, join_cols=["x"]) + + assert expr == In(Reference(name="x"), {DoubleLiteral(1.0), DoubleLiteral(2.0), DoubleLiteral(3.0)}) + + +def test_create_match_filter_single_column_with_null() -> None: + data = [ + {"x": 1.0}, + {"x": 2.0}, + {"x": None}, + {"x": 4.0}, + {"x": float("nan")}, + ] + schema = pa.schema([pa.field("x", pa.float64())]) + table = pa.Table.from_pylist(data, schema=schema) + + expr = create_match_filter(table, join_cols=["x"]) + + assert expr == Or( + left=IsNull(term=Reference(name="x")), + right=Or( + left=IsNaN(term=Reference(name="x")), + right=In(Reference(name="x"), {DoubleLiteral(1.0), DoubleLiteral(2.0), DoubleLiteral(4.0)}), ), - pytest.param( - [{"x": 1.0}, {"x": 2.0}, {"x": None}, {"x": 4.0}, {"x": float("nan")}], - Or( - left=IsNull(term=Reference(name="x")), - right=Or( - left=IsNaN(term=Reference(name="x")), - right=In(Reference(name="x"), {DoubleLiteral(1.0), DoubleLiteral(2.0), DoubleLiteral(4.0)}), - ), + ) + + +def test_create_match_filter_multi_column_with_null() -> None: + data = [ + {"x": 1.0, "y": 9.0}, + {"x": 2.0, "y": None}, + {"x": None, "y": 7.0}, + {"x": 4.0, "y": float("nan")}, + {"x": float("nan"), "y": 0.0}, + ] + schema = pa.schema([pa.field("x", pa.float64()), pa.field("y", pa.float64())]) + table = pa.Table.from_pylist(data, schema=schema) + + expr = create_match_filter(table, join_cols=["x", "y"]) + + assert expr == Or( + left=Or( + left=And( + left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(1.0)), + right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(9.0)), + ), + right=And( + left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(2.0)), + right=IsNull(term=Reference(name="y")), ), - id="single-column-with-null", ), - pytest.param( - [ - {"x": 1.0, "y": 9.0}, - {"x": 2.0, "y": None}, - {"x": None, "y": 7.0}, - {"x": 4.0, "y": float("nan")}, - {"x": float("nan"), "y": 0.0}, - ], - Or( - left=Or( - left=And( - left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(1.0)), - right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(9.0)), - ), - right=And( - left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(2.0)), - right=IsNull(term=Reference(name="y")), - ), + right=Or( + left=And( + left=IsNull(term=Reference(name="x")), + right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(7.0)), + ), + right=Or( + left=And( + left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(4.0)), + right=IsNaN(term=Reference(name="y")), ), - right=Or( - left=And( - left=IsNull(term=Reference(name="x")), - right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(7.0)), - ), - right=Or( - left=And( - left=EqualTo(term=Reference(name="x"), literal=DoubleLiteral(4.0)), - right=IsNaN(term=Reference(name="y")), - ), - right=And( - left=IsNaN(term=Reference(name="x")), - right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(0.0)), - ), - ), + right=And( + left=IsNaN(term=Reference(name="x")), + right=EqualTo(term=Reference(name="y"), literal=DoubleLiteral(0.0)), ), ), - id="multi-column-with-null", ), - ], -) -def test_create_match_filter(data: list[dict[str, Any]], expected: BooleanExpression) -> None: - schema = pa.schema([pa.field("x", pa.float64()), pa.field("y", pa.float64())]) - table = pa.Table.from_pylist(data, schema=schema) - join_cols = sorted({col for record in data for col in record}) - - expr = create_match_filter(table, join_cols) - - assert expr == expected + ) def test_upsert_with_duplicate_rows_in_table(catalog: Catalog) -> None: