
Commit cd45518

clean up
1 parent cf1a62c commit cd45518

File tree: 19 files changed (+95, -98 lines)


doc/source/quick_start.rst
Lines changed: 1 addition & 1 deletion

@@ -74,7 +74,7 @@ utilization of hardware. They can also be transferred between CPUs and GPUs usin
 
 .. code:: bash
 
-    PackedMolecule(batch_size=4, num_nodes=[12, 6, 14, 9], num_edges=[22, 10, 30, 18],
+    PackedMolecule(batch_size=4, num_atoms=[12, 6, 14, 9], num_bonds=[22, 10, 30, 18],
                    device='cuda:0')
 
 Just like original PyTorch tensors, graphs support a wide range of indexing
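For context, a minimal sketch of the quick-start usage that produces this output (the SMILES strings below are illustrative placeholders); only the repr field names change, from num_nodes/num_edges to num_atoms/num_bonds:

    from torchdrug import data

    # Pack a small batch of molecules and move it to the GPU.
    smiles_list = ["CCO", "c1ccccc1", "CC(=O)O", "CCN"]
    mols = data.PackedMolecule.from_smiles(smiles_list)
    mols = mols.cuda()
    print(mols)
    # e.g. PackedMolecule(batch_size=4, num_atoms=[...], num_bonds=[...], device='cuda:0')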

torchdrug/data/dataset.py
Lines changed: 25 additions & 18 deletions

@@ -116,6 +116,13 @@ def load_csv(self, csv_file, smiles_field="smiles", target_fields=None, verbose=
         self.load_smiles(smiles, targets, verbose=verbose, **kwargs)
 
     def load_pickle(self, pkl_file, verbose=0):
+        """
+        Load the dataset from a pickle file.
+
+        Parameters:
+            pkl_file (str): file name
+            verbose (int, optional): output verbose level
+        """
         with utils.smart_open(pkl_file, "rb") as fin:
             num_sample, tasks = pickle.load(fin)
 
@@ -133,6 +140,13 @@ def load_pickle(self, pkl_file, verbose=0):
             self.targets[task] = value
 
     def save_pickle(self, pkl_file, verbose=0):
+        """
+        Save the dataset to a pickle file.
+
+        Parameters:
+            pkl_file (str): file name
+            verbose (int, optional): output verbose level
+        """
         with utils.smart_open(pkl_file, "wb") as fout:
             num_sample = len(self.data)
             tasks = self.targets.keys()
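The pickle helpers documented above cache a processed molecule dataset so later runs can skip CSV/SMILES parsing. A hedged sketch of the round trip (the dataset choice and file name are placeholders, not part of this commit):

    from torchdrug import datasets

    # Process the dataset once, then cache it.
    dataset = datasets.ClinTox("~/molecule-datasets/")
    dataset.save_pickle("clintox.pkl", verbose=1)

    # A later run can restore the processed dataset instead of re-parsing SMILES.
    dataset.load_pickle("clintox.pkl", verbose=1)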
@@ -659,16 +673,16 @@ def load_sequence(self, sequences, targets, attributes=None, transform=None, laz
             self.targets[field].append(targets[field][i])
 
     @utils.copy_args(load_sequence)
-    def load_lmdbs(self, lmdb_files, number_field="num_examples", sequence_field="primary", target_fields=None,
+    def load_lmdbs(self, lmdb_files, sequence_field="primary", target_fields=None, number_field="num_examples",
                    transform=None, lazy=False, verbose=0, **kwargs):
         """
         Load the dataset from lmdb files.
 
         Parameters:
             lmdb_files (list of str): list of lmdb files
-            number_field (str, optional): name of the field of sample count in lmdb files
             sequence_field (str, optional): name of the field of protein sequence in lmdb files
             target_fields (list of str, optional): name of target fields in lmdb files
+            number_field (str, optional): name of the field of sample count in lmdb files
             transform (Callable, optional): protein sequence transformation function
             lazy (bool, optional): if lazy mode is used, the proteins are processed in the dataloader.
                 This may slow down the data loading process, but save a lot of CPU memory and dataset loading time.
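A hypothetical keyword-based call showing that the reordered load_lmdbs signature leaves such callers unaffected; only purely positional calls would need updating (the dataset class, lmdb paths, and target field below are illustrative):

    from torchdrug import data

    class StabilityLikeDataset(data.ProteinDataset):
        """Hypothetical dataset backed by pre-split lmdb files."""
        def __init__(self, lmdb_files, **kwargs):
            self.load_lmdbs(lmdb_files, **kwargs)

    dataset = StabilityLikeDataset(
        ["train.lmdb", "valid.lmdb", "test.lmdb"],
        sequence_field="primary", target_fields=["stability_score"],
        number_field="num_examples", verbose=1)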
@@ -701,12 +715,13 @@ def load_lmdbs(self, lmdb_files, number_field="num_examples", sequence_field="pr
         self.num_samples = num_samples
 
     @utils.copy_args(data.Protein.from_molecule)
-    def load_pdbs(self, pdb_files, transform=None, lazy=False, verbose=0, **kwargs):
+    def load_pdbs(self, pdb_files, sanitize=True, transform=None, lazy=False, verbose=0, **kwargs):
         """
         Load the dataset from pdb files.
 
         Parameters:
             pdb_files (list of str): pdb file names
+            sanitize (bool, optional): whether to sanitize the molecule
             transform (Callable, optional): protein sequence transformation function
             lazy (bool, optional): if lazy mode is used, the proteins are processed in the dataloader.
                 This may slow down the data loading process, but save a lot of CPU memory and dataset loading time.
 
@@ -729,7 +744,6 @@ def load_pdbs(self, pdb_files, transform=None, lazy=False, verbose=0, **kwargs):
             pdb_files = tqdm(pdb_files, "Constructing proteins from pdbs")
         for i, pdb_file in enumerate(pdb_files):
             if not lazy or i == 0:
-                sanitize = kwargs.pop("sanitize", True)
                 mol = Chem.MolFromPDBFile(pdb_file, sanitize=sanitize)
                 if not mol:
                     logger.debug("Can't construct molecule from pdb file `%s`. Ignore this sample." % pdb_file)
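With sanitize promoted to an explicit parameter of load_pdbs (previously it was popped out of **kwargs inside the loop), a hedged sketch of the call on a ProteinDataset subclass; the class name and PDB file paths are placeholders:

    from torchdrug import data

    class PDBFolderDataset(data.ProteinDataset):
        """Hypothetical dataset built from a local list of PDB files."""
        def __init__(self, pdb_files, **kwargs):
            self.load_pdbs(pdb_files, **kwargs)

    # sanitize is now a named argument with the same default (True) as before.
    dataset = PDBFolderDataset(["1a3n.pdb", "2hhb.pdb"], sanitize=True, verbose=1)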
@@ -779,10 +793,10 @@ def load_fasta(self, fasta_file, verbose=0, **kwargs):
     @utils.copy_args(data.Protein.from_molecule)
     def load_pickle(self, pkl_file, transform=None, lazy=False, verbose=0, **kwargs):
         """
-        Load the dataset from pickle files.
+        Load the dataset from a pickle file.
 
         Parameters:
-            pkl_file (str): pickle file name
+            pkl_file (str): file name
             transform (Callable, optional): protein sequence transformation function
             lazy (bool, optional): if lazy mode is used, the proteins are processed in the dataloader.
                 This may slow down the data loading process, but save a lot of CPU memory and dataset loading time.
 
@@ -808,13 +822,6 @@ def load_pickle(self, pkl_file, transform=None, lazy=False, verbose=0, **kwargs)
             self.data.append(protein)
 
     def save_pickle(self, pkl_file, verbose=0):
-        """
-        Save the dataset to pickle files.
-
-        Parameters:
-            pkl_file (str): pickle file name
-            verbose (int, optional): output verbose level
-        """
         with utils.smart_open(pkl_file, "wb") as fout:
             num_sample = len(self.data)
             pickle.dump(num_sample, fout)
 
@@ -890,16 +897,16 @@ def load_sequence(self, sequences, targets, attributes=None, transform=None, laz
             self.targets[field].append(targets[field][i])
 
     @utils.copy_args(load_sequence)
-    def load_lmdbs(self, lmdb_files, number_field="num_examples", sequence_field="primary", target_fields=None,
+    def load_lmdbs(self, lmdb_files, sequence_field="primary", target_fields=None, number_field="num_examples",
                    transform=None, lazy=False, verbose=0, **kwargs):
         """
         Load the dataset from lmdb files.
 
         Parameters:
             lmdb_files (list of str): file names
-            number_field (str, optional): name of the field of sample count in lmdb files
             sequence_field (str or list of str, optional): names of the fields of protein sequence in lmdb files
             target_fields (list of str, optional): name of target fields in lmdb files
+            number_field (str, optional): name of the field of sample count in lmdb files
             transform (Callable, optional): protein sequence transformation function
             lazy (bool, optional): if lazy mode is used, the protein pairs are processed in the dataloader.
                 This may slow down the data loading process, but save a lot of CPU memory and dataset loading time.
 
@@ -1022,17 +1029,17 @@ def load_sequence(self, sequences, smiles, targets, num_samples, attributes=None
         return num_samples
 
     @utils.copy_args(load_sequence)
-    def load_lmdbs(self, lmdb_files, number_field="num_examples", sequence_field="target", smiles_field="drug",
-                   target_fields=None, transform=None, lazy=False, verbose=0, **kwargs):
+    def load_lmdbs(self, lmdb_files, sequence_field="target", smiles_field="drug", target_fields=None,
+                   number_field="num_examples", transform=None, lazy=False, verbose=0, **kwargs):
         """
         Load the dataset from lmdb files.
 
         Parameters:
             lmdb_files (list of str): file names
-            number_field (str, optional): name of the field of sample count in lmdb files
             sequence_field (str, optional): name of the field of protein sequence in lmdb files
             smiles_field (str, optional): name of the field of ligand SMILES string in lmdb files
             target_fields (list of str, optional): name of target fields in lmdb files
+            number_field (str, optional): name of the field of sample count in lmdb files
             transform (Callable, optional): protein sequence transformation function
             lazy (bool, optional): if lazy mode is used, the protein-ligand pairs are processed in the dataloader.
                 This may slow down the data loading process, but save a lot of CPU memory and dataset loading time.

torchdrug/data/feature.py
Lines changed: 1 addition & 1 deletion

@@ -307,7 +307,7 @@ def residue_symbol(residue):
 
 @R.register("features.residue.default")
 def residue_default(residue):
-    """Default atom feature.
+    """Default residue feature.
 
     Features:
         GetResidueName(): one-hot embedding for the residue symbol

torchdrug/data/protein.py
Lines changed: 25 additions & 28 deletions

@@ -252,13 +252,7 @@ def from_molecule(cls, mol, atom_feature="default", bond_feature="default", resi
                    meta_dict=protein.meta_dict, **protein.data_dict)
 
     @classmethod
-    def from_sequence_fast(cls, sequence):
-        """
-        A faster version of creating a protein from a sequence.
-
-        Parameters:
-            sequence (str): string
-        """
+    def _residue_from_sequence(cls, sequence):
         residue_type = []
         residue_feature = []
         sequence = sequence + "G"
 
@@ -278,10 +272,16 @@ def from_sequence_fast(cls, sequence):
     @classmethod
     @utils.deprecated_alias(node_feature="atom_feature", edge_feature="bond_feature", graph_feature="mol_feature")
     def from_sequence(cls, sequence, atom_feature="default", bond_feature="default", residue_feature="default",
-                      mol_feature=None, kekulize=False, residue_only=False):
+                      mol_feature=None, kekulize=False):
         """
         Create a protein from a sequence.
 
+        .. note::
+
+            It takes considerable time to construct proteins with a large number of atoms and bonds.
+            If you only need residue information, you may speed up the construction by setting
+            ``atom_feature`` and ``bond_feature`` to ``None``.
+
         Parameters:
             sequence (str): protein sequence
             atom_feature (str or list of str, optional): atom features to extract
 
@@ -292,14 +292,9 @@ def from_sequence(cls, sequence, atom_feature="default", bond_feature="default",
                 Note this only affects the relation in ``edge_list``.
                 For ``bond_type``, aromatic bonds are always stored explicitly.
                 By default, aromatic bonds are stored.
-            residue_only (bool, optional): only store residue information without atom information.
-                This can speed up the processing.
         """
-        if residue_only:
-            if residue_feature != "default":
-                raise ValueError("`residue_only` only supports the default residue feature, "
-                                 "but found `%s` for `residue_feature`" % residue_feature)
-            return cls.from_sequence_fast(sequence)
+        if atom_feature is None and bond_feature is None and residue_feature == "default":
+            return cls._residue_from_sequence(sequence)
 
         mol = Chem.MolFromSequence(sequence)
         if mol is None:
 
@@ -324,7 +319,7 @@ def from_pdb(cls, pdb_file, atom_feature="default", bond_feature="default", resi
                 Note this only affects the relation in ``edge_list``.
                 For ``bond_type``, aromatic bonds are always stored explicitly.
                 By default, aromatic bonds are stored.
-            sanitize (bool, optional): whether to sanitize the molecule.
+            sanitize (bool, optional): whether to sanitize the molecule
         """
         if not os.path.exists(pdb_file):
             raise FileNotFoundError("No such file `%s`" % pdb_file)
 
@@ -524,7 +519,7 @@ def repeat(self, count):
                    num_relation=num_relation, meta_dict=self.meta_dict, **data_dict)
 
     def residue2atom(self, residue_index):
-        """Map residue id to atom ids."""
+        """Map residue ids to atom ids."""
         residue_index = self._standarize_index(residue_index, self.num_residue)
         if not hasattr(self, "node_inverted_index"):
             self.node_inverted_index = self._build_node_inverted_index()
 
@@ -992,7 +987,7 @@ def from_molecule(cls, mols, atom_feature="default", bond_feature="default", res
                    offsets=protein._offsets, meta_dict=protein.meta_dict, **protein.data_dict)
 
     @classmethod
-    def from_sequence_fast(cls, sequences):
+    def _residue_from_sequence(cls, sequences):
         num_residues = []
         residue_type = []
         residue_feature = []
 
@@ -1021,10 +1016,16 @@ def from_sequence_fast(cls, sequences):
     @classmethod
     @utils.deprecated_alias(node_feature="atom_feature", edge_feature="bond_feature", graph_feature="mol_feature")
     def from_sequence(cls, sequences, atom_feature="default", bond_feature="default", residue_feature="default",
-                      mol_feature=None, kekulize=False, residue_only=False):
+                      mol_feature=None, kekulize=False):
         """
         Create a packed protein from a list of sequences.
 
+        .. note::
+
+            It takes considerable time to construct proteins with a large number of atoms and bonds.
+            If you only need residue information, you may speed up the construction by setting
+            ``atom_feature`` and ``bond_feature`` to ``None``.
+
         Parameters:
             sequences (str): list of protein sequences
             atom_feature (str or list of str, optional): atom features to extract
 
@@ -1035,14 +1036,9 @@ def from_sequence(cls, sequences, atom_feature="default", bond_feature="default"
                 Note this only affects the relation in ``edge_list``.
                 For ``bond_type``, aromatic bonds are always stored explicitly.
                 By default, aromatic bonds are stored.
-            residue_only (bool, optional): only store residue information without atom information.
-                This can speed up the processing.
         """
-        if residue_only:
-            if residue_feature != "default":
-                raise ValueError("`residue_only` only supports the default residue feature, "
-                                 "but found `%s` for `residue_feature`" % residue_feature)
-            return cls.from_sequence_fast(sequences)
+        if atom_feature is None and bond_feature is None and residue_feature == "default":
+            return cls._residue_from_sequence(sequences)
 
         mols = []
         for sequence in sequences:
 
@@ -1056,7 +1052,7 @@ def from_sequence(cls, sequences, atom_feature="default", bond_feature="default"
     @classmethod
     @utils.deprecated_alias(node_feature="atom_feature", edge_feature="bond_feature", graph_feature="mol_feature")
     def from_pdb(cls, pdb_files, atom_feature="default", bond_feature="default", residue_feature="default",
-                 mol_feature=None, kekulize=False):
+                 mol_feature=None, kekulize=False, sanitize=False):
         """
         Create a protein from a list of PDB files.
 
@@ -1070,10 +1066,11 @@ def from_pdb(cls, pdb_files, atom_feature="default", bond_feature="default", res
                 Note this only affects the relation in ``edge_list``.
                 For ``bond_type``, aromatic bonds are always stored explicitly.
                 By default, aromatic bonds are stored.
+            sanitize (bool, optional): whether to sanitize the molecule
         """
         mols = []
         for pdb_file in pdb_files:
-            mol = Chem.MolFromPDBFile(pdb_file)
+            mol = Chem.MolFromPDBFile(pdb_file, sanitize=sanitize)
             mols.append(mol)
 
         return cls.from_molecule(mols, atom_feature, bond_feature, residue_feature, mol_feature, kekulize)
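A sketch of the new residue-only fast path described in the note above: when atom_feature and bond_feature are None and the residue feature is left at its default, from_sequence() now delegates to the internal _residue_from_sequence() helper instead of building the full atom-level graph through RDKit. The sequences below are illustrative:

    from torchdrug import data

    # Residue-level construction only; no atom or bond features are computed.
    protein = data.Protein.from_sequence(
        "KALTARQQEVFDLIRD", atom_feature=None, bond_feature=None)

    # The packed variant follows the same convention.
    proteins = data.PackedProtein.from_sequence(
        ["KALTARQQEVFDLIRD", "MSHHWGYGKHNGPEHW"],
        atom_feature=None, bond_feature=None)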

torchdrug/datasets/alphafolddb.py
Lines changed: 1 addition & 1 deletion

@@ -6,7 +6,7 @@
 
 
 @R.register("datasets.AlphaFoldDB")
-@utils.copy_args(data.ProteinDataset.load_pdbs, ignore=("filtered_pdb",))
+@utils.copy_args(data.ProteinDataset.load_pdbs)
 class AlphaFoldDB(data.ProteinDataset):
     """
     3D protein structures predicted by AlphaFold.

torchdrug/datasets/enzyme_commission.py
Lines changed: 10 additions & 8 deletions

@@ -10,7 +10,7 @@
 
 
 @R.register("datasets.EnzymeCommission")
-@utils.copy_args(data.ProteinDataset.load_pdbs, ignore=("filtered_pdb",))
+@utils.copy_args(data.ProteinDataset.load_pdbs)
 class EnzymeCommission(data.ProteinDataset):
     """
     A set of proteins with their 3D structures and EC numbers, which describes their
 
@@ -23,7 +23,7 @@ class EnzymeCommission(data.ProteinDataset):
 
     Parameters:
         path (str): the path to store the dataset
-        test_cutoff (float): the test cutoff used to split the dataset
+        test_cutoff (float, optional): the test cutoff used to split the dataset
        verbose (int, optional): output verbose level
        **kwargs
    """
 
@@ -47,13 +47,14 @@ def __init__(self, path, test_cutoff=0.95, verbose=1, **kwargs):
         pkl_file = os.path.join(path, self.processed_file)
 
         csv_file = os.path.join(path, "nrPDB-EC_test.csv")
-        filtered_pdb = set()
+        pdb_ids = []
         with open(csv_file, "r") as fin:
             reader = csv.reader(fin, delimiter=",")
             idx = self.test_cutoffs.index(test_cutoff) + 1
             _ = next(reader)
             for line in reader:
-                if line[idx] == "0": filtered_pdb.add(line[0])
+                if line[idx] == "0":
+                    pdb_ids.append(line[0])
 
         if os.path.exists(pkl_file):
             self.load_pickle(pkl_file, verbose=verbose, **kwargs)
 
@@ -64,8 +65,8 @@ def __init__(self, path, test_cutoff=0.95, verbose=1, **kwargs):
                 pdb_files += sorted(glob.glob(os.path.join(split_path, split, "*.pdb")))
             self.load_pdbs(pdb_files, verbose=verbose, **kwargs)
             self.save_pickle(pkl_file, verbose=verbose)
-        if len(filtered_pdb) > 0:
-            self.filter_pdb(filtered_pdb)
+        if len(pdb_ids) > 0:
+            self.filter_pdb(pdb_ids)
 
         tsv_file = os.path.join(path, "nrPDB-EC_annot.tsv")
         pdb_ids = [os.path.basename(pdb_file).split("_")[0] for pdb_file in self.pdb_files]
 
@@ -74,12 +75,13 @@ def __init__(self, path, test_cutoff=0.95, verbose=1, **kwargs):
         splits = [os.path.basename(os.path.dirname(pdb_file)) for pdb_file in self.pdb_files]
         self.num_samples = [splits.count("train"), splits.count("valid"), splits.count("test")]
 
-    def filter_pdb(self, filtered_pdb):
+    def filter_pdb(self, pdb_ids):
+        pdb_ids = set(pdb_ids)
         sequences = []
         pdb_files = []
         data = []
         for sequence, pdb_file, protein in zip(self.sequences, self.pdb_files, self.data):
-            if os.path.basename(pdb_file).split("_")[0] in filtered_pdb:
+            if os.path.basename(pdb_file).split("_")[0] in pdb_ids:
                 continue
             sequences.append(sequence)
             pdb_files.append(pdb_file)