from torch import nn
from torch.nn import functional as F

from torchdrug import layers


class ProteinResNetBlock(nn.Module):
    """
    Convolutional block with residual connection from `Deep Residual Learning for Image Recognition`_.

    .. _Deep Residual Learning for Image Recognition:
        https://arxiv.org/pdf/1512.03385.pdf

    Parameters:
        input_dim (int): input dimension
        output_dim (int): output dimension
        kernel_size (int, optional): size of convolutional kernel
        stride (int, optional): stride of convolution
        padding (int, optional): padding added to both sides of the input
        activation (str or function, optional): activation function
    """

    def __init__(self, input_dim, output_dim, kernel_size=3, stride=1, padding=1, activation="gelu"):
        super(ProteinResNetBlock, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        if isinstance(activation, str):
            self.activation = getattr(F, activation)
        else:
            self.activation = activation

        self.conv1 = nn.Conv1d(input_dim, output_dim, kernel_size, stride, padding, bias=False)
        self.layer_norm1 = nn.LayerNorm(output_dim)
        self.conv2 = nn.Conv1d(output_dim, output_dim, kernel_size, stride, padding, bias=False)
        self.layer_norm2 = nn.LayerNorm(output_dim)

    def forward(self, input, mask):
        """
        Perform 1D convolutions over the input.

        Parameters:
            input (Tensor): input representations of shape `(..., length, dim)`
            mask (Tensor): bool mask of shape `(..., length, dim)`
        """
        identity = input

        input = input * mask  # (B, L, d)
        out = self.conv1(input.transpose(1, 2)).transpose(1, 2)
        out = self.layer_norm1(out)
        out = self.activation(out)

        out = out * mask
        out = self.conv2(out.transpose(1, 2)).transpose(1, 2)
        out = self.layer_norm2(out)

        out += identity
        out = self.activation(out)

        return out
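

# Illustrative usage sketch (not part of the library): run a dummy batch through the
# block. The shapes and the helper name below are assumptions for demonstration only;
# note that the residual connection requires input_dim == output_dim.
def _example_protein_resnet_block():
    import torch
    block = ProteinResNetBlock(input_dim=128, output_dim=128)
    x = torch.randn(2, 100, 128)              # (batch, length, dim)
    mask = torch.ones(2, 100, 1)              # 1 for real residues, 0 for padding
    return block(x, mask)                     # same shape as the input: (2, 100, 128)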


class SelfAttentionBlock(nn.Module):
    """
    Multi-head self-attention block from
    `Attention Is All You Need`_.

    .. _Attention Is All You Need:
        https://arxiv.org/pdf/1706.03762.pdf

    Parameters:
        hidden_dim (int): hidden dimension
        num_heads (int): number of attention heads
        dropout (float, optional): dropout ratio of attention maps
    """

    def __init__(self, hidden_dim, num_heads, dropout=0.0):
        super(SelfAttentionBlock, self).__init__()
        if hidden_dim % num_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_dim, num_heads))
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_size = hidden_dim // num_heads

        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.key = nn.Linear(hidden_dim, hidden_dim)
        self.value = nn.Linear(hidden_dim, hidden_dim)

        self.attn = nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout)

    def forward(self, input, mask):
        """
        Perform self attention over the input.

        Parameters:
            input (Tensor): input representations of shape `(..., length, dim)`
            mask (Tensor): bool mask of shape `(..., length)`
        """
        query = self.query(input).transpose(0, 1)
        key = self.key(input).transpose(0, 1)
        value = self.value(input).transpose(0, 1)

        mask = (~mask.bool()).squeeze(-1)
        output = self.attn(query, key, value, key_padding_mask=mask)[0].transpose(0, 1)

        return output
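

# Illustrative usage sketch (assumed shapes, not library code): `forward` inverts the
# boolean mask and passes it to nn.MultiheadAttention as `key_padding_mask`, so True
# below marks valid (non-padding) positions.
def _example_self_attention_block():
    import torch
    block = SelfAttentionBlock(hidden_dim=128, num_heads=8)
    x = torch.randn(2, 100, 128)                     # (batch, length, hidden_dim)
    mask = torch.ones(2, 100, dtype=torch.bool)      # True for real residues
    return block(x, mask)                            # (2, 100, 128)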


class ProteinBERTBlock(nn.Module):
    """
    Transformer encoding block from
    `Attention Is All You Need`_.

    .. _Attention Is All You Need:
        https://arxiv.org/pdf/1706.03762.pdf

    Parameters:
        input_dim (int): input dimension
        hidden_dim (int): hidden dimension
        num_heads (int): number of attention heads
        attention_dropout (float, optional): dropout ratio of attention maps
        hidden_dropout (float, optional): dropout ratio of hidden features
        activation (str or function, optional): activation function
    """

    def __init__(self, input_dim, hidden_dim, num_heads, attention_dropout=0,
                 hidden_dropout=0, activation="relu"):
        super(ProteinBERTBlock, self).__init__()
        self.input_dim = input_dim
        self.num_heads = num_heads
        self.attention_dropout = attention_dropout
        self.hidden_dropout = hidden_dropout
        self.hidden_dim = hidden_dim

        self.attention = SelfAttentionBlock(input_dim, num_heads, attention_dropout)
        self.linear1 = nn.Linear(input_dim, input_dim)
        self.dropout1 = nn.Dropout(hidden_dropout)
        self.layer_norm1 = nn.LayerNorm(input_dim)

        self.intermediate = layers.MultiLayerPerceptron(input_dim, hidden_dim, activation=activation)

        self.linear2 = nn.Linear(hidden_dim, input_dim)
        self.dropout2 = nn.Dropout(hidden_dropout)
        self.layer_norm2 = nn.LayerNorm(input_dim)

    def forward(self, input, mask):
        """
        Perform a BERT-block transformation over the input.

        Parameters:
            input (Tensor): input representations of shape `(..., length, dim)`
            mask (Tensor): bool mask of shape `(..., length)`
        """
        x = self.attention(input, mask)
        x = self.linear1(x)
        x = self.dropout1(x)
        x = self.layer_norm1(x + input)

        hidden = self.intermediate(x)

        hidden = self.linear2(hidden)
        hidden = self.dropout2(hidden)
        output = self.layer_norm2(hidden + x)

        return output
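

# Illustrative usage sketch (assumed shapes and dimensions, not library code): a single
# encoder layer with a 512-dim feed-forward inner size applied to 128-dim residue features.
def _example_protein_bert_block():
    import torch
    block = ProteinBERTBlock(input_dim=128, hidden_dim=512, num_heads=8)
    x = torch.randn(2, 100, 128)                     # (batch, length, input_dim)
    mask = torch.ones(2, 100, dtype=torch.bool)      # True for real residues
    return block(x, mask)                            # (2, 100, 128)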