11import copy
22import logging
33from collections import deque
4+ from random import randint
45
56import torch
67
8+ from torchdrug import core
9+ from torchdrug .core import Registry as R
10+
711
812logger = logging .getLogger (__name__ )
913
1014
11- class TargetNormalize (object ):
15+ @R .register ("transforms.NormalizeTarget" )
16+ class NormalizeTarget (core .Configurable ):
1217 """
1318 Normalize the target values in a sample.
1419
@@ -30,9 +35,11 @@ def __call__(self, item):
3035 return item
3136
3237
33- class RemapAtomType (object ):
38+ @R .register ("transforms.RemapAtomType" )
39+ class RemapAtomType (core .Configurable ):
3440 """
3541 Map atom types to their index in a vocabulary. Atom types that don't present in the vocabulary are mapped to -1.
42+
3643 Parameters:
3744 atom_types (array_like): vocabulary of atom types
3845 """
@@ -51,7 +58,8 @@ def __call__(self, item):
5158 return item
5259
5360
54- class RandomBFSOrder (object ):
61+ @R .register ("transforms.RandomBFSOrder" )
62+ class RandomBFSOrder (core .Configurable ):
5563 """
5664 Order the nodes in a graph according to a random BFS order.
5765 """
@@ -81,9 +89,11 @@ def __call__(self, item):
8189 return item
8290
8391
84- class Shuffle (object ):
92+ @R .register ("transforms.Shuffle" )
93+ class Shuffle (core .Configurable ):
8594 """
8695 Shuffle the order of nodes and edges in a graph.
96+
8797 Parameters:
8898 shuffle_node (bool, optional): shuffle node order or not
8999 shuffle_edge (bool, optional): shuffle edge order or not
@@ -125,7 +135,8 @@ def transform_data(self, data, meta):
125135 return new_data
126136
127137
128- class VirtualNode (object ):
138+ @R .register ("transforms.VirtualNode" )
139+ class VirtualNode (core .Configurable ):
129140 """
130141 Add a virtual node and connect it with every node in the graph.
131142
@@ -199,9 +210,11 @@ def __call__(self, item):
199210 return item
200211
201212
202- class VirtualAtom (VirtualNode ):
213+ @R .register ("transforms.VirtualAtom" )
214+ class VirtualAtom (VirtualNode , core .Configurable ):
203215 """
204216 Add a virtual atom and connect it with every atom in the molecule.
217+
205218 Parameters:
206219 atom_type (int, optional): type of the virtual atom
207220 bond_type (int, optional): type of the virtual bonds
@@ -215,9 +228,73 @@ def __init__(self, atom_type=None, bond_type=None, node_feature=None, edge_featu
215228 edge_feature = edge_feature , atom_type = atom_type , ** kwargs )
216229
217230
218- class Compose (object ):
@R.register("transforms.TruncateProtein")
class TruncateProtein(core.Configurable):
    """
    Truncate overly long protein sequences to a fixed length.

    Parameters:
        max_length (int, optional): maximal length of the sequence. Truncate the sequence if it exceeds this limit.
            If not specified, no truncation is performed.
        random (bool, optional): truncate the sequence at a random position.
            If not, truncate the suffix of the sequence.
        keys (str or list of str, optional): keys for the items that require truncation in a sample
    """

    def __init__(self, max_length=None, random=False, keys="graph"):
        self.truncate_length = max_length
        self.random = random
        # Accept a single key as a convenience and normalize it to a list.
        if isinstance(keys, str):
            keys = [keys]
        self.keys = keys

    def __call__(self, item):
        """
        Return a shallow copy of ``item`` in which every entry listed in ``keys``
        is cut down to at most ``max_length`` residues.

        Parameters:
            item (dict): sample to transform. The entries under ``keys`` must expose
                ``num_residue``, ``device`` and ``subresidue(mask)``.

        Returns:
            dict: shallow copy of ``item`` with the (possibly truncated) entries replaced
        """
        new_item = item.copy()
        # ``max_length=None`` is the constructor default and means "no limit".
        # Comparing a residue count against None would raise TypeError, so bail out early.
        if self.truncate_length is None:
            return new_item

        for key in self.keys:
            graph = item[key]
            if graph.num_residue > self.truncate_length:
                if self.random:
                    # randint is inclusive on both ends, so the window can end
                    # exactly at the last residue.
                    start = randint(0, graph.num_residue - self.truncate_length)
                else:
                    # Deterministic mode keeps the prefix and drops the suffix.
                    start = 0
                end = start + self.truncate_length
                # Boolean mask selecting the contiguous window [start, end).
                mask = torch.zeros(graph.num_residue, dtype=torch.bool, device=graph.device)
                mask[start:end] = True
                graph = graph.subresidue(mask)

            new_item[key] = graph
        return new_item
266+
267+
@R.register("transforms.ProteinView")
class ProteinView(core.Configurable):
    """
    Switch the proteins in a sample to a given view.

    Parameters:
        view (str): protein view. Can be ``atom`` or ``residue``.
        keys (str or list of str, optional): keys for the items that require view change in a sample
    """

    def __init__(self, view, keys="graph"):
        self.view = view
        # A bare string denotes a single key; wrap it so iteration is uniform.
        self.keys = [keys] if isinstance(keys, str) else keys

    def __call__(self, item):
        """
        Return a shallow copy of ``item`` whose entries under ``keys`` are
        shallow copies of the originals with ``view`` set to the configured value.
        """
        converted = item.copy()
        for name in self.keys:
            # Copy before assigning the view so the object held by the
            # caller's sample is not the one being modified.
            protein = copy.copy(converted[name])
            protein.view = self.view
            converted[name] = protein
        return converted
291+
292+
293+ @R .register ("transforms.Compose" )
294+ class Compose (core .Configurable ):
219295 """
220296 Compose a list of transforms into one.
297+
221298 Parameters:
222299 transforms (list of callable): list of transforms
223300 """
0 commit comments