Skip to content

Commit df967f6

Browse files
gregmacnamarafacebook-github-bot
authored andcommitted
Add hash and equality functions to planner types (#3522)
Summary: This diff adds explicit `__hash__` and `__eq__` methods to custom objects in the torchrec planner that previously relied on Python's default object identity-based hashing. 1. **DeviceHardware** (fbcode/torchrec/distributed/planner/types.py): - Hash based on rank, storage, and perf fields - Equality checks all three fields 2. **BasicCommsBandwidths** (fbcode/torchrec/distributed/planner/types.py): - Hash based on inter_host_bw and intra_host_bw - Equality checks both bandwidth values Reviewed By: iamzainhuda Differential Revision: D85731424
1 parent 691d11f commit df967f6

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

torchrec/distributed/planner/types.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,18 @@ class DeviceHardware:
144144
storage: Storage
145145
perf: Perf
146146

147+
def __hash__(self) -> int:
148+
return hash((self.rank, self.storage, self.perf))
149+
150+
def __eq__(self, other: object) -> bool:
151+
if not isinstance(other, DeviceHardware):
152+
return False
153+
return (
154+
self.rank == other.rank
155+
and self.storage == other.storage
156+
and self.perf == other.perf
157+
)
158+
147159

148160
class CustomTopologyData:
149161
"""
@@ -247,6 +259,17 @@ def get_bw(
247259
else:
248260
return self.inter_host_bw
249261

262+
def __hash__(self) -> int:
263+
return hash((self._inter_host_bw, self._intra_host_bw))
264+
265+
def __eq__(self, other: object) -> bool:
266+
if not isinstance(other, BasicCommsBandwidths):
267+
return False
268+
return (
269+
self._inter_host_bw == other._inter_host_bw
270+
and self._intra_host_bw == other._intra_host_bw
271+
)
272+
250273

251274
class Topology:
252275
"""

0 commit comments

Comments
 (0)