
Commit 14473ce

axeisghost authored and meta-codesync[bot] committed

Enable fused path on Segment NE (meta-pytorch#3498) (meta-pytorch#3499)
Summary:

# Motivation
T242704386 noted that Segment NE is not compatible with FUSED metrics compute, which can bring an efficiency win, as described in this [post](https://fb.workplace.com/groups/429376538334034/permalink/1474708170467527).

# Solution
Run the group-by-group NE computation on tensors that stack all tasks.

# Compatibility
Thanks to the suggestion from ge0405, the compute mode is passed in a backward-compatible way using a bool, to avoid cyclic dependencies.

Pull Request resolved: meta-pytorch#3499

Reviewed By: iamzainhuda, ge0405

Differential Revision: D85879827

Pulled By: axeisghost

fbshipit-source-id: ada8f0cf2105ea4f5ce7bf3ba4719d73deea1d8d
1 parent a5dba57 commit 14473ce
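
To make the solution concrete, here is a standalone sketch (not torchrec code) of the group-by-group NE computation over tensors that stack all tasks. The function name, shapes, and the simplified eta handling are illustrative assumptions, not the implementation in this commit:

```python
# Standalone sketch: per-group NE over stacked per-task tensors.
import torch


def segmented_ne_sketch(
    labels: torch.Tensor,         # (n_tasks, batch)
    predictions: torch.Tensor,    # (n_tasks, batch), values in (0, 1)
    weights: torch.Tensor,        # (n_tasks, batch)
    grouping_keys: torch.Tensor,  # (batch,), integer group ids in [0, num_groups)
    num_groups: int,
    eta: float = 1e-12,
) -> torch.Tensor:
    n_tasks = labels.shape[0]
    # One (n_tasks, num_groups) buffer per state, mirroring the fused layout.
    ce_sum = torch.zeros(n_tasks, num_groups)
    weighted = torch.zeros(n_tasks, num_groups)
    pos = torch.zeros(n_tasks, num_groups)
    neg = torch.zeros(n_tasks, num_groups)
    for g in torch.unique(grouping_keys):
        mask = grouping_keys == g
        y = labels[:, mask]
        p = predictions[:, mask].clamp(eta, 1 - eta)
        w = weights[:, mask]
        # Weighted cross entropy summed within the group, per task.
        ce_sum[:, g] = (-w * (y * torch.log(p) + (1 - y) * torch.log(1 - p))).sum(dim=-1)
        weighted[:, g] = w.sum(dim=-1)
        pos[:, g] = (w * y).sum(dim=-1)
        neg[:, g] = (w * (1 - y)).sum(dim=-1)
    # NE = cross entropy normalized by the entropy of always predicting the mean label.
    mean_label = (pos / weighted).clamp(eta, 1 - eta)
    ce_norm = -(pos * torch.log(mean_label) + neg * torch.log(1 - mean_label))
    return ce_sum / ce_norm  # (n_tasks, num_groups)
```

For a single task and two groups this returns a (1, 2) tensor of per-group NE values, i.e. one NE per (task, group) pair rather than one NE per task.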

File tree

3 files changed: +195 −26 lines changed

torchrec/metrics/rec_metric.py

Lines changed: 10 additions & 5 deletions
@@ -135,7 +135,7 @@ def __init__(
         window_size: int,
         compute_on_all_ranks: bool = False,
         should_validate_update: bool = False,
-        fuse_state_tensors: bool = False,
+        compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION,
         process_group: Optional[dist.ProcessGroup] = None,
         fused_update_limit: int = 0,
         allow_missing_label_with_zero_weight: bool = False,
@@ -144,7 +144,13 @@ def __init__(
     ) -> None:
         metric_init_signature = inspect.signature(Metric.__init__)
         if "fuse_state_tensors" in metric_init_signature.parameters:
-            kwargs["fuse_state_tensors"] = fuse_state_tensors
+            kwargs["fuse_state_tensors"] = (
+                True
+                if compute_mode == RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION
+                else False
+            )
+        if "compute_mode" in metric_init_signature.parameters:
+            kwargs["compute_mode"] = compute_mode
         super().__init__(
             process_group=process_group,
             *args,
@@ -169,6 +175,7 @@ def __init__(
                 dist_reduce_fx=lambda x: torch.any(x, dim=0).byte(),
                 persistent=True,
             )
+        self._compute_mode: RecComputeMode = compute_mode

     @staticmethod
     def get_window_state_name(state_name: str) -> str:
@@ -428,9 +435,7 @@ def __init__(
                     window_size=self._window_size,
                     compute_on_all_ranks=compute_on_all_ranks,
                     should_validate_update=self._should_validate_update,
-                    fuse_state_tensors=(
-                        compute_mode == RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION
-                    ),
+                    compute_mode=compute_mode,
                     process_group=process_group,
                     **{**kwargs, **self._get_task_kwargs(task_config)},
                 )
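
The signature check above is the backward-compatibility mechanism mentioned in the summary: a keyword is forwarded to the underlying `Metric.__init__` only if that `__init__` actually declares it. A minimal, generic sketch of the same pattern, with hypothetical names (`OldBase`, `legacy_flag`, `new_mode`) rather than torchrec APIs:

```python
import inspect
from typing import Any, Callable, Dict


def forward_if_supported(
    base_init: Callable[..., None],
    kwargs: Dict[str, Any],
    **candidates: Any,
) -> Dict[str, Any]:
    # Forward each candidate keyword only when the base __init__ declares it,
    # so the call keeps working against older versions of the base class.
    params = inspect.signature(base_init).parameters
    for name, value in candidates.items():
        if name in params:
            kwargs[name] = value
    return kwargs


class OldBase:
    def __init__(self, legacy_flag: bool = False) -> None:  # no "new_mode" yet
        self.legacy_flag = legacy_flag


kwargs = forward_if_supported(OldBase.__init__, {}, legacy_flag=True, new_mode="fused")
print(kwargs)  # {'legacy_flag': True} -- "new_mode" is silently dropped
```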

torchrec/metrics/segmented_ne.py

Lines changed: 150 additions & 19 deletions
@@ -102,6 +102,29 @@ def compute_ne(
     return result_ne


+def compute_ne_fused(
+    ce_sum: torch.Tensor,
+    weighted_num_samples: torch.Tensor,
+    pos_labels: torch.Tensor,
+    neg_labels: torch.Tensor,
+    num_groups: int,
+    n_tasks: int,
+    eta: float,
+) -> torch.Tensor:
+    # size should be (n_tasks, num_groups)
+    result_ne = torch.zeros([n_tasks, num_groups])
+    for group in range(num_groups):
+        mean_label = pos_labels[:, group] / weighted_num_samples[:, group]
+        ce_norm = _compute_cross_entropy_norm(
+            mean_label, pos_labels[:, group], neg_labels[:, group], eta
+        )
+        ne = ce_sum[:, group] / ce_norm
+        result_ne[:, group] = ne
+
+    # ne indexed by group - tensor size (n_tasks, num_groups)
+    return result_ne
+
+
 def get_segemented_ne_states(
     labels: torch.Tensor,
     predictions: torch.Tensor,
@@ -111,12 +134,8 @@ def get_segemented_ne_states(
     num_groups: int,
 ) -> Dict[str, torch.Tensor]:
     groups = torch.unique(grouping_keys)
-    cross_entropy, weighted_num_samples, pos_labels, neg_labels = (
-        torch.zeros(num_groups).to(labels.device),
-        torch.zeros(num_groups).to(labels.device),
-        torch.zeros(num_groups).to(labels.device),
-        torch.zeros(num_groups).to(labels.device),
-    )
+    buffer = torch.zeros((4, num_groups), device=labels.device)
+    cross_entropy, weighted_num_samples, pos_labels, neg_labels = buffer.unbind(0)
     for group in groups:
         group_mask = grouping_keys == group

@@ -152,6 +171,53 @@ def get_segemented_ne_states(
     }


+def get_segemented_ne_states_fused(
+    labels: torch.Tensor,
+    predictions: torch.Tensor,
+    weights: torch.Tensor,
+    grouping_keys: torch.Tensor,
+    eta: float,
+    num_groups: int,
+    n_tasks: int,
+) -> Dict[str, torch.Tensor]:
+    groups = torch.unique(grouping_keys)
+    buffer = torch.zeros((4, n_tasks, num_groups), device=labels.device)
+    cross_entropy, weighted_num_samples, pos_labels, neg_labels = buffer.unbind(0)
+    for group in groups:
+        group_mask = grouping_keys == group
+
+        group_labels = labels[:, group_mask]
+        group_predictions = predictions[:, group_mask]
+        group_weights = weights[:, group_mask]
+
+        ce_sum_group = torch.sum(
+            compute_cross_entropy(
+                labels=group_labels,
+                predictions=group_predictions,
+                weights=group_weights,
+                eta=eta,
+            ),
+            dim=-1,
+        )
+
+        weighted_num_samples_group = torch.sum(group_weights, dim=-1)
+        pos_labels_group = torch.sum(group_weights * group_labels, dim=-1)
+        neg_labels_group = torch.sum(group_weights * (1.0 - group_labels), dim=-1)
+
+        cross_entropy[:, group] = ce_sum_group
+        weighted_num_samples[:, group] = weighted_num_samples_group
+        pos_labels[:, group] = pos_labels_group
+        neg_labels[:, group] = neg_labels_group
+
+    # tensor size for each value is (n_tasks, num_groups)
+    return {
+        "cross_entropy_sum": cross_entropy,
+        "weighted_num_samples": weighted_num_samples,
+        "pos_labels": pos_labels,
+        "neg_labels": neg_labels,
+    }
+
+
 def _state_reduction_sum(state: torch.Tensor) -> torch.Tensor:
     return state.sum(dim=0)

@@ -251,21 +317,91 @@ def update(
             )

         grouping_keys = kwargs["required_inputs"][self._grouping_keys]
-        states = get_segemented_ne_states(
-            labels,
-            predictions,
-            weights,
-            grouping_keys,
-            eta=self.eta,
-            num_groups=self._num_groups,
-        )
+        # When labels is 2D, we're in a fused mode (either FUSED_TASKS_COMPUTATION or FUSED_TASKS_AND_STATES_COMPUTATION).
+        # The states update and NE computation need to be done differently.
+        # On the fused path, all tasks are grouped together so NE is computed and states are updated in one tensor per state.
+        if (
+            self._compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION
+            or self._compute_mode == RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION
+        ):
+            states = get_segemented_ne_states_fused(
+                labels,
+                predictions,
+                weights,
+                grouping_keys,
+                eta=self.eta,
+                num_groups=self._num_groups,
+                n_tasks=self._n_tasks,
+            )
+        else:
+            states = get_segemented_ne_states(
+                labels,
+                predictions,
+                weights,
+                grouping_keys,
+                eta=self.eta,
+                num_groups=self._num_groups,
+            )

         for state_name, state_value in states.items():
             state = getattr(self, state_name)
             state += state_value

+    def _compute_fused(self) -> List[MetricComputationReport]:
+        reports = []
+        computed_ne = compute_ne_fused(
+            # pyre-fixme[6]: In call `compute_ne_fused`, for 1st positional argument, expected `Tensor` but got `Union[Tensor, Module]`
+            self.cross_entropy_sum,
+            # pyre-fixme[6]: In call `compute_ne_fused`, for 2nd positional argument, expected `Tensor` but got `Union[Tensor, Module]`
+            self.weighted_num_samples,
+            # pyre-fixme[6]: In call `compute_ne_fused`, for 3rd positional argument, expected `Tensor` but got `Union[Tensor, Module]`
+            self.pos_labels,
+            # pyre-fixme[6]: In call `compute_ne_fused`, for 4th positional argument, expected `Tensor` but got `Union[Tensor, Module]`
+            self.neg_labels,
+            num_groups=self._num_groups,
+            n_tasks=self._n_tasks,
+            eta=self.eta,
+        )
+        for group in range(self._num_groups):
+            reports.append(
+                MetricComputationReport(
+                    name=MetricName.SEGMENTED_NE,
+                    metric_prefix=MetricPrefix.LIFETIME,
+                    value=computed_ne[:, group],
+                    description="_" + str(group),
+                ),
+            )
+
+        if self._include_logloss:
+            log_loss_groups = compute_logloss(
+                # pyre-fixme[6]: In call `compute_logloss`, for 1st positional argument, expected `Tensor` but got `Union[Tensor, Module]`
+                self.cross_entropy_sum,
+                # pyre-fixme[6]: In call `compute_logloss`, for 2nd positional argument, expected `Tensor` but got `Union[Tensor, Module]`
+                self.pos_labels,
+                # pyre-fixme[6]: In call `compute_logloss`, for 3rd positional argument, expected `Tensor` but got `Union[Tensor, Module]`
+                self.neg_labels,
+                eta=self.eta,
+            )
+            for group in range(self._num_groups):
+                reports.append(
+                    MetricComputationReport(
+                        name=MetricName.LOG_LOSS,
+                        metric_prefix=MetricPrefix.LIFETIME,
+                        value=log_loss_groups[:, group],
+                        description="_" + str(group),
+                    )
+                )
+
+        return reports
+
     def _compute(self) -> List[MetricComputationReport]:
         reports = []
+        if (
+            self._compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION
+            or self._compute_mode == RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION
+        ):
+            return self._compute_fused()
+
         computed_ne = compute_ne(
             # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedS...
             self.cross_entropy_sum[0],
@@ -349,8 +485,3 @@ def __init__(
         else:
             # pyre-ignore[6]
             self._required_inputs.add(kwargs["grouping_keys"])
-        if self._compute_mode == RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION:
-            logging.warning(
-                f"compute_mode FUSED_TASKS_AND_STATES_COMPUTATION can't support {self._namespace} yet "
-                "because its states are not 1D Tensors. Only FUSED_TASKS_COMPUTATION will take effect."
-            )
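
A quick shape check for the two fused helpers added above. This assumes torchrec with this change installed; the toy labels, predictions, and grouping keys are illustrative only:

```python
import torch
from torchrec.metrics.segmented_ne import (
    compute_ne_fused,
    get_segemented_ne_states_fused,
)

n_tasks, num_groups = 2, 2
labels = torch.tensor([[0.0, 1.0, 1.0, 0.0], [1.0, 0.0, 1.0, 0.0]])
predictions = torch.tensor([[0.2, 0.7, 0.8, 0.4], [0.9, 0.3, 0.6, 0.7]])
weights = torch.ones(n_tasks, 4)
grouping_keys = torch.tensor([0, 0, 1, 1])

# Each state tensor comes back as (n_tasks, num_groups) on the fused path.
states = get_segemented_ne_states_fused(
    labels, predictions, weights, grouping_keys,
    eta=1e-12, num_groups=num_groups, n_tasks=n_tasks,
)
ne = compute_ne_fused(
    states["cross_entropy_sum"],
    states["weighted_num_samples"],
    states["pos_labels"],
    states["neg_labels"],
    num_groups=num_groups,
    n_tasks=n_tasks,
    eta=1e-12,
)
print(ne.shape)  # expected: torch.Size([2, 2]) -- one NE per (task, group)
```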

torchrec/metrics/tests/test_segmented_ne.py

Lines changed: 35 additions & 2 deletions
@@ -12,6 +12,7 @@

 import torch
 from torch import no_grad
+from torchrec.metrics.metrics_config import RecComputeMode
 from torchrec.metrics.rec_metric import RecTaskInfo
 from torchrec.metrics.segmented_ne import SegmentedNEMetric

@@ -33,6 +34,7 @@ def _test_segemented_ne_helper(
         grouping_keys: torch.Tensor,
         grouping_key_tensor_name: str = "grouping_keys",
         cast_keys_to_int: bool = False,
+        compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION,
     ) -> None:
         num_task = labels.shape[0]
         batch_size = labels.shape[0]
@@ -70,6 +72,7 @@ def _test_segemented_ne_helper(
             grouping_keys=grouping_key_tensor_name,
             # pyre-ignore
             cast_keys_to_int=cast_keys_to_int,
+            compute_mode=compute_mode,
         )
         ne.update(**inputs)
         actual_ne = ne.compute()
@@ -95,9 +98,39 @@ def test_grouped_ne(self) -> None:
         test_data = generate_model_outputs_cases()
         for inputs in test_data:
             try:
-                self._test_segemented_ne_helper(**inputs)
+                self._test_segemented_ne_helper(
+                    **inputs,
+                    compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
+                )
+            except AssertionError:
+                print(
+                    "Assertion error caught with data set in UNFUSED_TASKS_COMPUTATION mode",
+                    inputs,
+                )
+                raise
+
+            try:
+                self._test_segemented_ne_helper(
+                    **inputs,
+                    compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
+                )
             except AssertionError:
-                print("Assertion error caught with data set ", inputs)
+                print(
+                    "Assertion error caught with data set in FUSED_TASKS_COMPUTATION mode",
+                    inputs,
+                )
+                raise
+
+            try:
+                self._test_segemented_ne_helper(
+                    **inputs,
+                    compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
+                )
+            except AssertionError:
+                print(
+                    "Assertion error caught with data set in FUSED_TASKS_AND_STATES_COMPUTATION mode",
+                    inputs,
+                )
                 raise
