diff --git a/torchrec/distributed/planner/tests/test_storage_reservations.py b/torchrec/distributed/planner/tests/test_storage_reservations.py
index 4bc2b40f3..128f9f904 100644
--- a/torchrec/distributed/planner/tests/test_storage_reservations.py
+++ b/torchrec/distributed/planner/tests/test_storage_reservations.py
@@ -19,7 +19,7 @@
     _get_module_size,
     HeuristicalStorageReservation,
 )
-from torchrec.distributed.planner.types import Topology
+from torchrec.distributed.planner.types import PlannerError, PlannerErrorType, Topology
 from torchrec.distributed.test_utils.test_model import TestTowerInteraction
 from torchrec.distributed.types import ModuleSharder
 
@@ -36,6 +36,36 @@ def __init__(self, shardable_sparse: nn.Module) -> None:
 
 
 class TestHeuristicalStorageReservation(unittest.TestCase):
+
+    def test_validate_storage_reservations_errors(self) -> None:
+        tables = [
+            EmbeddingBagConfig(
+                num_embeddings=1_000_000,
+                embedding_dim=1024,
+                name="table_0",
+                feature_names=["feature_0"],
+            ),
+        ]
+
+        ebc = EmbeddingBagCollection(tables)
+        model = TestModel(shardable_sparse=ebc)
+
+        # percentage=1 reserves 100% of HBM, so reserve() is expected to raise.
+        heuristical_storage_reservation = HeuristicalStorageReservation(percentage=1)
+        with self.assertRaises(PlannerError) as context:
+            heuristical_storage_reservation.reserve(
+                topology=Topology(world_size=1, compute_device="cuda"),
+                batch_size=1024,
+                module=model,
+                sharders=cast(
+                    List[ModuleSharder[nn.Module]], [EmbeddingBagCollectionSharder()]
+                ),
+            )
+
+        self.assertEqual(
+            context.exception.error_type, PlannerErrorType.INSUFFICIENT_STORAGE
+        )
+
     def test_storage_reservations_ebc(self) -> None:
         tables = [
             EmbeddingBagConfig(