
Commit cf63c52

Justin Yang authored and meta-codesync[bot] committed
Back out "Add row based sharding support for FeaturedProcessedEBC" (#3537)
Summary:
Pull Request resolved: #3537
Original commit changeset: 4a8ad3bc6d14
Original Phabricator Diff: D82248545
Reviewed By: tyleretzel, sarckk, really121, aliafzal
Differential Revision: D86779963
fbshipit-source-id: a0cdbd03d249a8976d841abe8c417cbbe549669c
1 parent 67ad1e1 commit cf63c52

8 files changed: +25 −330 lines changed


torchrec/distributed/fp_embeddingbag.py

Lines changed: 9 additions & 50 deletions
@@ -8,18 +8,7 @@
 # pyre-strict

 from functools import partial
-from typing import (
-    Any,
-    Dict,
-    Iterator,
-    List,
-    Mapping,
-    Optional,
-    Tuple,
-    Type,
-    TypeVar,
-    Union,
-)
+from typing import Any, Dict, Iterator, List, Optional, Type, Union

 import torch
 from torch import nn
@@ -42,20 +31,14 @@
     ShardingEnv,
     ShardingType,
 )
-from torchrec.distributed.utils import (
-    append_prefix,
-    init_parameters,
-    modify_input_for_feature_processor,
-)
+from torchrec.distributed.utils import append_prefix, init_parameters
 from torchrec.modules.feature_processor_ import FeatureProcessorsCollection
 from torchrec.modules.fp_embedding_modules import (
     apply_feature_processors_to_kjt,
     FeatureProcessedEmbeddingBagCollection,
 )
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor

-_T = TypeVar("_T")
-

 def param_dp_sync(kt: KeyedTensor, no_op_tensor: torch.Tensor) -> KeyedTensor:
     kt._values.add_(no_op_tensor)
@@ -91,16 +74,6 @@ def __init__(
             )
         )

-        self._row_wise_sharded: bool = False
-        for param_sharding in table_name_to_parameter_sharding.values():
-            if param_sharding.sharding_type in [
-                ShardingType.ROW_WISE.value,
-                ShardingType.TABLE_ROW_WISE.value,
-                ShardingType.GRID_SHARD.value,
-            ]:
-                self._row_wise_sharded = True
-                break
-
         self._lookups: List[nn.Module] = self._embedding_bag_collection._lookups

         self._is_collection: bool = False
@@ -123,11 +96,6 @@ def __init__(
     def input_dist(
         self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor
     ) -> Awaitable[Awaitable[KJTList]]:
-        if not self.is_pipelined and self._row_wise_sharded:
-            # transform input to support row based sharding when not pipelined
-            modify_input_for_feature_processor(
-                features, self._feature_processors, self._is_collection
-            )
         return self._embedding_bag_collection.input_dist(ctx, features)

     def apply_feature_processors_to_kjt_list(self, dist_input: KJTList) -> KJTList:
@@ -137,7 +105,10 @@ def apply_feature_processors_to_kjt_list(self, dist_input: KJTList) -> KJTList:
                 kjt_list.append(self._feature_processors(features))
             else:
                 kjt_list.append(
-                    apply_feature_processors_to_kjt(features, self._feature_processors)
+                    apply_feature_processors_to_kjt(
+                        features,
+                        self._feature_processors,
+                    )
                 )
         return KJTList(kjt_list)

@@ -146,6 +117,7 @@ def compute(
         ctx: EmbeddingBagCollectionContext,
         dist_input: KJTList,
     ) -> List[torch.Tensor]:
+
         fp_features = self.apply_feature_processors_to_kjt_list(dist_input)
         return self._embedding_bag_collection.compute(ctx, fp_features)

@@ -194,18 +166,6 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
     def _initialize_torch_state(self, skip_registering: bool = False) -> None:  # noqa
         self._embedding_bag_collection._initialize_torch_state(skip_registering)

-    def preprocess_input(
-        self, args: List[_T], kwargs: Mapping[str, _T]
-    ) -> Tuple[List[_T], Mapping[str, _T]]:
-        for x in args + list(kwargs.values()):
-            if isinstance(x, KeyedJaggedTensor):
-                modify_input_for_feature_processor(
-                    features=x,
-                    feature_processors=self._feature_processors,
-                    is_collection=self._is_collection,
-                )
-        return args, kwargs
-

 class FeatureProcessedEmbeddingBagCollectionSharder(
     BaseEmbeddingSharder[FeatureProcessedEmbeddingBagCollection]
@@ -231,6 +191,7 @@ def shard(
         device: Optional[torch.device] = None,
         module_fqn: Optional[str] = None,
     ) -> ShardedFeatureProcessedEmbeddingBagCollection:
+
         if device is None:
             device = torch.device("cuda")

@@ -267,14 +228,12 @@ def sharding_types(self, compute_device_type: str) -> List[str]:
         if compute_device_type in {"mtia"}:
             return [ShardingType.TABLE_WISE.value, ShardingType.COLUMN_WISE.value]

+        # No row wise because position weighted FP and RW don't play well together.
         types = [
             ShardingType.DATA_PARALLEL.value,
             ShardingType.TABLE_WISE.value,
             ShardingType.COLUMN_WISE.value,
             ShardingType.TABLE_COLUMN_WISE.value,
-            ShardingType.TABLE_ROW_WISE.value,
-            ShardingType.ROW_WISE.value,
-            ShardingType.GRID_SHARD.value,
         ]

         return types
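
To see the net effect of the hunk above from a caller's perspective, here is a minimal sketch (not part of this commit): after the backout, the sharder no longer advertises any row-based sharding type, so sharding plans for a FeatureProcessedEmbeddingBagCollection must pick from the remaining types. The printed list is inferred from the diff, not captured from a run.

    # Minimal sketch, assuming a torchrec install that includes this commit.
    from torchrec.distributed.fp_embeddingbag import (
        FeatureProcessedEmbeddingBagCollectionSharder,
    )

    sharder = FeatureProcessedEmbeddingBagCollectionSharder()
    # Ask the sharder which sharding types it supports for a CUDA compute device.
    print(sharder.sharding_types("cuda"))
    # Expected per the diff: ['data_parallel', 'table_wise', 'column_wise',
    # 'table_column_wise'] -- no 'row_wise', 'table_row_wise', or 'grid_shard'.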

torchrec/distributed/tests/test_fp_embeddingbag.py

Lines changed: 1 addition & 0 deletions
@@ -231,6 +231,7 @@ class ShardedEmbeddingBagCollectionParallelTest(MultiProcessTestBase):
     def test_sharding_ebc(
         self, set_gradient_division: bool, use_dmp: bool, use_fp_collection: bool
     ) -> None:
+
         import hypothesis

         # don't need to test entire matrix

torchrec/distributed/tests/test_fp_embeddingbag_utils.py

Lines changed: 1 addition & 6 deletions
@@ -86,12 +86,7 @@ def forward(self, kjt: KeyedJaggedTensor) -> Tuple[torch.Tensor, torch.Tensor]:
         pred = torch.cat(
             [
                 fp_ebc_out[key]
-                for key in [
-                    "feature_0",
-                    "feature_1",
-                    "feature_2",
-                    "feature_3",
-                ]
+                for key in ["feature_0", "feature_1", "feature_2", "feature_3"]
             ],
             dim=1,
         )

torchrec/distributed/train_pipeline/tests/test_train_pipelines.py

Lines changed: 1 addition & 164 deletions
@@ -22,10 +22,7 @@
 from torch._dynamo.testing import reduce_to_scalar_loss
 from torch._dynamo.utils import counters
 from torchrec.distributed import DistributedModelParallel
-from torchrec.distributed.embedding_types import (
-    EmbeddingComputeKernel,
-    EmbeddingTableConfig,
-)
+from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
 from torchrec.distributed.fp_embeddingbag import (
     FeatureProcessedEmbeddingBagCollectionSharder,
@@ -34,13 +31,8 @@
 from torchrec.distributed.model_parallel import DMPCollection
 from torchrec.distributed.sharding_plan import (
     construct_module_sharding_plan,
-    row_wise,
     table_wise,
 )
-from torchrec.distributed.test_utils.multi_process import (
-    MultiProcessContext,
-    MultiProcessTestBase,
-)
 from torchrec.distributed.test_utils.test_model import (
     ModelInput,
     TestNegSamplingModule,
@@ -341,161 +333,6 @@ def test_equal_to_non_pipelined_with_input_transformer(self) -> None:
         torch.testing.assert_close(pred_gpu.cpu(), pred)


-def fp_ebc_rw_sharding_test_runner(
-    rank: int,
-    world_size: int,
-    tables: List[EmbeddingTableConfig],
-    weighted_tables: List[EmbeddingTableConfig],
-    data: List[Tuple[ModelInput, List[ModelInput]]],
-    backend: str = "nccl",
-    local_size: Optional[int] = None,
-) -> None:
-    with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
-        assert ctx.pg is not None
-        sharder = cast(
-            ModuleSharder[nn.Module],
-            FeatureProcessedEmbeddingBagCollectionSharder(),
-        )
-
-        class DummyWrapper(nn.Module):
-            def __init__(self, sparse_arch):
-                super().__init__()
-                self.m = sparse_arch
-
-            def forward(self, model_input) -> Tuple[torch.Tensor, torch.Tensor]:
-                return self.m(model_input.idlist_features)
-
-        max_feature_lengths = [10, 10, 12, 12]
-        sparse_arch = DummyWrapper(
-            create_module_and_freeze(
-                tables=tables,  # pyre-ignore[6]
-                device=ctx.device,
-                use_fp_collection=False,
-                max_feature_lengths=max_feature_lengths,
-            )
-        )
-
-        # compute_kernel = EmbeddingComputeKernel.FUSED.value
-        module_sharding_plan = construct_module_sharding_plan(
-            sparse_arch.m._fp_ebc,
-            per_param_sharding={
-                "table_0": row_wise(),
-                "table_1": row_wise(),
-                "table_2": row_wise(),
-                "table_3": row_wise(),
-            },
-            world_size=2,
-            device_type=ctx.device.type,
-            sharder=sharder,
-        )
-        sharded_sparse_arch_pipeline = DistributedModelParallel(
-            module=copy.deepcopy(sparse_arch),
-            plan=ShardingPlan({"m._fp_ebc": module_sharding_plan}),
-            env=ShardingEnv.from_process_group(ctx.pg),  # pyre-ignore[6]
-            sharders=[sharder],
-            device=ctx.device,
-        )
-        sharded_sparse_arch_no_pipeline = DistributedModelParallel(
-            module=copy.deepcopy(sparse_arch),
-            plan=ShardingPlan({"m._fp_ebc": module_sharding_plan}),
-            env=ShardingEnv.from_process_group(ctx.pg),  # pyre-ignore[6]
-            sharders=[sharder],
-            device=ctx.device,
-        )
-
-        batches = []
-        for d in data:
-            batches.append(d[1][ctx.rank].to(ctx.device))
-        dataloader = iter(batches)
-
-        optimizer_no_pipeline = optim.SGD(
-            sharded_sparse_arch_no_pipeline.parameters(), lr=0.1
-        )
-        optimizer_pipeline = optim.SGD(
-            sharded_sparse_arch_pipeline.parameters(), lr=0.1
-        )
-
-        pipeline = TrainPipelineSparseDist(
-            sharded_sparse_arch_pipeline,
-            optimizer_pipeline,
-            ctx.device,
-        )
-
-        for batch in batches[:-2]:
-            batch = batch.to(ctx.device)
-            optimizer_no_pipeline.zero_grad()
-            loss, pred = sharded_sparse_arch_no_pipeline(batch)
-            loss.backward()
-            optimizer_no_pipeline.step()
-
-            pred_pipeline = pipeline.progress(dataloader)
-            torch.testing.assert_close(pred_pipeline.cpu(), pred.cpu())
-
-
-class TrainPipelineGPUTest(MultiProcessTestBase):
-    def setUp(self, backend: str = "nccl") -> None:
-        super().setUp()
-
-        self.pipeline_class = TrainPipelineSparseDist
-        num_features = 4
-        num_weighted_features = 4
-        self.tables = [
-            EmbeddingBagConfig(
-                num_embeddings=(i + 1) * 100,
-                embedding_dim=(i + 1) * 4,
-                name="table_" + str(i),
-                feature_names=["feature_" + str(i)],
-            )
-            for i in range(num_features)
-        ]
-        self.weighted_tables = [
-            EmbeddingBagConfig(
-                num_embeddings=(i + 1) * 100,
-                embedding_dim=(i + 1) * 4,
-                name="weighted_table_" + str(i),
-                feature_names=["weighted_feature_" + str(i)],
-            )
-            for i in range(num_weighted_features)
-        ]
-
-        self.backend = backend
-        if torch.cuda.is_available():
-            self.device = torch.device("cuda")
-        else:
-            self.device = torch.device("cpu")
-
-        if self.backend == "nccl" and self.device == torch.device("cpu"):
-            self.skipTest("NCCL not supported on CPUs.")
-
-    def _generate_data(
-        self,
-        num_batches: int = 5,
-        batch_size: int = 1,
-        max_feature_lengths: Optional[List[int]] = None,
-    ) -> List[Tuple[ModelInput, List[ModelInput]]]:
-        return [
-            ModelInput.generate(
-                tables=self.tables,
-                weighted_tables=self.weighted_tables,
-                batch_size=batch_size,
-                world_size=2,
-                num_float_features=10,
-                max_feature_lengths=max_feature_lengths,
-            )
-            for i in range(num_batches)
-        ]
-
-    def test_fp_ebc_rw(self) -> None:
-        data = self._generate_data(max_feature_lengths=[10, 10, 12, 12])
-        self._run_multi_process_test(
-            callable=fp_ebc_rw_sharding_test_runner,
-            world_size=2,
-            tables=self.tables,
-            weighted_tables=self.weighted_tables,
-            data=data,
-        )
-
-
 class TrainPipelineSparseDistTest(TrainPipelineSparseDistTestBase):
     # pyre-fixme[56]: Pyre was not able to infer the type of argument
     @unittest.skipIf(

torchrec/distributed/train_pipeline/utils.py

Lines changed: 0 additions & 3 deletions
@@ -169,7 +169,6 @@ def _start_data_dist(
         # and this info was done in the _rewrite_model by tracing the
         # entire model to get the arg_info_list
         args, kwargs = forward.args.build_args_kwargs(batch)
-        args, kwargs = module.preprocess_input(args, kwargs)

         # Start input distribution.
         module_ctx = module.create_context()
@@ -405,8 +404,6 @@ def _rewrite_model(  # noqa C901
             logger.info(f"Module '{node.target}' will be pipelined")
             child = sharded_modules[node.target]
             original_forwards.append(child.forward)
-            # Set pipelining flag on the child module
-            child.is_pipelined = True
             # pyre-ignore[8] Incompatible attribute type
             child.forward = pipelined_forward(
                 node.target,
