@@ -19,27 +19,23 @@
 import copy
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, fields
-from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
 import torch
 import torch.distributed as dist
 
 from torch import nn, optim
 from torch.optim import Optimizer
 from torchrec.distributed import DistributedModelParallel
-from torchrec.distributed.embedding_types import EmbeddingComputeKernel
-from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
-from torchrec.distributed.planner.constants import NUM_POOLINGS, POOLING_FACTOR
+from torchrec.distributed.planner import EmbeddingShardingPlanner
 from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner
-from torchrec.distributed.planner.types import ParameterConstraints
-from torchrec.distributed.test_utils.model_input import ModelInput
+from torchrec.distributed.sharding_plan import get_default_sharders
 from torchrec.distributed.test_utils.test_model import (
-    TestEBCSharder,
     TestSparseNN,
     TestTowerCollectionSparseNN,
     TestTowerSparseNN,
 )
-from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType
+from torchrec.distributed.types import ShardingEnv
 from torchrec.models.deepfm import SimpleDeepFMNNWrapper
 from torchrec.models.dlrm import DLRMWrapper
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
@@ -240,137 +236,8 @@ def create_model_config(model_name: str, **kwargs) -> BaseModelConfig:
     return model_class(**filtered_kwargs)
 
 
-def generate_data(
-    tables: List[EmbeddingBagConfig],
-    weighted_tables: List[EmbeddingBagConfig],
-    model_config: BaseModelConfig,
-    batch_sizes: List[int],
-) -> List[ModelInput]:
-    """
-    Generate model input data for benchmarking.
-
-    Args:
-        tables: List of unweighted embedding tables
-        weighted_tables: List of weighted embedding tables
-        model_config: Configuration for model generation
-        num_batches: Number of batches to generate
-
-    Returns:
-        A list of ModelInput objects representing the generated batches
-    """
-    device = torch.device(model_config.dev_str) if model_config.dev_str else None
-
-    return [
-        ModelInput.generate(
-            batch_size=batch_size,
-            tables=tables,
-            weighted_tables=weighted_tables,
-            num_float_features=model_config.num_float_features,
-            pooling_avg=model_config.feature_pooling_avg,
-            use_offsets=model_config.use_offsets,
-            device=device,
-            indices_dtype=(
-                torch.int64 if model_config.long_kjt_indices else torch.int32
-            ),
-            offsets_dtype=(
-                torch.int64 if model_config.long_kjt_offsets else torch.int32
-            ),
-            lengths_dtype=(
-                torch.int64 if model_config.long_kjt_lengths else torch.int32
-            ),
-            pin_memory=model_config.pin_memory,
-        )
-        for batch_size in batch_sizes
-    ]
-
-
-def generate_planner(
-    planner_type: str,
-    topology: Topology,
-    tables: Optional[List[EmbeddingBagConfig]],
-    weighted_tables: Optional[List[EmbeddingBagConfig]],
-    sharding_type: ShardingType,
-    compute_kernel: EmbeddingComputeKernel,
-    batch_sizes: List[int],
-    pooling_factors: Optional[List[float]] = None,
-    num_poolings: Optional[List[float]] = None,
-) -> Union[EmbeddingShardingPlanner, HeteroEmbeddingShardingPlanner]:
-    """
-    Generate an embedding sharding planner based on the specified configuration.
-
-    Args:
-        planner_type: Type of planner to use ("embedding" or "hetero")
-        topology: Network topology for distributed training
-        tables: List of unweighted embedding tables
-        weighted_tables: List of weighted embedding tables
-        sharding_type: Strategy for sharding embedding tables
-        compute_kernel: Compute kernel to use for embedding tables
-        batch_sizes: Sizes of each batch
-        pooling_factors: Pooling factors for each feature of the table
-        num_poolings: Number of poolings for each feature of the table
-
-    Returns:
-        An instance of EmbeddingShardingPlanner or HeteroEmbeddingShardingPlanner
-
-    Raises:
-        RuntimeError: If an unknown planner type is specified
-    """
-    # Create parameter constraints for tables
-    constraints = {}
-    num_batches = len(batch_sizes)
-
-    if pooling_factors is None:
-        pooling_factors = [POOLING_FACTOR] * num_batches
-
-    if num_poolings is None:
-        num_poolings = [NUM_POOLINGS] * num_batches
-
-    assert (
-        len(pooling_factors) == num_batches and len(num_poolings) == num_batches
-    ), "The length of pooling_factors and num_poolings must match the number of batches."
-
-    if tables is not None:
-        for table in tables:
-            constraints[table.name] = ParameterConstraints(
-                sharding_types=[sharding_type.value],
-                compute_kernels=[compute_kernel.value],
-                device_group="cuda",
-                pooling_factors=pooling_factors,
-                num_poolings=num_poolings,
-                batch_sizes=batch_sizes,
-            )
-
-    if weighted_tables is not None:
-        for table in weighted_tables:
-            constraints[table.name] = ParameterConstraints(
-                sharding_types=[sharding_type.value],
-                compute_kernels=[compute_kernel.value],
-                device_group="cuda",
-                pooling_factors=pooling_factors,
-                num_poolings=num_poolings,
-                batch_sizes=batch_sizes,
-                is_weighted=True,
-            )
-
-    if planner_type == "embedding":
-        return EmbeddingShardingPlanner(
-            topology=topology,
-            constraints=constraints if constraints else None,
-        )
-    elif planner_type == "hetero":
-        topology_groups = {"cuda": topology}
-        return HeteroEmbeddingShardingPlanner(
-            topology_groups=topology_groups,
-            constraints=constraints if constraints else None,
-        )
-    else:
-        raise RuntimeError(f"Unknown planner type: {planner_type}")
-
-
 def generate_sharded_model_and_optimizer(
     model: nn.Module,
-    sharding_type: str,
-    kernel_type: str,
     pg: dist.ProcessGroup,
     device: torch.device,
     fused_params: Dict[str, Any],
@@ -404,12 +271,7 @@ def generate_sharded_model_and_optimizer(
     Returns:
         Tuple of sharded model and optimizer
    """
-    sharder = TestEBCSharder(
-        sharding_type=sharding_type,
-        kernel_type=kernel_type,
-        fused_params=fused_params,
-    )
-    sharders = [cast(ModuleSharder[nn.Module], sharder)]
+    sharders = get_default_sharders()
 
     # Use planner if provided
     plan = None
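For reference, a minimal sketch of how `get_default_sharders()` is typically wired together with an `EmbeddingShardingPlanner` and `DistributedModelParallel`. This is not code from the patched file: `model`, `pg`, and `device` are placeholder variables, and the exact wiring inside the benchmark may differ.

```python
# Hedged sketch of the usual torchrec flow around the default sharders.
# Assumes `model` (an nn.Module with embedding modules), `pg` (a
# torch.distributed ProcessGroup), and `device` already exist.
from torchrec.distributed import DistributedModelParallel
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.sharding_plan import get_default_sharders
from torchrec.distributed.types import ShardingEnv

sharders = get_default_sharders()  # default ModuleSharders for supported module types
topology = Topology(world_size=pg.size(), compute_device=device.type)
planner = EmbeddingShardingPlanner(topology=topology)

# Produce one sharding plan that all ranks agree on.
plan = planner.collective_plan(model, sharders, pg)

sharded_model = DistributedModelParallel(
    module=model,
    env=ShardingEnv.from_process_group(pg),
    plan=plan,
    sharders=sharders,
    device=device,
)
```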