diff --git a/torchrec/distributed/planner/types.py b/torchrec/distributed/planner/types.py index db185d47c..32a750290 100644 --- a/torchrec/distributed/planner/types.py +++ b/torchrec/distributed/planner/types.py @@ -144,6 +144,18 @@ class DeviceHardware: storage: Storage perf: Perf + def __hash__(self) -> int: + return hash((self.rank, self.storage, self.perf)) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, DeviceHardware): + return False + return ( + self.rank == other.rank + and self.storage == other.storage + and self.perf == other.perf + ) + class CustomTopologyData: """ @@ -247,6 +259,17 @@ def get_bw( else: return self.inter_host_bw + def __hash__(self) -> int: + return hash((self._inter_host_bw, self._intra_host_bw)) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, BasicCommsBandwidths): + return False + return ( + self._inter_host_bw == other._inter_host_bw + and self._intra_host_bw == other._intra_host_bw + ) + class Topology: """