@@ -175,12 +175,104 @@ def chemical_validity(pred):
175175 validity .append (1 if mol else 0 )
176176
177177 return torch .tensor (validity , dtype = torch .float , device = pred .device )
178+
179+
@R.register("metrics.variadic_auroc")
def variadic_area_under_roc(pred, target, size):
    """
    Area under receiver operating characteristic curve (ROC) for sets with variadic sizes.

    Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`.

    Every set is ranked by descending prediction. For each entry whose target is 0,
    the entries with target 1 ranked above it in the same set are counted; the sum of
    these counts is divided by (#target == 0) * (#target == 1) of the set.
    A small constant in the denominator makes sets without such pairs evaluate to 0.

    Parameters:
        pred (Tensor): prediction of shape :math:`(B,)`
        target (Tensor): target of shape :math:`(B,)`.
        size (Tensor): size of sets of shape :math:`(N,)`
    """
    set_index = functional._size_to_index(size)
    _, rank_in_set = functional.variadic_sort(pred, size, descending=True)
    # shift the sort result by the start offset of each entry's set before indexing
    set_offset = (size.cumsum(0) - size)[set_index]
    ranked_target = target[rank_in_set + set_offset]

    # exclusive prefix sum over sets: target mass located in all preceding sets
    target_per_set = functional.variadic_sum(ranked_target, size)
    target_before_set = target_per_set.cumsum(0) - target_per_set
    # running sum of target restricted to the entry's own set
    hit_so_far = ranked_target.cumsum(0) - target_before_set[set_index]

    # only entries with target 0 contribute their count of preceding hits
    is_negative = ranked_target == 0
    pair_hit = torch.where(is_negative, hit_so_far, torch.zeros_like(hit_so_far))

    num_negative = functional.variadic_sum(is_negative.float(), size)
    num_positive = functional.variadic_sum((ranked_target == 1).float(), size)
    num_pair = num_negative * num_positive
    return functional.variadic_sum(pair_hit, size) / (num_pair + 1e-10)
204+
205+
@R.register("metrics.variadic_auprc")
def variadic_area_under_prc(pred, target, size):
    """
    Area under precision-recall curve (PRC) for sets with variadic sizes.

    Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`.

    Every set is ranked by descending prediction. The precision at each rank is the
    running sum of target within the set divided by the 1-based rank. Precisions at
    entries whose target is 1 are summed and divided by the number of such entries in
    the set (plus a small constant, so sets without any evaluate to 0).

    Parameters:
        pred (Tensor): prediction of shape :math:`(B,)`
        target (Tensor): target of shape :math:`(B,)`.
        size (Tensor): size of sets of shape :math:`(N,)`
    """
    set_index = functional._size_to_index(size)
    _, rank_in_set = functional.variadic_sort(pred, size, descending=True)
    # start offset of the set each entry belongs to
    set_offset = (size.cumsum(0) - size)[set_index]
    ranked_target = target[rank_in_set + set_offset]

    # exclusive prefix sum over sets: target mass located in all preceding sets
    target_per_set = functional.variadic_sum(ranked_target, size)
    target_before_set = target_per_set.cumsum(0) - target_per_set
    hit_so_far = ranked_target.cumsum(0) - target_before_set[set_index]

    # global 1-based position minus the set offset gives the 1-based rank in the set
    rank = torch.ones_like(ranked_target).cumsum(0) - set_offset
    precision = hit_so_far / rank

    is_positive = ranked_target == 1
    precision = torch.where(is_positive, precision, torch.zeros_like(precision))
    num_positive = functional.variadic_sum(is_positive.float(), size)
    return functional.variadic_sum(precision, size) / (num_positive + 1e-10)
