From 9380e91f8bbdc5b7d6855815ea352070d1e248fb Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Tue, 18 Nov 2025 15:41:39 +0800 Subject: [PATCH 1/3] fix --- python/pyspark/ml/tests/test_linalg.py | 6 +++++ python/pyspark/sql/tests/test_types.py | 32 ++++++++++++++++++++++++++ python/pyspark/sql/types.py | 4 ++++ 3 files changed, 42 insertions(+) diff --git a/python/pyspark/ml/tests/test_linalg.py b/python/pyspark/ml/tests/test_linalg.py index 08fa529087ff..3bf36c1d4eee 100644 --- a/python/pyspark/ml/tests/test_linalg.py +++ b/python/pyspark/ml/tests/test_linalg.py @@ -364,6 +364,9 @@ def test_unwrap_udt(self): ] self.assertEqual(results, expected) + def test_hashable(self): + _ = hash(VectorUDT()) + class MatrixUDTTests(MLlibTestCase): dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10]) @@ -394,6 +397,9 @@ def test_infer_schema(self): else: raise ValueError("Expected a matrix but got type %r" % type(m)) + def test_hashable(self): + _ = hash(MatrixUDT()) + if __name__ == "__main__": from pyspark.ml.tests.test_linalg import * # noqa: F401 diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 0a5219202a3a..2c7691d9849e 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -2076,6 +2076,38 @@ def test_repr(self): for instance in instances: self.assertEqual(eval(repr(instance)), instance) + def test_hashable(self): + for dt in [ + NullType(), + StringType(), + StringType("UTF8_BINARY"), + StringType("UTF8_LCASE"), + StringType("UNICODE"), + StringType("UNICODE_CI"), + CharType(10), + VarcharType(10), + BinaryType(), + BooleanType(), + DateType(), + TimeType(), + TimestampType(), + DecimalType(), + DoubleType(), + FloatType(), + ByteType(), + IntegerType(), + LongType(), + ShortType(), + CalendarIntervalType(), + ArrayType(StringType()), + MapType(StringType(), IntegerType()), + StructField("f1", StringType(), True), + StructType([StructField("f1", StringType(), True)]), + VariantType(), + ExamplePointUDT(), + ]: + _ = hash(dt) + def test_daytime_interval_type_constructor(self): # SPARK-37277: Test constructors in day time interval. self.assertEqual(DayTimeIntervalType().simpleString(), "interval day to second") diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 95307ea3859c..caff84335bbb 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1995,6 +1995,10 @@ def fromJson(cls, json: Dict[str, Any]) -> "UserDefinedType": def __eq__(self, other: Any) -> bool: return type(self) == type(other) + # __hash__ should be defined together with __eq__, otherwise it is not hashable + def __hash__(self) -> int: + return hash(str(self)) + class VariantVal: """ From 96edab4c4d4f593b68a9a5629f88a0f0b1588e38 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 19 Nov 2025 09:43:51 +0800 Subject: [PATCH 2/3] del __eq__ --- python/pyspark/sql/types.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index caff84335bbb..72ac58aa8298 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1992,13 +1992,6 @@ def fromJson(cls, json: Dict[str, Any]) -> "UserDefinedType": UDT = getattr(m, pyClass) return UDT() - def __eq__(self, other: Any) -> bool: - return type(self) == type(other) - - # __hash__ should be defined together with __eq__, otherwise it is not hashable - def __hash__(self) -> int: - return hash(str(self)) - class VariantVal: """ From 4277671da0d75bbd5d5c978c5663e689ec4f5338 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 19 Nov 2025 10:46:13 +0800 Subject: [PATCH 3/3] fix --- python/pyspark/sql/tests/test_types.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 2c7691d9849e..5f5314a03c32 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -2091,6 +2091,8 @@ def test_hashable(self): DateType(), TimeType(), TimestampType(), + TimestampNTZType(), + TimeType(), DecimalType(), DoubleType(), FloatType(), @@ -2098,12 +2100,16 @@ def test_hashable(self): IntegerType(), LongType(), ShortType(), + DayTimeIntervalType(), + YearMonthIntervalType(), CalendarIntervalType(), ArrayType(StringType()), MapType(StringType(), IntegerType()), StructField("f1", StringType(), True), StructType([StructField("f1", StringType(), True)]), VariantType(), + GeometryType(0), + GeographyType(4326), ExamplePointUDT(), ]: _ = hash(dt)