Skip to content

Commit 5f9ce1c

Browse files
committed
add metrics
1 parent 4b4b30a commit 5f9ce1c

File tree

3 files changed

+98
-9
lines changed

3 files changed

+98
-9
lines changed

torchdrug/metrics/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from .metric import area_under_roc, area_under_prc, r2, QED, logP, penalized_logP, SA, chemical_validity, \
2-
variadic_accuracy
2+
accuracy, matthews_corrcoef, pearsonr, spearmanr, variadic_accuracy
33

44
# alias
55
AUROC = area_under_roc
66
AUPRC = area_under_prc
77

88
__all__ = [
99
"area_under_roc", "area_under_prc", "r2", "QED", "logP", "penalized_logP", "SA", "chemical_validity",
10+
"accuracy", "matthews_corrcoef", "pearsonr", "spearmanr",
1011
"variadic_accuracy",
1112
"AUROC", "AUPRC",
12-
]
13+
]

torchdrug/metrics/metric.py

Lines changed: 94 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
1-
import os
2-
import sys
3-
41
import torch
52
from torch.nn import functional as F
6-
from torch_scatter import scatter_max
3+
from torch_scatter import scatter_add, scatter_mean, scatter_max
74
import networkx as nx
85
from rdkit import Chem
9-
from rdkit.Chem import RDConfig, Descriptors
6+
from rdkit.Chem import Descriptors
107

118
from torchdrug import utils
129
from torchdrug.layers import functional
@@ -178,13 +175,103 @@ def chemical_validity(pred):
178175
validity.append(1 if mol else 0)
179176

180177
return torch.tensor(validity, dtype=torch.float, device=pred.device)
178+
179+
180+
@R.register("metrics.accuracy")
181+
def accuracy(pred, target):
182+
"""
183+
Compute classification accuracy over sets with equal size.
184+
185+
Suppose there are :math:`N` sets and :math:`C` categories.
186+
187+
Parameters:
188+
pred (Tensor): prediction of shape :math:`(N, C)`
189+
target (Tensor): target of shape :math:`(N,)`
190+
"""
191+
return (pred.argmax(dim=-1) == target).float().mean()
192+
193+
194+
@R.register("metrics.mcc")
195+
def matthews_corrcoef(pred, target, eps=1e-6):
196+
"""
197+
Matthews correlation coefficient between target and prediction.
198+
199+
Definition follows matthews_corrcoef for K classes in sklearn.
200+
For details, see: 'https://scikit-learn.org/stable/modules/model_evaluation.html#matthews-corrcoef'
201+
202+
Parameters:
203+
pred (Tensor): prediction of shape :math: `(N,)`
204+
target (Tensor): target of shape :math: `(N,)`
205+
"""
206+
num_class = pred.size(-1)
207+
pred = pred.argmax(-1)
208+
ones = torch.ones(len(target), device=pred.device)
209+
confusion_matrix = scatter_add(ones, target * num_class + pred, dim=0, dim_size=num_class ** 2)
210+
confusion_matrix = confusion_matrix.view(num_class, num_class)
211+
t = confusion_matrix.sum(dim=1)
212+
p = confusion_matrix.sum(dim=0)
213+
c = confusion_matrix.trace()
214+
s = confusion_matrix.sum()
215+
return (c * s - t @ p) / ((s * s - p @ p) * (s * s - t @ t) + eps).sqrt()
216+
217+
218+
@R.register("metrics.pearsonr")
219+
def pearsonr(pred, target):
220+
"""
221+
Pearson correlation between target and prediction.
222+
Mimics `scipy.stats.pearsonr`.
223+
224+
Parameters:
225+
pred (Tensor): prediction of shape :math: `(N,)`
226+
target (Tensor): target of shape :math: `(N,)`
227+
"""
228+
pred_mean = pred.float().mean()
229+
target_mean = target.float().mean()
230+
pred_centered = pred - pred_mean
231+
target_centered = target - target_mean
232+
pred_normalized = pred_centered / pred_centered.norm(2)
233+
target_normalized = target_centered / target_centered.norm(2)
234+
pearsonr = pred_normalized @ target_normalized
235+
return pearsonr
236+
237+
238+
@R.register("metrics.spearmanr")
239+
def spearmanr(pred, target, eps=1e-6):
240+
"""
241+
Spearman correlation between target and prediction.
242+
Implement in PyTorch, but non-diffierentiable. (validation metric only)
243+
244+
Parameters:
245+
pred (Tensor): prediction of shape :math: `(N,)`
246+
target (Tensor): target of shape :math: `(N,)`
247+
"""
248+
249+
def get_ranking(input):
250+
input_set, input_inverse = input.unique(return_inverse=True)
251+
order = input_inverse.argsort()
252+
ranking = torch.zeros(len(input_inverse), device=input.device)
253+
ranking[order] = torch.arange(1, len(input) + 1, dtype=torch.float, device=input.device)
254+
255+
# for elements that have the same value, replace their rankings with the mean of their rankings
256+
mean_ranking = scatter_mean(ranking, input_inverse, dim=0, dim_size=len(input_set))
257+
ranking = mean_ranking[input_inverse]
258+
return ranking
259+
260+
pred = get_ranking(pred)
261+
target = get_ranking(target)
262+
covariance = (pred * target).mean() - pred.mean() * target.mean()
263+
pred_std = pred.std(unbiased=False)
264+
target_std = target.std(unbiased=False)
265+
spearmanr = covariance / (pred_std * target_std + eps)
266+
return spearmanr
181267

182268

269+
@R.register("metrics.variadic_accuracy")
183270
def variadic_accuracy(input, target, size):
184271
"""
185272
Compute classification accuracy over variadic sizes of categories.
186273
187-
Suppose there are :math:`N` samples, and the number of categories in all samples is summed to :math`B`.
274+
Suppose there are :math:`N` samples, and the number of categories in all samples is summed to :math:`B`.
188275
189276
Parameters:
190277
input (Tensor): prediction of shape :math:`(B,)`
@@ -196,4 +283,4 @@ def variadic_accuracy(input, target, size):
196283
input_class = scatter_max(input, index2graph)[1]
197284
target_index = target + size.cumsum(0) - size
198285
accuracy = (input_class == target_index).float()
199-
return accuracy
286+
return accuracy

torchdrug/tasks/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
"mse": "mean squared error",
2020
"rmse": "root mean squared error",
2121
"acc": "accuracy",
22+
"mcc": "matthews correlation coefficient",
2223
}
2324

2425

0 commit comments

Comments
 (0)