Skip to content

Commit 31ed623

Browse files
committed
fix: upsert with null values in join columns
1 parent 52d810e commit 31ed623

File tree

2 files changed

+76
-13
lines changed

2 files changed

+76
-13
lines changed

pyiceberg/table/upsert_util.py

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,38 +14,61 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17-
import functools
18-
import operator
17+
from math import isnan
18+
from typing import Any
1919

2020
import pyarrow as pa
2121
from pyarrow import Table as pyarrow_table
2222
from pyarrow import compute as pc
2323

2424
from pyiceberg.expressions import (
2525
AlwaysFalse,
26+
And,
2627
BooleanExpression,
2728
EqualTo,
2829
In,
30+
IsNaN,
31+
IsNull,
2932
Or,
3033
)
3134

3235

3336
def _match_value(column: str, value: Any) -> BooleanExpression:
    """Build the predicate that matches ``column`` against one join-key value.

    Args:
        column: Name of the join column.
        value: The key value taken from the source rows; may be ``None`` or a float NaN.

    Returns:
        ``IsNull(column)`` for ``None``, ``IsNaN(column)`` for a float NaN,
        otherwise ``EqualTo(column, value)``.
    """
    # NOTE(review): null and NaN get dedicated predicates, presumably because an
    # equality predicate would not match them -- confirm against the expression semantics.
    if value is None:
        return IsNull(column)

    if isinstance(value, float) and isnan(value):
        return IsNaN(column)

    return EqualTo(column, value)


def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression:
    """Build a filter expression matching the distinct join-key combinations present in ``df``.

    Args:
        df: Table holding the source rows; only the ``join_cols`` columns are read.
        join_cols: Names of the columns that together form the match key.

    Returns:
        For a single join column: an ``In`` predicate over the distinct non-null,
        non-NaN values, combined via ``Or`` with ``IsNull`` / ``IsNaN`` when such keys occur.
        For several join columns: one ``And`` of per-column predicates for each distinct
        key row, combined via ``Or``. ``AlwaysFalse()`` when no predicate was produced.
    """
    # Group by the join columns with no aggregations to obtain the distinct key rows.
    unique_keys = df.select(join_cols).group_by(join_cols).aggregate([])
    filters: list[BooleanExpression] = []

    if len(join_cols) == 1:
        column = join_cols[0]
        values = set(unique_keys[0].to_pylist())

        # Null keys are pulled out of the value set and matched with IsNull instead.
        if None in values:
            filters.append(IsNull(column))
            values.remove(None)

        # NaN keys are likewise pulled out and matched with IsNaN.
        if nans := {v for v in values if isinstance(v, float) and isnan(v)}:
            filters.append(IsNaN(column))
            values -= nans

        # Appended unconditionally, so this branch always yields at least one filter.
        filters.append(In(column, values))
    else:
        # One conjunction per distinct key row, with null/NaN-aware predicates per column.
        filters = [And(*[_match_value(col, row[col]) for col in join_cols]) for row in unique_keys.to_pylist()]

    if len(filters) == 0:
        return AlwaysFalse()
    elif len(filters) == 1:
        return filters[0]
    else:
        return Or(*filters)
4972

5073

5174
def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool:

tests/table/test_upsert.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,46 @@ def test_upsert_with_nulls(catalog: Catalog) -> None:
710710
schema=schema,
711711
)
712712

713+
# upsert table with null value
714+
data_with_null = pa.Table.from_pylist(
715+
[
716+
{"foo": None, "bar": 1, "baz": False},
717+
],
718+
schema=schema,
719+
)
720+
upd = table.upsert(data_with_null, join_cols=["foo"])
721+
assert upd.rows_updated == 0
722+
assert upd.rows_inserted == 1
723+
assert table.scan().to_arrow() == pa.Table.from_pylist(
724+
[
725+
{"foo": None, "bar": 1, "baz": False},
726+
{"foo": "apple", "bar": 7, "baz": False},
727+
{"foo": "banana", "bar": None, "baz": False},
728+
],
729+
schema=schema,
730+
)
731+
732+
# upsert table with null and non-null values, in two join columns
733+
data_with_null = pa.Table.from_pylist(
734+
[
735+
{"foo": None, "bar": 1, "baz": True},
736+
{"foo": "lemon", "bar": None, "baz": False},
737+
],
738+
schema=schema,
739+
)
740+
upd = table.upsert(data_with_null, join_cols=["foo", "bar"])
741+
assert upd.rows_updated == 1
742+
assert upd.rows_inserted == 1
743+
assert table.scan().to_arrow() == pa.Table.from_pylist(
744+
[
745+
{"foo": "lemon", "bar": None, "baz": False},
746+
{"foo": None, "bar": 1, "baz": True},
747+
{"foo": "apple", "bar": 7, "baz": False},
748+
{"foo": "banana", "bar": None, "baz": False},
749+
],
750+
schema=schema,
751+
)
752+
713753

714754
def test_transaction(catalog: Catalog) -> None:
715755
"""Test the upsert within a Transaction. Make sure that if something fails the entire Transaction is

0 commit comments

Comments
 (0)