Skip to content

Commit 786d238

Browse files
committed
add protein models
1 parent 2be5c85 commit 786d238

File tree

19 files changed

+1826
-35
lines changed

19 files changed

+1826
-35
lines changed

torchdrug/layers/__init__.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1-
from .common import MultiLayerPerceptron, GaussianSmearing, MutualInformation, PairNorm, InstanceNorm, Sequential
1+
from .common import MultiLayerPerceptron, GaussianSmearing, MutualInformation, PairNorm, InstanceNorm, Sequential, \
2+
SinusoidalPositionEmbedding
23

4+
from .block import ProteinResNetBlock, SelfAttentionBlock, ProteinBERTBlock
35
from .conv import MessagePassingBase, GraphConv, GraphAttentionConv, RelationalGraphConv, GraphIsomorphismConv, \
4-
NeuralFingerprintConv, ContinuousFilterConv, MessagePassing, ChebyshevConv
6+
NeuralFingerprintConv, ContinuousFilterConv, MessagePassing, ChebyshevConv, GeometricRelationalGraphConv
57
from .pool import DiffPool, MinCutPool
6-
from .readout import MeanReadout, SumReadout, MaxReadout, Softmax, Set2Set, Sort
8+
from .readout import MeanReadout, SumReadout, MaxReadout, AttentionReadout, Softmax, Set2Set, Sort
79
from .flow import ConditionalFlow
810
from .sampler import NodeSampler, EdgeSampler
11+
from .geometry import GraphConstruction, SpatialLineGraph
912
from . import distribution, functional
1013

1114
# alias
@@ -20,12 +23,15 @@
2023

2124
__all__ = [
2225
"MultiLayerPerceptron", "GaussianSmearing", "MutualInformation", "PairNorm", "InstanceNorm", "Sequential",
26+
"SinusoidalPositionEmbedding",
2327
"MessagePassingBase", "GraphConv", "GraphAttentionConv", "RelationalGraphConv", "GraphIsomorphismConv",
24-
"NeuralFingerprintConv", "ContinuousFilterConv", "MessagePassing", "ChebyshevConv",
28+
"NeuralFingerprintConv", "ContinuousFilterConv", "MessagePassing", "ChebyshevConv", "GeometricRelationalGraphConv",
2529
"DiffPool", "MinCutPool",
26-
"MeanReadout", "SumReadout", "MaxReadout", "Softmax", "Set2Set", "Sort",
30+
"MeanReadout", "SumReadout", "MaxReadout", "AttentionReadout", "Softmax", "Set2Set", "Sort",
2731
"ConditionalFlow",
2832
"NodeSampler", "EdgeSampler",
33+
"GraphConstruction", "SpatialLineGraph",
2934
"distribution", "functional",
3035
"MLP", "RBF", "GCNConv", "RGCNConv", "GINConv", "NFPConv", "CFConv", "MPConv",
36+
"ProteinResNetBlock", "SelfAttentionBlock", "ProteinBERTBlock",
3137
]

torchdrug/layers/block.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
from torch import nn
2+
from torch.nn import functional as F
3+
4+
from torchdrug import layers
5+
6+
7+
class ProteinResNetBlock(nn.Module):
    """
    Convolutional block with residual connection from `Deep Residual Learning for Image Recognition`_.

    .. _Deep Residual Learning for Image Recognition:
        https://arxiv.org/pdf/1512.03385.pdf

    Parameters:
        input_dim (int): input dimension
        output_dim (int): output dimension
        kernel_size (int, optional): size of convolutional kernel
        stride (int, optional): stride of convolution
        padding (int, optional): padding added to both sides of the input
        activation (str or function, optional): activation function
    """

    def __init__(self, input_dim, output_dim, kernel_size=3, stride=1, padding=1, activation="gelu"):
        super(ProteinResNetBlock, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        # resolve a string name to the matching torch.nn.functional callable
        self.activation = getattr(F, activation) if isinstance(activation, str) else activation

        self.conv1 = nn.Conv1d(input_dim, output_dim, kernel_size, stride, padding, bias=False)
        self.layer_norm1 = nn.LayerNorm(output_dim)
        self.conv2 = nn.Conv1d(output_dim, output_dim, kernel_size, stride, padding, bias=False)
        self.layer_norm2 = nn.LayerNorm(output_dim)

    def forward(self, input, mask):
        """
        Perform 1D convolutions over the input.

        Parameters:
            input (Tensor): input representations of shape `(..., length, dim)`
            mask (Tensor): bool mask of shape `(..., length, dim)`
        """
        identity = input

        # zero out masked positions before each convolution; Conv1d wants (B, d, L)
        hidden = self.conv1((input * mask).transpose(1, 2)).transpose(1, 2)
        hidden = self.activation(self.layer_norm1(hidden))

        hidden = self.conv2((hidden * mask).transpose(1, 2)).transpose(1, 2)
        hidden = self.layer_norm2(hidden)

        # residual connection followed by a final non-linearity
        return self.activation(hidden + identity)
61+
62+
63+
class SelfAttentionBlock(nn.Module):
    """
    Multi-head self-attention block from
    `Attention Is All You Need`_.

    .. _Attention Is All You Need:
        https://arxiv.org/pdf/1706.03762.pdf

    Parameters:
        hidden_dim (int): hidden dimension
        num_heads (int): number of attention heads
        dropout (float, optional): dropout ratio of attention maps
    """

    def __init__(self, hidden_dim, num_heads, dropout=0.0):
        super(SelfAttentionBlock, self).__init__()
        # the hidden dimension is split evenly across the heads
        if hidden_dim % num_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_dim, num_heads))
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_size = hidden_dim // num_heads

        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.key = nn.Linear(hidden_dim, hidden_dim)
        self.value = nn.Linear(hidden_dim, hidden_dim)

        self.attn = nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout)

    def forward(self, input, mask):
        """
        Perform self attention over the input.

        Parameters:
            input (Tensor): input representations of shape `(..., length, dim)`
            mask (Tensor): bool mask of shape `(..., length)`
        """
        # nn.MultiheadAttention expects (length, batch, dim)
        q = self.query(input).transpose(0, 1)
        k = self.key(input).transpose(0, 1)
        v = self.value(input).transpose(0, 1)

        # key_padding_mask is True where positions should be IGNORED, hence the inversion
        padding_mask = (~mask.bool()).squeeze(-1)
        output, _ = self.attn(q, k, v, key_padding_mask=padding_mask)

        return output.transpose(0, 1)
