diff --git a/python/pyspark/ml/tests/test_linalg.py b/python/pyspark/ml/tests/test_linalg.py index 08fa529087ff..3bf36c1d4eee 100644 --- a/python/pyspark/ml/tests/test_linalg.py +++ b/python/pyspark/ml/tests/test_linalg.py @@ -364,6 +364,9 @@ def test_unwrap_udt(self): ] self.assertEqual(results, expected) + def test_hashable(self): + _ = hash(VectorUDT()) + class MatrixUDTTests(MLlibTestCase): dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10]) @@ -394,6 +397,9 @@ def test_infer_schema(self): else: raise ValueError("Expected a matrix but got type %r" % type(m)) + def test_hashable(self): + _ = hash(MatrixUDT()) + if __name__ == "__main__": from pyspark.ml.tests.test_linalg import * # noqa: F401 diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 0a5219202a3a..5f5314a03c32 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -2076,6 +2076,44 @@ def test_repr(self): for instance in instances: self.assertEqual(eval(repr(instance)), instance) + def test_hashable(self): + for dt in [ + NullType(), + StringType(), + StringType("UTF8_BINARY"), + StringType("UTF8_LCASE"), + StringType("UNICODE"), + StringType("UNICODE_CI"), + CharType(10), + VarcharType(10), + BinaryType(), + BooleanType(), + DateType(), + TimeType(), + TimestampType(), + TimestampNTZType(), + TimeType(), + DecimalType(), + DoubleType(), + FloatType(), + ByteType(), + IntegerType(), + LongType(), + ShortType(), + DayTimeIntervalType(), + YearMonthIntervalType(), + CalendarIntervalType(), + ArrayType(StringType()), + MapType(StringType(), IntegerType()), + StructField("f1", StringType(), True), + StructType([StructField("f1", StringType(), True)]), + VariantType(), + GeometryType(0), + GeographyType(4326), + ExamplePointUDT(), + ]: + _ = hash(dt) + def test_daytime_interval_type_constructor(self): # SPARK-37277: Test constructors in day time interval. self.assertEqual(DayTimeIntervalType().simpleString(), "interval day to second") diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 95307ea3859c..72ac58aa8298 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1992,9 +1992,6 @@ def fromJson(cls, json: Dict[str, Any]) -> "UserDefinedType": UDT = getattr(m, pyClass) return UDT() - def __eq__(self, other: Any) -> bool: - return type(self) == type(other) - class VariantVal: """