Skip to content

Commit 921f3f3

Browse files
committed
add protein tasks
1 parent 786d238 commit 921f3f3

File tree

10 files changed

+1107
-150
lines changed

10 files changed

+1107
-150
lines changed

torchdrug/metrics/__init__.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from .metric import area_under_roc, area_under_prc, r2, QED, logP, penalized_logP, SA, chemical_validity, \
2-
accuracy, matthews_corrcoef, pearsonr, spearmanr, variadic_accuracy
2+
accuracy, variadic_accuracy, matthews_corrcoef, pearsonr, spearmanr, \
3+
variadic_area_under_prc, variadic_area_under_roc, variadic_top_precision, f1_max
34

45
# alias
56
AUROC = area_under_roc
67
AUPRC = area_under_prc
78

89
__all__ = [
910
"area_under_roc", "area_under_prc", "r2", "QED", "logP", "penalized_logP", "SA", "chemical_validity",
10-
"accuracy", "matthews_corrcoef", "pearsonr", "spearmanr",
11-
"variadic_accuracy",
11+
"accuracy", "variadic_accuracy", "matthews_corrcoef", "pearsonr", "spearmanr",
12+
"variadic_area_under_prc", "variadic_area_under_roc", "variadic_top_precision", "f1_max",
1213
"AUROC", "AUPRC",
13-
]
14+
]

torchdrug/metrics/metric.py

Lines changed: 146 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -175,12 +175,104 @@ def chemical_validity(pred):
175175
validity.append(1 if mol else 0)
176176

177177
return torch.tensor(validity, dtype=torch.float, device=pred.device)
178+
179+
180+
@R.register("metrics.variadic_auroc")
def variadic_area_under_roc(pred, target, size):
    """
    Area under receiver operating characteristic curve (ROC) for sets with variadic sizes.

    Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`.

    The score of a set is the fraction of (positive, negative) pairs in which the positive
    sample is ranked before the negative one. A set that lacks positives or lacks negatives
    has no such pair and gets a score of 0.

    Parameters:
        pred (Tensor): prediction of shape :math:`(B,)`
        target (Tensor): target of shape :math:`(B,)`.
        size (Tensor): size of sets of shape :math:`(N,)`

    Returns:
        Tensor: AUROC of every set
    """
    index2graph = functional._size_to_index(size)
    _, order = functional.variadic_sort(pred, size, descending=True)
    # start offset of the owning set, for every element
    # NOTE(review): adding it to ``order`` implies variadic_sort returns set-local indices -- confirm
    cum_size = (size.cumsum(0) - size)[index2graph]
    # reorder targets so that each set is sorted by descending prediction
    target = target[order + cum_size]
    # number of positives in all preceding sets (exclusive prefix sum over sets)
    total_hit = functional.variadic_sum(target, size)
    total_hit = total_hit.cumsum(0) - total_hit
    # running count of positives inside each set
    hit = target.cumsum(0) - total_hit[index2graph]
    # keep the count only at negatives: number of positives ranked before this negative
    hit = torch.where(target == 0, hit, torch.zeros_like(hit))
    # number of (negative, positive) pairs per set
    # (named ``num_pair`` rather than ``all`` to avoid shadowing the builtin)
    num_negative = functional.variadic_sum((target == 0).float(), size)
    num_positive = functional.variadic_sum((target == 1).float(), size)
    num_pair = num_negative * num_positive
    # the epsilon guards against division by zero for sets without any pair
    auroc = functional.variadic_sum(hit, size) / (num_pair + 1e-10)
    return auroc
204+
205+
206+
@R.register("metrics.variadic_auprc")
def variadic_area_under_prc(pred, target, size):
    """
    Area under precision-recall curve (PRC) for sets with variadic sizes.

    Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`.

    The score of a set is the precision measured at the rank of every positive sample,
    averaged over the positives of that set. A set without positives gets a score of 0.

    Parameters:
        pred (Tensor): prediction of shape :math:`(B,)`
        target (Tensor): target of shape :math:`(B,)`.
        size (Tensor): size of sets of shape :math:`(N,)`

    Returns:
        Tensor: AUPRC of every set
    """
    index2graph = functional._size_to_index(size)
    _, order = functional.variadic_sort(pred, size, descending=True)
    # start offset of the owning set, for every element
    # NOTE(review): adding it to ``order`` implies variadic_sort returns set-local indices -- confirm
    cum_size = (size.cumsum(0) - size)[index2graph]
    # reorder targets so that each set is sorted by descending prediction
    target = target[order + cum_size]
    # number of positives in all preceding sets (exclusive prefix sum over sets)
    total_hit = functional.variadic_sum(target, size)
    total_hit = total_hit.cumsum(0) - total_hit
    # running count of positives inside each set
    hit = target.cumsum(0) - total_hit[index2graph]
    # 1-based rank of every element inside its set; reuses ``cum_size`` instead of recomputing it
    total = torch.ones_like(target).cumsum(0) - cum_size
    precision = hit / total
    # only the precision at positive samples contributes to the average
    precision = torch.where(target == 1, precision, torch.zeros_like(precision))
    num_positive = functional.variadic_sum((target == 1).float(), size)
    # the epsilon guards against division by zero for sets without positives
    auprc = functional.variadic_sum(precision, size) / (num_positive + 1e-10)
    return auprc
