diff --git a/chebai/models/electra.py b/chebai/models/electra.py index 7009406d..dc6c719b 100644 --- a/chebai/models/electra.py +++ b/chebai/models/electra.py @@ -329,7 +329,7 @@ def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]: except RuntimeError as e: print(f"RuntimeError at forward: {e}") print(f'data[features]: {data["features"]}') - raise Exception + raise e inp = self.word_dropout(inp) electra = self.electra(inputs_embeds=inp, **kwargs) d = electra.last_hidden_state[:, 0, :] diff --git a/chebai/models/ffn.py b/chebai/models/ffn.py new file mode 100644 index 00000000..c9c6f912 --- /dev/null +++ b/chebai/models/ffn.py @@ -0,0 +1,153 @@ +from typing import Any, Dict, List, Optional, Tuple + +import torch +from torch import Tensor, nn + +from chebai.models import ChebaiBaseNet + + +class FFN(ChebaiBaseNet): + # Reference: https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/models.py#L121-L139 + + NAME = "FFN" + + def __init__( + self, + input_size: int, + hidden_layers: List[int] = [ + 1024, + ], + **kwargs + ): + super().__init__(**kwargs) + + layers = [] + current_layer_input_size = input_size + for hidden_dim in hidden_layers: + layers.append(MLPBlock(current_layer_input_size, hidden_dim)) + layers.append(Residual(MLPBlock(hidden_dim, hidden_dim))) + current_layer_input_size = hidden_dim + + layers.append(torch.nn.Linear(current_layer_input_size, self.out_dim)) + layers.append(nn.Sigmoid()) + self.model = nn.Sequential(*layers) + + def _get_prediction_and_labels(self, data, labels, model_output): + d = model_output["logits"] + loss_kwargs = data.get("loss_kwargs", dict()) + if "non_null_labels" in loss_kwargs: + n = loss_kwargs["non_null_labels"] + d = d[n] + return torch.sigmoid(d), labels.int() if labels is not None else None + + def _process_for_loss( + self, + model_output: Dict[str, Tensor], + labels: Tensor, + loss_kwargs: Dict[str, Any], + ) -> Tuple[Tensor, Tensor, Dict[str, Any]]: + """ + Process the model output for calculating the loss. + + Args: + model_output (Dict[str, Tensor]): The output of the model. + labels (Tensor): The target labels. + loss_kwargs (Dict[str, Any]): Additional loss arguments. + + Returns: + tuple: A tuple containing the processed model output, labels, and loss arguments. + """ + kwargs_copy = dict(loss_kwargs) + if labels is not None: + labels = labels.float() + return model_output["logits"], labels, kwargs_copy + + def forward(self, data, **kwargs): + x = data["features"] + return {"logits": self.model(x)} + + +class Residual(nn.Module): + """ + A residual layer that adds the output of a function to its input. + + Args: + fn (nn.Module): The function to be applied to the input. + + References: + https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/base.py#L6-L35 + """ + + def __init__(self, fn): + """ + Initialize the Residual layer with a given function. + + Args: + fn (nn.Module): The function to be applied to the input. + """ + super().__init__() + self.fn = fn + + def forward(self, x): + """ + Forward pass of the Residual layer. + + Args: + x: Input tensor. + + Returns: + torch.Tensor: The input tensor added to the result of applying the function `fn` to it. + """ + return x + self.fn(x) + + +class MLPBlock(nn.Module): + """ + A basic Multi-Layer Perceptron (MLP) block with one fully connected layer. + + Args: + in_features (int): The number of input features. + output_size (int): The number of output features. 
+ bias (boolean): Add bias to the linear layer + layer_norm (boolean): Apply layer normalization + dropout (float): The dropout value + activation (nn.Module): The activation function to be applied after each fully connected layer. + + References: + https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/base.py#L38-L73 + + Example: + ```python + # Create an MLP block with 2 hidden layers and ReLU activation + mlp_block = MLPBlock(input_size=64, output_size=10, activation=nn.ReLU()) + + # Apply the MLP block to an input tensor + input_tensor = torch.randn(32, 64) + output = mlp_block(input_tensor) + ``` + """ + + def __init__( + self, + in_features, + out_features, + bias=True, + layer_norm=True, + dropout=0.1, + activation=nn.ReLU, + ): + super().__init__() + self.linear = nn.Linear(in_features, out_features, bias) + self.activation = activation() + self.layer_norm: Optional[nn.LayerNorm] = ( + nn.LayerNorm(out_features) if layer_norm else None + ) + self.dropout: Optional[nn.Dropout] = nn.Dropout(dropout) if dropout else None + + def forward(self, x): + x = self.activation(self.linear(x)) + if self.layer_norm: + x = self.layer_norm(x) + if self.dropout: + x = self.dropout(x) + return x diff --git a/chebai/preprocessing/bin/protein_token/tokens.txt b/chebai/preprocessing/bin/protein_token/tokens.txt index 72ad1b6d..c31c5b72 100644 --- a/chebai/preprocessing/bin/protein_token/tokens.txt +++ b/chebai/preprocessing/bin/protein_token/tokens.txt @@ -18,3 +18,4 @@ W E V H +X diff --git a/chebai/preprocessing/bin/protein_token_3_gram/tokens.txt b/chebai/preprocessing/bin/protein_token_3_gram/tokens.txt index 69dca126..534e5db1 100644 --- a/chebai/preprocessing/bin/protein_token_3_gram/tokens.txt +++ b/chebai/preprocessing/bin/protein_token_3_gram/tokens.txt @@ -7998,3 +7998,362 @@ WWC WCC WCH WWM +TAX +AXD +XDR +IEX +EXV +QAX +AXX +XXE +XES +MXN +XNF +NRX +RXX +XXX +XXR +XRI +SAX +AXG +XGG +PRX +RXR +XRX +RXE +XEF +QEX +EXQ +XQR +REX +EXR +RXQ +XQQ +DRX +RXP +XPG +QMX +MXT +XTX +TXR +XRM +APX +PXX +XXG +XGI +NLX +LXX +XXM +XMA +LNX +NXE +XEA +GTX +TXN +XND +LIX +IXI +XIM +MVX +VXX +XXK +XKT +GLX +LXP +XPP +QGX +GXD +XDL +XAP +QNX +NXM +XMN +VAX +XGV +IKX +KXY +KEX +EXL +XLY +GQX +QXE +XEP +PLX +XKC +PVX +XKE +RXI +XIR +AXL +XLN +LLX +LXD +XDA +AXE +XEL +GGX +GXG +KAX +XXA +XAG +XWS +SPX +PXC +XCD +GWX +WXH +XHF +MPX +ESX +SXN +XNK +DLX +LXN +XNS +QXG +XGD +ITX +XRG +NEX +EXA +XAL +LDX +DXI +XII +TPX +PXM +XMR +NXG +XGY +ASX +SXV +XVE +TKX +KXA +KRX +XXT +XTL +IDX +DXX +XXL +XLV +AKX +KXX +QHX +HXV +XVN +NSX +SXX +XKX +XDP +DAX +AXK +XKQ +PIX +IXX +XXF +VLX +XDI +DIX +IXL +XLK +LKX +KXV +XVA +DNX +NXD +ILX +LXK +XKV +VYX +YXE +XEI +RXS +XSH +KGX +XGF +AVX +VXY +XYG +HVX +XXI +XID +TVX +XXS +XSA +ENX +NXX +XMD +IIX +XMQ +AEX +EXX +XME +PGX +GXP +XPR +SKX +KXF +XFT +HRX +XSW +PQX +XGR +QQX +VTX +XRP +PSX +SXP +XPL +VGX +GXY +RSX +SXS +XSL +VSX +XST +AXV +XVL +AGX +GXX +XTK +KLX +LXR +XRV +AHX +HXC +XCS +LVX +VXN +XNR +NGX +GXL +TSX +SXQ +XQN +KXL +XLL +VIX +IXG +XGA +GFX +FXG +XGL +PTX +TXT +XTS +EMX +MXQ +SXY +XYA +IQX +QXY +XYR +TXK +IGX +XPS +PXT +XTG +NXQ +VKX +KXS +XSN +GVX +VXE +GRX +XRE +YKX +KXE +XEE +EEX +EXT +XTI +EHX +HXN +XNL +NDX +DXD +IAX +KSX +SXL +RRX +XRK +DDX +DXE +RXG +VXL +XLS +DTX +TXG +VXF +XFA +XIG +VXT +XTA +ISX +SXR +XRY +VQX +QXP +XPC +LGX +GXS +HGX +XGH +XXD +XDD +KKX +XXV +PKX +XLT +XSP +XLD +RAX +AXS +XSI +IYX +YXX +XXP +XPI +MSX +SXT +GEX +XHP +LFX +FXX +VXI +XIW +QTX +TXX +XXQ +XQA +FLX +DXN +XNC +MXS +XSR +YLX +EQX +QXS +TMX +MXC +XCY +NXA +XAV +EXE 
+XEQ +HPX +PXP +LMX +MXX +KTX +XKK +XXH +XHS +MKX +XIH +WRX +XKS +EXY +XYQ +QKX diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 857a5862..817bc1d1 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -155,8 +155,19 @@ def fold_dir(self) -> str: return f"cv_{self.inner_k_folds}_fold" @property + @abstractmethod def _name(self) -> str: - raise NotImplementedError + """ + Abstract property representing the name of the data module. + + This property should be implemented in subclasses to provide a unique name for the data module. + The name is used to create subdirectories within the base directory or `processed_dir_main` + for storing relevant data associated with this module. + + Returns: + str: The name of the data module. + """ + pass def _filter_labels(self, row: dict) -> dict: """ @@ -729,7 +740,7 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None: processed_name = self.processed_main_file_names_dict["data"] if not os.path.isfile(os.path.join(self.processed_dir_main, processed_name)): - print("Missing processed data file (`data.pkl` file)") + print(f"Missing processed data file (`{processed_name}` file)") os.makedirs(self.processed_dir_main, exist_ok=True) data_path = self._download_required_data() g = self._extract_class_hierarchy(data_path) @@ -813,7 +824,10 @@ def setup_processed(self) -> None: None """ os.makedirs(self.processed_dir, exist_ok=True) - print("Missing transformed data (`data.pt` file). Transforming data.... ") + transformed_file_name = self.processed_file_names_dict["data"] + print( + f"Missing transformed data (`{transformed_file_name}` file). Transforming data.... " + ) torch.save( self._load_data_from_file( os.path.join( @@ -821,7 +835,7 @@ def setup_processed(self) -> None: self.processed_main_file_names_dict["data"], ) ), - os.path.join(self.processed_dir, self.processed_file_names_dict["data"]), + os.path.join(self.processed_dir, transformed_file_name), ) @staticmethod diff --git a/chebai/preprocessing/datasets/deepGO/__init__.py b/chebai/preprocessing/datasets/deepGO/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/deepGO/go_uniprot.py similarity index 69% rename from chebai/preprocessing/datasets/go_uniprot.py rename to chebai/preprocessing/datasets/deepGO/go_uniprot.py index a2c4ae54..1b0eb2aa 100644 --- a/chebai/preprocessing/datasets/go_uniprot.py +++ b/chebai/preprocessing/datasets/deepGO/go_uniprot.py @@ -1,18 +1,29 @@ -# Reference for this file : +# References for this file : +# Reference 1: # Maxat Kulmanov, Mohammed Asif Khan, Robert Hoehndorf; # DeepGO: Predicting protein functions from sequence and interactions # using a deep ontology-aware classifier, Bioinformatics, 2017. # https://doi.org/10.1093/bioinformatics/btx624 # Github: https://github.com/bio-ontology-research-group/deepgo + +# Reference 2: # https://www.ebi.ac.uk/GOA/downloads # https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/docs/keywlist.txt # https://www.uniprot.org/uniprotkb +# Reference 3: +# Kulmanov, M., Guzmán-Vega, F.J., Duek Roggli, +# P. et al. Protein function prediction as approximate semantic entailment. Nat Mach Intell 6, 220–228 (2024). 
+# https://doi.org/10.1038/s42256-024-00795-w +# https://github.com/bio-ontology-research-group/deepgo2 + __all__ = [ "GOUniProtOver250", "GOUniProtOver50", "EXPERIMENTAL_EVIDENCE_CODES", "AMBIGUOUS_AMINO_ACIDS", + "DeepGO1MigratedData", + "DeepGO2MigratedData", ] import gzip @@ -29,11 +40,13 @@ import pandas as pd import requests import torch +import tqdm from Bio import SwissProt from chebai.preprocessing import reader as dr from chebai.preprocessing.datasets.base import _DynamicDataset +# https://github.com/bio-ontology-research-group/deepgo/blob/master/utils.py#L15 EXPERIMENTAL_EVIDENCE_CODES = { "EXP", "IDA", @@ -43,10 +56,19 @@ "IEP", "TAS", "IC", + # New evidence codes added in latest paper year 2024 Reference number 3 + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/utils.py#L24-L26 + "HTP", + "HDA", + "HMP", + "HGI", + "HEP", } # https://github.com/bio-ontology-research-group/deepgo/blob/d97447a05c108127fee97982fd2c57929b2cf7eb/aaindex.py#L8 -AMBIGUOUS_AMINO_ACIDS = {"B", "O", "J", "U", "X", "Z", "*"} +# https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L10 +# `X` is now considered as valid amino acid, as per latest paper year 2024 Refernce number 3 +AMBIGUOUS_AMINO_ACIDS = {"B", "O", "J", "U", "Z", "*"} class _GOUniProtDataExtractor(_DynamicDataset, ABC): @@ -56,10 +78,16 @@ class _GOUniProtDataExtractor(_DynamicDataset, ABC): Args: dynamic_data_split_seed (int, optional): The seed for random data splitting. Defaults to 42. splits_file_path (str, optional): Path to the splits CSV file. Defaults to None. - **kwargs: Additional keyword arguments passed to XYBaseDataModule. + max_sequence_length (int, optional): Specifies the maximum allowed sequence length for a protein, with a + default of 1002. During data preprocessing, any proteins exceeding this length will be excluded from further + processing. + **kwargs: Additional keyword arguments passed to DynamicDataset and XYBaseDataModule. Attributes: dynamic_data_split_seed (int): The seed for random data splitting, default is 42. + max_sequence_length (int, optional): Specifies the maximum allowed sequence length for a protein, with a + default of 1002. During data preprocessing, any proteins exceeding this length will be excluded from further + processing. splits_file_path (Optional[str]): Path to the CSV file containing split assignments. """ @@ -405,12 +433,9 @@ def _get_swiss_to_go_mapping(self) -> pd.DataFrame: Note: This mapping is necessary because the GO data does not include the protein sequence representation. - - Quote from the DeepGo Paper: - `We select proteins with annotations having experimental evidence codes - (EXP, IDA, IPI, IMP, IGI, IEP, TAS, IC) and filter the proteins by a - maximum length of 1002, ignoring proteins with ambiguous amino acid codes - (B, O, J, U, X, Z) in their sequence.` + We select proteins with annotations having experimental evidence codes, as specified in + `EXPERIMENTAL_EVIDENCE_CODES` and filter the proteins by a maximum length of 1002, ignoring proteins with + ambiguous amino acid codes specified in `AMBIGUOUS_AMINO_ACIDS` in their sequence. Check the link below for keyword details: https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/docs/keywlist.txt @@ -591,9 +616,6 @@ class _GOUniProtOverX(_GOUniProtDataExtractor, ABC): Attributes: READER (dr.ProteinDataReader): The reader used for reading the dataset. THRESHOLD (int): The threshold for selecting classes based on the number of subclasses. 
- - Property: - label_number (int): The number of labels in the dataset. This property must be implemented by subclasses. """ READER: dr.ProteinDataReader = dr.ProteinDataReader @@ -709,3 +731,277 @@ class GOUniProtOver50(_GOUniProtOverX): """ THRESHOLD: int = 50 + + +class _DeepGOMigratedData(_GOUniProtDataExtractor, ABC): + """ + Base class for use of the migrated DeepGO data with common properties, name formatting, and file paths. + + Attributes: + READER (dr.ProteinDataReader): Protein data reader class. + THRESHOLD (Optional[int]): Threshold value for GO class selection, + determined by the GO branch type in derived classes. + """ + + READER: dr.ProteinDataReader = dr.ProteinDataReader + THRESHOLD: Optional[int] = None + + # Mapping from GO branch conventions used in DeepGO to our conventions + GO_BRANCH_MAPPING: dict = { + "cc": "CC", + "mf": "MF", + "bp": "BP", + } + + @property + def _name(self) -> str: + """ + Generates a unique identifier for the migrated data based on the GO + branch and max sequence length, optionally including a threshold. + + Returns: + str: A formatted name string for the data. + """ + threshold_part = f"GO{self.THRESHOLD}_" if self.THRESHOLD is not None else "GO_" + + if self.go_branch != self._ALL_GO_BRANCHES: + return f"{threshold_part}{self.go_branch}_{self.max_sequence_length}" + + return f"{threshold_part}{self.max_sequence_length}" + + # ------------------------------ Phase: Prepare data ----------------------------------- + def prepare_data(self, *args: Any, **kwargs: Any) -> None: + """ + Checks for the existence of migrated DeepGO data in the specified directory. + Raises an error if the required data file is not found, prompting + migration from DeepGO to this data structure. + + Args: + *args (Any): Additional positional arguments. + **kwargs (Any): Additional keyword arguments. + + Raises: + FileNotFoundError: If the processed data file does not exist. + """ + print("Checking for processed data in", self.processed_dir_main) + + processed_name = self.processed_main_file_names_dict["data"] + if not os.path.isfile(os.path.join(self.processed_dir_main, processed_name)): + raise FileNotFoundError( + f"File {processed_name} not found.\n" + f"You must run the appropriate DeepGO migration script " + f"(chebai/preprocessing/migration/deep_go) before executing this configuration " + f"to migrate data from DeepGO to this data structure." + ) + + def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: + # Selection of GO classes not needed for migrated data + pass + + # ------------------------------ Phase: Raw Properties ----------------------------------- + @property + @abstractmethod + def processed_main_file_names_dict(self) -> Dict[str, str]: + """ + Abstract property for defining main processed file names. + These files are stored in the same directory as the generated data files + but have distinct names to differentiate them during training. + + Returns: + dict: A dictionary with key-value pairs for main processed file names. + """ + pass + + @property + @abstractmethod + def processed_file_names_dict(self) -> Dict[str, str]: + """ + Abstract property for defining additional processed file names. + These files are stored in the same directory as the generated data files + but have distinct names to differentiate them during training. + + Returns: + dict: A dictionary with key-value pairs for processed file names. + """ + pass + + +class DeepGO1MigratedData(_DeepGOMigratedData): + """ + Migrated data class specific to DeepGO1. 
Sets threshold values according + to the research paper based on the GO branch. + + Note: + Refer reference number 1 at the top of this file for the corresponding research paper. + + Args: + **kwargs: Arbitrary keyword arguments passed to the superclass. + + Raises: + ValueError: If an unsupported GO branch is provided. + """ + + def __init__(self, **kwargs): + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 + assert int(kwargs.get("max_sequence_length")) == 1002 + + # Set threshold based on GO branch, as per DeepGO1 paper and its data. + if kwargs.get("go_branch") in ["CC", "MF"]: + self.THRESHOLD = 50 + elif kwargs.get("go_branch") == "BP": + self.THRESHOLD = 250 + else: + raise ValueError( + f"DeepGO1 paper has no defined threshold for branch {self.go_branch}" + ) + + super(_DeepGOMigratedData, self).__init__(**kwargs) + + @property + def processed_main_file_names_dict(self) -> Dict[str, str]: + """ + Returns main processed file names specific to DeepGO1. + + Returns: + dict: Dictionary with the main data file name for DeepGO1. + """ + return {"data": "data_deep_go1.pkl"} + + @property + def processed_file_names_dict(self) -> Dict[str, str]: + """ + Returns processed file names specific to DeepGO1. + + Returns: + dict: Dictionary with data file name for DeepGO1. + """ + return {"data": "data_deep_go1.pt"} + + +class DeepGO2MigratedData(_DeepGOMigratedData): + """ + Migrated data class specific to DeepGO2, inheriting from DeepGO1MigratedData + with different processed file names. + + Note: + Refer reference number 3 at the top of this file for the corresponding research paper. + + Returns: + dict: Dictionary with file names specific to DeepGO2. + """ + + _LABELS_START_IDX: int = 5 # additional esm2_embeddings column in the dataframe + _ESM_EMBEDDINGS_COL_IDX: int = 4 + + def __init__(self, use_esm2_embeddings=False, **kwargs): + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 + assert int(kwargs.get("max_sequence_length")) == 1000 + self.use_esm2_embeddings: bool = use_esm2_embeddings + super(_DeepGOMigratedData, self).__init__(**kwargs) + + # ------------------------------ Phase: Setup data ----------------------------------- + def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]: + """ + Load and process data from a file into a list of dictionaries containing features and labels. + + This method processes data differently based on the `use_esm2_embeddings` flag: + - If `use_esm2_embeddings` is True, raw dictionaries from `_load_dict` are returned, _load_dict already returns + the numerical features (esm2 embeddings) from the data file, hence no reader is required. + - Otherwise, a reader is used to process the data (generate numerical features). + + Args: + path (str): The path to the input file. + + Returns: + List[Dict[str, Any]]: A list of dictionaries with the following keys: + - `features`: Sequence or embedding data, depending on the context. + - `labels`: A boolean array of labels. + - `ident`: The identifier for the sequence. 
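+
+        Example:
+            A minimal usage sketch; the constructor arguments shown here are illustrative and
+            additional kwargs required by the base data module may apply:
+
+            ```python
+            # token-based features: each sequence is encoded by the protein reader
+            dm = DeepGO2MigratedData(go_branch="MF", max_sequence_length=1000)
+
+            # pre-computed ESM2 embeddings: features are taken directly from the embeddings column
+            dm_esm = DeepGO2MigratedData(
+                go_branch="MF", max_sequence_length=1000, use_esm2_embeddings=True
+            )
+            data = dm_esm._load_data_from_file(
+                os.path.join(dm_esm.processed_dir_main, "data_deep_go2.pkl")
+            )
+            # -> [{"features": ..., "labels": array([...]), "ident": ...}, ...]
+            ```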
+ """ + lines = self._get_data_size(path) + print(f"Processing {lines} lines...") + + if self.use_esm2_embeddings: + data = [ + d + for d in tqdm.tqdm(self._load_dict(path), total=lines) + if d["features"] is not None + ] + else: + data = [ + self.reader.to_data(d) + for d in tqdm.tqdm(self._load_dict(path), total=lines) + if d["features"] is not None + ] + + # filter for missing features in resulting data + data = [val for val in data if val["features"] is not None] + + return data + + def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]: + """ + Loads data from a pickled file and yields individual dictionaries for each row. + + The pickled file is expected to contain rows with the following structure: + - Data at row index `self._ID_IDX`: ID of go data instance + - Data at row index `self._DATA_REPRESENTATION_IDX`: Sequence representation of protein + - Data at row index `self._ESM2_EMBEDDINGS_COL_IDX`: ESM2 embeddings of the protein + - Data from row index `self._LABELS_START_IDX` onwards: Labels + + The method adapts based on the `use_esm2_embeddings` flag: + - If `use_esm2_embeddings` is True, features are loaded from the column specified by `self._ESM_EMBEDDINGS_COL_IDX`. + - Otherwise, features are loaded from the column specified by `self._DATA_REPRESENTATION_IDX`. + + Args: + input_file_path (str): The path to the pickled input file. + + Yields: + Dict[str, Any]: A dictionary containing: + - `features` (Any): Sequence or embedding data for the instance. + - `labels` (np.ndarray): A boolean array of labels starting from row index 4. + - `ident` (Any): The identifier from row index 0. + """ + with open(input_file_path, "rb") as input_file: + df = pd.read_pickle(input_file) + + if self.use_esm2_embeddings: + features_idx = self._ESM_EMBEDDINGS_COL_IDX + else: + features_idx = self._DATA_REPRESENTATION_IDX + + for row in df.values: + labels = row[self._LABELS_START_IDX :].astype(bool) + yield dict( + features=row[features_idx], + labels=labels, + ident=row[self._ID_IDX], + ) + + # ------------------------------ Phase: Raw Properties ----------------------------------- + @property + def processed_main_file_names_dict(self) -> Dict[str, str]: + """ + Returns main processed file names specific to DeepGO2. + + Returns: + dict: Dictionary with the main data file name for DeepGO2. + """ + return {"data": "data_deep_go2.pkl"} + + @property + def processed_file_names_dict(self) -> Dict[str, str]: + """ + Returns processed file names specific to DeepGO2. + + Returns: + dict: Dictionary with data file name for DeepGO2. 
+ """ + return {"data": "data_deep_go2.pt"} + + @property + def identifier(self) -> tuple: + """Identifier for the dataset.""" + if self.use_esm2_embeddings: + return (dr.ESM2EmbeddingReader.name(),) + return (self.reader.name(),) diff --git a/chebai/preprocessing/datasets/protein_pretraining.py b/chebai/preprocessing/datasets/deepGO/protein_pretraining.py similarity index 97% rename from chebai/preprocessing/datasets/protein_pretraining.py rename to chebai/preprocessing/datasets/deepGO/protein_pretraining.py index 6b5d1df0..8f7e9c4d 100644 --- a/chebai/preprocessing/datasets/protein_pretraining.py +++ b/chebai/preprocessing/datasets/deepGO/protein_pretraining.py @@ -12,7 +12,7 @@ from sklearn.model_selection import train_test_split from chebai.preprocessing.datasets.base import _DynamicDataset -from chebai.preprocessing.datasets.go_uniprot import ( +from chebai.preprocessing.datasets.deepGO.go_uniprot import ( AMBIGUOUS_AMINO_ACIDS, EXPERIMENTAL_EVIDENCE_CODES, GOUniProtOver250, @@ -96,15 +96,15 @@ def _download_required_data(self) -> str: def _parse_protein_data_for_pretraining(self) -> pd.DataFrame: """ Parses the Swiss-Prot data and returns a DataFrame containing Swiss-Prot proteins which does not have any valid - Gene Ontology(GO) label. A valid GO label is the one which has one of the following evidence code - (EXP, IDA, IPI, IMP, IGI, IEP, TAS, IC). + Gene Ontology(GO) label. A valid GO label is the one which has one of the following evidence codes, as specified in + `EXPERIMENTAL_EVIDENCE_CODES`. The DataFrame includes the following columns: - "swiss_id": The unique identifier for each Swiss-Prot record. - "sequence": The protein sequence. Note: - We ignore proteins with ambiguous amino acid codes (B, O, J, U, X, Z) in their sequence.` + We ignore proteins with ambiguous amino acid specified in `AMBIGUOUS_AMINO_ACIDS` in their sequence.` Returns: pd.DataFrame: A DataFrame where each row corresponds to a Swiss-Prot record with not associated valid GO. diff --git a/chebai/preprocessing/datasets/scope/__init__.py b/chebai/preprocessing/datasets/scope/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/chebai/preprocessing/datasets/scope/scope.py b/chebai/preprocessing/datasets/scope/scope.py new file mode 100644 index 00000000..e9127b25 --- /dev/null +++ b/chebai/preprocessing/datasets/scope/scope.py @@ -0,0 +1,972 @@ +# References for this file : + +# Reference 1: +# John-Marc Chandonia, Naomi K Fox, Steven E Brenner, SCOPe: classification of large macromolecular structures +# in the structural classification of proteins—extended database, Nucleic Acids Research, Volume 47, +# Issue D1, 08 January 2019, Pages D475–D481, https://doi.org/10.1093/nar/gky1134 +# https://scop.berkeley.edu/about/ver=2.08 + +# Reference 2: +# Murzin AG, Brenner SE, Hubbard TJP, Chothia C. 1995. SCOP: a structural classification of proteins database for +# the investigation of sequences and structures. 
Journal of Molecular Biology 247:536-540 + +import gzip +import os +import re +import shutil +from abc import ABC, abstractmethod +from tempfile import NamedTemporaryFile +from typing import Any, Dict, Generator, List, Optional, Tuple + +import networkx as nx +import pandas as pd +import requests +import torch +from Bio import SeqIO + +from chebai.preprocessing.datasets.base import _DynamicDataset +from chebai.preprocessing.reader import ProteinDataReader + + +class _SCOPeDataExtractor(_DynamicDataset, ABC): + """ + A class for extracting and processing data from the SCOPe (Structural Classification of Proteins - extended) dataset. + + This class is designed to handle the parsing, preprocessing, and hierarchical structure extraction from various + SCOPe dataset files, such as classification (CLA), hierarchy (HIE), and description (DES) files. + Additionally, it supports downloading related data like PDB sequence files. + + Args: + scope_version (str): The SCOPe version to use. + scope_version_train (Optional[str]): The training SCOPe version, if different. + dynamic_data_split_seed (int, optional): The seed for random data splitting. Defaults to 42. + splits_file_path (str, optional): Path to the splits CSV file. Defaults to None. + **kwargs: Additional keyword arguments passed to DynamicDataset and XYBaseDataModule. + """ + + # -- Index for columns of processed `data.pkl` (derived from `_graph_to_raw_dataset`) + # "id" at row index 0 + # "sids" at row index 1 + # "sequence" at row index 2 + # labels starting from row index 3 + _ID_IDX: int = 0 + _DATA_REPRESENTATION_IDX: int = 2 # here `sequence` column + _LABELS_START_IDX: int = 3 + + _SCOPE_GENERAL_URL = "https://scop.berkeley.edu/downloads/parse/dir.{data_type}.scope.{version_number}-stable.txt" + _PDB_SEQUENCE_DATA_URL = ( + "https://files.rcsb.org/pub/pdb/derived_data/pdb_seqres.txt.gz" + ) + + SCOPE_HIERARCHY: Dict[str, str] = { + "cl": "class", + "cf": "fold", + "sf": "superfamily", + "fa": "family", + "dm": "protein", + "sp": "species", + "px": "domain", + } + + def __init__( + self, + scope_version: str, + scope_version_train: Optional[str] = None, + max_sequence_len: int = 1000, + **kwargs, + ): + self.scope_version: str = scope_version + self.scope_version_train: str = scope_version_train + self.max_sequence_len: int = max_sequence_len + + super(_SCOPeDataExtractor, self).__init__(**kwargs) + + if self.scope_version_train is not None: + # Instantiate another same class with "scope_version" as "scope_version_train", if train_version is given + # This is to get the data from respective directory related to "scope_version_train" + _init_kwargs = kwargs + _init_kwargs["scope_version"] = self.scope_version_train + self._scope_version_train_obj = self.__class__( + **_init_kwargs, + ) + + @staticmethod + def _get_scope_url(data_type: str, version_number: str) -> str: + """ + Generates the URL for downloading SCOPe files. + + Args: + data_type (str): The type of data (e.g., 'cla', 'hie', 'des'). + version_number (str): The version of the SCOPe file. + + Returns: + str: The formatted SCOPe file URL. + """ + return _SCOPeDataExtractor._SCOPE_GENERAL_URL.format( + data_type=data_type, version_number=version_number + ) + + # ------------------------------ Phase: Prepare data ----------------------------------- + def _download_required_data(self) -> str: + """ + Downloads the required raw data for SCOPe and PDB sequence datasets. + + Returns: + str: Path to the downloaded data. 
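+
+        Note:
+            The PDB sequence file is stored under `self.scope_root_dir`, while the SCOPe CLA/HIE/DES
+            files are stored under `self.raw_dir`; the returned string is only a placeholder, since
+            downstream steps locate the files via `raw_file_names_dict`.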
+ """ + self._download_pdb_sequence_data() + return self._download_scope_raw_data() + + def _download_pdb_sequence_data(self) -> None: + """ + Downloads and unzips the PDB sequence dataset from the RCSB PDB repository. + + The file is downloaded as a temporary gzip file, which is then extracted to the + specified directory. + """ + pdb_seq_file_path = os.path.join( + self.scope_root_dir, self.raw_file_names_dict["PDB"] + ) + os.makedirs(os.path.dirname(pdb_seq_file_path), exist_ok=True) + + if not os.path.isfile(pdb_seq_file_path): + print(f"Missing PDB raw data, Downloading PDB sequence data....") + + # Create a temporary file + with NamedTemporaryFile(delete=False) as tf: + temp_filename = tf.name + print(f"Downloading to temporary file {temp_filename}") + + # Download the file + response = requests.get(self._PDB_SEQUENCE_DATA_URL, stream=True) + with open(temp_filename, "wb") as temp_file: + shutil.copyfileobj(response.raw, temp_file) + + print(f"Downloaded to {temp_filename}") + + # Unpack the gzipped file + try: + print(f"Unzipping the file....") + with gzip.open(temp_filename, "rb") as f_in: + output_file_path = pdb_seq_file_path + with open(output_file_path, "wb") as f_out: + shutil.copyfileobj(f_in, f_out) + print(f"Unpacked and saved to {output_file_path}") + + except Exception as e: + print(f"Failed to unpack the file: {e}") + finally: + # Clean up the temporary file + os.remove(temp_filename) + print(f"Removed temporary file {temp_filename}") + + def _download_scope_raw_data(self) -> str: + """ + Downloads the raw SCOPe dataset files (CLA, HIE, DES, and COM). + + Each file is downloaded from the SCOPe repository and saved to the specified directory. + Files are only downloaded if they do not already exist. + + Returns: + str: A dummy path to indicate completion (can be extended for custom behavior). + """ + os.makedirs(self.raw_dir, exist_ok=True) + for data_type in ["CLA", "HIE", "DES"]: + data_file_name = self.raw_file_names_dict[data_type] + scope_path = os.path.join(self.raw_dir, data_file_name) + if not os.path.isfile(scope_path): + print(f"Missing Scope: {data_file_name} raw data, Downloading...") + r = requests.get( + self._get_scope_url(data_type.lower(), self.scope_version), + allow_redirects=False, + verify=False, # Disable SSL verification + ) + r.raise_for_status() # Check if the request was successful + open(scope_path, "wb").write(r.content) + return "dummy/path" + + def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: + """ + Extracts the class hierarchy from SCOPe data and computes its transitive closure. + + Args: + data_path (str): Path to the processed SCOPe dataset. + + Returns: + nx.DiGraph: A directed acyclic graph representing the SCOPe class hierarchy. 
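+
+        Note:
+            Every remaining taxonomy node (class/fold/superfamily/family/protein/species) carries a
+            `num_seq_successors` attribute counting the distinct PDB chain sequences reachable below
+            it; the domain (`px`) and sequence nodes used to compute these counts are removed again
+            before the transitive closure is taken.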
+ """ + print("Extracting class hierarchy...") + df_scope = self._get_scope_data() + pdb_chain_df = self._parse_pdb_sequence_file() + pdb_id_set = set(pdb_chain_df["pdb_id"]) # Search time complexity - O(1) + + # Initialize sets and dictionaries for storing edges and attributes + parent_node_edges, node_child_edges = set(), set() + node_attrs = {} + px_level_nodes = set() + sequence_nodes = dict() + px_to_seq_edges = set() + required_graph_nodes = set() + + # Create a lookup dictionary for PDB chain sequences + lookup_dict = ( + pdb_chain_df.groupby("pdb_id")[["chain_id", "sequence"]] + .apply(lambda x: dict(zip(x["chain_id"], x["sequence"]))) + .to_dict() + ) + + def add_sequence_nodes_edges(chain_sequence, px_sun_id): + """Adds sequence nodes and edges connecting px-level nodes to sequence nodes.""" + if chain_sequence not in sequence_nodes: + sequence_nodes[chain_sequence] = f"seq_{len(sequence_nodes)}" + px_to_seq_edges.add((px_sun_id, sequence_nodes[chain_sequence])) + + # Step 1: Build the graph structure and store node attributes + for row in df_scope.itertuples(index=False): + if row.level == "px": + + pdb_id, chain_id = row.sid[1:5], row.sid[5] + + if pdb_id not in pdb_id_set or chain_id == "_": + # Don't add domain level nodes that don't have pdb_id in pdb_sequences.txt file + # Also chain_id with "_" which corresponds to no chain + continue + px_level_nodes.add(row.sunid) + + # Add edges between px-level nodes and sequence nodes + if chain_id != ".": + if chain_id not in lookup_dict[pdb_id]: + continue + add_sequence_nodes_edges(lookup_dict[pdb_id][chain_id], row.sunid) + else: + # If chain_id is '.', connect all chains of this PDB ID + for chain, chain_sequence in lookup_dict[pdb_id].items(): + add_sequence_nodes_edges(chain_sequence, row.sunid) + else: + required_graph_nodes.add(row.sunid) + + node_attrs[row.sunid] = {"sid": row.sid, "level": row.level} + + if row.parent_sunid != -1: + parent_node_edges.add((row.parent_sunid, row.sunid)) + + for child_id in row.children_sunids: + node_child_edges.add((row.sunid, child_id)) + + del df_scope, pdb_chain_df, pdb_id_set + + g = nx.DiGraph() + g.add_nodes_from(node_attrs.items()) + # Note - `add_edges` internally create a node, if a node doesn't exist already + g.add_edges_from({(p, c) for p, c in parent_node_edges if p in node_attrs}) + g.add_edges_from({(p, c) for p, c in node_child_edges if c in node_attrs}) + + seq_nodes = set(sequence_nodes.values()) + g.add_nodes_from([(seq_id, {"level": "sequence"}) for seq_id in seq_nodes]) + g.add_edges_from( + { + (px_node, seq_node) + for px_node, seq_node in px_to_seq_edges + if px_node in node_attrs and seq_node in seq_nodes + } + ) + + # Step 2: Count sequence successors for required graph nodes only + for node in required_graph_nodes: + num_seq_successors = sum( + g.nodes[child]["level"] == "sequence" + for child in nx.descendants(g, node) + ) + g.nodes[node]["num_seq_successors"] = num_seq_successors + + # Step 3: Remove nodes which are not required before computing transitive closure for better efficiency + g.remove_nodes_from(px_level_nodes | seq_nodes) + + print("Computing Transitive Closure.........") + # Transitive closure is not needed in `select_classes` method but is required in _SCOPeOverXPartial + return nx.transitive_closure_dag(g) + + def _get_scope_data(self) -> pd.DataFrame: + """ + Merges and preprocesses the SCOPe classification, hierarchy, and description files into a unified DataFrame. 
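+
+        The HIE table is merged with the CLA table on `sunid`, and the result is merged with the
+        DES table (whose duplicate `sid` column is dropped) to attach level descriptions.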
+ + Returns: + pd.DataFrame: A DataFrame containing combined SCOPe data with classification and hierarchy details. + """ + df_cla = self._get_classification_data() + df_hie = self._get_hierarchy_data() + df_des = self._get_node_description_data() + df_hie_with_cla = pd.merge(df_hie, df_cla, how="left", on="sunid") + df_all = pd.merge( + df_hie_with_cla, + df_des.drop(columns=["sid"], axis=1), + how="left", + on="sunid", + ) + return df_all + + def _get_classification_data(self) -> pd.DataFrame: + """ + Parses and processes the SCOPe CLA (classification) file. + + Returns: + pd.DataFrame: A DataFrame containing classification details, including hierarchy levels. + """ + df_cla = pd.read_csv( + os.path.join(self.raw_dir, self.raw_file_names_dict["CLA"]), + sep="\t", + header=None, + comment="#", + ) + df_cla.columns = [ + "sid", + "PDB_ID", + "description", + "sccs", + "sunid", + "hie_levels", + ] + + # Convert to dict - {cl:46456, cf:46457, sf:46458, fa:46459, dm:46460, sp:116748, px:113449} + df_cla["hie_levels"] = df_cla["hie_levels"].apply( + lambda x: {k: int(v) for k, v in (item.split("=") for item in x.split(","))} + ) + + # Split ancestor_nodes into separate columns and assign values + for key in self.SCOPE_HIERARCHY.keys(): + df_cla[self.SCOPE_HIERARCHY[key]] = df_cla["hie_levels"].apply( + lambda x: x[key] + ) + + df_cla["sunid"] = df_cla["sunid"].astype("int64") + + return df_cla + + def _get_hierarchy_data(self) -> pd.DataFrame: + """ + Parses and processes the SCOPe HIE (hierarchy) file. + + Returns: + pd.DataFrame: A DataFrame containing hierarchy details, including parent-child relationships. + """ + df_hie = pd.read_csv( + os.path.join(self.raw_dir, self.raw_file_names_dict["HIE"]), + sep="\t", + header=None, + comment="#", + low_memory=False, + ) + df_hie.columns = ["sunid", "parent_sunid", "children_sunids"] + + # if not parent id, then insert -1 + df_hie["parent_sunid"] = df_hie["parent_sunid"].replace("-", -1).astype(int) + # convert children ids to list of ids + df_hie["children_sunids"] = df_hie["children_sunids"].apply( + lambda x: list(map(int, x.split(","))) if x != "-" else [] + ) + + # Ensure the 'sunid' column in both DataFrames has the same type + df_hie["sunid"] = df_hie["sunid"].astype("int64") + return df_hie + + def _get_node_description_data(self) -> pd.DataFrame: + """ + Parses and processes the SCOPe DES (description) file. + + Returns: + pd.DataFrame: A DataFrame containing node-level descriptions from the SCOPe dataset. + """ + df_des = pd.read_csv( + os.path.join(self.raw_dir, self.raw_file_names_dict["DES"]), + sep="\t", + header=None, + comment="#", + low_memory=False, + ) + df_des.columns = ["sunid", "level", "scss", "sid", "description"] + df_des.loc[len(df_des)] = {"sunid": 0, "level": "root"} + + # Ensure the 'sunid' column in both DataFrames has the same type + df_des["sunid"] = df_des["sunid"].astype("int64") + return df_des + + def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: + """ + Processes a directed acyclic graph (DAG) to generate a raw dataset in DataFrame format. This dataset includes + chain-level sequences and their corresponding labels based on the hierarchical structure of the associated domains. + + The process: + - Extracts SCOPe domain identifiers (sids) from the graph. + - Retrieves class labels for each domain based on all applicable taxonomy levels. + - Fetches the chain-level sequences from the Protein Data Bank (PDB) for each domain. 
+ - For each sequence, identifies all domains associated with the same chain and assigns their corresponding labels. + + Notes: + - SCOPe hierarchy levels are used as labels, with each level represented by a column. The value in each column + indicates whether a PDB chain is associated with that particular hierarchy level. + - PDB chains are treated as samples. The method considers only domains that are mapped to the selected hierarchy levels. + + Data Format: pd.DataFrame + - Column 0 : id (Unique identifier for each sequence entry) + - Column 1 : sids (List of domain identifiers associated with the sequence) + - Column 2 : sequence (Amino acid sequence of the chain) + - Column 3 to Column "n": Each column corresponds to a SCOPe class hierarchy level with a value + of True/False indicating whether the chain is associated with the corresponding level. + + Args: + graph (nx.DiGraph): The class hierarchy graph. + + Returns: + pd.DataFrame: The raw dataset created from the graph. + + Raises: + RuntimeError: If no sunids are selected. + """ + print(f"Process graph") + + selected_sun_ids_per_lvl = self.select_classes(graph) + + if not selected_sun_ids_per_lvl: + raise RuntimeError("No sunid selected.") + + df_cla = self._get_classification_data() + hierarchy_levels = list(self.SCOPE_HIERARCHY.values()) + hierarchy_levels.remove("domain") + + df_cla = df_cla[["sid", "sunid"] + hierarchy_levels] + + # Initialize selected target columns + df_encoded = df_cla[["sid", "sunid"]].copy() + + # Collect all new columns in a dictionary first (avoids fragmentation) + encoded_df_columns = {} + + lvl_to_target_cols_mapping = {} + # Iterate over only the selected sun_ids (nodes) to one-hot encode them + for level, selected_sun_ids in selected_sun_ids_per_lvl.items(): + level_column = self.SCOPE_HIERARCHY[level] + if level_column in df_cla.columns: + # Create binary encoding for only relevant sun_ids + for sun_id in selected_sun_ids: + col_name = f"{level_column}_{sun_id}" + encoded_df_columns[col_name] = ( + df_cla[level_column] == sun_id + ).astype(bool) + + lvl_to_target_cols_mapping.setdefault(level_column, []).append( + col_name + ) + + # Convert the dictionary into a DataFrame and concatenate at once (prevents fragmentation) + df_encoded = pd.concat([df_encoded, pd.DataFrame(encoded_df_columns)], axis=1) + + encoded_target_columns = [] + for level in hierarchy_levels: + if level in lvl_to_target_cols_mapping: + encoded_target_columns.extend(lvl_to_target_cols_mapping[level]) + + print( + f"{len(encoded_target_columns)} labels has been selected for specified threshold, " + ) + print("Constructing data.pkl file .....") + + df_encoded = df_encoded[["sid", "sunid"] + encoded_target_columns] + + # Filter to select only domains that atleast map to any one selected sunid in any level + df_encoded = df_encoded[df_encoded.iloc[:, 2:].any(axis=1)] + + df_encoded["pdb_id"] = df_encoded["sid"].str[1:5] + df_encoded["chain_id"] = df_encoded["sid"].str[5] + + # "_" (underscore) means it has no chain + df_encoded = df_encoded[df_encoded["chain_id"] != "_"] + + pdb_chain_df = self._parse_pdb_sequence_file() + + # Handle chain_id == "." 
- Multiple chain case + # Split df_encoded into two: One for specific chains, one for "multiple chains" (".") + df_specific_chains = df_encoded[df_encoded["chain_id"] != "."] + df_multiple_chains = df_encoded[df_encoded["chain_id"] == "."].drop( + columns=["chain_id"] + ) + + # Merge specific chains normally + merged_specific = df_specific_chains.merge( + pdb_chain_df, on=["pdb_id", "chain_id"], how="left" + ) + + # Merge all chains case -> Join by pdb_id (not chain_id) + merged_all_chains = df_multiple_chains.merge( + pdb_chain_df, on="pdb_id", how="left" + ) + + # Combine both cases + sequence_hierarchy_df = pd.concat( + [merged_specific, merged_all_chains], ignore_index=True + ).dropna(subset=["sequence"]) + + # Vectorized Aggregation Instead of Row-wise Updates + sequence_hierarchy_df = ( + sequence_hierarchy_df.groupby("sequence", as_index=False) + .agg( + { + "sid": list, # Collect all SIDs per sequence + **{ + col: "max" for col in encoded_target_columns + }, # Max works as Bitwise OR for labels + } + ) + .rename(columns={"sid": "sids"}) + ) # Rename for clarity + + sequence_hierarchy_df = sequence_hierarchy_df.assign( + id=range(1, len(sequence_hierarchy_df) + 1) + )[["id", "sids", "sequence"] + encoded_target_columns] + + # Ensure atleast one label is true for each protein sequence + sequence_hierarchy_df = sequence_hierarchy_df[ + sequence_hierarchy_df.iloc[:, self._LABELS_START_IDX :].any(axis=1) + ] + + with open(os.path.join(self.processed_dir_main, "classes.txt"), "wt") as fout: + fout.writelines(str(sun_id) + "\n" for sun_id in encoded_target_columns) + + return sequence_hierarchy_df + + def _parse_pdb_sequence_file(self) -> pd.DataFrame: + """ + Parses the PDB sequence file and returns a DataFrame containing PDB IDs, chain IDs, and sequences. + + Returns: + pd.DataFrame: A DataFrame with columns ["pdb_id", "chain_id", "sequence"]. + """ + records = [] + valid_amino_acids = "".join(ProteinDataReader.AA_LETTER) + + for record in SeqIO.parse( + os.path.join(self.scope_root_dir, self.raw_file_names_dict["PDB"]), "fasta" + ): + + if not record.seq or len(record.seq) > self.max_sequence_len: + continue + + pdb_id, chain = record.id.split("_") + sequence = re.sub(f"[^{valid_amino_acids}]", "X", str(record.seq)) + + # Store as a dictionary entry (list of dicts -> DataFrame later) + records.append( + { + "pdb_id": pdb_id.lower(), + "chain_id": chain.lower(), + "sequence": sequence, + } + ) + + # Convert list of dictionaries to a DataFrame + pdb_chain_df = pd.DataFrame.from_records(records) + + return pdb_chain_df + + @abstractmethod + def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> Dict[str, List[int]]: + # Override the return type of the method from superclass + pass + + # ------------------------------ Phase: Setup data ----------------------------------- + def setup_processed(self) -> None: + """ + Transform and prepare processed data for the SCOPe dataset. + + Main function of this method is to transform `data.pkl` into a model input data format (`data.pt`), + ensuring that the data is in a format compatible for input to the model. + The transformed data must contain the following keys: `ident`, `features`, `labels`, and `group`. + This method uses a subclass of Data Reader to perform the transformation. + + It will transform the data related to `scope_version_train`, if specified. 
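+
+        Note:
+            The encoded file for `scope_version_train` is written to the processed directory of the
+            dedicated `scope_version_train` instance (`self._scope_version_train_obj`), not to the
+            directory of this instance.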
+ """ + super().setup_processed() + + # Transform the data related to "scope_version_train" to encoded data, if it doesn't exist + if self.scope_version_train is not None and not os.path.isfile( + os.path.join( + self._scope_version_train_obj.processed_dir, + self._scope_version_train_obj.processed_file_names_dict["data"], + ) + ): + print( + f"Missing encoded data related to train version: {self.scope_version_train}" + ) + print("Calling the setup method related to it") + self._scope_version_train_obj.setup() + + def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]: + """ + Loads data from a pickled file and yields individual dictionaries for each row. + + The pickled file is expected to contain rows with the following structure: + - Data at row index `self._ID_IDX`: ID of go data instance + - Data at row index `self._DATA_REPRESENTATION_IDX`: Sequence representation of protein + - Data from row index `self._LABELS_START_IDX` onwards: Labels + + This method is used by `_load_data_from_file` to generate dictionaries that are then + processed and converted into a list of dictionaries containing the features and labels. + + Args: + input_file_path (str): The path to the pickled input file. + + Yields: + Dict[str, Any]: A dictionary containing: + - `features` (str): The sequence data from the file. + - `labels` (np.ndarray): A boolean array of labels starting from row index 4. + - `ident` (Any): The identifier from row index 0. + """ + with open(input_file_path, "rb") as input_file: + df = pd.read_pickle(input_file) + for row in df.values: + labels = row[self._LABELS_START_IDX :].astype(bool) + # chebai.preprocessing.reader.DataReader only needs features, labels, ident, group + # "group" set to None, by default as no such entity for this data + yield dict( + features=row[self._DATA_REPRESENTATION_IDX], + labels=labels, + ident=row[self._ID_IDX], + ) + + # ------------------------------ Phase: Dynamic Splits ----------------------------------- + def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + """ + Loads encoded/transformed data and generates training, validation, and test splits. + + This method first loads encoded data from a file named `data.pt`, which is derived from either + `scope_version` or `scope_version_train`. It then splits the data into training, validation, and test sets. + + If `scope_version_train` is provided: + - Loads additional encoded data from `scope_version_train`. + - Splits this data into training and validation sets, while using the test set from `scope_version`. + - Prunes the test set from `scope_version` to include only labels that exist in `scope_version_train`. + + If `scope_version_train` is not provided: + - Splits the data from `scope_version` into training, validation, and test sets without modification. + + Raises: + FileNotFoundError: If the required `data.pt` file(s) do not exist. Ensure that `prepare_data` + and/or `setup` methods have been called to generate the dataset files. + + Returns: + Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: A tuple containing three DataFrames: + - Training set + - Validation set + - Test set + """ + try: + filename = self.processed_file_names_dict["data"] + data_scope_version = torch.load( + os.path.join(self.processed_dir, filename), weights_only=False + ) + except FileNotFoundError: + raise FileNotFoundError( + f"File data.pt doesn't exists. 
" + f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files" + ) + + df_scope_version = pd.DataFrame(data_scope_version) + train_df_scope_ver, df_test_scope_ver = self.get_test_split( + df_scope_version, seed=self.dynamic_data_split_seed + ) + + if self.scope_version_train is not None: + # Load encoded data derived from "scope_version_train" + try: + filename_train = ( + self._scope_version_train_obj.processed_file_names_dict["data"] + ) + data_scope_train_version = torch.load( + os.path.join( + self._scope_version_train_obj.processed_dir, filename_train + ), + weights_only=False, + ) + except FileNotFoundError: + raise FileNotFoundError( + f"File data.pt doesn't exists related to scope_version_train {self.scope_version_train}." + f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files" + ) + + df_scope_train_version = pd.DataFrame(data_scope_train_version) + # Get train/val split of data based on "scope_version_train", but + # using test set from "scope_version" + df_train, df_val = self.get_train_val_splits_given_test( + df_scope_train_version, + df_test_scope_ver, + seed=self.dynamic_data_split_seed, + ) + # Modify test set from "scope_version" to only include the labels that + # exists in "scope_version_train", all other entries remains same. + df_test = self._setup_pruned_test_set(df_test_scope_ver) + else: + # Get all splits based on "scope_version" + df_train, df_val = self.get_train_val_splits_given_test( + train_df_scope_ver, + df_test_scope_ver, + seed=self.dynamic_data_split_seed, + ) + df_test = df_test_scope_ver + + return df_train, df_val, df_test + + def _setup_pruned_test_set( + self, df_test_scope_version: pd.DataFrame + ) -> pd.DataFrame: + """ + Create a test set with the same leaf nodes, but use only classes that appear in the training set. + + Args: + df_test_scope_version (pd.DataFrame): The test dataset. + + Returns: + pd.DataFrame: The pruned test dataset. 
+ """ + # TODO: find a more efficient way to do this + filename_old = "classes.txt" + # filename_new = f"classes_v{self.scope_version_train}.txt" + # dataset = torch.load(os.path.join(self.processed_dir, "test.pt")) + + # Load original classes (from the current SCOPe version - scope_version) + with open(os.path.join(self.processed_dir_main, filename_old), "r") as file: + orig_classes = file.readlines() + + # Load new classes (from the training SCOPe version - scope_version_train) + with open( + os.path.join( + self._scope_version_train_obj.processed_dir_main, filename_old + ), + "r", + ) as file: + new_classes = file.readlines() + + # Create a mapping which give index of a class from scope_version, if the corresponding + # class exists in scope_version_train, Size = Number of classes in scope_version + mapping = [ + None if or_class not in new_classes else new_classes.index(or_class) + for or_class in orig_classes + ] + + # Iterate over each data instance in the test set which is derived from scope_version + for _, row in df_test_scope_version.iterrows(): + # Size = Number of classes in scope_version_train + new_labels = [False for _ in new_classes] + for ind, label in enumerate(row["labels"]): + # If the scope_version class exists in the scope_version_train and has a True label, + # set the corresponding label in new_labels to True + if mapping[ind] is not None and label: + new_labels[mapping[ind]] = label + # Update the labels from test instance from scope_version to the new labels, which are compatible to both versions + row["labels"] = new_labels + + return df_test_scope_version + + # ------------------------------ Phase: Raw Properties ----------------------------------- + @property + def scope_root_dir(self) -> str: + """ + Returns the root directory of scope data. + + Returns: + str: The path to the base directory, which is "data/GO_UniProt". + """ + return os.path.join("data", "SCOPe") + + @property + def base_dir(self) -> str: + """ + Returns the base directory path for storing SCOPe data. + + Returns: + str: The path to the base directory, which is "data/GO_UniProt". + """ + return os.path.join(self.scope_root_dir, f"version_{self.scope_version}") + + @property + def raw_file_names_dict(self) -> dict: + """ + Returns a dictionary of raw file names used in data processing. + + Returns: + dict: A dictionary mapping dataset names to their respective file names. + """ + return { + "CLA": "cla.txt", + "DES": "des.txt", + "HIE": "hie.txt", + "PDB": "pdb_sequences.txt", + } + + +class _SCOPeOverX(_SCOPeDataExtractor, ABC): + """ + A class for extracting data from the SCOPe dataset with a threshold for selecting classes/labels based on + the number of subclasses. + + This class is designed to filter SCOPe classes/labels based on a specified threshold, selecting only those classes + which have a certain number of subclasses in the hierarchy. + + Attributes: + READER (dr.ProteinDataReader): The reader used for reading the dataset. + THRESHOLD (int): The threshold for selecting classes/labels based on the number of subclasses. + + """ + + READER = ProteinDataReader + THRESHOLD: int = None + + @property + def _name(self) -> str: + """ + Returns the name of the dataset. + + Returns: + str: The dataset name, formatted with the current threshold. + """ + return f"SCOPe{self.THRESHOLD}" + + def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> Dict[str, List[int]]: + """ + Selects classes from the SCOPe dataset based on the number of successors meeting a specified threshold. 
+ + This method iterates over the nodes in the graph, counting the number of successors for each node. + Nodes with a number of successors greater than or equal to the defined threshold are selected. + + Note: + The input graph must be transitive closure of a directed acyclic graph. + + Args: + g (nx.Graph): The graph representing the dataset. + *args: Additional positional arguments (not used). + **kwargs: Additional keyword arguments (not used). + + Returns: + Dict: A dict containing selected nodes at each hierarchy level. + + Notes: + - The `THRESHOLD` attribute should be defined in the subclass of this class. + """ + selected_sunids_for_level = {} + for node, attr_dict in g.nodes(data=True): + if attr_dict["level"] in {"root", "px", "sequence"}: + # Skip nodes with level "root", "px", or "sequence" + continue + + # Check if the number of "sequence"-level successors meets or exceeds the threshold + if g.nodes[node]["num_seq_successors"] >= self.THRESHOLD: + selected_sunids_for_level.setdefault(attr_dict["level"], []).append( + node + ) + return selected_sunids_for_level + + +class _SCOPeOverXPartial(_SCOPeOverX, ABC): + """ + Dataset that doesn't use the full SCOPe dataset, but extracts a part of SCOPe (subclasses of a given top class) + + Attributes: + top_class_sunid (int): The Sun-ID of the top class from which to extract subclasses. + """ + + def __init__(self, top_class_sunid: int, **kwargs): + """ + Initializes the _SCOPeOverXPartial dataset. + + Args: + top_class_sunid (int): The Sun-ID of the top class from which to extract subclasses. + **kwargs: Additional keyword arguments passed to the superclass initializer. + """ + if "top_class_sunid" not in kwargs: + kwargs["top_class_sunid"] = top_class_sunid + + self.top_class_sunid: int = top_class_sunid + super().__init__(**kwargs) + + @property + def processed_dir_main(self) -> str: + """ + Returns the main processed data directory specific to the top class. + + Returns: + str: The processed data directory path. + """ + return os.path.join( + self.base_dir, + self._name, + f"partial_{self.top_class_sunid}", + "processed", + ) + + def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: + """ + Extracts a subset of SCOPe based on subclasses of the top class ID. + + This method calls the superclass method to extract the full class hierarchy, + then extracts the subgraph containing only the descendants of the top class ID, including itself. + + Args: + data_path (str): The file path to the SCOPe ontology file. + + Returns: + nx.DiGraph: The extracted class hierarchy as a directed graph, limited to the + descendants of the top class ID. + """ + g = super()._extract_class_hierarchy(data_path) + g = g.subgraph( + list(g.successors(self.top_class_sunid)) + [self.top_class_sunid] + ) + return g + + +class SCOPeOver2000(_SCOPeOverX): + """ + A class for extracting data from the SCOPe dataset with a threshold of 2000 for selecting classes. + + Inherits from `_SCOPeOverX` and sets the threshold for selecting classes to 2000. + + Attributes: + THRESHOLD (int): The threshold for selecting classes (2000). + """ + + THRESHOLD: int = 2000 + + +class SCOPeOver50(_SCOPeOverX): + + THRESHOLD = 50 + + +class SCOPeOverPartial2000(_SCOPeOverXPartial): + """ + A class for extracting data from the SCOPe dataset with a threshold of 2000 for selecting classes. + + Inherits from `_SCOPeOverXPartial` and sets the threshold for selecting classes to 2000. + + Attributes: + THRESHOLD (int): The threshold for selecting classes (2000). 
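+
+    Example:
+        An illustrative instantiation; the sun-id value is an assumption, not a recommendation:
+
+        ```python
+        # extract only the subtree rooted at the given SCOPe class sun-id
+        ds = SCOPeOverPartial2000(scope_version="2.08", top_class_sunid=46456)
+        ```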
+ """ + + THRESHOLD: int = 2000 + + +if __name__ == "__main__": + scope = SCOPeOver50(scope_version="2.08") + + # g = scope._extract_class_hierarchy("dummy/path") + # # Save graph + # import pickle + # with open("graph.gpickle", "wb") as f: + # pickle.dump(g, f) + + # Load graph + import pickle + + with open("graph.gpickle", "rb") as f: + g = pickle.load(f) + + # print(len([node for node in g.nodes() if g.out_degree(node) > 10000])) + scope._graph_to_raw_dataset(g) diff --git a/chebai/preprocessing/migration/deep_go/__init__.py b/chebai/preprocessing/migration/deep_go/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py new file mode 100644 index 00000000..7d59c699 --- /dev/null +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py @@ -0,0 +1,316 @@ +import os +from collections import OrderedDict +from typing import List, Literal, Optional, Tuple + +import pandas as pd +from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit +from jsonargparse import CLI + +from chebai.preprocessing.datasets.deepGO.go_uniprot import DeepGO1MigratedData + + +class DeepGo1DataMigration: + """ + A class to handle data migration and processing for the DeepGO project. + It migrates the DeepGO data to our data structure followed for GO-UniProt data. + + This class handles migration of data from the DeepGO paper below: + Maxat Kulmanov, Mohammed Asif Khan, Robert Hoehndorf, + DeepGO: predicting protein functions from sequence and interactions using a deep ontology-aware classifier, + Bioinformatics, Volume 34, Issue 4, February 2018, Pages 660–668 + (https://doi.org/10.1093/bioinformatics/btx624). + """ + + # Max sequence length as per DeepGO1 + _MAXLEN = 1002 + _LABELS_START_IDX = DeepGO1MigratedData._LABELS_START_IDX + + def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]): + """ + Initializes the data migration object with a data directory and GO branch. + + Args: + data_dir (str): Directory containing the data files. + go_branch (Literal["cc", "mf", "bp"]): GO branch to use. + """ + valid_go_branches = list(DeepGO1MigratedData.GO_BRANCH_MAPPING.keys()) + if go_branch not in valid_go_branches: + raise ValueError(f"go_branch must be one of {valid_go_branches}") + self._go_branch = go_branch + + self._data_dir: str = rf"{data_dir}" + self._train_df: Optional[pd.DataFrame] = None + self._test_df: Optional[pd.DataFrame] = None + self._validation_df: Optional[pd.DataFrame] = None + self._terms_df: Optional[pd.DataFrame] = None + self._classes: Optional[List[str]] = None + + def migrate(self) -> None: + """ + Executes the data migration by loading, processing, and saving the data. + """ + print("Starting the migration process...") + self._load_data() + if not all( + df is not None + for df in [ + self._train_df, + self._validation_df, + self._test_df, + self._terms_df, + ] + ): + raise Exception( + "Data splits or terms data is not available in instance variables." + ) + splits_df = self._record_splits() + data_with_labels_df = self._extract_required_data_from_splits() + + if not all( + var is not None for var in [data_with_labels_df, splits_df, self._classes] + ): + raise Exception( + "Data splits or terms data is not available in instance variables." 
+ ) + + self.save_migrated_data(data_with_labels_df, splits_df) + + def _load_data(self) -> None: + """ + Loads the test, train, validation, and terms data from the pickled files + in the data directory. + """ + try: + print(f"Loading data files from directory: {self._data_dir}") + self._test_df = pd.DataFrame( + pd.read_pickle( + os.path.join(self._data_dir, f"test-{self._go_branch}.pkl") + ) + ) + + # DeepGO 1 lacks a validation split, so we will create one by further splitting the training set. + # Although this reduces the training data slightly compared to the original DeepGO setup, + # given the data size, the impact should be minimal. + train_df = pd.DataFrame( + pd.read_pickle( + os.path.join(self._data_dir, f"train-{self._go_branch}.pkl") + ) + ) + + self._train_df, self._validation_df = self._get_train_val_split(train_df) + + self._terms_df = pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, f"{self._go_branch}.pkl")) + ) + + except FileNotFoundError as e: + raise FileNotFoundError( + f"Data file not found in directory: {e}. " + "Please ensure all required files are available in the specified directory." + ) + + @staticmethod + def _get_train_val_split( + train_df: pd.DataFrame, + ) -> Tuple[pd.DataFrame, pd.DataFrame]: + """ + Splits the training data into a smaller training set and a validation set. + + Args: + train_df (pd.DataFrame): Original training DataFrame. + + Returns: + Tuple[pd.DataFrame, pd.DataFrame]: Training and validation DataFrames. + """ + labels_list_train = train_df["labels"].tolist() + train_split = 0.85 + test_size = ((1 - train_split) ** 2) / train_split + + splitter = MultilabelStratifiedShuffleSplit( + n_splits=1, test_size=test_size, random_state=42 + ) + + train_indices, validation_indices = next( + splitter.split(labels_list_train, labels_list_train) + ) + + df_validation = train_df.iloc[validation_indices] + df_train = train_df.iloc[train_indices] + return df_train, df_validation + + def _record_splits(self) -> pd.DataFrame: + """ + Creates a DataFrame that stores the IDs and their corresponding data splits. + + Returns: + pd.DataFrame: A combined DataFrame containing split assignments. + """ + print("Recording data splits for train, validation, and test sets.") + split_assignment_list: List[pd.DataFrame] = [ + pd.DataFrame({"id": self._train_df["proteins"], "split": "train"}), + pd.DataFrame( + {"id": self._validation_df["proteins"], "split": "validation"} + ), + pd.DataFrame({"id": self._test_df["proteins"], "split": "test"}), + ] + + combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) + return combined_split_assignment + + def _extract_required_data_from_splits(self) -> pd.DataFrame: + """ + Extracts required columns from the combined data splits. + + Returns: + pd.DataFrame: A DataFrame containing the essential columns for processing. 
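+                The returned frame contains the columns ``swiss_id``, ``accession``,
+                ``go_ids`` and ``sequence``, followed by one label column per
+                selected GO term.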
+ """ + print("Combining data splits into a single DataFrame with required columns.") + required_columns = [ + "proteins", + "accessions", + "sequences", + "gos", + "labels", + ] + + new_df = pd.concat( + [ + self._train_df[required_columns], + self._validation_df[required_columns], + self._test_df[required_columns], + ], + ignore_index=True, + ) + new_df["go_ids"] = new_df.apply( + lambda row: self.extract_go_id(row["gos"]), axis=1 + ) + + labels_df = self._get_labels_columns(new_df) + + data_df = pd.DataFrame( + OrderedDict( + swiss_id=new_df["proteins"], + accession=new_df["accessions"], + go_ids=new_df["go_ids"], + sequence=new_df["sequences"], + ) + ) + + df = pd.concat([data_df, labels_df], axis=1) + + return df + + @staticmethod + def extract_go_id(go_list: List[str]) -> List[int]: + """ + Extracts and parses GO IDs from a list of GO annotations. + + Args: + go_list (List[str]): List of GO annotation strings. + + Returns: + List[int]: List of parsed GO IDs. + """ + return [DeepGO1MigratedData._parse_go_id(go_id_str) for go_id_str in go_list] + + def _get_labels_columns(self, data_df: pd.DataFrame) -> pd.DataFrame: + """ + Generates columns for labels based on provided selected terms. + + Args: + data_df (pd.DataFrame): DataFrame with GO annotations and labels. + + Returns: + pd.DataFrame: DataFrame with label columns. + """ + print("Generating label columns from provided selected terms.") + parsed_go_ids: pd.Series = self._terms_df["functions"].apply( + lambda gos: DeepGO1MigratedData._parse_go_id(gos) + ) + all_go_ids_list = parsed_go_ids.values.tolist() + self._classes = all_go_ids_list + + new_label_columns = pd.DataFrame( + data_df["labels"].tolist(), index=data_df.index, columns=all_go_ids_list + ) + + return new_label_columns + + def save_migrated_data( + self, data_df: pd.DataFrame, splits_df: pd.DataFrame + ) -> None: + """ + Saves the processed data and split information. + + Args: + data_df (pd.DataFrame): Data with GO labels. + splits_df (pd.DataFrame): Split assignment DataFrame. + """ + print("Saving transformed data files.") + + deepgo_migr_inst: DeepGO1MigratedData = DeepGO1MigratedData( + go_branch=DeepGO1MigratedData.GO_BRANCH_MAPPING[self._go_branch], + max_sequence_length=self._MAXLEN, + ) + + # Save data file + deepgo_migr_inst.save_processed( + data_df, deepgo_migr_inst.processed_main_file_names_dict["data"] + ) + print( + f"{deepgo_migr_inst.processed_main_file_names_dict['data']} saved to {deepgo_migr_inst.processed_dir_main}" + ) + + # Save splits file + splits_df.to_csv( + os.path.join(deepgo_migr_inst.processed_dir_main, "splits_deep_go1.csv"), + index=False, + ) + print(f"splits_deep_go1.csv saved to {deepgo_migr_inst.processed_dir_main}") + + # Save classes file + classes = sorted(self._classes) + with open( + os.path.join(deepgo_migr_inst.processed_dir_main, "classes_deep_go1.txt"), + "wt", + ) as fout: + fout.writelines(str(node) + "\n" for node in classes) + print(f"classes_deep_go1.txt saved to {deepgo_migr_inst.processed_dir_main}") + + print("Migration process completed!") + + +class Main: + """ + Main class to handle the migration process for DeepGo1DataMigration. + + Methods: + migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]): + Initiates the migration process for the specified data directory and GO branch. + """ + + @staticmethod + def migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]) -> None: + """ + Initiates the migration process by creating a DeepGoDataMigration instance + and invoking its migrate method. 
+ + Args: + data_dir (str): Directory containing the data files. + go_branch (Literal["cc", "mf", "bp"]): GO branch to use + ("cc" for cellular_component, + "mf" for molecular_function, + or "bp" for biological_process). + """ + DeepGo1DataMigration(data_dir, go_branch).migrate() + + +if __name__ == "__main__": + # Example: python script_name.py migrate --data_dir="data/deep_go1" --go_branch="mf" + # --data_dir specifies the directory containing the data files. + # --go_branch specifies the GO branch (cc, mf, or bp) you want to use for the migration. + CLI( + Main, + description="DeepGo1DataMigration CLI tool to handle migration of GO data for specified branches (cc, mf, bp).", + as_positional=False, + ) diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py new file mode 100644 index 00000000..d23247c0 --- /dev/null +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py @@ -0,0 +1,366 @@ +import os +import re +from collections import OrderedDict +from typing import List, Literal, Optional + +import pandas as pd +from jsonargparse import CLI + +from chebai.preprocessing.datasets.deepGO.go_uniprot import DeepGO2MigratedData +from chebai.preprocessing.reader import ProteinDataReader + + +class DeepGo2DataMigration: + """ + A class to handle data migration and processing for the DeepGO project. It migrates the data from the DeepGO-SE + data structure to our data structure followed for GO-UniProt data. + + This class handles migration of data from the DeepGO paper below: + Maxat Kulmanov, Mohammed Asif Khan, Robert Hoehndorf, + DeepGO: predicting protein functions from sequence and interactions using a deep ontology-aware classifier, + Bioinformatics, Volume 34, Issue 4, February 2018, Pages 660–668 + (https://doi.org/10.1093/bioinformatics/btx624) + """ + + _LABELS_START_IDX = DeepGO2MigratedData._LABELS_START_IDX + + def __init__( + self, data_dir: str, go_branch: Literal["cc", "mf", "bp"], max_len: int = 1000 + ): + """ + Initializes the data migration object with a data directory and GO branch. + + Args: + data_dir (str): Directory containing the data files. + go_branch (Literal["cc", "mf", "bp"]): GO branch to use. + max_len (int): Used to truncate the sequence to this length. Default is 1000. + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 + """ + valid_go_branches = list(DeepGO2MigratedData.GO_BRANCH_MAPPING.keys()) + if go_branch not in valid_go_branches: + raise ValueError(f"go_branch must be one of {valid_go_branches}") + self._go_branch = go_branch + + self._data_dir: str = os.path.join(rf"{data_dir}", go_branch) + self._max_len: int = max_len + + self._train_df: Optional[pd.DataFrame] = None + self._test_df: Optional[pd.DataFrame] = None + self._validation_df: Optional[pd.DataFrame] = None + self._terms_df: Optional[pd.DataFrame] = None + self._classes: Optional[List[str]] = None + + def migrate(self) -> None: + """ + Executes the data migration by loading, processing, and saving the data. + """ + print("Starting the migration process...") + self._load_data() + if not all( + df is not None + for df in [ + self._train_df, + self._validation_df, + self._test_df, + self._terms_df, + ] + ): + raise Exception( + "Data splits or terms data is not available in instance variables." 
+ ) + splits_df = self._record_splits() + + data_df = self._extract_required_data_from_splits() + data_with_labels_df = self._generate_labels(data_df) + + if not all( + var is not None for var in [data_with_labels_df, splits_df, self._classes] + ): + raise Exception( + "Data splits or terms data is not available in instance variables." + ) + + self.save_migrated_data(data_with_labels_df, splits_df) + + def _load_data(self) -> None: + """ + Loads the test, train, validation, and terms data from the pickled files + in the data directory. + """ + + try: + print(f"Loading data from directory: {self._data_dir}......") + + print( + "Pre-processing the data before loading them into instance variables\n" + f"2-Steps preprocessing: \n" + f"\t 1: Truncating every sequence to {self._max_len}\n" + f"\t 2: Replacing every amino acid which is not in {ProteinDataReader.AA_LETTER}" + ) + + self._test_df = self._pre_process_data( + pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, "test_data.pkl")) + ) + ) + self._train_df = self._pre_process_data( + pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, "train_data.pkl")) + ) + ) + self._validation_df = self._pre_process_data( + pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, "valid_data.pkl")) + ) + ) + + self._terms_df = pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, "terms.pkl")) + ) + + except FileNotFoundError as e: + raise FileNotFoundError( + f"Data file not found in directory: {e}. " + "Please ensure all required files are available in the specified directory." + ) + + def _pre_process_data(self, df: pd.DataFrame) -> pd.DataFrame: + """ + Pre-processes the input dataframe by truncating sequences to the maximum + length and replacing invalid amino acids with 'X'. + + Args: + df (pd.DataFrame): The dataframe to preprocess. + + Returns: + pd.DataFrame: The processed dataframe. + """ + df = self._truncate_sequences(df) + df = self._replace_invalid_amino_acids(df) + return df + + def _truncate_sequences( + self, df: pd.DataFrame, column: str = "sequences" + ) -> pd.DataFrame: + """ + Truncate sequences in a specified column of a dataframe to the maximum length. + + https://github.com/bio-ontology-research-group/deepgo2/blob/main/train_cnn.py#L206-L217 + + Args: + df (pd.DataFrame): The input dataframe containing the data to be processed. + column (str, optional): The column containing sequences to truncate. + Defaults to "sequences". + + Returns: + pd.DataFrame: The dataframe with sequences truncated to `self._max_len`. + """ + df[column] = df[column].apply(lambda x: x[: self._max_len]) + return df + + @staticmethod + def _replace_invalid_amino_acids( + df: pd.DataFrame, column: str = "sequences" + ) -> pd.DataFrame: + """ + Replaces invalid amino acids in a sequence with 'X' using regex. + + https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L26-L33 + https://github.com/ChEB-AI/python-chebai/pull/64#issuecomment-2517067073 + + Args: + df (pd.DataFrame): The dataframe containing the sequences to be processed. + column (str, optional): The column containing the sequences. Defaults to "sequences". + + Returns: + pd.DataFrame: The dataframe with invalid amino acids replaced by 'X'. 
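+
+        Example:
+            ``B`` and ``Z`` are not part of ``ProteinDataReader.AA_LETTER`` and are
+            therefore replaced:
+
+            >>> df = pd.DataFrame({"sequences": ["MKBZAC"]})
+            >>> DeepGo2DataMigration._replace_invalid_amino_acids(df)["sequences"][0]
+            'MKXXAC'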
+ """ + valid_amino_acids = "".join(ProteinDataReader.AA_LETTER) + # Replace any character not in the valid set with 'X' + df[column] = df[column].apply( + lambda x: re.sub(f"[^{valid_amino_acids}]", "X", x) + ) + return df + + def _record_splits(self) -> pd.DataFrame: + """ + Creates a DataFrame that stores the IDs and their corresponding data splits. + + Returns: + pd.DataFrame: A combined DataFrame containing split assignments. + """ + print("Recording data splits for train, validation, and test sets.") + split_assignment_list: List[pd.DataFrame] = [ + pd.DataFrame({"id": self._train_df["proteins"], "split": "train"}), + pd.DataFrame( + {"id": self._validation_df["proteins"], "split": "validation"} + ), + pd.DataFrame({"id": self._test_df["proteins"], "split": "test"}), + ] + + combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) + return combined_split_assignment + + def _extract_required_data_from_splits(self) -> pd.DataFrame: + """ + Extracts required columns from the combined data splits. + + Returns: + pd.DataFrame: A DataFrame containing the essential columns for processing. + """ + print("Combining the data splits with required data..... ") + required_columns = [ + "proteins", + "accessions", + "sequences", + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L60-L69 + "prop_annotations", # Direct and Transitively associated GO ids + "esm2", + ] + + new_df = pd.concat( + [ + self._train_df[required_columns], + self._validation_df[required_columns], + self._test_df[required_columns], + ], + ignore_index=True, + ) + new_df["go_ids"] = new_df["prop_annotations"].apply( + lambda x: self.extract_go_id(x) + ) + + data_df = pd.DataFrame( + OrderedDict( + swiss_id=new_df["proteins"], + accession=new_df["accessions"], + go_ids=new_df["go_ids"], + sequence=new_df["sequences"], + esm2_embeddings=new_df["esm2"], + ) + ) + return data_df + + @staticmethod + def extract_go_id(go_list: List[str]) -> List[int]: + """ + Extracts and parses GO IDs from a list of GO annotations. + + Args: + go_list (List[str]): List of GO annotation strings. + + Returns: + List[str]: List of parsed GO IDs. + """ + return [DeepGO2MigratedData._parse_go_id(go_id_str) for go_id_str in go_list] + + def _generate_labels(self, data_df: pd.DataFrame) -> pd.DataFrame: + """ + Generates label columns for each GO term in the dataset. + + Args: + data_df (pd.DataFrame): DataFrame containing data with GO IDs. + + Returns: + pd.DataFrame: DataFrame with new label columns. + """ + print("Generating labels based on terms.pkl file.......") + parsed_go_ids: pd.Series = self._terms_df["gos"].apply( + DeepGO2MigratedData._parse_go_id + ) + all_go_ids_list = parsed_go_ids.values.tolist() + self._classes = all_go_ids_list + new_label_columns = pd.DataFrame( + False, index=data_df.index, columns=all_go_ids_list + ) + data_df = pd.concat([data_df, new_label_columns], axis=1) + + for index, row in data_df.iterrows(): + for go_id in row["go_ids"]: + if go_id in data_df.columns: + data_df.at[index, go_id] = True + + data_df = data_df[data_df.iloc[:, self._LABELS_START_IDX :].any(axis=1)] + return data_df + + def save_migrated_data( + self, data_df: pd.DataFrame, splits_df: pd.DataFrame + ) -> None: + """ + Saves the processed data and split information. + + Args: + data_df (pd.DataFrame): Data with GO labels. + splits_df (pd.DataFrame): Split assignment DataFrame. 
+ """ + print("Saving transformed data......") + deepgo_migr_inst: DeepGO2MigratedData = DeepGO2MigratedData( + go_branch=DeepGO2MigratedData.GO_BRANCH_MAPPING[self._go_branch], + max_sequence_length=self._max_len, + ) + + # Save data file + deepgo_migr_inst.save_processed( + data_df, deepgo_migr_inst.processed_main_file_names_dict["data"] + ) + print( + f"{deepgo_migr_inst.processed_main_file_names_dict['data']} saved to {deepgo_migr_inst.processed_dir_main}" + ) + + # Save split file + splits_df.to_csv( + os.path.join(deepgo_migr_inst.processed_dir_main, "splits_deep_go2.csv"), + index=False, + ) + print(f"splits_deep_go2.csv saved to {deepgo_migr_inst.processed_dir_main}") + + # Save classes.txt file + classes = sorted(self._classes) + with open( + os.path.join(deepgo_migr_inst.processed_dir_main, "classes_deep_go2.txt"), + "wt", + ) as fout: + fout.writelines(str(node) + "\n" for node in classes) + print(f"classes_deep_go2.txt saved to {deepgo_migr_inst.processed_dir_main}") + + print("Migration completed!") + + +class Main: + """ + Main class to handle the migration process for DeepGoDataMigration. + + Methods: + migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]): + Initiates the migration process for the specified data directory and GO branch. + """ + + @staticmethod + def migrate( + data_dir: str, go_branch: Literal["cc", "mf", "bp"], max_len: int = 1000 + ) -> None: + """ + Initiates the migration process by creating a DeepGoDataMigration instance + and invoking its migrate method. + + Args: + data_dir (str): Directory containing the data files. + go_branch (Literal["cc", "mf", "bp"]): GO branch to use + ("cc" for cellular_component, + "mf" for molecular_function, + or "bp" for biological_process). + max_len (int): Used to truncate the sequence to this length. Default is 1000. + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 + """ + DeepGo2DataMigration(data_dir, go_branch, max_len).migrate() + + +if __name__ == "__main__": + # Example: python script_name.py migrate --data_dir="data/deep_go_se_training_data" --go_branch="bp" + # --data_dir specifies the directory containing the data files. + # --go_branch specifies the GO branch (cc, mf, or bp) you want to use for the migration. 
+ CLI( + Main, + description="DeepGoDataMigration CLI tool to handle migration of GO data for specified branches (cc, mf, bp).", + as_positional=False, + ) diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index e220e1e4..7e943eb5 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -1,8 +1,18 @@ import os -from typing import Any, Dict, List, Optional +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple +from urllib.error import HTTPError import deepsmiles import selfies as sf +import torch +from esm import Alphabet +from esm.model.esm2 import ESM2 +from esm.pretrained import ( + _has_regression_weights, + load_model_and_alphabet_core, + load_model_and_alphabet_local, +) from pysmiles.read_smiles import _tokenize from transformers import RobertaTokenizerFast @@ -348,7 +358,7 @@ class ProteinDataReader(DataReader): COLLATOR = RaggedCollator - # 20 natural amino acid notation + # 21 natural amino acid notation AA_LETTER = [ "A", "R", @@ -370,6 +380,8 @@ class ProteinDataReader(DataReader): "W", "Y", "V", + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L3-L5 + "X", # Consider valid in latest paper year 2024 Reference number 3 in go_uniprot.py ] def name(self) -> str: @@ -469,3 +481,249 @@ def on_finish(self) -> None: print(f"Saving {len(self.cache)} tokens to {self.token_path}...") print(f"First 10 tokens: {self.cache[:10]}") pk.writelines([f"{c}\n" for c in self.cache]) + + +class ESM2EmbeddingReader(DataReader): + """ + A data reader to process protein sequences using the ESM2 model for embeddings. + + References: + https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/extract_esm.py + + Note: + For layer availability by model, Please check below link: + https://github.com/facebookresearch/esm?tab=readme-ov-file#pre-trained-models- + + To test this reader, try lighter models: + esm2_t6_8M_UR50D: 6 layers (valid layers: 1–6), (~28 Mb) - A tiny 8M parameter model. + esm2_t12_35M_UR50D: 12 layers (valid layers: 1–12), (~128 Mb) - A slightly larger, 35M parameter model. + These smaller models are good for testing and debugging purposes. + + """ + + # https://github.com/facebookresearch/esm/blob/main/esm/pretrained.py#L53 + _MODELS_URL = "https://dl.fbaipublicfiles.com/fair-esm/models/{}.pt" + _REGRESSION_URL = ( + "https://dl.fbaipublicfiles.com/fair-esm/regression/{}-contact-regression.pt" + ) + + def __init__( + self, + save_model_dir: str = os.path.join("data", "esm2_reader"), + model_name: str = "esm2_t36_3B_UR50D", + device: Optional[torch.device] = None, + truncation_length: int = 1022, + toks_per_batch: int = 4096, + return_contacts: bool = False, + repr_layer: int = 36, + *args, + **kwargs, + ): + """ + Initialize the ESM2EmbeddingReader class. + + Args: + save_model_dir (str): Directory to save/load the pretrained ESM model. + model_name (str): Name of the pretrained model. Defaults to "esm2_t36_3B_UR50D". + device (torch.device or str, optional): Device for computation (e.g., 'cpu', 'cuda'). + truncation_length (int): Maximum sequence length for truncation. Defaults to 1022. + toks_per_batch (int): Tokens per batch for data processing. Defaults to 4096. + return_contacts (bool): Whether to return contact maps. Defaults to False. + repr_layers (int): Layer number to extract representations from. Defaults to 36. 
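+
+        Example:
+            A minimal sketch using one of the lighter checkpoints listed in the class
+            docstring, intended for quick local tests (for ``esm2_t6_8M_UR50D`` the
+            valid representation layers are 1-6):
+            ```python
+            reader = ESM2EmbeddingReader(
+                model_name="esm2_t6_8M_UR50D",
+                repr_layer=6,
+            )
+            embedding = reader._read_data("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")
+            ```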
+ """ + self.save_model_dir = save_model_dir + if not os.path.exists(self.save_model_dir): + os.makedirs((os.path.dirname(self.save_model_dir)), exist_ok=True) + self.model_name = model_name + self.device = device + self.truncation_length = truncation_length + self.toks_per_batch = toks_per_batch + self.return_contacts = return_contacts + self.repr_layer = repr_layer + + self._model: Optional[ESM2] = None + self._alphabet: Optional[Alphabet] = None + + self._model, self._alphabet = self.load_model_and_alphabet() + self._model.eval() + + if self.device: + self._model = self._model.to(device) + + super().__init__(*args, **kwargs) + + def load_model_and_alphabet(self) -> Tuple[ESM2, Alphabet]: + """ + Load the ESM2 model and its alphabet. + + References: + https://github.com/facebookresearch/esm/blob/main/esm/pretrained.py#L24-L28 + + Returns: + Tuple[ESM2, Alphabet]: Loaded model and alphabet. + """ + model_location = os.path.join(self.save_model_dir, f"{self.model_name}.pt") + if os.path.exists(model_location): + return load_model_and_alphabet_local(model_location) + else: + return self.load_model_and_alphabet_hub() + + def load_model_and_alphabet_hub(self) -> Tuple[ESM2, Alphabet]: + """ + Load the model and alphabet from the hub URL. + + References: + https://github.com/facebookresearch/esm/blob/main/esm/pretrained.py#L62-L64 + + Returns: + Tuple[ESM2, Alphabet]: Loaded model and alphabet. + """ + model_url = self._MODELS_URL.format(self.model_name) + model_data = self.load_hub_workaround(model_url) + regression_data = None + if _has_regression_weights(self.model_name): + regression_url = self._REGRESSION_URL.format(self.model_name) + regression_data = self.load_hub_workaround(regression_url) + return load_model_and_alphabet_core( + self.model_name, model_data, regression_data + ) + + def load_hub_workaround(self, url) -> torch.Tensor: + """ + Workaround to load models from the PyTorch Hub. + + References: + https://github.com/facebookresearch/esm/blob/main/esm/pretrained.py#L31-L43 + + Returns: + torch.Tensor: Loaded model state dictionary. + """ + try: + data = torch.hub.load_state_dict_from_url( + url, self.save_model_dir, progress=True, map_location=self.device + ) + + except RuntimeError: + # Handle PyTorch version issues + fn = Path(url).name + data = torch.load( + f"{torch.hub.get_dir()}/checkpoints/{fn}", + map_location="cpu", + ) + except HTTPError as e: + raise Exception( + f"Could not load {url}. Did you specify the correct model name?" + ) + return data + + @staticmethod + def name() -> str: + """ + Returns the name of the data reader. This method identifies the specific type of data reader. + + Returns: + str: The name of the data reader, which is "protein_token". + """ + return "esm2_embedding" + + @property + def token_path(self) -> None: + """ + Not used as no token file is not created for this reader. + + Returns: + str: Empty string since this method is not implemented. + """ + return + + def _read_data(self, raw_data: str) -> List[int]: + """ + Reads protein sequence data and generates embeddings. + + Args: + raw_data (str): The protein sequence. + + Returns: + List[int]: Embeddings generated for the sequence. + """ + alp_tokens_idx = self._sequence_to_alphabet_tokens_idx(raw_data) + return self._alphabet_tokens_to_esm_embedding(alp_tokens_idx).tolist() + + def _sequence_to_alphabet_tokens_idx(self, sequence: str) -> torch.Tensor: + """ + Converts a protein sequence into ESM alphabet token indices. + + Args: + sequence (str): Protein sequence. 
+ + References: + https://github.com/facebookresearch/esm/blob/2b369911bb5b4b0dda914521b9475cad1656b2ac/esm/data.py#L249-L250 + https://github.com/facebookresearch/esm/blob/2b369911bb5b4b0dda914521b9475cad1656b2ac/esm/data.py#L262-L297 + + Returns: + torch.Tensor: Tokenized sequence with special tokens (BOS/EOS) included. + """ + seq_encoded = self._alphabet.encode(sequence) + tokens = [] + + # Add BOS token if configured + if self._alphabet.prepend_bos: + tokens.append(self._alphabet.cls_idx) + + # Add the main sequence + tokens.extend(seq_encoded) + + # Add EOS token if configured + if self._alphabet.append_eos: + tokens.append(self._alphabet.eos_idx) + + # Convert to PyTorch tensor and return + return torch.tensor([tokens], dtype=torch.int64) + + def _alphabet_tokens_to_esm_embedding(self, tokens: torch.Tensor) -> torch.Tensor: + """ + Converts alphabet tokens into ESM embeddings. + + Args: + tokens (torch.Tensor): Tokenized protein sequences. + + References: + https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/extract_esm.py#L82-L107 + + Returns: + torch.Tensor: Protein embedding from the specified representation layer. + """ + if self.device: + tokens = tokens.to(self.device, non_blocking=True) + + with torch.no_grad(): + out = self._model( + tokens, + repr_layers=[ + self.repr_layer, + ], + return_contacts=self.return_contacts, + ) + + # Extract representations and compute the mean embedding for each layer + representations = { + layer: t.to(self.device) for layer, t in out["representations"].items() + } + truncate_len = min(self.truncation_length, tokens.size(1)) + + result = { + "mean_representations": { + layer: t[0, 1 : truncate_len + 1].mean(0).clone() + for layer, t in representations.items() + } + } + return result["mean_representations"][self.repr_layer] + + def on_finish(self) -> None: + """ + Not used here as no token file exists for this reader. + + Returns: + None + """ + pass diff --git a/chebai/result/evaluate_predictions.py b/chebai/result/evaluate_predictions.py new file mode 100644 index 00000000..355c07c0 --- /dev/null +++ b/chebai/result/evaluate_predictions.py @@ -0,0 +1,108 @@ +from typing import Tuple + +import numpy as np +import torch +from jsonargparse import CLI +from torchmetrics.functional.classification import multilabel_auroc + +from chebai.callbacks.epoch_metrics import MacroF1 +from chebai.result.utils import load_results_from_buffer + + +class EvaluatePredictions: + def __init__(self, eval_dir: str): + """ + Initializes the EvaluatePredictions class. + + Args: + eval_dir (str): Path to the directory containing evaluation files. + """ + self.eval_dir = eval_dir + self.metrics = [] + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.num_labels = None + + @staticmethod + def validate_eval_dir(label_files: torch.Tensor, pred_files: torch.Tensor) -> None: + """ + Validates that the number of labels matches the number of predictions, + ensuring that they have the same shape. + + Args: + label_files (torch.Tensor): Tensor containing label data. + pred_files (torch.Tensor): Tensor containing prediction data. + + Raises: + ValueError: If label and prediction tensors are mismatched in shape. 
+ """ + if label_files is None or pred_files is None: + raise ValueError("Both label and prediction tensors must be provided.") + + # Check if the number of labels matches the number of predictions + if label_files.shape[0] != pred_files.shape[0]: + raise ValueError( + "Number of label tensors does not match the number of prediction tensors." + ) + + # Validate that the last dimension matches the expected number of classes + if label_files.shape[1] != pred_files.shape[1]: + raise ValueError( + "Label and prediction tensors must have the same shape in terms of class outputs." + ) + + def evaluate(self) -> None: + """ + Loads predictions and labels, validates file correspondence, and calculates Multilabel AUROC and Fmax. + """ + test_preds, test_labels = load_results_from_buffer(self.eval_dir, self.device) + self.validate_eval_dir(test_labels, test_preds) + self.num_labels = test_preds.shape[1] + + ml_auroc = multilabel_auroc( + test_preds, test_labels, num_labels=self.num_labels + ).item() + + print("Multilabel AUC-ROC:", ml_auroc) + + fmax, threshold = self.calculate_fmax(test_preds, test_labels) + print(f"F-max : {fmax}, threshold: {threshold}") + + def calculate_fmax( + self, test_preds: torch.Tensor, test_labels: torch.Tensor + ) -> Tuple[float, float]: + """ + Calculates the Fmax metric using the F1 score at various thresholds. + + Args: + test_preds (torch.Tensor): Predicted scores for the labels. + test_labels (torch.Tensor): True labels for the evaluation. + + Returns: + Tuple[float, float]: The maximum F1 score and the corresponding threshold. + """ + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/metrics.py#L51-L52 + thresholds = np.linspace(0, 1, 101) + fmax = 0.0 + best_threshold = 0.0 + + for t in thresholds: + custom_f1_metric = MacroF1(num_labels=self.num_labels, threshold=t) + custom_f1_metric.update(test_preds, test_labels) + custom_f1_metric_score = custom_f1_metric.compute().item() + + # Check if the current score is the best we've seen + if custom_f1_metric_score > fmax: + fmax = custom_f1_metric_score + best_threshold = t + + return fmax, best_threshold + + +class Main: + def evaluate(self, eval_dir: str): + EvaluatePredictions(eval_dir).evaluate() + + +if __name__ == "__main__": + # evaluate_predictions.py evaluate + CLI(Main) diff --git a/configs/data/chebi100.yml b/configs/data/chebi/chebi100.yml similarity index 100% rename from configs/data/chebi100.yml rename to configs/data/chebi/chebi100.yml diff --git a/configs/data/chebi100_SELFIES.yml b/configs/data/chebi/chebi100_SELFIES.yml similarity index 100% rename from configs/data/chebi100_SELFIES.yml rename to configs/data/chebi/chebi100_SELFIES.yml diff --git a/configs/data/chebi100_deepSMILES.yml b/configs/data/chebi/chebi100_deepSMILES.yml similarity index 100% rename from configs/data/chebi100_deepSMILES.yml rename to configs/data/chebi/chebi100_deepSMILES.yml diff --git a/configs/data/chebi100_mixed.yml b/configs/data/chebi/chebi100_mixed.yml similarity index 100% rename from configs/data/chebi100_mixed.yml rename to configs/data/chebi/chebi100_mixed.yml diff --git a/configs/data/chebi50.yml b/configs/data/chebi/chebi50.yml similarity index 100% rename from configs/data/chebi50.yml rename to configs/data/chebi/chebi50.yml diff --git a/configs/data/chebi50_mixed.yml b/configs/data/chebi/chebi50_mixed.yml similarity index 100% rename from configs/data/chebi50_mixed.yml rename to configs/data/chebi/chebi50_mixed.yml diff --git a/configs/data/chebi50_partial.yml 
b/configs/data/chebi/chebi50_partial.yml
similarity index 100%
rename from configs/data/chebi50_partial.yml
rename to configs/data/chebi/chebi50_partial.yml
diff --git a/configs/data/deepGO/deepgo2_esm2.yml b/configs/data/deepGO/deepgo2_esm2.yml
new file mode 100644
index 00000000..5a0436e3
--- /dev/null
+++ b/configs/data/deepGO/deepgo2_esm2.yml
@@ -0,0 +1,5 @@
+class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.DeepGO2MigratedData
+init_args:
+  go_branch: "MF"
+  max_sequence_length: 1000
+  use_esm2_embeddings: True
diff --git a/configs/data/deepGO/deepgo_1_migrated_data.yml b/configs/data/deepGO/deepgo_1_migrated_data.yml
new file mode 100644
index 00000000..0924e023
--- /dev/null
+++ b/configs/data/deepGO/deepgo_1_migrated_data.yml
@@ -0,0 +1,4 @@
+class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.DeepGO1MigratedData
+init_args:
+  go_branch: "MF"
+  max_sequence_length: 1002
diff --git a/configs/data/deepGO/deepgo_2_migrated_data.yml b/configs/data/deepGO/deepgo_2_migrated_data.yml
new file mode 100644
index 00000000..5a0436e3
--- /dev/null
+++ b/configs/data/deepGO/deepgo_2_migrated_data.yml
@@ -0,0 +1,5 @@
+class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.DeepGO2MigratedData
+init_args:
+  go_branch: "MF"
+  max_sequence_length: 1000
+  use_esm2_embeddings: True
diff --git a/configs/data/deepGO/go250.yml b/configs/data/deepGO/go250.yml
new file mode 100644
index 00000000..01e34aa4
--- /dev/null
+++ b/configs/data/deepGO/go250.yml
@@ -0,0 +1,3 @@
+class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.GOUniProtOver250
+init_args:
+  go_branch: "BP"
diff --git a/configs/data/deepGO/go50.yml b/configs/data/deepGO/go50.yml
new file mode 100644
index 00000000..bee43773
--- /dev/null
+++ b/configs/data/deepGO/go50.yml
@@ -0,0 +1 @@
+class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.GOUniProtOver50
diff --git a/configs/data/go250.yml b/configs/data/go250.yml
deleted file mode 100644
index 5598495c..00000000
--- a/configs/data/go250.yml
+++ /dev/null
@@ -1,3 +0,0 @@
-class_path: chebai.preprocessing.datasets.go_uniprot.GOUniProtOver250
-init_args:
-  go_branch: "BP"
diff --git a/configs/data/go50.yml b/configs/data/go50.yml
deleted file mode 100644
index 2ed4d14c..00000000
--- a/configs/data/go50.yml
+++ /dev/null
@@ -1 +0,0 @@
-class_path: chebai.preprocessing.datasets.go_uniprot.GOUniProtOver50
diff --git a/configs/data/pubchem_SELFIES.yml b/configs/data/pubchem/pubchem_SELFIES.yml
similarity index 100%
rename from configs/data/pubchem_SELFIES.yml
rename to configs/data/pubchem/pubchem_SELFIES.yml
diff --git a/configs/data/pubchem_deepSMILES.yml b/configs/data/pubchem/pubchem_deepSMILES.yml
similarity index 100%
rename from configs/data/pubchem_deepSMILES.yml
rename to configs/data/pubchem/pubchem_deepSMILES.yml
diff --git a/configs/data/pubchem_dissimilar.yml b/configs/data/pubchem/pubchem_dissimilar.yml
similarity index 100%
rename from configs/data/pubchem_dissimilar.yml
rename to configs/data/pubchem/pubchem_dissimilar.yml
diff --git a/configs/data/scope/scope2000.yml b/configs/data/scope/scope2000.yml
new file mode 100644
index 00000000..d75c807f
--- /dev/null
+++ b/configs/data/scope/scope2000.yml
@@ -0,0 +1,3 @@
+class_path: chebai.preprocessing.datasets.scope.scope.SCOPeOver2000
+init_args:
+  scope_version: "2.08"
diff --git a/configs/data/scope/scope50.yml b/configs/data/scope/scope50.yml
new file mode 100644
index 00000000..c65028e2
--- /dev/null
+++ b/configs/data/scope/scope50.yml
@@ -0,0 +1,3 @@
+class_path: 
chebai.preprocessing.datasets.scope.scope.SCOPeOver50 +init_args: + scope_version: "2.08" \ No newline at end of file diff --git a/configs/data/tox21_moleculenet.yml b/configs/data/tox21/tox21_moleculenet.yml similarity index 100% rename from configs/data/tox21_moleculenet.yml rename to configs/data/tox21/tox21_moleculenet.yml diff --git a/configs/model/ffn.yml b/configs/model/ffn.yml new file mode 100644 index 00000000..ba94a43e --- /dev/null +++ b/configs/model/ffn.yml @@ -0,0 +1,5 @@ +class_path: chebai.models.ffn.FFN +init_args: + optimizer_kwargs: + lr: 1e-3 + input_size: 2560 diff --git a/setup.py b/setup.py index 9afca834..8a6d3e0c 100644 --- a/setup.py +++ b/setup.py @@ -51,6 +51,7 @@ "pyyaml", "torchmetrics", "biopython", + "fair-esm", ], extras_require={"dev": ["black", "isort", "pre-commit"]}, ) diff --git a/tests/unit/dataset_classes/testGOUniProDataExtractor.py b/tests/unit/dataset_classes/testGOUniProDataExtractor.py index 9da48bee..96ff9a3a 100644 --- a/tests/unit/dataset_classes/testGOUniProDataExtractor.py +++ b/tests/unit/dataset_classes/testGOUniProDataExtractor.py @@ -6,7 +6,7 @@ import networkx as nx import pandas as pd -from chebai.preprocessing.datasets.go_uniprot import _GOUniProtDataExtractor +from chebai.preprocessing.datasets.deepGO.go_uniprot import _GOUniProtDataExtractor from chebai.preprocessing.reader import ProteinDataReader from tests.unit.mock_data.ontology_mock_data import GOUniProtMockData diff --git a/tests/unit/dataset_classes/testGoUniProtOverX.py b/tests/unit/dataset_classes/testGoUniProtOverX.py index d4157770..3f329c56 100644 --- a/tests/unit/dataset_classes/testGoUniProtOverX.py +++ b/tests/unit/dataset_classes/testGoUniProtOverX.py @@ -5,7 +5,7 @@ import networkx as nx import pandas as pd -from chebai.preprocessing.datasets.go_uniprot import _GOUniProtOverX +from chebai.preprocessing.datasets.deepGO.go_uniprot import _GOUniProtOverX from tests.unit.mock_data.ontology_mock_data import GOUniProtMockData diff --git a/tests/unit/dataset_classes/testProteinPretrainingData.py b/tests/unit/dataset_classes/testProteinPretrainingData.py index cb6b0688..caac3eac 100644 --- a/tests/unit/dataset_classes/testProteinPretrainingData.py +++ b/tests/unit/dataset_classes/testProteinPretrainingData.py @@ -1,7 +1,9 @@ import unittest from unittest.mock import PropertyMock, mock_open, patch -from chebai.preprocessing.datasets.protein_pretraining import _ProteinPretrainingData +from chebai.preprocessing.datasets.deepGO.protein_pretraining import ( + _ProteinPretrainingData, +) from chebai.preprocessing.reader import ProteinDataReader from tests.unit.mock_data.ontology_mock_data import GOUniProtMockData diff --git a/tests/unit/mock_data/ontology_mock_data.py b/tests/unit/mock_data/ontology_mock_data.py index a05b89f1..552d2918 100644 --- a/tests/unit/mock_data/ontology_mock_data.py +++ b/tests/unit/mock_data/ontology_mock_data.py @@ -658,18 +658,19 @@ def get_UniProt_raw_data() -> str: - **Swiss_Prot_1**: A valid protein with three valid GO classes and one invalid GO class. - **Swiss_Prot_2**: Another valid protein with two valid GO classes and one invalid. - **Swiss_Prot_3**: Contains valid GO classes but has a sequence length > 1002. - - **Swiss_Prot_4**: Has valid GO classes but contains an invalid amino acid, 'X'. + - **Swiss_Prot_4**: Has valid GO classes but contains an invalid amino acid, 'B'. - **Swiss_Prot_5**: Has a sequence but no GO classes associated. - **Swiss_Prot_6**: Has GO classes without any associated evidence codes. 
- **Swiss_Prot_7**: Has a GO class with an invalid evidence code. - **Swiss_Prot_8**: Has a sequence length > 1002 and has only invalid GO class. - - **Swiss_Prot_9**: Has no GO classes but contains an invalid amino acid, 'X', in its sequence. + - **Swiss_Prot_9**: Has no GO classes but contains an invalid amino acid, 'B', in its sequence. - **Swiss_Prot_10**: Has a valid GO class but lacks a sequence. - **Swiss_Prot_11**: Has only Invalid GO class but lacks a sequence. Note: - A valid GO label is the one which has one of the following evidence code - (EXP, IDA, IPI, IMP, IGI, IEP, TAS, IC). + A valid GO label is the one which has one of the following evidence code specified in + go_uniprot.py->`EXPERIMENTAL_EVIDENCE_CODES`. + Invalid amino acids are specified in go_uniprot.py->`AMBIGUOUS_AMINO_ACIDS`. Returns: str: The raw UniProt data in string format. @@ -715,7 +716,7 @@ def get_UniProt_raw_data() -> str: "DR GO; GO:0000005; P:regulation of viral transcription; IEA:InterPro.\n" "DR GO; GO:0000006; P:regulation of viral transcription; EXP:PomBase.\n" "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" - " XAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" + " BAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" "//\n" # Below protein with sequence string but has no GO class "ID Swiss_Prot_5 Reviewed; 60 AA.\n" @@ -749,7 +750,7 @@ def get_UniProt_raw_data() -> str: "ID Swiss_Prot_9 Reviewed; 60 AA.\n" "AC Q6GZX4;\n" "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" - " XAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" + " BAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" "//\n" # Below protein with a `valid` associated GO class but without sequence string "ID Swiss_Prot_10 Reviewed; 60 AA.\n" diff --git a/tutorials/data_exploration_scope.ipynb b/tutorials/data_exploration_scope.ipynb new file mode 100644 index 00000000..c14046ac --- /dev/null +++ b/tutorials/data_exploration_scope.ipynb @@ -0,0 +1,1182 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0bd757ea-a6a0-43f8-8701-cafb44f20f6b", + "metadata": {}, + "source": [ + "# Introduction\n", + "\n", + "This notebook serves as a guide for new developers using the `chebai` package. If you just want to run the experiments, you can refer to the [README.md](https://github.com/ChEB-AI/python-chebai/blob/dev/README.md) and the [wiki](https://github.com/ChEB-AI/python-chebai/wiki) for the basic commands. This notebook explains what happens under the hood for the SCOPe dataset. It covers\n", + "- how to instantiate a data class and generate data\n", + "- how the data is processed and stored\n", + "- and how to work with different molecule encodings.\n", + "\n", + "The `chebai` package simplifies the handling of these datasets by **automatically downloading and processing** them as needed. This means that you do not have to input any data manually; the package will generate and organize the data files based on the parameters and encodings selected. 
You can however provide your own data files, for instance if you want to replicate a specific experiment.\n", + "\n", + "---\n" + ] + }, + { + "cell_type": "markdown", + "id": "cca637ce-d4ea-4365-acd9-657418e0640f", + "metadata": {}, + "source": [ + "### Overview of SCOPe Data and its Usage in Protein-Related Tasks\n", + "\n", + "#### **What is SCOPe?**\n", + "\n", + "The **Structural Classification of Proteins — extended (SCOPe)** is a comprehensive database that extends the original SCOP (Structural Classification of Proteins) database. SCOPe offers a detailed classification of protein domains based on their structural and evolutionary relationships.\n", + "\n", + "The SCOPe database, like SCOP, organizes proteins into a hierarchy of domains based on structural similarities, which is crucial for understanding evolutionary patterns and functional aspects of proteins. This hierarchical structure is comparable to taxonomy in biology, where species are classified based on shared characteristics.\n", + "\n", + "#### **SCOPe Hierarchy:**\n", + "By analogy with taxonomy, SCOP was created as a hierarchy of several levels where the fundamental unit of classification is a **domain** in the experimentally determined protein structure. Starting at the bottom, the hierarchy of SCOP domains comprises the following levels:\n", + "\n", + "1. **Species**: Representing distinct protein sequences and their naturally occurring or artificially created variants.\n", + "2. **Protein**: Groups together similar sequences with essentially the same functions. These can originate from different biological species or represent isoforms within the same species.\n", + "3. **Family**: Contains proteins with similar sequences but typically distinct functions.\n", + "4. **Superfamily**: Bridges protein families with common functional and structural features, often inferred from a shared evolutionary ancestor.\n", + "5. **Fold**: Groups structurally similar superfamilies. \n", + "6. **Class**: Based on secondary structure content and organization. This level classifies proteins based on their secondary structure properties, such as alpha-helices and beta-sheets.\n", + "\n", + "\n", + "\n", + "For more details, you can refer to the [SCOPe documentation](https://scop.berkeley.edu/help/ver=2.08).\n", + "\n", + "---\n", + "\n", + "#### **Why are We Using SCOPe?**\n", + "\n", + "We are integrating the SCOPe data into our pipeline as part of an ontology pretraining task for protein-related models. SCOPe is a great fit for our goal because it is primarily **structure-based**, unlike other protein-related databases like Gene Ontology (GO), which focuses more on functional classes.\n", + "\n", + "Our primary objective is to reproduce **ontology pretraining** on a protein-related task, and SCOPe provides the structural ontology that we need for this. The steps in our pipeline are aligned as follows:\n", + "\n", + "| **Stage** | **Chemistry Task** | **Proteins Task** |\n", + "|--------------------------|-------------------------------------|------------------------------------------------|\n", + "| **Unsupervised Pretraining** | Mask pretraining (ELECTRA) | Mask pretraining (ESM2, optional) |\n", + "| **Ontology Pretraining** | ChEBI | SCOPe |\n", + "| **Finetuning Task** | Toxicity, Solubility, etc. | GO (MF, BP, CC branches) |\n", + "\n", + " \n", + "This integration will allow us to use **SCOPe** for tasks such as **protein classification** and will contribute to the success of **pretraining models** for protein structures. 
The data will be processed with the same approach as the GO data, with **different labels** corresponding to the SCOPe classification system.\n", + "\n", + "---\n", + "\n", + "#### **Why SCOPe is Suitable for Our Task**\n", + "\n", + "1. **Structure-Based Classification**: SCOPe is primarily concerned with the structural characteristics of proteins, making it ideal for protein structure pretraining tasks. This contrasts with other ontology databases like **GO**, which categorize proteins based on more complex functional relationships.\n", + " \n", + "2. **Manageable Size**: SCOPe contains around **140,000 entries**, making it a manageable dataset for training models. This is similar in size to **ChEBI**, which is used in the chemical domain, and ensures we can work with it effectively for pretraining." + ] + }, + { + "cell_type": "markdown", + "id": "338e452f-426c-493d-bec2-5bd51e24e4aa", + "metadata": {}, + "source": [ + "\n", + "### Protein Data Bank (PDB)\n", + "\n", + "The **Protein Data Bank (PDB)** is a global repository that stores 3D structural data of biological macromolecules like proteins and nucleic acids. It contains information obtained through experimental methods such as **X-ray crystallography**, **NMR spectroscopy**, and **cryo-EM**. The data includes atomic coordinates, secondary structure details, and experimental conditions.\n", + "\n", + "The PDB is an essential resource for **structural biology**, **bioinformatics**, and **drug discovery**, enabling scientists to understand protein functions, interactions, and mechanisms at the molecular level.\n", + "\n", + "For more details, visit the [RCSB PDB website](https://www.rcsb.org/).\n" + ] + }, + { + "cell_type": "markdown", + "id": "f6c25706-251c-438c-9915-e8002647eb94", + "metadata": {}, + "source": [ + "### Understanding [SCOPe](https://scop.berkeley.edu/) and [PDB](https://www.rcsb.org/) \n", + "\n", + "\n", + "1. **Protein domains form chains.** \n", + "2. **Chains form complexes** (protein complexes or structures). \n", + "3. These **complexes are the entries in PDB**, represented by unique identifiers like `\"1A3N\"`. \n", + "\n", + "---\n", + "\n", + "#### **Protein Domain** \n", + "A **protein domain** is a **structural and functional unit** of a protein. \n", + "\n", + "\n", + "##### Key Characteristics:\n", + "- **Domains are part of a protein chain.** \n", + "- A domain can span: \n", + " 1. **The entire chain** (single-domain protein): \n", + " - In this case, the protein domain is equivalent to the chain itself. \n", + " - Example: \n", + " - All chains of the **PDB structure \"1A3N\"** are single-domain proteins. \n", + " - Each chain has a SCOPe domain identifier. \n", + " - For example, Chain **A**: \n", + " - Domain identifier: `d1a3na_` \n", + " - Breakdown of the identifier: \n", + " - `d`: Denotes domain. \n", + " - `1a3n`: Refers to the PDB protein structure identifier. \n", + " - `a`: Specifies the chain within the structure. (`_` for None and `.` for multiple chains)\n", + " - `_`: Indicates the domain spans the entire chain (single-domain protein). \n", + " - Example: [PDB Structure 1A3N - Chain A](https://www.rcsb.org/sequence/1A3N#A)\n", + " 2. **A specific portion of the chain** (multi-domain protein): \n", + " - Here, a single chain contains multiple domains. \n", + " - Example: Chain **A** of the **PDB structure \"1PKN\"** contains three domains: `d1pkna1`, `d1pkna2`, `d1pkna3`. \n", + " - Example: [PDB Structure 1PKN - Chain A](https://www.rcsb.org/annotations/1PKN). 
\n", + "\n", + "---\n", + "\n", + "#### **Protein Chain** \n", + "A **protein chain** refers to the entire **polypeptide chain** observed in a protein's 3D structure (as described in PDB files). \n", + "\n", + "##### Key Points:\n", + "- A chain can consist of **one or multiple domains**:\n", + " - **Single-domain chain**: The chain and domain are identical. \n", + " - Example: Myoglobin. \n", + " - **Multi-domain chain**: Contains several domains, each with distinct structural and functional roles. \n", + "- Chains assemble to form **protein complexes** or **structures**. \n", + "\n", + "\n", + "---\n", + "\n", + "#### **Key Observations About SCOPe** \n", + "- The **fundamental classification unit** in SCOPe is the **protein domain**, not the entire protein. \n", + "- _**The taxonomy in SCOPe is not for the entire protein (i.e., the full-length amino acid sequence as encoded by a gene) but for protein domains, which are smaller, structurally and functionally distinct regions of the protein.**_\n", + "\n", + "\n", + "--- \n", + "\n", + "**SCOPe 2.08 Data Analysis:**\n", + "\n", + "The current SCOPe version (2.08) includes the following statistics based on analysis for relevant data:\n", + "\n", + "- **Classes**: 12\n", + "- **Folds**: 1485\n", + "- **Superfamilies**: 2368\n", + "- **Families**: 5431\n", + "- **Proteins**: 13,514\n", + "- **Species**: 30,294\n", + "- **Domains**: 344,851\n", + "\n", + "For more detailed statistics, please refer to the official SCOPe website:\n", + "\n", + "- [SCOPe 2.08 Statistics](https://scop.berkeley.edu/statistics/ver=2.08)\n", + "- [SCOPe 2.08 Release](https://scop.berkeley.edu/ver=2.08)\n", + "\n", + "---\n", + "\n", + "## SCOPe Labeling \n", + "\n", + "- Use SCOPe labels for protein domains.\n", + "- Map them back to their **protein-chain** sequences (protein sequence label = sum of all domain labels).\n", + "- Train on protein sequences.\n", + "- This pretraining task would be comparable to GO-based training.\n", + "\n", + "--- " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "990cc6f2-6b4a-4fa7-905f-dda183c3ec4c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Changed to project root directory: G:\\github-aditya0by0\\python-chebai\n" + ] + } + ], + "source": [ + "# To run this notebook, you need to change the working directory of the jupyter notebook to root dir of the project.\n", + "import os\n", + "\n", + "# Root directory name of the project\n", + "expected_root_dir = \"python-chebai\"\n", + "\n", + "# Check if the current directory ends with the expected root directory name\n", + "if not os.getcwd().endswith(expected_root_dir):\n", + " os.chdir(\"..\") # Move up one directory level\n", + " if os.getcwd().endswith(expected_root_dir):\n", + " print(\"Changed to project root directory:\", os.getcwd())\n", + " else:\n", + " print(\"Warning: Directory change unsuccessful. Current directory:\", os.getcwd())\n", + "else:\n", + " print(\"Already in the project root directory:\", os.getcwd())" + ] + }, + { + "cell_type": "markdown", + "id": "4550d01fc7af5ae4", + "metadata": {}, + "source": [ + "# 1. Instantiation of a Data Class\n", + "\n", + "To start working with `chebai`, you first need to instantiate a SCOPe data class. This class is responsible for managing, interacting with, and preprocessing the ChEBI chemical data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "f3a66e07-edc9-4aa2-9cd0-d4ea58914d22", + "metadata": {}, + "outputs": [], + "source": [ + "from chebai.preprocessing.datasets.scope.scope import SCOPeOver50" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a71b7301-6195-4155-a439-f5eb3183d0f3", + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-05T21:07:26.371796Z", + "start_time": "2024-10-05T21:07:26.058728Z" + } + }, + "outputs": [], + "source": [ + "scope_class = SCOPeOver50(scope_version=\"2.08\")" + ] + }, + { + "cell_type": "markdown", + "id": "b810d7c9-4f7f-4725-9bc2-452ff2c3a89d", + "metadata": {}, + "source": [ + "\n", + "### Inheritance Hierarchy\n", + "\n", + "SCOPe data classes inherit from [`_DynamicDataset`](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/datasets/base.py#L598), which in turn inherits from [`XYBaseDataModule`](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/datasets/base.py#L23). Specifically:\n", + "\n", + "- **`_DynamicDataset`**: This class serves as an intermediate base class that provides additional functionality or customization for datasets that require dynamic behavior. It inherits from `XYBaseDataModule`, which provides the core methods for data loading and processing.\n", + "\n", + "- **`XYBaseDataModule`**: This is the base class for data modules, providing foundational properties and methods for handling and processing datasets, including data splitting, loading, and preprocessing.\n", + "\n", + "In summary, ChEBI data classes are designed to manage and preprocess chemical data effectively by leveraging the capabilities provided by `XYBaseDataModule` through the `_DynamicDataset` intermediary.\n", + "\n", + "\n", + "### Input parameters\n", + "A SCOPe data class can be configured with a range of parameters, including:\n", + "\n", + "- **scope_version (str)**: Specifies the version of the ChEBI database to be used. Specifying a version ensures the reproducibility of your experiments by using a consistent dataset.\n", + "\n", + "- **scope_version_train (str, optional)**: The version of ChEBI to use specifically for training and validation. If not set, the `scope_version` specified will be used for all data splits, including training, validation, and test. Defaults to `None`.\n", + "\n", + "- **splits_file_path (str, optional)**: Path to a CSV file containing data splits. If not provided, the class will handle splits internally. Defaults to `None`.\n", + "\n", + "### Additional Input Parameters\n", + "\n", + "To get more control over various aspects of data loading, processing, and splitting, you can refer to documentation of additional parameters in docstrings of the respective classes: [`_SCOPeDataExtractor`](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/datasets/scope/scope.py#L31), [`XYBaseDataModule`](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/datasets/base.py#L22), [`_DynamicDataset`](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/datasets/base.py#L597), etc.\n" + ] + }, + { + "cell_type": "markdown", + "id": "8578b7aa-1bd9-4e50-9eee-01bfc6d5464a", + "metadata": {}, + "source": [ + "# Available SCOPe Data Classes\n", + "\n", + "__Note__: Check the code implementation of classes [here](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/datasets/scope/scope.py):\n", + "\n", + "There is a range of available dataset classes for SCOPe. 
Usually, you want to use `SCOPeOver2000` or `SCOPeOver50`. The number indicates the threshold for selecting label classes: SCOPe classes which have at least 2000 / 50 subclasses will be used as labels.\n", + "\n", + "Both inherit from `SCOPeOverX`. If you need a different threshold, you can create your own subclass. By default, `SCOPeOverX` uses the Protein encoding (see Section 5).\n", + "\n", + "Finally, `SCOPeOver2000Partial` extracts a part of SCOPe based on a given top class, with a threshold of 2000 for selecting labels.\n", + "This class inherits from `SCOPeOverXPartial`.\n" + ] + }, + { + "cell_type": "markdown", + "id": "8456b545-88c5-401d-baa5-47e8ae710f04", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "ed973fb59df11849", + "metadata": {}, + "source": [ + "# 2. Preparation / Setup Methods\n", + "\n", + "Now we have a SCOPe data class with all the relevant parameters. Next, we need to generate the actual dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "11f2208e-fa40-44c9-bfe7-576ca23ad366", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Checking for processed data in data\SCOPe\version_2.08\SCOPe50\processed\n", + "Missing processed data file (`data.pkl` file)\n", + "Missing PDB raw data, Downloading PDB sequence data....\n", + "Downloading to temporary file C:\Users\HP\AppData\Local\Temp\tmpsif7r129\n", + "Downloaded to C:\Users\HP\AppData\Local\Temp\tmpsif7r129\n", + "Unzipping the file....\n", + "Unpacked and saved to data\SCOPe\pdb_sequences.txt\n", + "Removed temporary file C:\Users\HP\AppData\Local\Temp\tmpsif7r129\n", + "Missing Scope: cla.txt raw data, Downloading...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "G:\anaconda3\envs\env_chebai\lib\site-packages\urllib3\connectionpool.py:1099: InsecureRequestWarning: Unverified HTTPS request is being made to host 'scop.berkeley.edu'. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#tls-warnings\n", + "warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Missing Scope: hie.txt raw data, Downloading...\n", + "Missing Scope: des.txt raw data, Downloading...\n", + "Extracting class hierarchy...\n", + "Computing transitive closure\n", + "Process graph\n", + "101 labels has been selected for specified threshold, \n", + "Constructing data.pkl file .....\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Check for processed data in data\SCOPe\version_2.08\SCOPe50\processed\protein_token\n", + "Cross-validation enabled: False\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Missing transformed data (`data.pt` file). Transforming data.... 
\n", + "Processing 60298 lines...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████| 60298/60298 [00:53<00:00, 1119.10it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Saving 21 tokens to G:\\github-aditya0by0\\python-chebai\\chebai\\preprocessing\\bin\\protein_token\\tokens.txt...\n", + "First 10 tokens: ['M', 'S', 'I', 'G', 'A', 'T', 'R', 'L', 'Q', 'N']\n" + ] + } + ], + "source": [ + "scope_class.prepare_data()\n", + "scope_class.setup()" + ] + }, + { + "cell_type": "markdown", + "id": "1655d489-25fe-46de-9feb-eeca5d36936f", + "metadata": {}, + "source": [ + "\n", + "### Automatic Execution: \n", + "These methods are executed automatically when using the training command `chebai fit`. Users do not need to call them explicitly, as the code internally manages the preparation and setup of data, ensuring that it is ready for subsequent use in training and validation processes.\n", + "\n", + "### Why is Preparation Needed?\n", + "\n", + "- **Data Availability**: The preparation step ensures that the required SCOPe data files are downloaded or loaded, which are essential for analysis.\n", + "- **Data Integrity**: It ensures that the data files are transformed into a compatible format required for model input.\n", + "\n", + "### Main Methods for Data Preprocessing\n", + "\n", + "The data preprocessing in a data class involves two main methods:\n", + "\n", + "1. **`prepare_data` Method**:\n", + " - **Purpose**: This method checks for the presence of raw data in the specified directory. If the raw data is missing, it fetches the ontology, creates a dataframe, and saves it to a file (`data.pkl`). The dataframe includes columns such as IDs, data representations, and labels. This step is independent of input encodings.\n", + " - **Documentation**: [PyTorch Lightning - `prepare_data`](https://lightning.ai/docs/pytorch/stable/data/datamodule.html#prepare-data)\n", + "\n", + "2. **`setup` Method**:\n", + " - **Purpose**: This method sets up the data module for training, validation, and testing. It checks for the processed data and, if necessary, performs additional setup to ensure the data is ready for model input. It also handles cross-validation settings if enabled.\n", + " - **Description**: Transforms `data.pkl` into a model input data format (`data.pt`), tokenizing the input according to the specified encoding. The transformed data contains the following keys: `ident`, `features`, `labels`, and `group`. This method uses a subclass of Data Reader to perform the tokenization.\n", + " - **Documentation**: [PyTorch Lightning - `setup`](https://lightning.ai/docs/pytorch/stable/data/datamodule.html#setup)\n", + "\n", + "These methods ensure that the data is correctly prepared and set up for subsequent use in training and validation processes." + ] + }, + { + "cell_type": "markdown", + "id": "f5aaa12d-5f01-4b74-8b59-72562af953bf", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "bb6e9a81554368f7", + "metadata": {}, + "source": [ + "# 3. Overview of the 3 preprocessing stages\n", + "\n", + "The `chebai` library follows a three-stage preprocessing pipeline, which is reflected in its file structure:\n", + "\n", + "1. **Raw Data Stage**:\n", + " - **Files**: `cla.txt`, `des.txt` and `hie.txt`. 
A description of each file can be found [here](https://scop.berkeley.edu/help/ver=2.08#parseablefiles-2.08).\n", + " - **Description**: This stage contains the raw SCOPe data in txt format, serving as the initial input for further processing.\n", + " - **File Path**: `data/SCOPe/version_${scope_version}/raw/${filename}.txt`\n", + "\n", + "2. **Processed Data Stage 1**:\n", + " - **File**: `data.pkl`\n", + " - **Description**: This stage includes the data after initial processing. It contains protein sequence strings, class columns, and metadata but lacks data splits.\n", + " - **File Path**: `data/SCOPe/version_${scope_version}/${dataset_name}/processed/data.pkl`\n", + " - **Additional File**: `classes.txt` - A file listing the relevant SCOPe classes.\n", + "\n", + "3. **Processed Data Stage 2**:\n", + " - **File**: `data.pt`\n", + " - **Description**: This final stage includes the encoded data in a format compatible with PyTorch, ready for model input. This stage also references data splits when available.\n", + " - **File Path**: `data/SCOPe/version_${scope_version}/${dataset_name}/processed/${reader_name}/data.pt`\n", + " - **Additional File**: `splits.csv` - Contains saved splits for reproducibility.\n", + "\n", + "This structured approach to data management ensures that each stage of data processing is well-organized and documented, from raw data acquisition to the preparation of model-ready inputs. It also facilitates reproducibility and traceability across different experiments.\n", + "\n", + "### Data Splits\n", + "\n", + "- **Creation**: Data splits are generated dynamically \"on the fly\" during training and evaluation to ensure flexibility and adaptability to different tasks.\n", + "- **Reproducibility**: To maintain consistency across different runs, splits can be reproduced by comparing hashes with a fixed seed value.\n" + ] + }, + { + "cell_type": "markdown", + "id": "7e172c0d1e8bb93f", + "metadata": {}, + "source": [ + "# 4. Data Files and Their Structure\n", + "\n", + "`chebai` creates and manages several data files during its operation. These files store the protein data and metadata essential for different tasks. Let’s explore these files and their content.\n" + ] + }, + { + "cell_type": "markdown", + "id": "43329709-5134-4ce5-88e7-edd2176bf84d", + "metadata": {}, + "source": [ + "## Raw Files\n", + "- cla.txt, des.txt and hie.txt\n", + "\n", + "For a detailed description of the raw files and their structures, please refer to the official website [here](https://scop.berkeley.edu/help/ver=2.08#parseablefiles-2.08).\n" + ] + }, + { + "cell_type": "markdown", + "id": "558295e5a7ded456", + "metadata": {}, + "source": [ + "## data.pkl File\n", + "\n", + "**Description**: Generated by the `prepare_data` method, this file contains processed data in a dataframe format. It includes the IDs, the `sids` associated with each protein-chain sequence, the protein-chain sequence itself, and one boolean column per label.\n",
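+ "\n",
+ "The boolean label columns follow the labeling scheme described in the introduction: a chain inherits every label assigned to any of its domains (the `sids`). A purely illustrative sketch of that aggregation (the arrays below are made up and not produced by the library) might look like this:\n",
+ "\n",
+ "```python\n",
+ "import numpy as np\n",
+ "\n",
+ "# Hypothetical boolean label vectors for two domains of the same chain\n",
+ "domain_labels = {\n",
+ "    \"d1gkub1\": np.array([False, False, True, False]),\n",
+ "    \"d1gkub2\": np.array([False, False, False, True]),\n",
+ "}\n",
+ "\n",
+ "# The chain-level label vector is the element-wise OR over its domains\n",
+ "chain_labels = np.logical_or.reduce(list(domain_labels.values()))\n",
+ "print(chain_labels)  # [False False  True  True]\n",
+ "```"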
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "fd490270-59b8-4c1c-8b09-204defddf592", + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-05T21:09:01.622317Z", + "start_time": "2024-10-05T21:09:01.606698Z" + } + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import os" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "d7d16247-092c-4e8d-96c2-ab23931cf766", + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-05T21:11:51.296162Z", + "start_time": "2024-10-05T21:11:44.559304Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Size of the data (rows x columns): (60424, 1035)\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idsidssequenceclass_46456class_48724class_51349class_53931class_56572class_56835class_56992...species_187294species_56257species_186882species_56690species_161316species_57962species_58067species_267696species_311502species_311501
01[d4oq9a_, d4oq9b_, d4oq9c_, d4oq9d_, d4niaa_, ...AAAAAAAAAAFalseTrueFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
12[d7dxhc_]AAAAAAAAAAAAAAAAAAAAAAAFalseFalseFalseFalseFalseTrueFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
23[d1gkub1, d1gkub2, d1gkub3, d1gkub4]AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAASLCLFPEDFLLKEF...FalseFalseTrueFalseTrueFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseTrue
34[d3c9wa2, d3c9wb2, d3c9wa3, d3c9wb3]AAAAAAGPEMVRGQVFDVGPRYTNLSYIGEGAYGMVCSAYDNLNKV...FalseFalseFalseTrueFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseTrue
45[d1xwaa1, d1xwab_, d1xwac_, d1xwad_, d1xwaa2]AAAAAMVYQVKDKADLDGQLTKASGKLVVLDFFATWCGPCKMISPK...FalseFalseTrueFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseTrue
\n", + "

5 rows × 1035 columns

\n", + "
" + ], + "text/plain": [ + " id sids \\\n", + "0 1 [d4oq9a_, d4oq9b_, d4oq9c_, d4oq9d_, d4niaa_, ... \n", + "1 2 [d7dxhc_] \n", + "2 3 [d1gkub1, d1gkub2, d1gkub3, d1gkub4] \n", + "3 4 [d3c9wa2, d3c9wb2, d3c9wa3, d3c9wb3] \n", + "4 5 [d1xwaa1, d1xwab_, d1xwac_, d1xwad_, d1xwaa2] \n", + "\n", + " sequence class_46456 \\\n", + "0 AAAAAAAAAA False \n", + "1 AAAAAAAAAAAAAAAAAAAAAAA False \n", + "2 AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAASLCLFPEDFLLKEF... False \n", + "3 AAAAAAGPEMVRGQVFDVGPRYTNLSYIGEGAYGMVCSAYDNLNKV... False \n", + "4 AAAAAMVYQVKDKADLDGQLTKASGKLVVLDFFATWCGPCKMISPK... False \n", + "\n", + " class_48724 class_51349 class_53931 class_56572 class_56835 \\\n", + "0 True False False False False \n", + "1 False False False False True \n", + "2 False True False True False \n", + "3 False False True False False \n", + "4 False True False False False \n", + "\n", + " class_56992 ... species_187294 species_56257 species_186882 \\\n", + "0 False ... False False False \n", + "1 False ... False False False \n", + "2 False ... False False False \n", + "3 False ... False False False \n", + "4 False ... False False False \n", + "\n", + " species_56690 species_161316 species_57962 species_58067 \\\n", + "0 False False False False \n", + "1 False False False False \n", + "2 False False False False \n", + "3 False False False False \n", + "4 False False False False \n", + "\n", + " species_267696 species_311502 species_311501 \n", + "0 False False False \n", + "1 False False False \n", + "2 False False True \n", + "3 False False True \n", + "4 False False True \n", + "\n", + "[5 rows x 1035 columns]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pkl_df = pd.DataFrame(\n", + " pd.read_pickle(\n", + " os.path.join(\n", + " scope_class.processed_dir_main,\n", + " scope_class.processed_main_file_names_dict[\"data\"],\n", + " )\n", + " )\n", + ")\n", + "print(\"Size of the data (rows x columns): \", pkl_df.shape)\n", + "pkl_df.head()" + ] + }, + { + "cell_type": "markdown", + "id": "322bc926-69ff-4b93-9e95-5e8b85869c38", + "metadata": {}, + "source": [ + "**File Path**: `data/SCOPe/version_${scope_version}/${dataset_name}/processed/data.pkl`\n", + "\n", + "\n", + "### Structure of `data.pkl`\n", + "`data.pkl` as following structure: \n", + "- **Column 0**: Contains the ID of eachdata instance.\n", + "- **Column 1**: Contains the `sids` which are associated with corresponding protein-chain sequence.\n", + "- **Column 2**: Contains the protein-chain sequence.\n", + "- **Column 3 and onwards**: Contains the labels, starting from column 3.\n", + "\n", + "This structure ensures that the data is organized and ready for further processing, such as further encoding.\n" + ] + }, + { + "cell_type": "markdown", + "id": "ba019d2d4324bd0b", + "metadata": {}, + "source": [ + "## data.pt File\n", + "\n", + "\n", + "**Description**: Generated by the `setup` method, this file contains encoded data in a format compatible with the PyTorch library, specifically as a list of dictionaries. Each dictionary in this list includes keys such as `ident`, `features`, `labels`, and `group`, ready for model input." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "977ddd83-b469-4b58-ab1a-8574fb8769b4", + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-05T21:12:49.338943Z", + "start_time": "2024-10-05T21:12:49.323319Z" + } + }, + "outputs": [], + "source": [ + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "3266ade9-efdc-49fe-ae07-ed52b2eb52d0", + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-05T21:14:12.892845Z", + "start_time": "2024-10-05T21:13:59.859953Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Type of loaded data: \n" + ] + } + ], + "source": [ + "data_pt = torch.load(\n", + " os.path.join(\n", + " scope_class.processed_dir, scope_class.processed_file_names_dict[\"data\"]\n", + " ),\n", + " weights_only=False,\n", + ")\n", + "print(\"Type of loaded data:\", type(data_pt))" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "84cfa3e6-f60d-47c0-9f82-db3d5673d1e7", + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-05T21:14:21.185027Z", + "start_time": "2024-10-05T21:14:21.169358Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'features': [14, 14, 14, 14, 20, 15, 15, 28, 15, 18, 25, 17, 18, 11, 25, 21, 27, 19, 14, 27, 19, 13, 14, 17, 16, 21, 25, 22, 27, 28, 12, 10, 20, 19, 13, 13, 14, 28, 17, 20, 20, 12, 19, 11, 17, 15, 27, 28, 15, 12, 17, 14, 23, 11, 19, 27, 14, 26, 19, 11, 11, 19, 12, 19, 19, 28, 17, 16, 20, 16, 19, 21, 10, 16, 18, 12, 17, 19, 10, 29, 12, 12, 21, 20, 16, 17, 19, 28, 20, 21, 12, 16, 18, 21, 19, 14, 19, 17, 12, 14, 18, 28, 23, 15, 28, 19, 19, 19, 15, 25, 17, 22, 25, 19, 28, 16, 13, 27, 13, 11, 20, 15, 28, 12, 15, 28, 27, 13, 13, 13, 28, 19, 14, 15, 28, 12, 18, 14, 20, 28, 14, 18, 15, 19, 13, 22, 28, 29, 12, 12, 20, 29, 28, 17, 13, 28, 23, 22, 15, 15, 28, 17, 13, 21, 17, 27, 11, 20, 23, 10, 10, 11, 20, 15, 22, 21, 10, 13, 21, 25, 11, 29, 25, 19, 20, 18, 17, 19, 19, 15, 18, 16, 16, 25, 15, 22, 25, 28, 23, 16, 20, 21, 13, 26, 18, 21, 15, 27, 17, 20, 22, 23, 11, 14, 29, 21, 21, 17, 25, 10, 14, 20, 25, 11, 22, 29, 11, 21, 11, 12, 17, 27, 16, 29, 17, 14, 12, 11, 20, 21, 27, 22, 15, 10, 21, 20, 17, 28, 21, 25, 11, 18, 27, 11, 13, 11, 28, 12, 17, 23, 15, 25, 16, 20, 11, 17, 11, 12, 16, 28, 27, 27, 27, 14, 13, 16, 22, 28, 12, 12, 26, 19, 22, 21, 21, 12, 19, 28, 22, 16, 23, 20, 28, 27, 24, 15, 19, 13, 12, 12, 29, 28, 12, 20, 22, 23, 17, 17, 27, 27, 21, 20, 28, 28, 28, 14, 13, 13, 11, 14, 14, 14, 14, 14], 'labels': array([False, True, False, ..., False, False, False]), 'ident': 6, 'group': None}\n" + ] + } + ], + "source": [ + "for i in range(5, 6):\n", + " print(data_pt[i])" + ] + }, + { + "cell_type": "markdown", + "id": "0d80ffbb-5f1e-4489-9bc8-d688c9be1d07", + "metadata": {}, + "source": [ + "**File Path**: `data/SCOPe/version_${scope_version}/${dataset_name}/processed/${reader_name}/data.pt`\n", + "\n", + "\n", + "### Structure of `data.pt`\n", + "\n", + "The `data.pt` file is a list where each element is a dictionary with the following keys:\n", + "\n", + "- **`features`**: \n", + " - **Description**: This key holds the input features for the model. The features are typically stored as tensors and represent the attributes used by the model for training and evaluation.\n", + "\n", + "- **`labels`**: \n", + " - **Description**: This key contains the labels or target values associated with each instance. 
Labels are also stored as tensors and are used by the model to learn and make predictions.\n", + "\n", + "- **`ident`**: \n", + " - **Description**: This key holds identifiers for each data instance. These identifiers help track and reference the individual samples in the dataset.\n" + ] + }, + { + "cell_type": "markdown", + "id": "186ec6f0eed6ecf7", + "metadata": {}, + "source": [ + "## classes.txt File\n", + "\n", + "**Description**: A file containing the list of selected SCOPe **labels** based on the specified threshold. This file is crucial for ensuring that only relevant **labels** are included in the dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "8d1fbe6c-beb8-4038-93d4-c56bc7628716", + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-05T21:15:19.146285Z", + "start_time": "2024-10-05T21:15:18.503284Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "class_48724\n", + "class_53931\n", + "class_310555\n", + "fold_48725\n", + "fold_56111\n", + "fold_56234\n", + "fold_310573\n", + "superfamily_48726\n", + "superfamily_56112\n", + "superfamily_56235\n", + "superfamily_310607\n", + "family_48942\n", + "family_56251\n", + "family_191359\n", + "family_191470\n" + ] + } + ], + "source": [ + "with open(os.path.join(scope_class.processed_dir_main, \"classes.txt\"), \"r\") as file:\n", + " for i in range(15):\n", + " line = file.readline()\n", + " print(line.strip())" + ] + }, + { + "cell_type": "markdown", + "id": "861da1c3-0401-49f0-a22f-109814ed95d5", + "metadata": {}, + "source": [ + "\n", + "**File Path**: `data/SCOPe/version_${scope_version}/${dataset_name}/processed/classes.txt`\n", + "\n", + "The `classes.txt` file lists selected SCOPe classes. These classes are chosen based on a specified threshold, which is typically used for filtering or categorizing the dataset. Each line in the file corresponds to a unique SCOPe class ID, identifying a specific class within the SCOPe ontology along with its hierarchy level.\n", + "\n", + "This file is essential for organizing the data and ensuring that only relevant classes, as defined by the threshold, are included in subsequent processing and analysis tasks.\n" + ] + }, + { + "cell_type": "markdown", + "id": "fb72be449e52b63f", + "metadata": {}, + "source": [ + "## splits.csv File\n", + "\n", + "**Description**: Contains saved data splits from previous runs. During subsequent runs, this file is used to reconstruct the train, validation, and test splits by filtering the encoded data (`data.pt`) based on the IDs stored in `splits.csv`." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "3ebdcae4-4344-46bd-8fc0-a82ef5d40da5", + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-05T21:15:54.575116Z", + "start_time": "2024-10-05T21:15:53.945139Z" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idsplit
01train
13train
24train
36train
49train
\n", + "
" + ], + "text/plain": [ + " id split\n", + "0 1 train\n", + "1 3 train\n", + "2 4 train\n", + "3 6 train\n", + "4 9 train" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "csv_df = pd.read_csv(os.path.join(scope_class.processed_dir_main, \"splits.csv\"))\n", + "csv_df.head()" + ] + }, + { + "cell_type": "markdown", + "id": "b058714f-e434-4367-89b9-74c129ac727f", + "metadata": {}, + "source": [ + "\n", + "\n", + "**File Path**: `data/SCOPe/version_${scope_version}/${dataset_name}/processed/splits.csv`\n", + "\n", + "The `splits.csv` file contains the saved data splits from previous runs, including the train, validation, and test sets. During subsequent runs, this file is used to reconstruct these splits by filtering the encoded data (`data.pt`) based on the IDs stored in `splits.csv`. This ensures consistency and reproducibility in data splitting, allowing for reliable evaluation and comparison of model performance across different run.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "6dc3fd6c-7cf6-47ef-812f-54319a0cdeb9", + "metadata": {}, + "outputs": [], + "source": [ + "# You can specify a literal path for the `splits_file_path`, or if another `scope_class` instance is already defined,\n", + "# you can use its existing `splits_file_path` attribute for consistency.\n", + "scope_class_with_splits = SCOPeOver2000(\n", + " scope_version=\"2.08\",\n", + " # splits_file_path=\"data/chebi_v231/ChEBI50/processed/splits.csv\", # Literal path option\n", + " splits_file_path=scope_class.splits_file_path, # Use path from an existing `chebi_class` instance\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "a5eb482c-ce5b-4efc-b2ec-85ac7b1a78ee", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "ab110764-216d-4d52-a9d1-4412c8ac8c9d", + "metadata": {}, + "source": [ + "## 5.1 Protein Representation Using Amino Acid Sequence Notation\n", + "\n", + "Proteins are composed of chains of amino acids, and these sequences can be represented using a one-letter notation for each amino acid. This notation provides a concise way to describe the primary structure of a protein.\n", + "\n", + "### Example Protein Sequence\n", + "\n", + "Protein-Chain: PDB ID:**1cph** Chain ID:**B** mol:protein length:30 INSULIN (PH 10)\n", + "
Refer - [1cph_B](https://www.rcsb.org/sequence/1CPH)\n", + "\n", + "- **Sequence**: `FVNQHLCGSHLVEALYLVCGERGFFYTPKA`\n", + "- **Sequence Length**: 30\n", + "\n", + "In this sequence, each letter corresponds to a specific amino acid. This notation is widely used in bioinformatics and molecular biology to represent protein sequences.\n", + "\n", + "### Tokenization and Encoding\n", + "\n", + "To tokenize and numerically encode this protein sequence, the `ProteinDataReader` class is used. This class allows for n-gram tokenization, where the `n_gram` parameter defines the size of the tokenized units. If `n_gram` is not provided (default is `None`), each amino acid letter is treated as a single token.\n", + "\n", + "For more details, you can explore the implementation of the `ProteinDataReader` class in the source code [here](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/reader.py)." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "da47d47e-4560-46af-b246-235596f27d82", + "metadata": {}, + "outputs": [], + "source": [ + "from chebai.preprocessing.reader import ProteinDataReader" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "8bdbf309-29ec-4aab-a6dc-9e09bc6961a2", + "metadata": {}, + "outputs": [], + "source": [ + "protein_dr_3gram = ProteinDataReader(n_gram=3)\n", + "protein_dr = ProteinDataReader()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "68e5c87c-79c3-4d5f-91e6-635399a84d3d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[25, 28, 19, 18, 29, 17, 24, 13, 11, 29, 17, 28, 27, 14, 17, 22, 17, 28, 24, 13, 27, 16, 13, 25, 25, 22, 15, 23, 21, 14]\n", + "[5023, 2218, 3799, 2290, 6139, 2208, 6917, 4674, 484, 439, 2737, 851, 365, 2624, 3240, 4655, 1904, 3737, 1453, 2659, 5160, 3027, 2355, 7163, 4328, 3115, 6207, 1234]\n" + ] + } + ], + "source": [ + "protein = \"FVNQHLCGSHLVEALYLVCGERGFFYTPKA\"\n", + "print(protein_dr._read_data(protein))\n", + "print(protein_dr_3gram._read_data(protein))" + ] + }, + { + "cell_type": "markdown", + "id": "5b7211ee-2ccc-46d3-8e8f-790f344726ba", + "metadata": {}, + "source": [ + "The numbers mentioned above refer to the index of each individual token from the [`tokens.txt`](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/bin/protein_token/tokens.txt) file, which is used by the `ProteinDataReader` class. \n", + "\n", + "Each token in the `tokens.txt` file corresponds to a specific amino-acid letter, and these tokens are referenced by their index. Additionally, the index values are offset by the `EMBEDDING_OFFSET`, ensuring that the token embeddings are adjusted appropriately during processing." + ] + }, + { + "cell_type": "markdown", + "id": "93e328cf-09f9-4694-b175-28320590937d", + "metadata": {}, + "source": [ + "---" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}