231+
232+
@R.register("metrics.f1_max")
def f1_max(pred, target):
    """
    F1 score with the optimal threshold.

    This function first enumerates all possible thresholds for deciding positive and negative
    samples, and then pick the threshold with the maximal F1 score.

    Precision and recall curves are built per row, then merged along a single global
    ranking of all predictions: at each global threshold the per-row precisions are
    averaged over the rows that already have at least one selected entry, and the
    per-row recalls are averaged over all rows. The result is a scalar tensor.

    Parameters:
        pred (Tensor): predictions of shape :math:`(B, N)`
        target (Tensor): binary targets of shape :math:`(B, N)`
    """
    # per-row curves: entry [b, j] is the value after selecting the top j + 1 of row b
    row_order = pred.argsort(descending=True, dim=1)
    ranked_target = target.gather(1, row_order)
    hit_count = ranked_target.cumsum(1)
    precision = hit_count / torch.ones_like(ranked_target).cumsum(1)
    recall = hit_count / (ranked_target.sum(1, keepdim=True) + 1e-10)

    # mark, in the original column layout, the top-ranked entry of every row
    first_in_row = torch.zeros_like(ranked_target).bool()
    first_in_row[:, 0] = 1
    first_in_row = torch.scatter(first_in_row, 1, row_order, first_in_row)

    # NOTE(review): neither argsort is requested as stable, so tied scores inside a row
    # might be ordered differently by the row-wise and the global sort -- confirm
    # whether this matters for the intended inputs.
    global_order = pred.flatten().argsort(descending=True)
    row_start = torch.arange(row_order.shape[0], device=row_order.device).unsqueeze(1) * row_order.shape[1]
    flat_row_order = (row_order + row_start).flatten()
    # inverse permutation: flat original index -> position in the row-ranked layout
    ranked_position = torch.zeros_like(flat_row_order)
    ranked_position[flat_row_order] = torch.arange(flat_row_order.shape[0], device=flat_row_order.device)

    first_in_row = first_in_row.flatten()[global_order]
    global_position = ranked_position[global_order]
    precision = precision.flatten()
    recall = recall.flatten()

    # change of the row's precision caused by selecting this entry; the first entry
    # of a row has no predecessor, so its previous value is taken as 0
    previous_precision = torch.where(first_in_row, torch.zeros_like(precision), precision[global_position - 1])
    precision_delta = precision[global_position] - previous_precision
    # running sum of current per-row precisions over the number of rows started so far
    mean_precision = precision_delta.cumsum(0) / first_in_row.cumsum(0)

    previous_recall = torch.where(first_in_row, torch.zeros_like(recall), recall[global_position - 1])
    recall_delta = recall[global_position] - previous_recall
    # recall is averaged over every row, started or not
    mean_recall = recall_delta.cumsum(0) / pred.shape[0]

    all_f1 = 2 * mean_precision * mean_recall / (mean_precision + mean_recall + 1e-10)
    return all_f1.max()