231+
232+
233+
@R.register("metrics.f1_max")
def f1_max(pred, target):
    """
    F1 score with the optimal threshold.

    Every prediction value is tried as a global decision threshold. For each threshold,
    precision is averaged over the rows that have at least one prediction above it, recall is
    averaged over all rows, and the largest resulting F1 score is returned.

    Parameters:
        pred (Tensor): predictions of shape :math:`(B, N)`
        target (Tensor): binary targets of shape :math:`(B, N)`
    """
    # per-row ranking and the precision / recall after each ranked prediction
    row_rank = pred.argsort(descending=True, dim=1)
    ranked_target = target.gather(1, row_rank)
    cum_hit = ranked_target.cumsum(1)
    row_precision = cum_hit / torch.ones_like(ranked_target).cumsum(1)
    row_recall = cum_hit / (ranked_target.sum(1, keepdim=True) + 1e-10)

    # mark, in the original layout, the top-ranked prediction of every row
    first_mask = torch.zeros_like(ranked_target).bool()
    first_mask[:, 0] = 1
    first_mask = torch.scatter(first_mask, 1, row_rank, first_mask)

    # global ranking over all predictions of all rows
    global_rank = pred.flatten().argsort(descending=True)
    row_offset = torch.arange(row_rank.shape[0], device=row_rank.device).unsqueeze(1) * row_rank.shape[1]
    flat_rank = (row_rank + row_offset).flatten()
    # position[i]: where the original flat element i sits in the row-wise ranked layout
    position = torch.zeros_like(flat_rank)
    position[flat_rank] = torch.arange(flat_rank.shape[0], device=flat_rank.device)
    first_mask = first_mask.flatten()[global_rank]
    global_rank = position[global_rank]

    row_precision = row_precision.flatten()
    row_recall = row_recall.flatten()

    # change of a row's precision when the threshold passes one more of its predictions
    previous_precision = torch.where(first_mask, torch.zeros_like(row_precision), row_precision[global_rank - 1])
    all_precision = row_precision[global_rank] - previous_precision
    # average over the rows that already have a prediction above the threshold
    all_precision = all_precision.cumsum(0) / first_mask.cumsum(0)

    # same incremental update for recall, averaged over all rows
    previous_recall = torch.where(first_mask, torch.zeros_like(row_recall), row_recall[global_rank - 1])
    all_recall = row_recall[global_rank] - previous_recall
    all_recall = all_recall.cumsum(0) / pred.shape[0]

    all_f1 = 2 * all_precision * all_recall / (all_precision + all_recall + 1e-10)
    return all_f1.max()
178270

179271

180272
@R.register("metrics.accuracy")
181273
def accuracy(pred, target):
182274
"""
183-
Compute classification accuracy over sets with equal size.
275+
Classification accuracy.
184276
185277
Suppose there are :math:`N` sets and :math:`C` categories.
186278
@@ -191,16 +283,56 @@ def accuracy(pred, target):
191283
return (pred.argmax(dim=-1) == target).float().mean()
192284

193285

286+
@R.register("metrics.variadic_accuracy")
def variadic_accuracy(input, target, size):
    """
    Classification accuracy for categories with variadic sizes.

    Suppose there are :math:`N` samples, and the number of categories in all samples is summed to :math:`B`.

    Parameters:
        input (Tensor): prediction of shape :math:`(B,)`
        target (Tensor): target of shape :math:`(N,)`. Each target is a relative index in a sample.
        size (Tensor): number of categories of shape :math:`(N,)`
    """
    index2graph = functional._size_to_index(size)

    # flat position of the highest-scoring category of every sample
    predicted_index = scatter_max(input, index2graph)[1]
    # shift the relative targets by the start offset of their sample
    absolute_target = target + size.cumsum(0) - size
    is_correct = predicted_index == absolute_target
    return is_correct.float()
304+
305+
306+
@R.register("metrics.variadic_top_precision")
def variadic_top_precision(pred, target, size, k):
    """
    Top-k precision for sets with variadic sizes.

    Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`.

    Sets that contain fewer than k elements get a precision of 0.

    Parameters:
        pred (Tensor): prediction of shape :math:`(B,)`
        target (Tensor): target of shape :math:`(B,)`
        size (Tensor): size of sets of shape :math:`(N,)`
        k (LongTensor): the k in "top-k" for different sets of shape :math:`(N,)`
    """
    topk_result = functional.variadic_topk(pred, size, k, largest=True)
    # start offset of every set, repeated once per selected element of that set
    set_start = (size.cumsum(0) - size).repeat_interleave(k)
    flat_index = topk_result[1] + set_start
    # number of positives among the selected elements of every set
    num_hit = functional.variadic_sum(target[flat_index], k)
    precision = num_hit / k
    precision[size < k] = 0
    return precision
324+
325+
194326
@R.register("metrics.mcc")
195-
def matthews_corrcoef(pred, target, eps=1e-6):
327+
def matthews_corrcoef(pred, target):
196328
"""
197-
Matthews correlation coefficient between target and prediction.
329+
Matthews correlation coefficient between prediction and target.
198330
199331
Definition follows matthews_corrcoef for K classes in sklearn.
200-
For details, see: 'https://scikit-learn.org/stable/modules/model_evaluation.html#matthews-corrcoef'
332+
For details, see: `https://scikit-learn.org/stable/modules/model_evaluation.html#matthews-corrcoef`
201333
202334
Parameters:
203-
pred (Tensor): prediction of shape :math: `(N,)`
335+
pred (Tensor): prediction of shape :math:`(N, K)`
204336
target (Tensor): target of shape :math:`(N,)`
205337
"""
206338
num_class = pred.size(-1)
@@ -212,14 +344,13 @@ def matthews_corrcoef(pred, target, eps=1e-6):
212344
p = confusion_matrix.sum(dim=0)
213345
c = confusion_matrix.trace()
214346
s = confusion_matrix.sum()
215-
return (c * s - t @ p) / ((s * s - p @ p) * (s * s - t @ t) + eps).sqrt()
347+
return (c * s - t @ p) / ((s * s - p @ p) * (s * s - t @ t) + 1e-10).sqrt()
216348

217349

218350
@R.register("metrics.pearsonr")
219351
def pearsonr(pred, target):
220352
"""
221-
Pearson correlation between target and prediction.
222-
Mimics `scipy.stats.pearsonr`.
353+
Pearson correlation between prediction and target.
223354
224355
Parameters:
225356
pred (Tensor): prediction of shape :math:`(N,)`
@@ -236,10 +367,13 @@ def pearsonr(pred, target):
236367

