diff --git a/torchrec/metrics/auprc.py b/torchrec/metrics/auprc.py index ed99417d2..8c90217da 100644 --- a/torchrec/metrics/auprc.py +++ b/torchrec/metrics/auprc.py @@ -65,9 +65,16 @@ def _compute_auprc_helper( precision = torch.cat([precision, precision.new_ones(1)]) recall = torch.cat([recall, recall.new_zeros(1)]) - # If recalls are NaNs, set NaNs to 1.0s. - if torch.isnan(recall[0]): - recall = torch.nan_to_num(recall, 1.0) + # nan happens with 0.0 / 0.0. For recall's case, this could happen from its right side: + # num_fp is a cumsum and thus 0.0 starts from its left side. But given recall has a flip, + # then those 0.0 goes to right side and thus nan. + # If recalls are NaNs, set NaNs to 0.0s, as append a 0.0 on its right side above. + recall = torch.nan_to_num(recall, 0.0) + + # similar as recall, precision's nan would happen from its right side. + # since we append 1.0 on its right side above, we replace nan by 1.0. + # If any element in precision is Nan, _riemann_integral will return NaN. + precision = torch.nan_to_num(precision, 1.0) auprc = _riemann_integral(recall, precision) return auprc