Skip to content

Commit ffffb3f

Browse files
zhengruifenghuangxiaopingRD
authored andcommitted
[SPARK-54397][PYTHON] Make UserDefinedType hashable
### What changes were proposed in this pull request? Fix the hashability of `UserDefinedType` ### Why are the changes needed? UDT is not hashable, e.g. ``` In [11]: from pyspark.testing.objects import ExamplePointUDT In [12]: e = ExamplePointUDT() In [13]: {e: 0} --------------------------------------------------------------------------- TypeError Traceback (most recent call last) Cell In[13], line 1 ----> 1 {e: 0} TypeError: unhashable type: 'ExamplePointUDT' In [14]: from pyspark.ml.linalg import VectorUDT In [15]: v = VectorUDT() In [16]: {v: 1} --------------------------------------------------------------------------- TypeError Traceback (most recent call last) Cell In[16], line 1 ----> 1 {v: 1} TypeError: unhashable type: 'VectorUDT' ``` see https://docs.python.org/3/reference/datamodel.html#object.__hash__ > If a class does not define an __eq__() method it should not define a __hash__() operation either; if it defines __eq__() but not __hash__(), its instances will not be usable as items in hashable collections. > A class that overrides __eq__() and does not define __hash__() will have its __hash__() implicitly set to None. When the __hash__() method of a class is None, instances of the class will raise an appropriate TypeError when a program attempts to retrieve their hash value, and will also be correctly identified as unhashable when checking isinstance(obj, collections.abc.Hashable). ### Does this PR introduce _any_ user-facing change? yes, `hash(udt)` will work after this fix ### How was this patch tested? added tests ### Was this patch authored or co-authored using generative AI tooling? no Closes apache#53113 from zhengruifeng/type_hashable. Authored-by: Ruifeng Zheng <ruifengz@apache.org> Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
1 parent 7609a10 commit ffffb3f

File tree

3 files changed

+44
-3
lines changed

3 files changed

+44
-3
lines changed

python/pyspark/ml/tests/test_linalg.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,9 @@ def test_unwrap_udt(self):
364364
]
365365
self.assertEqual(results, expected)
366366

367+
def test_hashable(self):
368+
_ = hash(VectorUDT())
369+
367370

368371
class MatrixUDTTests(MLlibTestCase):
369372
dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10])
@@ -394,6 +397,9 @@ def test_infer_schema(self):
394397
else:
395398
raise ValueError("Expected a matrix but got type %r" % type(m))
396399

400+
def test_hashable(self):
401+
_ = hash(MatrixUDT())
402+
397403

398404
if __name__ == "__main__":
399405
from pyspark.ml.tests.test_linalg import * # noqa: F401

python/pyspark/sql/tests/test_types.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2076,6 +2076,44 @@ def test_repr(self):
20762076
for instance in instances:
20772077
self.assertEqual(eval(repr(instance)), instance)
20782078

2079+
def test_hashable(self):
2080+
for dt in [
2081+
NullType(),
2082+
StringType(),
2083+
StringType("UTF8_BINARY"),
2084+
StringType("UTF8_LCASE"),
2085+
StringType("UNICODE"),
2086+
StringType("UNICODE_CI"),
2087+
CharType(10),
2088+
VarcharType(10),
2089+
BinaryType(),
2090+
BooleanType(),
2091+
DateType(),
2092+
TimeType(),
2093+
TimestampType(),
2094+
TimestampNTZType(),
2095+
TimeType(),
2096+
DecimalType(),
2097+
DoubleType(),
2098+
FloatType(),
2099+
ByteType(),
2100+
IntegerType(),
2101+
LongType(),
2102+
ShortType(),
2103+
DayTimeIntervalType(),
2104+
YearMonthIntervalType(),
2105+
CalendarIntervalType(),
2106+
ArrayType(StringType()),
2107+
MapType(StringType(), IntegerType()),
2108+
StructField("f1", StringType(), True),
2109+
StructType([StructField("f1", StringType(), True)]),
2110+
VariantType(),
2111+
GeometryType(0),
2112+
GeographyType(4326),
2113+
ExamplePointUDT(),
2114+
]:
2115+
_ = hash(dt)
2116+
20792117
def test_daytime_interval_type_constructor(self):
20802118
# SPARK-37277: Test constructors in day time interval.
20812119
self.assertEqual(DayTimeIntervalType().simpleString(), "interval day to second")

python/pyspark/sql/types.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2003,9 +2003,6 @@ def fromJson(cls, json: Dict[str, Any]) -> "UserDefinedType":
20032003
UDT = getattr(m, pyClass)
20042004
return UDT()
20052005

2006-
def __eq__(self, other: Any) -> bool:
2007-
return type(self) == type(other)
2008-
20092006

20102007
class VariantVal:
20112008
"""

0 commit comments

Comments
 (0)