Commit dbabd25

TroyGarden authored and meta-codesync[bot] committed

create ModelSelectionConfig and fix bugs (#3467)

Summary:
Pull Request resolved: #3467

# context
* move `ModelSelectionConfig` to test_utils.py and make `create_model_config` a class method
* fix previously failed test and some pyre ignore issues
* bump pre-commit check version from python 3.9 to python 3.12

Reviewed By: spmex

Differential Revision: D84755189

fbshipit-source-id: ef060c5b94f25c04b35c0a3156ebf2c0d736e636
1 parent bfcbd1e commit dbabd25
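
As a rough sketch of the new entry point this commit introduces (assuming torchrec is installed and importable; the default selection picks TestSparseNNConfig with num_float_features=10, per the model_config.py diff below):

from torchrec.distributed.test_utils.model_config import ModelSelectionConfig

selection = ModelSelectionConfig()            # defaults: model_name="test_sparse_nn"
config = selection.create_model_config()      # builds the config from the model_config dict
print(type(config).__name__, config.num_float_features)  # -> TestSparseNNConfig 10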

File tree

11 files changed: +78 −101 lines changed

.github/workflows/pre-commit.yaml

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ jobs:
       - name: Setup Python
         uses: actions/setup-python@v5
         with:
-          python-version: 3.9
+          python-version: 3.12
           architecture: x64
           packages: |
             ufmt==2.5.1

examples/retrieval/knn_index.py

Lines changed: 0 additions & 7 deletions
@@ -21,7 +21,6 @@ def get_index(
     num_subquantizers: int,
     bits_per_code: int,
     device: Optional[torch.device] = None,
-    # pyre-ignore[11]
 ) -> Union[faiss.GpuIndexIVFPQ, faiss.IndexIVFPQ]:
     """
     returns a FAISS IVFPQ index, placed on the device passed in
@@ -39,25 +38,19 @@ def get_index(
 
     """
     if device is not None and device.type == "cuda":
-        # pyre-fixme[16]
         res = faiss.StandardGpuResources()
-        # pyre-fixme[16]
         config = faiss.GpuIndexIVFPQConfig()
-        # pyre-ignore[16]
         index = faiss.GpuIndexIVFPQ(
             res,
             embedding_dim,
             num_centroids,
             num_subquantizers,
             bits_per_code,
-            # pyre-fixme[16]
             faiss.METRIC_L2,
             config,
         )
     else:
-        # pyre-fixme[16]
        quantizer = faiss.IndexFlatL2(embedding_dim)
-        # pyre-fixme[16]
        index = faiss.IndexIVFPQ(
            quantizer,
            embedding_dim,
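
For context, the CPU branch above follows the standard FAISS IVFPQ recipe. A minimal, self-contained sketch (not part of this commit; the sizes are made up for illustration):

import numpy as np
import faiss  # requires the faiss-cpu (or faiss-gpu) package

d, num_centroids, num_subquantizers, bits_per_code = 64, 16, 8, 8
quantizer = faiss.IndexFlatL2(d)  # coarse quantizer, as in get_index()
index = faiss.IndexIVFPQ(quantizer, d, num_centroids, num_subquantizers, bits_per_code)

xb = np.random.rand(1000, d).astype("float32")
index.train(xb)  # IVFPQ must be trained before add()
index.add(xb)

xq = np.random.rand(5, d).astype("float32")
distances, ids = index.search(xq, 10)  # top-10 neighbours per query row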

examples/retrieval/modules/two_tower.py

Lines changed: 1 addition & 1 deletion
@@ -169,7 +169,6 @@ class TwoTowerRetrieval(nn.Module):
 
     def __init__(
         self,
-        # pyre-ignore[11]
         faiss_index: Union[faiss.GpuIndexIVFPQ, faiss.IndexIVFPQ],
         query_ebc: EmbeddingBagCollection,
         candidate_ebc: EmbeddingBagCollection,
@@ -222,6 +221,7 @@ def forward(self, query_kjt: KeyedJaggedTensor) -> torch.Tensor:
             (batch_size, self.k), device=self.device, dtype=torch.int64
         )
         query_embedding = query_embedding.to(torch.float32)  # required by faiss
+        # pyre-ignore[19]
         self.faiss_index.search(query_embedding, self.k, distances, candidates)
 
         # candidate lookup

examples/retrieval/two_tower_retrieval.py

Lines changed: 7 additions & 4 deletions
@@ -128,12 +128,9 @@ def infer(
     retrieval_sd = None
     if load_dir is not None:
         load_dir = load_dir.rstrip("/")
-        # pyre-ignore[16]
         index = faiss.index_cpu_to_gpu(
-            # pyre-ignore[16]
             faiss.StandardGpuResources(),
             faiss_device_idx,
-            # pyre-ignore[16]
             faiss.read_index(f"{load_dir}/faiss.index"),
         )
         two_tower_sd = torch.load(f"{load_dir}/model.pt", weights_only=True)
@@ -158,7 +155,13 @@ def infer(
     index.add(embeddings)
 
     retrieval_model = TwoTowerRetrieval(
-        index, ebcs[0], ebcs[1], layer_sizes, k, device, dtype=torch.float16
+        index,  # pyre-ignore[6]
+        ebcs[0],
+        ebcs[1],
+        layer_sizes,
+        k,
+        device,
+        dtype=torch.float16,
     )
 
     constraints = {}

examples/retrieval/two_tower_train.py

Lines changed: 0 additions & 1 deletion
@@ -227,7 +227,6 @@ def train(
             model, dtype=torch.qint8, inplace=True
         )
         torch.save(quant_model.state_dict(), f"{save_dir}/model.pt")
-        # pyre-ignore[16]
         faiss.write_index(index, f"{save_dir}/faiss.index")
 
 
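
The write_index call above pairs with the read_index call in two_tower_retrieval.py (see that file's diff). A small round-trip sketch, illustrative only, with a hypothetical path:

import numpy as np
import faiss

d = 32
index = faiss.IndexFlatL2(d)
index.add(np.random.rand(100, d).astype("float32"))

faiss.write_index(index, "/tmp/example.faiss")     # serialize, as two_tower_train.py does
restored = faiss.read_index("/tmp/example.faiss")  # deserialize, as two_tower_retrieval.py does
assert restored.ntotal == index.ntotal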

torchrec/distributed/benchmark/base.py

Lines changed: 1 addition & 1 deletion
@@ -849,7 +849,7 @@ class BenchFuncConfig:
     world_size: int
     num_profiles: int
     num_benchmarks: int
-    profile_dir: str
+    profile_dir: str = ""
     device_type: str = "cuda"
     pre_gpu_load: int = 0
     export_stacks: bool = False
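
Giving profile_dir a default matters because RunOptions subclasses BenchFuncConfig and, per the next file's diff, drops its own profile_dir override. A simplified, self-contained sketch with stand-in class names (not the real torchrec definitions):

from dataclasses import dataclass

@dataclass
class BaseBench:              # stand-in for BenchFuncConfig
    num_benchmarks: int
    profile_dir: str = ""     # default added by this commit

@dataclass
class Options(BaseBench):     # stand-in for RunOptions
    num_benchmarks: int = 5   # subclass override supplies the remaining default
    world_size: int = 2

opts = Options()              # constructible with no arguments once every field has a default
assert opts.profile_dir == ""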

torchrec/distributed/benchmark/benchmark_train_pipeline.py

Lines changed: 6 additions & 66 deletions
@@ -20,8 +20,8 @@
 See benchmark_pipeline_utils.py for step-by-step instructions.
 """
 
-from dataclasses import dataclass, field
-from typing import Dict, List, Optional, Type
+from dataclasses import dataclass
+from typing import List, Optional
 
 import torch
 from fbgemm_gpu.split_embedding_configs import EmbOptimType
@@ -37,8 +37,8 @@
 from torchrec.distributed.test_utils.input_config import ModelInputConfig
 from torchrec.distributed.test_utils.model_config import (
     BaseModelConfig,
-    create_model_config,
     generate_sharded_model_and_optimizer,
+    ModelSelectionConfig,
 )
 from torchrec.distributed.test_utils.model_input import ModelInput
 
@@ -49,7 +49,6 @@
 from torchrec.distributed.test_utils.pipeline_config import PipelineConfig
 from torchrec.distributed.test_utils.sharding_config import PlannerConfig
 from torchrec.distributed.test_utils.table_config import EmbeddingTablesConfig
-from torchrec.distributed.test_utils.test_model import TestOverArchLarge
 from torchrec.distributed.train_pipeline import TrainPipeline
 from torchrec.distributed.types import ShardingType
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
@@ -94,11 +93,11 @@ class RunOptions(BenchFuncConfig):
     """
 
     world_size: int = 2
+    batch_size: int = 1024 * 32
+    num_float_features: int = 10
     num_batches: int = 10
     sharding_type: ShardingType = ShardingType.TABLE_WISE
     input_type: str = "kjt"
-    name: str = ""
-    profile_dir: str = ""
     num_benchmarks: int = 5
     num_profiles: int = 2
     num_poolings: Optional[List[float]] = None
@@ -113,39 +112,6 @@ class RunOptions(BenchFuncConfig):
     export_stacks: bool = False
 
 
-@dataclass
-class ModelSelectionConfig:
-    model_name: str = "test_sparse_nn"
-
-    # Common config for all model types
-    batch_size: int = 1024 * 32
-    batch_sizes: Optional[List[int]] = None
-    num_float_features: int = 10
-    feature_pooling_avg: int = 10
-    use_offsets: bool = False
-    dev_str: str = ""
-    long_kjt_indices: bool = True
-    long_kjt_offsets: bool = True
-    long_kjt_lengths: bool = True
-    pin_memory: bool = True
-
-    # TestSparseNN specific config
-    embedding_groups: Optional[Dict[str, List[str]]] = None
-    feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None
-    max_feature_lengths: Optional[Dict[str, int]] = None
-    over_arch_clazz: Type[nn.Module] = TestOverArchLarge
-    postproc_module: Optional[nn.Module] = None
-    zch: bool = False
-
-    # DeepFM specific config
-    hidden_layer_size: int = 20
-    deep_fm_dimension: int = 5
-
-    # DLRM specific config
-    dense_arch_layer_sizes: List[int] = field(default_factory=lambda: [20, 128])
-    over_arch_layer_sizes: List[int] = field(default_factory=lambda: [5, 1])
-
-
 # single-rank runner
 def runner(
     rank: int,
@@ -303,35 +269,9 @@ def main(
     pipeline_config: PipelineConfig,
     input_config: ModelInputConfig,
     planner_config: PlannerConfig,
-    model_config: Optional[BaseModelConfig] = None,
 ) -> None:
     tables, weighted_tables, *_ = table_config.generate_tables()
-
-    if model_config is None:
-        model_config = create_model_config(
-            model_name=model_selection.model_name,
-            batch_size=model_selection.batch_size,
-            batch_sizes=model_selection.batch_sizes,
-            num_float_features=model_selection.num_float_features,
-            feature_pooling_avg=model_selection.feature_pooling_avg,
-            use_offsets=model_selection.use_offsets,
-            dev_str=model_selection.dev_str,
-            long_kjt_indices=model_selection.long_kjt_indices,
-            long_kjt_offsets=model_selection.long_kjt_offsets,
-            long_kjt_lengths=model_selection.long_kjt_lengths,
-            pin_memory=model_selection.pin_memory,
-            embedding_groups=model_selection.embedding_groups,
-            feature_processor_modules=model_selection.feature_processor_modules,
-            max_feature_lengths=model_selection.max_feature_lengths,
-            over_arch_clazz=model_selection.over_arch_clazz,
-            postproc_module=model_selection.postproc_module,
-            zch=model_selection.zch,
-            hidden_layer_size=model_selection.hidden_layer_size,
-            deep_fm_dimension=model_selection.deep_fm_dimension,
-            dense_arch_layer_sizes=model_selection.dense_arch_layer_sizes,
-            over_arch_layer_sizes=model_selection.over_arch_layer_sizes,
-        )
-
+    model_config = model_selection.create_model_config()
     # launch trainers
     run_multi_process_func(
         func=runner,
torchrec/distributed/benchmark/yaml/sparse_data_dist_base.yml

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,8 @@ RunOptions:
   # export_stacks: True # enable this to export stack traces
 PipelineConfig:
   pipeline: "sparse"
+ModelInputConfig:
+  feature_pooling_avg: 10
 EmbeddingTablesConfig:
   num_unweighted_features: 100
   num_weighted_features: 100

torchrec/distributed/test_utils/model_config.py

Lines changed: 58 additions & 18 deletions
@@ -18,7 +18,7 @@
 
 import copy
 from abc import ABC, abstractmethod
-from dataclasses import dataclass, fields
+from dataclasses import dataclass, field, fields
 from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
 import torch
@@ -31,6 +31,7 @@
 from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner
 from torchrec.distributed.sharding_plan import get_default_sharders
 from torchrec.distributed.test_utils.test_model import (
+    TestOverArchLarge,
     TestSparseNN,
     TestTowerCollectionSparseNN,
     TestTowerSparseNN,
@@ -51,8 +52,9 @@ class BaseModelConfig(ABC):
     and requires each concrete implementation to provide its own generate_model method.
     """
 
-    # Common parameters for all model types
-    num_float_features: int  # we assume all model arch has a single dense feature layer
+    ## Common parameters for all model types, please do not set default values here
+    # we assume all model arch has a single dense feature layer
+    num_float_features: int
 
     @abstractmethod
     def generate_model(
@@ -80,12 +82,12 @@ def generate_model(
 class TestSparseNNConfig(BaseModelConfig):
     """Configuration for TestSparseNN model."""
 
-    embedding_groups: Optional[Dict[str, List[str]]]
-    feature_processor_modules: Optional[Dict[str, torch.nn.Module]]
-    max_feature_lengths: Optional[Dict[str, int]]
-    over_arch_clazz: Type[nn.Module]
-    postproc_module: Optional[nn.Module]
-    zch: bool
+    embedding_groups: Optional[Dict[str, List[str]]] = None
+    feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None
+    max_feature_lengths: Optional[Dict[str, int]] = None
+    over_arch_clazz: Type[nn.Module] = TestOverArchLarge
+    postproc_module: Optional[nn.Module] = None
+    zch: bool = False
 
     def generate_model(
         self,
@@ -113,8 +115,8 @@ def generate_model(
 class TestTowerSparseNNConfig(BaseModelConfig):
     """Configuration for TestTowerSparseNN model."""
 
-    embedding_groups: Optional[Dict[str, List[str]]]
-    feature_processor_modules: Optional[Dict[str, torch.nn.Module]]
+    embedding_groups: Optional[Dict[str, List[str]]] = None
+    feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None
 
     def generate_model(
         self,
@@ -138,8 +140,8 @@ def generate_model(
 class TestTowerCollectionSparseNNConfig(BaseModelConfig):
     """Configuration for TestTowerCollectionSparseNN model."""
 
-    embedding_groups: Optional[Dict[str, List[str]]]
-    feature_processor_modules: Optional[Dict[str, torch.nn.Module]]
+    embedding_groups: Optional[Dict[str, List[str]]] = None
+    feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None
 
     def generate_model(
         self,
@@ -163,8 +165,8 @@ def generate_model(
 class DeepFMConfig(BaseModelConfig):
     """Configuration for DeepFM model."""
 
-    hidden_layer_size: int
-    deep_fm_dimension: int
+    hidden_layer_size: int = 20
+    deep_fm_dimension: int = 5
 
     def generate_model(
         self,
@@ -189,8 +191,8 @@ def generate_model(
 class DLRMConfig(BaseModelConfig):
     """Configuration for DLRM model."""
 
-    dense_arch_layer_sizes: List[int]
-    over_arch_layer_sizes: List[int]
+    dense_arch_layer_sizes: List[int] = field(default_factory=lambda: [20, 128])
+    over_arch_layer_sizes: List[int] = field(default_factory=lambda: [5, 1])
 
     def generate_model(
         self,
@@ -213,7 +215,9 @@ def generate_model(
 
 # pyre-ignore[2]: Missing parameter annotation
 def create_model_config(model_name: str, **kwargs) -> BaseModelConfig:
-
+    """
+    deprecated function, please use ModelSelectionConfig.create_model_config instead
+    """
     model_configs = {
         "test_sparse_nn": TestSparseNNConfig,
         "test_tower_sparse_nn": TestTowerSparseNNConfig,
@@ -309,3 +313,39 @@ def generate_sharded_model_and_optimizer(
     optimizer = optimizer_class(dense_params, **optimizer_kwargs)
 
     return sharded_model, optimizer
+
+
+@dataclass
+class ModelSelectionConfig:
+    model_name: str = "test_sparse_nn"
+    model_config: Dict[str, Any] = field(
+        default_factory=lambda: {"num_float_features": 10}
+    )
+
+    def get_model_config_class(self) -> Type[BaseModelConfig]:
+        match self.model_name:
+            case "test_sparse_nn":
+                return TestSparseNNConfig
+            case "test_tower_sparse_nn":
+                return TestTowerSparseNNConfig
+            case "test_tower_collection_sparse_nn":
+                return TestTowerCollectionSparseNNConfig
+            case "deepfm":
+                return DeepFMConfig
+            case "dlrm":
+                return DLRMConfig
+            case _:
+                raise ValueError(f"Unknown model name: {self.model_name}")
+
+    def create_model_config(self) -> BaseModelConfig:
+        config_class = self.get_model_config_class()
+        valid_field_names = {field.name for field in fields(config_class)}
+        filtered_kwargs = {
+            k: v for k, v in self.model_config.items() if k in valid_field_names
+        }
+        # pyre-ignore[45]: Invalid class instantiation
+        return config_class(**filtered_kwargs)
+
+    def create_test_model(self, **kwargs: Any) -> nn.Module:
+        model_config = self.create_model_config()
+        return model_config.generate_model(**kwargs)
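
A minimal usage sketch of the relocated ModelSelectionConfig (the override values below are hypothetical; only keys that match fields of the selected config class survive the filtering in create_model_config):

from torchrec.distributed.test_utils.model_config import ModelSelectionConfig

selection = ModelSelectionConfig(
    model_name="dlrm",
    model_config={
        "num_float_features": 16,            # common BaseModelConfig field
        "dense_arch_layer_sizes": [32, 64],  # DLRM-specific field
        "not_a_real_field": 123,             # unknown key, silently dropped
    },
)
model_config = selection.create_model_config()   # returns a DLRMConfig instance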

torchrec/distributed/test_utils/test_model.py

Lines changed: 1 addition & 1 deletion
@@ -2364,7 +2364,7 @@ def __init__(
         if device is None:
             device = torch.device("cpu")
         if max_sequence_length is None:
-            max_sequence_length = 10
+            max_sequence_length = 20
         if dense_arch_out_size is None:
             dense_arch_out_size = DENSE_LAYER_OUT_SIZE
         if over_arch_out_size is None:
