From d57460d0842bafef5d21616f6cf28a01f6d34438 Mon Sep 17 00:00:00 2001 From: Raahul Kalyaan Jakka Date: Mon, 3 Nov 2025 10:29:40 -0800 Subject: [PATCH] Added unit test for Heuristic Storage Reservation Summary: **Context:** Heuristic storage reservation is a common component for all planners that checks whether the given module, along with the constraints, can be sharded across the topology. **In this diff:** We added a unit test (UT) to validate the error for storage use in the storage reservation process. If the given module is larger than the provided topology, we need to OOM the process as soon as possible, with an appropriate error to notify the PG. Differential Revision: D85892579 --- .../tests/test_storage_reservations.py | 32 ++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/torchrec/distributed/planner/tests/test_storage_reservations.py b/torchrec/distributed/planner/tests/test_storage_reservations.py index 4bc2b40f3..128f9f904 100644 --- a/torchrec/distributed/planner/tests/test_storage_reservations.py +++ b/torchrec/distributed/planner/tests/test_storage_reservations.py @@ -19,7 +19,7 @@ _get_module_size, HeuristicalStorageReservation, ) -from torchrec.distributed.planner.types import Topology +from torchrec.distributed.planner.types import PlannerError, PlannerErrorType, Topology from torchrec.distributed.test_utils.test_model import TestTowerInteraction from torchrec.distributed.types import ModuleSharder @@ -36,6 +36,36 @@ def __init__(self, shardable_sparse: nn.Module) -> None: class TestHeuristicalStorageReservation(unittest.TestCase): + + def test_validate_storage_reservations_errors(self) -> None: + tables = [ + EmbeddingBagConfig( + num_embeddings=1_000_000, + embedding_dim=1024, + name="table_0", + feature_names=["feature_0"], + ), + ] + + ebc = EmbeddingBagCollection(tables) + model = TestModel(shardable_sparse=ebc) + + # Reserving 100% of HBM to make sure the heuristic storage reservation fails + heuristical_storage_reservation = 
HeuristicalStorageReservation(percentage=1) + with self.assertRaises(PlannerError) as context: + heuristical_storage_reservation.reserve( + topology=Topology(world_size=1, compute_device="cuda"), + batch_size=1024, + module=model, + sharders=cast( + List[ModuleSharder[nn.Module]], [EmbeddingBagCollectionSharder()] + ), + ) + + self.assertEqual( + context.exception.error_type, PlannerErrorType.INSUFFICIENT_STORAGE + ) + def test_storage_reservations_ebc(self) -> None: tables = [ EmbeddingBagConfig(