@@ -102,6 +102,29 @@ def compute_ne(
     return result_ne


+def compute_ne_fused(
+    ce_sum: torch.Tensor,
+    weighted_num_samples: torch.Tensor,
+    pos_labels: torch.Tensor,
+    neg_labels: torch.Tensor,
+    num_groups: int,
+    n_tasks: int,
+    eta: float,
+) -> torch.Tensor:
+    # size should be (n_tasks, num_groups)
+    result_ne = torch.zeros([n_tasks, num_groups])
+    for group in range(num_groups):
+        mean_label = pos_labels[:, group] / weighted_num_samples[:, group]
+        ce_norm = _compute_cross_entropy_norm(
+            mean_label, pos_labels[:, group], neg_labels[:, group], eta
+        )
+        ne = ce_sum[:, group] / ce_norm
+        result_ne[:, group] = ne
+
+    # ne indexed by task and group - tensor size (n_tasks, num_groups)
+    return result_ne
+
+
 def get_segemented_ne_states(
     labels: torch.Tensor,
     predictions: torch.Tensor,
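A minimal shape sketch of how the new compute_ne_fused is expected to be called (toy values; the import path assumes this patch lands in torchrec.metrics.segmented_ne, and the concrete numbers are arbitrary):

import torch
from torchrec.metrics.segmented_ne import compute_ne_fused  # assumed module path for this patch

n_tasks, num_groups = 2, 3
# Pre-reduced states: one row per task, one column per grouping-key value.
weighted_num_samples = torch.full((n_tasks, num_groups), 10.0)
pos_labels = 0.3 * weighted_num_samples        # 30% positive labels per (task, group)
neg_labels = weighted_num_samples - pos_labels
ce_sum = torch.rand(n_tasks, num_groups) * 5.0

ne = compute_ne_fused(
    ce_sum,
    weighted_num_samples,
    pos_labels,
    neg_labels,
    num_groups=num_groups,
    n_tasks=n_tasks,
    eta=1e-12,
)
assert ne.shape == (n_tasks, num_groups)  # one NE value per (task, group) pair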
@@ -111,12 +134,8 @@ def get_segemented_ne_states(
     num_groups: int,
 ) -> Dict[str, torch.Tensor]:
     groups = torch.unique(grouping_keys)
-    cross_entropy, weighted_num_samples, pos_labels, neg_labels = (
-        torch.zeros(num_groups).to(labels.device),
-        torch.zeros(num_groups).to(labels.device),
-        torch.zeros(num_groups).to(labels.device),
-        torch.zeros(num_groups).to(labels.device),
-    )
+    buffer = torch.zeros((4, num_groups), device=labels.device)
+    cross_entropy, weighted_num_samples, pos_labels, neg_labels = buffer.unbind(0)
     for group in groups:
         group_mask = grouping_keys == group

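The stacked buffer plus unbind(0) above changes only the allocation, not the behavior: unbind returns views, so in-place writes through the per-state names still land in one shared tensor. A quick standalone illustration with toy shapes:

import torch

buffer = torch.zeros((4, 3))
cross_entropy, weighted_num_samples, pos_labels, neg_labels = buffer.unbind(0)
cross_entropy[1] = 5.0  # write through the per-state view...
assert buffer[0, 1].item() == 5.0  # ...and it is visible in the shared buffer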
@@ -152,6 +171,53 @@ def get_segemented_ne_states(
     }


+def get_segemented_ne_states_fused(
+    labels: torch.Tensor,
+    predictions: torch.Tensor,
+    weights: torch.Tensor,
+    grouping_keys: torch.Tensor,
+    eta: float,
+    num_groups: int,
+    n_tasks: int,
+) -> Dict[str, torch.Tensor]:
+    groups = torch.unique(grouping_keys)
+    buffer = torch.zeros((4, n_tasks, num_groups), device=labels.device)
+    cross_entropy, weighted_num_samples, pos_labels, neg_labels = buffer.unbind(0)
+    for group in groups:
+        group_mask = grouping_keys == group
+
+        group_labels = labels[:, group_mask]
+        group_predictions = predictions[:, group_mask]
+        group_weights = weights[:, group_mask]
+
+        ce_sum_group = torch.sum(
+            compute_cross_entropy(
+                labels=group_labels,
+                predictions=group_predictions,
+                weights=group_weights,
+                eta=eta,
+            ),
+            dim=-1,
+        )
+
+        weighted_num_samples_group = torch.sum(group_weights, dim=-1)
+        pos_labels_group = torch.sum(group_weights * group_labels, dim=-1)
+        neg_labels_group = torch.sum(group_weights * (1.0 - group_labels), dim=-1)
+
+        cross_entropy[:, group] = ce_sum_group
+        weighted_num_samples[:, group] = weighted_num_samples_group
+        pos_labels[:, group] = pos_labels_group
+        neg_labels[:, group] = neg_labels_group
+
+    # tensor size for each value is (n_tasks, num_groups)
+    return {
+        "cross_entropy_sum": cross_entropy,
+        "weighted_num_samples": weighted_num_samples,
+        "pos_labels": pos_labels,
+        "neg_labels": neg_labels,
+    }
+
+
 def _state_reduction_sum(state: torch.Tensor) -> torch.Tensor:
     return state.sum(dim=0)

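A shape sketch for the fused state computation (toy inputs; batch size and eta are arbitrary, and the import again assumes torchrec.metrics.segmented_ne): per-task inputs are 2D of shape (n_tasks, batch_size), grouping_keys is 1D over the batch, and every returned state is (n_tasks, num_groups):

import torch
from torchrec.metrics.segmented_ne import get_segemented_ne_states_fused  # assumed module path

n_tasks, batch_size, num_groups = 2, 8, 3
labels = torch.randint(0, 2, (n_tasks, batch_size)).float()
predictions = torch.rand(n_tasks, batch_size)
weights = torch.ones(n_tasks, batch_size)
grouping_keys = torch.randint(0, num_groups, (batch_size,))

states = get_segemented_ne_states_fused(
    labels,
    predictions,
    weights,
    grouping_keys,
    eta=1e-12,
    num_groups=num_groups,
    n_tasks=n_tasks,
)
for name, value in states.items():
    assert value.shape == (n_tasks, num_groups), name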
@@ -251,21 +317,91 @@ def update(
             )

         grouping_keys = kwargs["required_inputs"][self._grouping_keys]
-        states = get_segemented_ne_states(
-            labels,
-            predictions,
-            weights,
-            grouping_keys,
-            eta=self.eta,
-            num_groups=self._num_groups,
-        )
+        # When labels is 2D we are in a fused compute mode (FUSED_TASKS_COMPUTATION or
+        # FUSED_TASKS_AND_STATES_COMPUTATION), so states and NE are computed differently:
+        # all tasks are handled together in one (n_tasks, num_groups) tensor per state.
+        if (
+            self._compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION
+            or self._compute_mode == RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION
+        ):
+            states = get_segemented_ne_states_fused(
+                labels,
+                predictions,
+                weights,
+                grouping_keys,
+                eta=self.eta,
+                num_groups=self._num_groups,
+                n_tasks=self._n_tasks,
+            )
+        else:
+            states = get_segemented_ne_states(
+                labels,
+                predictions,
+                weights,
+                grouping_keys,
+                eta=self.eta,
+                num_groups=self._num_groups,
+            )

         for state_name, state_value in states.items():
             state = getattr(self, state_name)
             state += state_value

+    def _compute_fused(self) -> List[MetricComputationReport]:
+        reports = []
+        computed_ne = compute_ne_fused(
+            # pyre-fixme[6]: In call `compute_ne_fused`, for 1st positional argument, expected `Tensor` but got `Union[Tensor, Module]`
+            self.cross_entropy_sum,
+            # pyre-fixme[6]: In call `compute_ne_fused`, for 2nd positional argument, expected `Tensor` but got `Union[Tensor, Module]`
+            self.weighted_num_samples,
+            # pyre-fixme[6]: In call `compute_ne_fused`, for 3rd positional argument, expected `Tensor` but got `Union[Tensor, Module]`
+            self.pos_labels,
+            # pyre-fixme[6]: In call `compute_ne_fused`, for 4th positional argument, expected `Tensor` but got `Union[Tensor, Module]`
+            self.neg_labels,
+            num_groups=self._num_groups,
+            n_tasks=self._n_tasks,
+            eta=self.eta,
+        )
+        for group in range(self._num_groups):
+            reports.append(
+                MetricComputationReport(
+                    name=MetricName.SEGMENTED_NE,
+                    metric_prefix=MetricPrefix.LIFETIME,
+                    value=computed_ne[:, group],
+                    description="_" + str(group),
+                ),
+            )
+
+        if self._include_logloss:
+            log_loss_groups = compute_logloss(
+                # pyre-fixme[6]: In call `compute_logloss`, for 1st positional argument, expected `Tensor` but got `Union[Tensor, Module]`
+                self.cross_entropy_sum,
+                # pyre-fixme[6]: In call `compute_logloss`, for 2nd positional argument, expected `Tensor` but got `Union[Tensor, Module]`
+                self.pos_labels,
+                # pyre-fixme[6]: In call `compute_logloss`, for 3rd positional argument, expected `Tensor` but got `Union[Tensor, Module]`
+                self.neg_labels,
+                eta=self.eta,
+            )
+            for group in range(self._num_groups):
+                reports.append(
+                    MetricComputationReport(
+                        name=MetricName.LOG_LOSS,
+                        metric_prefix=MetricPrefix.LIFETIME,
+                        value=log_loss_groups[:, group],
+                        description="_" + str(group),
+                    )
+                )
+
+        return reports
+
     def _compute(self) -> List[MetricComputationReport]:
         reports = []
+        if (
+            self._compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION
+            or self._compute_mode == RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION
+        ):
+            return self._compute_fused()
+
         computed_ne = compute_ne(
             # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedS...
             self.cross_entropy_sum[0],
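For reference, the fused reports keep the task dimension: each per-group report carries a length-n_tasks vector instead of a scalar, so task i's segmented NE for group g sits at computed_ne[i, g]. A hypothetical illustration of that slicing (stand-in tensor, not real metric output):

import torch

n_tasks, num_groups = 2, 3
computed_ne = torch.rand(n_tasks, num_groups)  # stand-in for compute_ne_fused output
group = 1
per_group_value = computed_ne[:, group]  # what _compute_fused stores in one report
assert per_group_value.shape == (n_tasks,)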
@@ -349,8 +485,3 @@ def __init__(
         else:
             # pyre-ignore[6]
             self._required_inputs.add(kwargs["grouping_keys"])
-        if self._compute_mode == RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION:
-            logging.warning(
-                f"compute_mode FUSED_TASKS_AND_STATES_COMPUTATION can't support {self._namespace} yet "
-                "because its states are not 1D Tensors. Only FUSED_TASKS_COMPUTATION will take effect."
-            )