178270
179271
180272@R .register ("metrics.accuracy" )
181273def accuracy (pred , target ):
182274 """
183- Compute classification accuracy over sets with equal size .
275+ Classification accuracy.
184276
185277 Suppose there are :math:`N` sets and :math:`C` categories.
186278
@@ -191,16 +283,56 @@ def accuracy(pred, target):
191283 return (pred .argmax (dim = - 1 ) == target ).float ().mean ()
192284
193285
@R.register("metrics.variadic_accuracy")
def variadic_accuracy(input, target, size):
    """
    Classification accuracy for categories with variadic sizes.

    Suppose there are :math:`N` samples, and the number of categories in all samples is summed to :math:`B`.

    The result holds one value per sample: 1 where the highest-scoring category of
    the sample is the target category, 0 otherwise. No averaging is applied.

    Parameters:
        input (Tensor): prediction of shape :math:`(B,)`
        target (Tensor): target of shape :math:`(N,)`. Each target is a relative index in a sample.
        size (Tensor): number of categories of shape :math:`(N,)`
    """
    sample_index = functional._size_to_index(size)

    # the second output of scatter_max is compared with absolute positions below,
    # i.e. it is used as the per-sample argmax index into the flat input
    predicted_index = scatter_max(input, sample_index)[1]
    # turn the relative target index into an absolute one by adding the sample's start offset
    absolute_target = target + size.cumsum(0) - size
    is_correct = predicted_index == absolute_target
    return is_correct.float()
304+
305+
@R.register("metrics.variadic_top_precision")
def variadic_top_precision(pred, target, size, k):
    """
    Top-k precision for sets with variadic sizes.

    Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`.

    For every set, the targets of its k highest-scoring entries are summed and divided
    by k. Sets that contain fewer than k entries are assigned a precision of 0.

    Parameters:
        pred (Tensor): prediction of shape :math:`(B,)`
        target (Tensor): target of shape :math:`(B,)`
        size (Tensor): size of sets of shape :math:`(N,)`
        k (LongTensor): the k in "top-k" for different sets of shape :math:`(N,)`
    """
    top_index = functional.variadic_topk(pred, size, k, largest=True)[1]
    # shift the selected indices by the start offset of their set, repeated k times per set
    set_offset = size.cumsum(0) - size
    top_index = top_index + set_offset.repeat_interleave(k)

    top_target = target[top_index]
    precision = functional.variadic_sum(top_target, k) / k
    # a set with fewer than k entries cannot provide a full top-k selection
    undersized = size < k
    precision[undersized] = 0
    return precision
324+
325+
194326@R .register ("metrics.mcc" )
195- def matthews_corrcoef (pred , target , eps = 1e-6 ):
327+ def matthews_corrcoef (pred , target ):
196328 """
197- Matthews correlation coefficient between target and prediction .
329+ Matthews correlation coefficient between prediction and target .
198330
199331 Definition follows matthews_corrcoef for K classes in sklearn.
200- For details, see: ' https://scikit-learn.org/stable/modules/model_evaluation.html#matthews-corrcoef'
332+ For details, see: ` https://scikit-learn.org/stable/modules/model_evaluation.html#matthews-corrcoef`
201333
202334 Parameters:
203- pred (Tensor): prediction of shape :math: `(N,)`
335+ pred (Tensor): prediction of shape :math: `(N, K )`
204336 target (Tensor): target of shape :math: `(N,)`
205337 """
206338 num_class = pred .size (- 1 )
@@ -212,14 +344,13 @@ def matthews_corrcoef(pred, target, eps=1e-6):
212344 p = confusion_matrix .sum (dim = 0 )
213345 c = confusion_matrix .trace ()
214346 s = confusion_matrix .sum ()
215- return (c * s - t @ p ) / ((s * s - p @ p ) * (s * s - t @ t ) + eps ).sqrt ()
347+ return (c * s - t @ p ) / ((s * s - p @ p ) * (s * s - t @ t ) + 1e-10 ).sqrt ()
216348
217349
218350@R .register ("metrics.pearsonr" )
219351def pearsonr (pred , target ):
220352 """
221- Pearson correlation between target and prediction.
222- Mimics `scipy.stats.pearsonr`.
353+ Pearson correlation between prediction and target.
223354
224355 Parameters:
225356 pred (Tensor): prediction of shape :math: `(N,)`
@@ -236,10 +367,13 @@ def pearsonr(pred, target):
236367
237368
238369@R .register ("metrics.spearmanr" )
239- def spearmanr (pred , target , eps = 1e-6 ):
370+ def spearmanr (pred , target ):
240371 """
241- Spearman correlation between target and prediction.
242- Implement in PyTorch, but non-diffierentiable. (validation metric only)
372+ Spearman correlation between prediction and target.
373+
374+ .. note::
375+
376+ This function is not differentiable.
243377
244378 Parameters:
245379 pred (Tensor): prediction of shape :math: `(N,)`
@@ -262,25 +396,5 @@ def get_ranking(input):
262396 covariance = (pred * target ).mean () - pred .mean () * target .mean ()
263397 pred_std = pred .std (unbiased = False )
264398 target_std = target .std (unbiased = False )
265- spearmanr = covariance / (pred_std * target_std + eps )
399+ spearmanr = covariance / (pred_std * target_std + 1e-10 )
266400 return spearmanr
267-
268-
269- @R .register ("metrics.variadic_accuracy" )
270- def variadic_accuracy (input , target , size ):
271- """
272- Compute classification accuracy over variadic sizes of categories.
273-
274- Suppose there are :math:`N` samples, and the number of categories in all samples is summed to :math:`B`.
275-
276- Parameters:
277- input (Tensor): prediction of shape :math:`(B,)`
278- target (Tensor): target of shape :math:`(N,)`. Each target is a relative index in a sample.
279- size (Tensor): number of categories of shape :math:`(N,)`
280- """
281- index2graph = functional ._size_to_index (size )
282-
283- input_class = scatter_max (input , index2graph )[1 ]
284- target_index = target + size .cumsum (0 ) - size
285- accuracy = (input_class == target_index ).float ()
286- return accuracy
0 commit comments