Skip to content

Commit cf1a62c

Browse files
committed
add protein transforms
1 parent 921f3f3 commit cf1a62c

File tree

2 files changed

+88
-10
lines changed

2 files changed

+88
-10
lines changed

torchdrug/transforms/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from .transform import TargetNormalize, RemapAtomType, RandomBFSOrder, Shuffle, VirtualNode, VirtualAtom, Compose
1+
from .transform import NormalizeTarget, RemapAtomType, RandomBFSOrder, Shuffle, VirtualNode, \
2+
VirtualAtom, TruncateProtein, ProteinView, Compose
23

34
__all__ = [
4-
"TargetNormalize", "RemapAtomType", "RandomBFSOrder", "Shuffle",
5-
"VirtualNode", "VirtualAtom", "Compose",
5+
"NormalizeTarget", "RemapAtomType", "RandomBFSOrder", "Shuffle",
6+
"VirtualNode", "VirtualAtom", "TruncateProtein", "ProteinView", "Compose",
67
]

torchdrug/transforms/transform.py

Lines changed: 84 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
11
import copy
22
import logging
33
from collections import deque
4+
from random import randint
45

56
import torch
67

8+
from torchdrug import core
9+
from torchdrug.core import Registry as R
10+
711

812
logger = logging.getLogger(__name__)
913

1014

11-
class TargetNormalize(object):
15+
@R.register("transforms.NormalizeTarget")
16+
class NormalizeTarget(core.Configurable):
1217
"""
1318
Normalize the target values in a sample.
1419
@@ -30,9 +35,11 @@ def __call__(self, item):
3035
return item
3136

3237

33-
class RemapAtomType(object):
38+
@R.register("transforms.RemapAtomType")
39+
class RemapAtomType(core.Configurable):
3440
"""
3541
Map atom types to their index in a vocabulary. Atom types that are not present in the vocabulary are mapped to -1.
42+
3643
Parameters:
3744
atom_types (array_like): vocabulary of atom types
3845
"""
@@ -51,7 +58,8 @@ def __call__(self, item):
5158
return item
5259

5360

54-
class RandomBFSOrder(object):
61+
@R.register("transforms.RandomBFSOrder")
62+
class RandomBFSOrder(core.Configurable):
5563
"""
5664
Order the nodes in a graph according to a random BFS order.
5765
"""
@@ -81,9 +89,11 @@ def __call__(self, item):
8189
return item
8290

8391

84-
class Shuffle(object):
92+
@R.register("transforms.Shuffle")
93+
class Shuffle(core.Configurable):
8594
"""
8695
Shuffle the order of nodes and edges in a graph.
96+
8797
Parameters:
8898
shuffle_node (bool, optional): shuffle node order or not
8999
shuffle_edge (bool, optional): shuffle edge order or not
@@ -125,7 +135,8 @@ def transform_data(self, data, meta):
125135
return new_data
126136

127137

128-
class VirtualNode(object):
138+
@R.register("transforms.VirtualNode")
139+
class VirtualNode(core.Configurable):
129140
"""
130141
Add a virtual node and connect it with every node in the graph.
131142
@@ -199,9 +210,11 @@ def __call__(self, item):
199210
return item
200211

201212

202-
class VirtualAtom(VirtualNode):
213+
@R.register("transforms.VirtualAtom")
214+
class VirtualAtom(VirtualNode, core.Configurable):
203215
"""
204216
Add a virtual atom and connect it with every atom in the molecule.
217+
205218
Parameters:
206219
atom_type (int, optional): type of the virtual atom
207220
bond_type (int, optional): type of the virtual bonds
@@ -215,9 +228,73 @@ def __init__(self, atom_type=None, bond_type=None, node_feature=None, edge_featu
215228
edge_feature=edge_feature, atom_type=atom_type, **kwargs)
216229

217230

218-
class Compose(object):
231+
@R.register("transforms.TruncateProtein")
class TruncateProtein(core.Configurable):
    """
    Truncate overly long protein sequences to a fixed length.

    Parameters:
        max_length (int, optional): maximal length of the sequence. Truncate the sequence if it exceeds this limit.
            If ``None``, no truncation is applied.
        random (bool, optional): keep a window of ``max_length`` residues starting at a random position.
            If not, keep the first ``max_length`` residues, i.e. truncate the suffix of the sequence.
        keys (str or list of str, optional): keys for the items that require truncation in a sample
    """

    def __init__(self, max_length=None, random=False, keys="graph"):
        self.truncate_length = max_length
        self.random = random
        # Accept a single key as a convenience and normalize it to a list.
        if isinstance(keys, str):
            keys = [keys]
        self.keys = keys

    def __call__(self, item):
        # Work on a copy of the sample so the entries of the caller's item are not reassigned.
        new_item = item.copy()
        # ``max_length=None`` is the constructor default; comparing a residue count
        # against None would raise TypeError, so treat it as "no limit".
        if self.truncate_length is None:
            return new_item

        for key in self.keys:
            graph = item[key]
            if graph.num_residue > self.truncate_length:
                if self.random:
                    # randint is inclusive on both ends, so the window always fits in the sequence.
                    start = randint(0, graph.num_residue - self.truncate_length)
                else:
                    start = 0
                end = start + self.truncate_length
                # Boolean mask over residues marking the half-open window [start, end).
                mask = torch.zeros(graph.num_residue, dtype=torch.bool, device=graph.device)
                mask[start:end] = True
                # NOTE(review): subresidue is presumed to return the protein restricted
                # to the masked residues -- confirm against the Protein API.
                graph = graph.subresidue(mask)

            new_item[key] = graph
        return new_item
266+
267+
268+
@R.register("transforms.ProteinView")
class ProteinView(core.Configurable):
    """
    Switch the proteins in a sample to a given view.

    Parameters:
        view (str): protein view. Can be ``atom`` or ``residue``.
        keys (str or list of str, optional): keys for the items that require view change in a sample
    """

    def __init__(self, view, keys="graph"):
        self.view = view
        # A lone string key is wrapped into a one-element list.
        self.keys = [keys] if isinstance(keys, str) else keys

    def __call__(self, item):
        # Copy the sample container so the caller's item is not reassigned.
        result = item.copy()
        for name in self.keys:
            # Shallow-copy the graph before setting the view, so the attribute
            # assignment lands on the copy rather than on the original object.
            viewed = copy.copy(result[name])
            viewed.view = self.view
            result[name] = viewed
        return result
291+
292+
293+
@R.register("transforms.Compose")
294+
class Compose(core.Configurable):
219295
"""
220296
Compose a list of transforms into one.
297+
221298
Parameters:
222299
transforms (list of callable): list of transforms
223300
"""

0 commit comments

Comments
 (0)