237368

238369
@R.register("metrics.spearmanr")
239-
def spearmanr(pred, target, eps=1e-6):
370+
def spearmanr(pred, target):
240371
"""
241-
Spearman correlation between target and prediction.
242-
Implement in PyTorch, but non-diffierentiable. (validation metric only)
372+
Spearman correlation between prediction and target.
373+
374+
.. note::
375+
376+
This function is not differentiable.
243377
244378
Parameters:
245379
pred (Tensor): prediction of shape :math:`(N,)`
@@ -262,25 +396,5 @@ def get_ranking(input):
262396
covariance = (pred * target).mean() - pred.mean() * target.mean()
263397
pred_std = pred.std(unbiased=False)
264398
target_std = target.std(unbiased=False)
265-
spearmanr = covariance / (pred_std * target_std + eps)
399+
spearmanr = covariance / (pred_std * target_std + 1e-10)
266400
return spearmanr
267-
268-
269-
@R.register("metrics.variadic_accuracy")
def variadic_accuracy(input, target, size):
    """
    Compute classification accuracy over variadic sizes of categories.

    Suppose there are :math:`N` samples, and the number of categories in all samples is summed to :math:`B`.

    Parameters:
        input (Tensor): prediction of shape :math:`(B,)`
        target (Tensor): target of shape :math:`(N,)`. Each target is a relative index in a sample.
        size (Tensor): number of categories of shape :math:`(N,)`
    """
    index2graph = functional._size_to_index(size)

    # flat position of the highest-scoring category of every sample
    best_index = scatter_max(input, index2graph)[1]
    # shift the relative targets by the start offset of their sample
    shifted_target = target + size.cumsum(0) - size
    matched = best_index == shifted_target
    return matched.float()

torchdrug/tasks/__init__.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from .task import Task
22

3-
from .property_prediction import PropertyPrediction, Unsupervised
4-
from .pretrain import EdgePrediction, AttributeMasking, ContextPrediction
3+
from .property_prediction import PropertyPrediction, MultipleBinaryClassification, \
4+
NodePropertyPrediction, InteractionPrediction, Unsupervised
5+
from .pretrain import EdgePrediction, AttributeMasking, ContextPrediction, DistancePrediction, \
6+
AnglePrediction, DihedralPrediction
57
from .generation import AutoregressiveGeneration, GCPNGeneration
68
from .retrosynthesis import CenterIdentification, SynthonCompletion, Retrosynthesis
79
from .reasoning import KnowledgeGraphCompletion
10+
from .contact_prediction import ContactPrediction
811

912

1013
_criterion_name = {
@@ -36,9 +39,12 @@ def _get_metric_name(metric):
3639

3740

3841
__all__ = [
39-
"PropertyPrediction", "Unsupervised",
40-
"EdgePrediction", "AttributeMasking", "ContextPrediction",
42+
"PropertyPrediction", "MultipleBinaryClassification", "NodePropertyPrediction", "InteractionPrediction",
43+
"Unsupervised",
44+
"EdgePrediction", "AttributeMasking", "ContextPrediction", "DistancePrediction", "AnglePrediction",
45+
"DihedralPrediction",
4146
"AutoregressiveGeneration", "GCPNGeneration",
4247
"CenterIdentification", "SynthonCompletion", "Retrosynthesis",
4348
"KnowledgeGraphCompletion",
49+
"ContactPrediction",
4450
]

0 commit comments

Comments
 (0)