109+
110+
111+
class ProteinBERTBlock(nn.Module):
    """
    Transformer encoding block from
    `Attention Is All You Need`_.

    .. _Attention Is All You Need:
        https://arxiv.org/pdf/1706.03762.pdf

    Parameters:
        input_dim (int): input dimension
        hidden_dim (int): hidden dimension
        num_heads (int): number of attention heads
        attention_dropout (float, optional): dropout ratio of attention maps
        hidden_dropout (float, optional): dropout ratio of hidden features
        activation (str or function, optional): activation function
    """

    def __init__(self, input_dim, hidden_dim, num_heads, attention_dropout=0,
                 hidden_dropout=0, activation="relu"):
        super(ProteinBERTBlock, self).__init__()
        self.input_dim = input_dim
        self.num_heads = num_heads
        self.attention_dropout = attention_dropout
        self.hidden_dropout = hidden_dropout
        self.hidden_dim = hidden_dim

        # self-attention sub-layer
        self.attention = SelfAttentionBlock(input_dim, num_heads, attention_dropout)
        self.linear1 = nn.Linear(input_dim, input_dim)
        self.dropout1 = nn.Dropout(hidden_dropout)
        self.layer_norm1 = nn.LayerNorm(input_dim)

        # position-wise feed-forward sub-layer
        self.intermediate = layers.MultiLayerPerceptron(input_dim, hidden_dim, activation=activation)
        self.linear2 = nn.Linear(hidden_dim, input_dim)
        self.dropout2 = nn.Dropout(hidden_dropout)
        self.layer_norm2 = nn.LayerNorm(input_dim)

    def forward(self, input, mask):
        """
        Perform a BERT-block transformation over the input.

        Parameters:
            input (Tensor): input representations of shape `(..., length, dim)`
            mask (Tensor): bool mask of shape `(..., length)`
        """
        # attention sub-layer with residual connection and post-layer-norm
        attended = self.dropout1(self.linear1(self.attention(input, mask)))
        attended = self.layer_norm1(attended + input)

        # feed-forward sub-layer with residual connection and post-layer-norm
        projected = self.dropout2(self.linear2(self.intermediate(attended)))
        return self.layer_norm2(projected + attended)

