1818
1919import copy
2020from abc import ABC , abstractmethod
21- from dataclasses import dataclass , fields
21+ from dataclasses import dataclass , field , fields
2222from typing import Any , Dict , List , Optional , Tuple , Type , Union
2323
2424import torch
3131from torchrec .distributed .planner .planners import HeteroEmbeddingShardingPlanner
3232from torchrec .distributed .sharding_plan import get_default_sharders
3333from torchrec .distributed .test_utils .test_model import (
34+ TestOverArchLarge ,
3435 TestSparseNN ,
3536 TestTowerCollectionSparseNN ,
3637 TestTowerSparseNN ,
@@ -51,8 +52,9 @@ class BaseModelConfig(ABC):
5152 and requires each concrete implementation to provide its own generate_model method.
5253 """
5354
54- # Common parameters for all model types
55- num_float_features : int # we assume all model arch has a single dense feature layer
55+ ## Common parameters for all model types, please do not set default values here
56+ # we assume all model arch has a single dense feature layer
57+ num_float_features : int
5658
5759 @abstractmethod
5860 def generate_model (
@@ -80,12 +82,12 @@ def generate_model(
8082class TestSparseNNConfig (BaseModelConfig ):
8183 """Configuration for TestSparseNN model."""
8284
83- embedding_groups : Optional [Dict [str , List [str ]]]
84- feature_processor_modules : Optional [Dict [str , torch .nn .Module ]]
85- max_feature_lengths : Optional [Dict [str , int ]]
86- over_arch_clazz : Type [nn .Module ]
87- postproc_module : Optional [nn .Module ]
88- zch : bool
85+ embedding_groups : Optional [Dict [str , List [str ]]] = None
86+ feature_processor_modules : Optional [Dict [str , torch .nn .Module ]] = None
87+ max_feature_lengths : Optional [Dict [str , int ]] = None
88+ over_arch_clazz : Type [nn .Module ] = TestOverArchLarge
89+ postproc_module : Optional [nn .Module ] = None
90+ zch : bool = False
8991
9092 def generate_model (
9193 self ,
@@ -113,8 +115,8 @@ def generate_model(
113115class TestTowerSparseNNConfig (BaseModelConfig ):
114116 """Configuration for TestTowerSparseNN model."""
115117
116- embedding_groups : Optional [Dict [str , List [str ]]]
117- feature_processor_modules : Optional [Dict [str , torch .nn .Module ]]
118+ embedding_groups : Optional [Dict [str , List [str ]]] = None
119+ feature_processor_modules : Optional [Dict [str , torch .nn .Module ]] = None
118120
119121 def generate_model (
120122 self ,
@@ -138,8 +140,8 @@ def generate_model(
138140class TestTowerCollectionSparseNNConfig (BaseModelConfig ):
139141 """Configuration for TestTowerCollectionSparseNN model."""
140142
141- embedding_groups : Optional [Dict [str , List [str ]]]
142- feature_processor_modules : Optional [Dict [str , torch .nn .Module ]]
143+ embedding_groups : Optional [Dict [str , List [str ]]] = None
144+ feature_processor_modules : Optional [Dict [str , torch .nn .Module ]] = None
143145
144146 def generate_model (
145147 self ,
@@ -163,8 +165,8 @@ def generate_model(
163165class DeepFMConfig (BaseModelConfig ):
164166 """Configuration for DeepFM model."""
165167
166- hidden_layer_size : int
167- deep_fm_dimension : int
168+ hidden_layer_size : int = 20
169+ deep_fm_dimension : int = 5
168170
169171 def generate_model (
170172 self ,
@@ -189,8 +191,8 @@ def generate_model(
189191class DLRMConfig (BaseModelConfig ):
190192 """Configuration for DLRM model."""
191193
192- dense_arch_layer_sizes : List [int ]
193- over_arch_layer_sizes : List [int ]
194+ dense_arch_layer_sizes : List [int ] = field ( default_factory = lambda : [ 20 , 128 ])
195+ over_arch_layer_sizes : List [int ] = field ( default_factory = lambda : [ 5 , 1 ])
194196
195197 def generate_model (
196198 self ,
@@ -213,7 +215,9 @@ def generate_model(
213215
214216# pyre-ignore[2]: Missing parameter annotation
215217def create_model_config (model_name : str , ** kwargs ) -> BaseModelConfig :
216-
218+ """
Deprecated: use ModelSelectionConfig.create_model_config instead.
220+ """
217221 model_configs = {
218222 "test_sparse_nn" : TestSparseNNConfig ,
219223 "test_tower_sparse_nn" : TestTowerSparseNNConfig ,
@@ -309,3 +313,39 @@ def generate_sharded_model_and_optimizer(
309313 optimizer = optimizer_class (dense_params , ** optimizer_kwargs )
310314
311315 return sharded_model , optimizer
316+
317+
@dataclass
class ModelSelectionConfig:
    """
    Selects a benchmark model architecture by name and carries the keyword
    arguments used to build its config.

    Attributes:
        model_name: key identifying which BaseModelConfig subclass to use.
        model_config: keyword arguments forwarded to that config class;
            keys that are not dataclass fields of the chosen class are
            silently dropped.
    """

    model_name: str = "test_sparse_nn"
    model_config: Dict[str, Any] = field(
        default_factory=lambda: {"num_float_features": 10}
    )

    def get_model_config_class(self) -> Type[BaseModelConfig]:
        """
        Resolve ``model_name`` to its config class.

        Raises:
            ValueError: if ``model_name`` is not one of the supported models.
        """
        # Dict dispatch instead of a match statement on string constants:
        # keeps the name->class mapping in one introspectable place,
        # consistent with the mapping used by the module-level
        # create_model_config() helper.
        model_configs: Dict[str, Type[BaseModelConfig]] = {
            "test_sparse_nn": TestSparseNNConfig,
            "test_tower_sparse_nn": TestTowerSparseNNConfig,
            "test_tower_collection_sparse_nn": TestTowerCollectionSparseNNConfig,
            "deepfm": DeepFMConfig,
            "dlrm": DLRMConfig,
        }
        try:
            return model_configs[self.model_name]
        except KeyError:
            raise ValueError(f"Unknown model name: {self.model_name}") from None

    def create_model_config(self) -> BaseModelConfig:
        """
        Instantiate the selected config class from ``model_config``.

        Unknown keys in ``model_config`` are filtered out so callers may pass
        a superset of options shared across model types.
        """
        config_class = self.get_model_config_class()
        # `f`, not `field`, to avoid shadowing the imported dataclasses.field.
        valid_field_names = {f.name for f in fields(config_class)}
        filtered_kwargs = {
            k: v for k, v in self.model_config.items() if k in valid_field_names
        }
        # pyre-ignore[45]: Invalid class instantiation
        return config_class(**filtered_kwargs)

    def create_test_model(self, **kwargs: Any) -> nn.Module:
        """
        Build the nn.Module by delegating to the selected config's
        generate_model(); ``kwargs`` are passed through unchanged.
        """
        model_config = self.create_model_config()
        return model_config.generate_model(**kwargs)
0 commit comments