1- import os
2- import sys
3-
41import torch
52from torch .nn import functional as F
6- from torch_scatter import scatter_max
3+ from torch_scatter import scatter_add , scatter_mean , scatter_max
74import networkx as nx
85from rdkit import Chem
9- from rdkit .Chem import RDConfig , Descriptors
6+ from rdkit .Chem import Descriptors
107
118from torchdrug import utils
129from torchdrug .layers import functional
@@ -178,13 +175,103 @@ def chemical_validity(pred):
178175 validity .append (1 if mol else 0 )
179176
180177 return torch .tensor (validity , dtype = torch .float , device = pred .device )
178+
179+
@R.register("metrics.accuracy")
def accuracy(pred, target):
    """
    Fraction of samples whose highest-scoring category equals the target.

    Suppose there are :math:`N` sets and :math:`C` categories.

    Parameters:
        pred (Tensor): prediction of shape :math:`(N, C)`
        target (Tensor): target of shape :math:`(N,)`
    """
    # the predicted category is the index of the largest score along the last dimension
    predicted_class = pred.argmax(dim=-1)
    hits = predicted_class == target
    return hits.float().mean()
192+
193+
@R.register("metrics.mcc")
def matthews_corrcoef(pred, target, eps=1e-6):
    """
    Matthews correlation coefficient between target and prediction.

    Definition follows matthews_corrcoef for K classes in sklearn.
    For details, see: 'https://scikit-learn.org/stable/modules/model_evaluation.html#matthews-corrcoef'

    Suppose there are :math:`N` samples and :math:`C` categories.

    Parameters:
        pred (Tensor): prediction of shape :math:`(N, C)`
        target (Tensor): target of shape :math:`(N,)`
        eps (float, optional): small constant added under the square root of the denominator
    """
    # the size of the last dimension is the number of categories
    num_class = pred.size(-1)
    pred = pred.argmax(-1)
    # target is used as a scatter index below, so make sure it is an integer (long) tensor
    target = target.long()
    ones = torch.ones(len(target), device=pred.device)
    # encode every (target, pred) pair as one flat index so a single scatter fills the matrix
    confusion_matrix = scatter_add(ones, target * num_class + pred, dim=0, dim_size=num_class ** 2)
    # rows index the target class, columns index the predicted class
    confusion_matrix = confusion_matrix.view(num_class, num_class)
    t = confusion_matrix.sum(dim=1)  # per-class counts in target
    p = confusion_matrix.sum(dim=0)  # per-class counts in prediction
    c = confusion_matrix.trace()  # number of correct predictions
    s = confusion_matrix.sum()  # total number of samples
    return (c * s - t @ p) / ((s * s - p @ p) * (s * s - t @ t) + eps).sqrt()
216+
217+
@R.register("metrics.pearsonr")
def pearsonr(pred, target):
    """
    Pearson correlation between target and prediction.
    Mimics `scipy.stats.pearsonr`.

    Both inputs are converted to a common floating point dtype before the computation,
    so integer tensors and mixed floating point dtypes are accepted.

    Parameters:
        pred (Tensor): prediction of shape :math:`(N,)`
        target (Tensor): target of shape :math:`(N,)`
    """
    # The dot product at the end needs both vectors in the same dtype, and the means
    # should be computed at the precision of the data, so promote both inputs once.
    dtype = torch.promote_types(pred.dtype, target.dtype)
    if not dtype.is_floating_point:
        dtype = torch.float
    pred = pred.to(dtype)
    target = target.to(dtype)

    pred_centered = pred - pred.mean()
    target_centered = target - target.mean()
    # a constant input has zero norm, in which case the result is NaN
    pred_normalized = pred_centered / pred_centered.norm(2)
    target_normalized = target_centered / target_centered.norm(2)
    return pred_normalized @ target_normalized
236+
237+
@R.register("metrics.spearmanr")
def spearmanr(pred, target, eps=1e-6):
    """
    Spearman correlation between target and prediction.
    Implemented in PyTorch, but non-differentiable. (validation metric only)

    Parameters:
        pred (Tensor): prediction of shape :math:`(N,)`
        target (Tensor): target of shape :math:`(N,)`
        eps (float, optional): small constant added to the product of the standard deviations
            in the denominator
    """

    def get_ranking(values):
        # inverse maps every element to the index of its value among the sorted unique values
        value_set, inverse = values.unique(return_inverse=True)
        order = inverse.argsort()
        ranking = torch.zeros(len(inverse), device=values.device)
        ranking[order] = torch.arange(1, len(values) + 1, dtype=torch.float, device=values.device)

        # for elements that have the same value, replace their rankings with the mean of their rankings
        mean_ranking = scatter_mean(ranking, inverse, dim=0, dim_size=len(value_set))
        return mean_ranking[inverse]

    pred_rank = get_ranking(pred)
    target_rank = get_ranking(target)
    # Center before multiplying. E[xy] - E[x]E[y] subtracts two values that grow with N^2,
    # which loses precision for float32 ranks.
    pred_centered = pred_rank - pred_rank.mean()
    target_centered = target_rank - target_rank.mean()
    covariance = (pred_centered * target_centered).mean()
    pred_std = pred_rank.std(unbiased=False)
    target_std = target_rank.std(unbiased=False)
    return covariance / (pred_std * target_std + eps)
181267
182268
269+ @R .register ("metrics.variadic_accuracy" )
183270def variadic_accuracy (input , target , size ):
184271 """
185272 Compute classification accuracy over variadic sizes of categories.
186273
187- Suppose there are :math:`N` samples, and the number of categories in all samples is summed to :math`B`.
274+ Suppose there are :math:`N` samples, and the number of categories in all samples is summed to :math: `B`.
188275
189276 Parameters:
190277 input (Tensor): prediction of shape :math:`(B,)`
@@ -196,4 +283,4 @@ def variadic_accuracy(input, target, size):
196283 input_class = scatter_max (input , index2graph )[1 ]
197284 target_index = target + size .cumsum (0 ) - size
198285 accuracy = (input_class == target_index ).float ()
199- return accuracy
286+ return accuracy
0 commit comments