Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions python/pyspark/ml/tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,9 @@ def test_unwrap_udt(self):
]
self.assertEqual(results, expected)

def test_hashable(self):
_ = hash(VectorUDT())


class MatrixUDTTests(MLlibTestCase):
dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10])
Expand Down Expand Up @@ -394,6 +397,9 @@ def test_infer_schema(self):
else:
raise ValueError("Expected a matrix but got type %r" % type(m))

def test_hashable(self):
_ = hash(MatrixUDT())


if __name__ == "__main__":
from pyspark.ml.tests.test_linalg import * # noqa: F401
Expand Down
32 changes: 32 additions & 0 deletions python/pyspark/sql/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2076,6 +2076,38 @@ def test_repr(self):
for instance in instances:
self.assertEqual(eval(repr(instance)), instance)

def test_hashable(self):
for dt in [
NullType(),
StringType(),
StringType("UTF8_BINARY"),
StringType("UTF8_LCASE"),
StringType("UNICODE"),
StringType("UNICODE_CI"),
CharType(10),
VarcharType(10),
BinaryType(),
BooleanType(),
DateType(),
TimeType(),
TimestampType(),
DecimalType(),
DoubleType(),
FloatType(),
ByteType(),
IntegerType(),
LongType(),
ShortType(),
CalendarIntervalType(),
ArrayType(StringType()),
MapType(StringType(), IntegerType()),
StructField("f1", StringType(), True),
StructType([StructField("f1", StringType(), True)]),
VariantType(),
ExamplePointUDT(),
]:
_ = hash(dt)

def test_daytime_interval_type_constructor(self):
# SPARK-37277: Test constructors in day time interval.
self.assertEqual(DayTimeIntervalType().simpleString(), "interval day to second")
Expand Down
3 changes: 0 additions & 3 deletions python/pyspark/sql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1992,9 +1992,6 @@ def fromJson(cls, json: Dict[str, Any]) -> "UserDefinedType":
UDT = getattr(m, pyClass)
return UDT()

def __eq__(self, other: Any) -> bool:
return type(self) == type(other)


class VariantVal:
"""
Expand Down