torchdrug/layers/common.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
class MultiLayerPerceptron(nn.Module):
1414
"""
1515
Multi-layer Perceptron.
16-
1716
Note there is no batch normalization, activation or dropout in the last layer.
1817
1918
Parameters:
@@ -322,4 +321,19 @@ def forward(self, *args, **kwargs):
322321
else:
323322
args.append(output)
324323

325-
return output
324+
return output
325+
326+
327+
class SinusoidalPositionEmbedding(nn.Module):
    """
    Positional embedding based on sine and cosine functions, proposed in
    `Attention Is All You Need`_.

    .. _Attention Is All You Need:
        https://arxiv.org/pdf/1706.03762.pdf

    Parameters:
        output_dim (int): dimension of the embedding (expected to be even)
    """

    def __init__(self, output_dim):
        super(SinusoidalPositionEmbedding, self).__init__()
        # frequencies follow a geometric progression: 10000^(-2i / output_dim)
        inverse_frequency = 1 / (10000 ** (torch.arange(0.0, output_dim, 2.0) / output_dim))
        self.register_buffer("inverse_frequency", inverse_frequency)

    def forward(self, input):
        """
        Return sinusoidal embeddings for the positions of ``input``.

        Parameters:
            input (Tensor): input of shape `(batch_size, length, ...)`; only its
                length, dtype and device are used
        """
        # positions count down from length - 1 to 0
        positions = torch.arange(input.shape[1] - 1, -1, -1.0, dtype=input.dtype, device=input.device)
        phases = torch.outer(positions, self.inverse_frequency)
        return torch.cat([phases.sin(), phases.cos()], dim=-1)

torchdrug/layers/conv.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -779,3 +779,56 @@ def forward(self, graph, input):
779779
def combine(self, input, update):
780780
output = input + update
781781
return output
782+
783+
784+
class GeometricRelationalGraphConv(RelationalGraphConv):
    """
    Geometry-aware relational graph convolution operator from
    `Protein Representation Learning by Geometric Structure Pretraining`_.

    .. _Protein Representation Learning by Geometric Structure Pretraining:
        https://arxiv.org/pdf/2203.06125.pdf

    Parameters:
        input_dim (int): input dimension
        output_dim (int): output dimension
        num_relation (int): number of relations
        edge_input_dim (int, optional): dimension of edge features
        batch_norm (bool, optional): apply batch normalization on nodes or not
        activation (str or function, optional): activation function
    """

    def __init__(self, input_dim, output_dim, num_relation, edge_input_dim=None, batch_norm=False, activation="relu"):
        super(GeometricRelationalGraphConv, self).__init__(input_dim, output_dim, num_relation, edge_input_dim,
                                                           batch_norm, activation)

    def aggregate(self, graph, message):
        # Sum incoming messages independently for each (target node, relation) pair.
        assert graph.num_relation == self.num_relation

        # flatten (target node, relation) into a single scatter index
        # NOTE(review): edge_list columns appear to be (source, target, relation) — confirm against Graph
        node_out = graph.edge_list[:, 1] * self.num_relation + graph.edge_list[:, 2]
        edge_weight = graph.edge_weight.unsqueeze(-1)
        update = scatter_add(message * edge_weight, node_out, dim=0, dim_size=graph.num_node * self.num_relation)
        # regroup the per-relation sums into one concatenated feature vector per node
        update = update.view(graph.num_node, self.num_relation * self.input_dim)

        return update

    def message_and_aggregate(self, graph, input):
        # Fused message passing + aggregation via a single sparse matrix multiplication,
        # equivalent to message() followed by aggregate() but without materializing messages.
        assert graph.num_relation == self.num_relation

        node_in, node_out, relation = graph.edge_list.t()
        # same flattened (target node, relation) index as in aggregate()
        node_out = node_out * self.num_relation + relation
        adjacency = utils.sparse_coo_tensor(torch.stack([node_in, node_out]), graph.edge_weight,
                                            (graph.num_node, graph.num_node * graph.num_relation))
        update = torch.sparse.mm(adjacency.t(), input)
        if self.edge_linear:
            edge_input = graph.edge_feature.float()
            # project before scattering when the projection shrinks the features (cheaper),
            # otherwise scatter first and project the aggregated result
            if self.edge_linear.in_features > self.edge_linear.out_features:
                edge_input = self.edge_linear(edge_input)
            edge_weight = graph.edge_weight.unsqueeze(-1)
            edge_update = scatter_add(edge_input * edge_weight, node_out, dim=0,
                                      dim_size=graph.num_node * graph.num_relation)
            if self.edge_linear.in_features <= self.edge_linear.out_features:
                edge_update = self.edge_linear(edge_update)
            update += edge_update

        return update.view(graph.num_node, self.num_relation * self.input_dim)

torchdrug/layers/functional/functional.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ def variadic_sort(input, size, descending=False):
375375
input (Tensor): input of shape :math:`(B, ...)`
376376
size (LongTensor): size of sets of shape :math:`(N,)`
377377
descending (bool, optional): return ascending or descending order
378-
378+
379379
Returns
380380
(Tensor, LongTensor): sorted values and indexes
381381
"""
@@ -385,8 +385,11 @@ def variadic_sort(input, size, descending=False):
385385
mask = ~torch.isinf(input)
386386
max = input[mask].max().item()
387387
min = input[mask].min().item()
388-
safe_input = input.clamp(2 * min - max, 2 * max - min)
389-
offset = (max - min) * 4
388+
abs_max = input[mask].abs().max().item()
389+
# special case: max = min
390+
gap = max - min + abs_max * 1e-6
391+
safe_input = input.clamp(min - gap, max + gap)
392+
offset = gap * 4
390393
if descending:
391394
offset = -offset
392395
input_ext = safe_input + offset * index2sample
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from .graph import GraphConstruction, SpatialLineGraph
2+
from .function import BondEdge, KNNEdge, SpatialEdge, SequentialEdge, AlphaCarbonNode, \
3+
IdentityNode, RandomEdgeMask, SubsequenceNode, SubspaceNode
4+
5+
__all__ = [
6+
"GraphConstruction", "SpatialLineGraph",
7+
"BondEdge", "KNNEdge", "SpatialEdge", "SequentialEdge", "AlphaCarbonNode",
8+
"IdentityNode", "RandomEdgeMask", "SubsequenceNode", "SubspaceNode"
9+
]

0 commit comments

Comments
 (0)