From 876d94614676b69627a8588aaf6263f4fae30b71 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 16 Jul 2024 00:22:37 +0200 Subject: [PATCH 01/30] basic data processing for go-uniprot dataset --- chebai/preprocessing/datasets/go_uniprot.py | 633 ++++++++++++++++++++ 1 file changed, 633 insertions(+) create mode 100644 chebai/preprocessing/datasets/go_uniprot.py diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/go_uniprot.py new file mode 100644 index 00000000..108611ae --- /dev/null +++ b/chebai/preprocessing/datasets/go_uniprot.py @@ -0,0 +1,633 @@ +# Reference for this file : +# Maxat Kulmanov, Mohammed Asif Khan, Robert Hoehndorf; +# DeepGO: Predicting protein functions from sequence and interactions +# using a deep ontology-aware classifier, Bioinformatics, 2017. +# https://doi.org/10.1093/bioinformatics/btx624 +# Github: https://github.com/bio-ontology-research-group/deepgo +# https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/docs/keywlist.txt +# https://www.ebi.ac.uk/GOA/downloads + + +__all__ = ["GOUniprotDataModule"] + +import gzip +import os +from abc import ABC, abstractmethod +from collections import OrderedDict +from tempfile import NamedTemporaryFile, TemporaryDirectory, gettempdir +from typing import Any, Dict, Generator, List +from urllib import request + +import fastobo +import networkx as nx +import pandas as pd +import requests +import torch +from Bio import SwissProt +from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit + +from chebai.preprocessing.datasets import XYBaseDataModule + + +class _GOUniprotDataExtractor(XYBaseDataModule, ABC): + _GO_DATA_INIT = "GO" + + def __init__(self): + pass + + @property + def dynamic_split_dfs(self) -> Dict[str, pd.DataFrame]: + """ + Property to retrieve dynamic train, validation, and test splits. + + This property checks if dynamic data splits (`dynamic_df_train`, `dynamic_df_val`, `dynamic_df_test`) + are already loaded. If any of them is None, it either generates them dynamically or retrieves them + from data file with help of pre-existing Split csv file (`splits_file_path`) containing splits assignments. + + Returns: + dict: A dictionary containing the dynamic train, validation, and test DataFrames. + Keys are 'train', 'validation', and 'test'. + """ + if any( + split is None + for split in [ + self.dynamic_df_test, + self.dynamic_df_val, + self.dynamic_df_train, + ] + ): + if self.splits_file_path is None: + # Generate splits based on given seed, create csv file to records the splits + self._generate_dynamic_splits() + else: + # If user has provided splits file path, use it to get the splits from the data + self._retrieve_splits_from_csv() + return { + "train": self.dynamic_df_train, + "validation": self.dynamic_df_val, + "test": self.dynamic_df_test, + } + + def _generate_dynamic_splits(self) -> None: + """ + Generate data splits during runtime and save them in class variables. + + This method loads encoded data derived from either `chebi_version` or `chebi_version_train` + and generates train, validation, and test splits based on the loaded data. + If `chebi_version_train` is specified, the test set is pruned to include only labels that + exist in `chebi_version_train`. + + Raises: + FileNotFoundError: If the required data file (`data.pt`) for either `chebi_version` or `chebi_version_train` + does not exist. It advises calling `prepare_data` or `setup` methods to generate + the dataset files. + """ + print("Generate dynamic splits...") + # Load encoded data derived from "chebi_version" + try: + filename = self.processed_file_names_dict["data"] + data_chebi_version = torch.load(os.path.join(self.processed_dir, filename)) + except FileNotFoundError: + raise FileNotFoundError( + f"File data.pt doesn't exists. " + f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files" + ) + + df_chebi_version = pd.DataFrame(data_chebi_version) + train_df_chebi_ver, df_test_chebi_ver = self.get_test_split( + df_chebi_version, seed=self.dynamic_data_split_seed + ) + + if self.chebi_version_train is not None: + # Load encoded data derived from "chebi_version_train" + try: + filename_train = ( + self._chebi_version_train_obj.processed_file_names_dict["data"] + ) + data_chebi_train_version = torch.load( + os.path.join( + self._chebi_version_train_obj.processed_dir, filename_train + ) + ) + except FileNotFoundError: + raise FileNotFoundError( + f"File data.pt doesn't exists related to chebi_version_train {self.chebi_version_train}." + f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files" + ) + + df_chebi_train_version = pd.DataFrame(data_chebi_train_version) + # Get train/val split of data based on "chebi_version_train", but + # using test set from "chebi_version" + df_train, df_val = self.get_train_val_splits_given_test( + df_chebi_train_version, + df_test_chebi_ver, + seed=self.dynamic_data_split_seed, + ) + # Modify test set from "chebi_version" to only include the labels that + # exists in "chebi_version_train", all other entries remains same. + df_test = self._setup_pruned_test_set(df_test_chebi_ver) + else: + # Get all splits based on "chebi_version" + df_train, df_val = self.get_train_val_splits_given_test( + train_df_chebi_ver, + df_test_chebi_ver, + seed=self.dynamic_data_split_seed, + ) + df_test = df_test_chebi_ver + + # Generate splits.csv file to store ids of each corresponding split + split_assignment_list: List[pd.DataFrame] = [ + pd.DataFrame({"id": df_train["ident"], "split": "train"}), + pd.DataFrame({"id": df_val["ident"], "split": "validation"}), + pd.DataFrame({"id": df_test["ident"], "split": "test"}), + ] + combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) + combined_split_assignment.to_csv( + os.path.join(self.processed_dir_main, "splits.csv") + ) + + # Store the splits in class variables + self.dynamic_df_train = df_train + self.dynamic_df_val = df_val + self.dynamic_df_test = df_test + + def _retrieve_splits_from_csv(self) -> None: + """ + Retrieve previously saved data splits from splits.csv file or from provided file path. + + This method loads the splits.csv file located at `self.splits_file_path`. + It then loads the encoded data (`data.pt`) derived from `chebi_version` and filters + it based on the IDs retrieved from splits.csv to reconstruct the train, validation, + and test splits. + """ + print(f"Loading splits from {self.splits_file_path}...") + splits_df = pd.read_csv(self.splits_file_path) + + filename = self.processed_file_names_dict["data"] + data_chebi_version = torch.load(os.path.join(self.processed_dir, filename)) + df_chebi_version = pd.DataFrame(data_chebi_version) + + train_ids = splits_df[splits_df["split"] == "train"]["id"] + validation_ids = splits_df[splits_df["split"] == "validation"]["id"] + test_ids = splits_df[splits_df["split"] == "test"]["id"] + + self.dynamic_df_train = df_chebi_version[ + df_chebi_version["ident"].isin(train_ids) + ] + self.dynamic_df_val = df_chebi_version[ + df_chebi_version["ident"].isin(validation_ids) + ] + self.dynamic_df_test = df_chebi_version[ + df_chebi_version["ident"].isin(test_ids) + ] + + def get_test_split( + self, df: pd.DataFrame, seed: Optional[int] = None + ) -> Tuple[pd.DataFrame, pd.DataFrame]: + """ + Split the input DataFrame into training and testing sets based on multilabel stratified sampling. + + This method uses MultilabelStratifiedShuffleSplit to split the data such that the distribution of labels + in the training and testing sets is approximately the same. The split is based on the "labels" column + in the DataFrame. + + Args: + df (pd.DataFrame): The input DataFrame containing the data to be split. It must contain a column + named "labels" with the multilabel data. + seed (int, optional): The random seed to be used for reproducibility. Default is None. + + Returns: + Tuple[pd.DataFrame, pd.DataFrame]: A tuple containing the training set and testing set DataFrames. + + Raises: + ValueError: If the DataFrame does not contain a column named "labels". + """ + print("\nGet test data split") + + labels_list = df["labels"].tolist() + + test_size = 1 - self.train_split - (1 - self.train_split) ** 2 + msss = MultilabelStratifiedShuffleSplit( + n_splits=1, test_size=test_size, random_state=seed + ) + + train_indices, test_indices = next(msss.split(labels_list, labels_list)) + + df_train = df.iloc[train_indices] + df_test = df.iloc[test_indices] + return df_train, df_test + + def get_train_val_splits_given_test( + self, df: pd.DataFrame, test_df: pd.DataFrame, seed: int = None + ) -> Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]: + """ + Split the dataset into train and validation sets, given a test set. + Use test set (e.g., loaded from another chebi version or generated in get_test_split), to avoid overlap + + Args: + df (pd.DataFrame): The original dataset. + test_df (pd.DataFrame): The test dataset. + seed (int, optional): The random seed to be used for reproducibility. Default is None. + + Returns: + Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]: A dictionary containing train and + validation sets if self.use_inner_cross_validation is True, otherwise a tuple containing the train + and validation DataFrames. The keys are the names of the train and validation sets, and the values + are the corresponding DataFrames. + """ + print(f"Split dataset into train / val with given test set") + + test_ids = test_df["ident"].tolist() + # ---- list comprehension degrades performance, dataframe operations are faster + # mask = [trainval_id not in test_ids for trainval_id in df_trainval["ident"]] + # df_trainval = df_trainval[mask] + df_trainval = df[~df["ident"].isin(test_ids)] + labels_list_trainval = df_trainval["labels"].tolist() + + if self.use_inner_cross_validation: + folds = {} + kfold = MultilabelStratifiedKFold( + n_splits=self.inner_k_folds, random_state=seed + ) + for fold, (train_ids, val_ids) in enumerate( + kfold.split( + labels_list_trainval, + labels_list_trainval, + ) + ): + df_validation = df_trainval.iloc[val_ids] + df_train = df_trainval.iloc[train_ids] + folds[self.raw_file_names_dict[f"fold_{fold}_train"]] = df_train + folds[self.raw_file_names_dict[f"fold_{fold}_validation"]] = ( + df_validation + ) + + return folds + + # scale val set size by 1/self.train_split to compensate for (hypothetical) test set size (1-self.train_split) + test_size = ((1 - self.train_split) ** 2) / self.train_split + msss = MultilabelStratifiedShuffleSplit( + n_splits=1, test_size=test_size, random_state=seed + ) + + train_indices, validation_indices = next( + msss.split(labels_list_trainval, labels_list_trainval) + ) + + df_validation = df_trainval.iloc[validation_indices] + df_train = df_trainval.iloc[train_indices] + return df_train, df_validation + + def setup_processed(self): + print("Transform data") + os.makedirs(self.processed_dir, exist_ok=True) + + processed_name = self.processed_file_names_dict["data"] + if not os.path.isfile( + os.path.join(self.processed_dir, self.processed_dir_file_names["data"]) + ): + print("Missing transformed `data.pt` file. Transforming data.... ") + + torch.save( + self._load_data_from_file( + os.path.join( + self.processed_dir_main, + self.processed_dir_main_file_names["data"], + ) + ), + os.path.join(self.processed_dir, self.processed_file_names["data"]), + ) + + def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]: + """ + Loads a dictionary from a pickled file, yielding individual dictionaries for each row. + + Args: + input_file_path (str): The path to the file. + + Yields: + Dict[str, Any]: The dictionary, keys are `features`, `labels` and `ident`. + """ + with open(input_file_path, "rb") as input_file: + df = pd.read_pickle(input_file) + for row in df.values: + yield dict(features=row[2], labels=row[1], ident=row[0]) + + def prepare_data(self) -> None: + print("Checking for processed data in", self.processed_dir_main) + + if not os.path.isfile( + self.processed_dir_main, self.processed_dir_main_names_dict["GO"] + ): + print("Missing Gene Ontology processed data") + os.makedirs(self.processed_dir_main, exist_ok=True) + # swiss_path = self._download_swiss_uni_prot_data() + + go_path = self._download_gene_ontology_data() + g = self._extract_go_class_hierarchy(go_path) + data_df = self._graph_to_raw_dataset(g) + self.save_processed(data_df, self.processed_dir_main_file_names["data"]) + + @abstractmethod + def select_classes(self, g, *args, **kwargs): + raise NotImplementedError + + def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: + """ + Preparation step before creating splits, uses the graph created by _extract_go_class_hierarchy(). + + Args: + g (nx.DiGraph): The class hierarchy graph. + + Returns: + pd.DataFrame: The raw dataset created from the graph. + """ + names = nx.get_node_attributes(g, "name") + ids = nx.get_node_attributes(g, "id") + go_to_swiss_mapping = self._get_go_swiss_data_mapping() + + print(f"Processing graph") + + terms = list(g.nodes) + data = OrderedDict(id=terms) + + data_list = [] + for node in terms: + data_list.append( + ( + names.get(node), + ids.get(node), + go_to_swiss_mapping.get(ids.get(node))["sequence"], + go_to_swiss_mapping.get(ids.get(node))["swiss_ident"], + ) + ) + + names_list, ids_list, sequences_list, swiss_identifier_list = zip(*data_list) + + data["go_id"] = ids_list + data["name"] = names_list + data["sequence"] = sequences_list + data["swiss_ident"] = swiss_identifier_list + + # Assuming select_classes is implemented and returns a list of class IDs + for n in self.select_classes(g): + data[n] = [((n in g.predecessors(node)) or (n == node)) for node in terms] + + return pd.DataFrame(data) + + def _get_go_swiss_data_mapping(self) -> Dict[int : Dict[str:str]]: + # --------- --------------------------- ------------------------------ + # Line code Content Occurrence in an entry + # --------- --------------------------- ------------------------------ + # ID Identifier (keyword) Once; starts a keyword entry + # IC Identifier (category) Once; starts a category entry + # AC Accession (KW-xxxx) Once + # DE Definition Once or more + # SY Synonyms Optional; once or more + # GO Gene ontology (GO) mapping Optional; once or more + # HI Hierarchy Optional; once or more + # WW Relevant WWW site Optional; once or more + # CA Category Once per keyword entry; + # absent in category entries + # // Terminator Once; ends an entry + # --------------------------------------------------------------------------- + print("Parsing swiss uniprot raw data....") + + swiss_go_mapping = {} + swiss_data = SwissProt.parse( + open(self.raw_file_names_dict["SwissUniProt"], "r") + ) + + for record in swiss_data: + if record.data_class != "Reviewed": + # To consider only manually-annotated swiss data + continue + # Cross-reference has mapping for each protein to each type of data set + for cross_ref in record.cross_references: + if cross_ref[0] == self._GO_DATA_INIT: + # Only consider cross-reference related to GO dataset + go_id = _GOUniprotDataExtractor._parse_go_id(cross_ref[1]) + swiss_go_mapping[go_id] = { + "sequence": record.sequence, + "swiss_ident": record.entry_name, # Unique identifier for each swiss data record + } + return swiss_go_mapping + + def _extract_go_class_hierarchy(self, go_path: str) -> nx.DiGraph: + elements = [] + for term in fastobo.load(go_path): + if isinstance(term, fastobo.typedef.TypedefFrame): + # To avoid term frame of the below format/structure + # [Typedef] + # id: part_of + # name: part of + # namespace: external + # xref: BFO:0000050 + # is_transitive: true + continue + if ( + term + and isinstance(term.id, fastobo.id.PrefixedIdent) + and term.id.prefix == self._GO_DATA_INIT + ): + # Consider only terms with id in following format - GO:2001271 + term_dict = self.term_callback(term) + if term_dict: + elements.append(term_dict) + + g = nx.DiGraph() + for n in elements: + g.add_node(n["id"], **n) + g.add_edges_from([(p, q["id"]) for q in elements for p in q["parents"]]) + + print("Compute transitive closure") + g = nx.transitive_closure_dag(g) + # g = g.subgraph(list(nx.descendants(g, self.top_class_id)) + [self.top_class_id]) + return g + + @staticmethod + def term_callback(term: fastobo.term.TermFrame) -> dict: + """ + Extracts information from a Gene Ontology (GO) term document. + + Args: + term: A Gene Ontology term Frame document. + + Returns: + dict: A dictionary containing the extracted information: + - "id": The ID of the GO term. + - "parents": A list of parent term IDs. + - "name": The name of the GO term. + """ + parents = [] + name = None + + for clause in term: + if isinstance(clause, fastobo.term.IsAClause): + parents.append(_GOUniprotDataExtractor._parse_go_id(clause.term)) + elif isinstance(clause, fastobo.term.NameClause): + name = clause.name + elif isinstance(clause, fastobo.term.IsObsoleteClause): + if clause.obsolete: + # if the term contains clause as obsolete as true, skips this term + return None + + return { + "id": _GOUniprotDataExtractor._parse_go_id(term.id), + "parents": parents, + "name": name, + } + + @staticmethod + def _parse_go_id(go_id: str) -> int: + """ + Helper function to parse and normalize GO term IDs. + + Args: + go_id: The raw GO term ID string. + + Returns: + str: The parsed and normalized GO term ID. + """ + # `is_a` clause has GO id in the following format: + # is_a: GO:0009968 ! negative regulation of signal transduction + return int(str(go_id).split(":")[1].split("!")[0].strip()) + + def _download_gene_ontology_data(self) -> str: + """ + Download the Gene Ontology data `.obo` file. + + Note: + Quote from : https://geneontology.org/docs/download-ontology/ + Three versions of the ontology are available, the one use in this method is described below: + http://purl.obolibrary.org/obo/go/go-basic.obo + The basic version of the GO, filtered such that the graph is guaranteed to be acyclic and annotations + can be propagated up the graph. The relations included are `is a, part of, regulates, negatively` + `regulates` and `positively regulates`. This version excludes relationships that cross the 3 GO + hierarchies. This version should be used with most GO-based annotation tools. + + Returns: + str: The file path of the loaded Gene Ontology data. + """ + go_path = os.path.join(self.raw_dir, self.raw_file_names_dict["GO"]) + os.makedirs(os.path.dirname(go_path), exist_ok=True) + + if not os.path.isfile(go_path): + print("Missing Gene Ontology raw data") + print(f"Downloading Gene Ontology data....") + url = f"http://purl.obolibrary.org/obo/go/go-basic.obo" + r = requests.get(url, allow_redirects=True) + r.raise_for_status() # Check if the request was successful + open(go_path, "wb").write(r.content) + return go_path + + def _download_swiss_uni_prot_data(self) -> str: + """ + Download the Swiss-Prot data file from UniProt Knowledgebase. + + Note: + UniProt Knowledgebase is collection of functional information on proteins, with accurate, consistent + and rich annotation. + + Swiss-Prot contains manually-annotated records with information extracted from literature and + curator-evaluated computational analysis. + + Returns: + str: The file path of the loaded Swiss-Prot data file. + """ + uni_prot_file_path = os.path.join( + self.raw_dir, self.raw_file_names_dict["SwissUniProt"] + ) + os.makedirs(os.path.dirname(uni_prot_file_path), exist_ok=True) + temp_dir = gettempdir() + + if not os.path.isfile(uni_prot_file_path): + print(f"Downloading Swiss UniProt data....") + url = f"https://ftp.uniprot.org/pub/databases/uniprot/knowledgebase/complete/uniprot_sprot.dat.gz" + # TODO : Permission error, manually extracted the data as of now + temp_file_path = os.path.join(temp_dir, "uniprot_sprot.dat.gz") + try: + # Download the gzip file + request.urlretrieve(url, temp_file_path) + print(f"Downloaded to temporary file: {temp_file_path}") + + # Extract the gzip file + with gzip.open(temp_file_path, "rb") as gfile: + file_content = gfile.read() + print("Extracted the content from the gzip file.") + + # Decode and write the contents to the target file + with open(uni_prot_file_path, "wt", encoding="utf-8") as fout: + fout.write(file_content.decode("utf-8")) + print(f"Data written to: {uni_prot_file_path}") + + except PermissionError as e: + print(f"PermissionError: {e}") + return None + except Exception as e: + print(f"An error occurred: {e}") + return None + finally: + # Clean up the temporary file + if os.path.exists(temp_file_path): + os.remove(temp_file_path) + print(f"Temporary file {temp_file_path} removed.") + + return uni_prot_file_path + + def select_classes(self, g, split_name, *args, **kwargs): + raise NotImplementedError + + def save_processed(self, data: pd.DataFrame, filename: str) -> None: + """ + Save the processed dataset to a pickle file. + + Args: + data (pd.DataFrame): The processed dataset to be saved. + filename (str): The filename for the pickle file. + """ + pd.to_pickle(data, open(os.path.join(self.processed_dir_main, filename), "wb")) + + @staticmethod + def _get_data_size(input_file_path: str) -> int: + """ + Get the size of the data from a pickled file. + + Args: + input_file_path (str): The path to the file. + + Returns: + int: The size of the data. + """ + with open(input_file_path, "rb") as f: + return len(pd.read_pickle(f)) + + @property + def raw_file_names_dict(self) -> dict: + return {"GO": "go-basic.obo", "SwissUniProt": "uniprot_sprot.dat"} + + @property + def base_dir(self): + return os.path.join("data", f"Go_UniProt") + + @property + def processed_dir_main(self): + return os.path.join( + self.base_dir, + self._name, + "processed", + ) + + @property + def processed_dir_main_file_names(self) -> dict: + return {"data": "data.pkl"} + + @property + def processed_file_names(self) -> dict: + return {"data": "data.pt"} + + +class GOUniprotDataModule(_GOUniprotDataExtractor): + @property + def _name(self): + return f"GoUniProt_v1" From 484438072a90fddde6a50b92a9a4ce1431a6d9ca Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 21 Jul 2024 18:12:09 +0200 Subject: [PATCH 02/30] prepare_data : sequence added to graph creation process --- chebai/preprocessing/datasets/go_uniprot.py | 250 ++++++++++++++++---- 1 file changed, 209 insertions(+), 41 deletions(-) diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/go_uniprot.py index 108611ae..39f75a82 100644 --- a/chebai/preprocessing/datasets/go_uniprot.py +++ b/chebai/preprocessing/datasets/go_uniprot.py @@ -8,14 +8,14 @@ # https://www.ebi.ac.uk/GOA/downloads -__all__ = ["GOUniprotDataModule"] +# __all__ = ["_GOUniprotDataModule"] import gzip import os from abc import ABC, abstractmethod from collections import OrderedDict from tempfile import NamedTemporaryFile, TemporaryDirectory, gettempdir -from typing import Any, Dict, Generator, List +from typing import Any, Dict, Generator, List, Optional, Tuple, Union from urllib import request import fastobo @@ -24,16 +24,71 @@ import requests import torch from Bio import SwissProt -from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit +from iterstrat.ml_stratifiers import ( + MultilabelStratifiedKFold, + MultilabelStratifiedShuffleSplit, +) +from chebai.preprocessing import reader as dr from chebai.preprocessing.datasets import XYBaseDataModule class _GOUniprotDataExtractor(XYBaseDataModule, ABC): + """ + A class for extracting and processing data from the ChEBI dataset. + + Args: + chebi_version_train (int, optional): The version of ChEBI to use for training and validation. If not set, + chebi_version will be used for training, validation and test. Defaults to None. + single_class (int, optional): The ID of the single class to predict. If not set, all available labels will be + predicted. Defaults to None. + dynamic_data_split_seed (int, optional): The seed for random data splitting. Defaults to 42. + splits_file_path (str, optional): Path to the splits CSV file. Defaults to None. + **kwargs: Additional keyword arguments (passed to XYBaseDataModule). + + Attributes: + single_class (Optional[int]): The ID of the single class to predict. + chebi_version_train (Optional[int]): The version of ChEBI to use for training and validation. + dynamic_data_split_seed (int): The seed for random data splitting, default is 42. + dynamic_df_train (Optional[pd.DataFrame]): DataFrame to store the training data split. + dynamic_df_test (Optional[pd.DataFrame]): DataFrame to store the test data split. + dynamic_df_val (Optional[pd.DataFrame]): DataFrame to store the validation data split. + splits_file_path (Optional[str]): Path to csv file containing split assignments. + """ + _GO_DATA_INIT = "GO" - def __init__(self): - pass + def __init__( + self, + # chebi_version_train: Optional[int] = None, + # single_class: Optional[int] = None, + **kwargs, + ): + # predict only single class (given as id of one of the classes present in the raw data set) + # self.single_class = single_class + super(_GOUniprotDataExtractor, self).__init__(**kwargs) + # use different version of chebi for training and validation (if not None) + # (still uses self.chebi_version for test set) + # self.chebi_version_train = chebi_version_train + self.dynamic_data_split_seed = int(kwargs.get("seed", 42)) # default is 42 + # Class variables to store the dynamics splits + self.dynamic_df_train = None + self.dynamic_df_test = None + self.dynamic_df_val = None + + # if self.chebi_version_train is not None: + # # Instantiate another same class with "chebi_version" as "chebi_version_train", if train_version is given + # # This is to get the data from respective directory related to "chebi_version_train" + # _init_kwargs = kwargs + # _init_kwargs["chebi_version"] = self.chebi_version_train + # self._chebi_version_train_obj = self.__class__( + # single_class=self.single_class, + # **_init_kwargs, + # ) + # Path of csv file which contains a list of chebi ids & their assignment to a dataset (either train, validation or test). + # self.splits_file_path = self._validate_splits_file_path( + # kwargs.get("splits_file_path", None) + # ) @property def dynamic_split_dfs(self) -> Dict[str, pd.DataFrame]: @@ -283,19 +338,17 @@ def setup_processed(self): os.makedirs(self.processed_dir, exist_ok=True) processed_name = self.processed_file_names_dict["data"] - if not os.path.isfile( - os.path.join(self.processed_dir, self.processed_dir_file_names["data"]) - ): + if not os.path.isfile(os.path.join(self.processed_dir, processed_name)): print("Missing transformed `data.pt` file. Transforming data.... ") torch.save( self._load_data_from_file( os.path.join( self.processed_dir_main, - self.processed_dir_main_file_names["data"], + self.processed_dir_main_file_names_dict["data"], ) ), - os.path.join(self.processed_dir, self.processed_file_names["data"]), + os.path.join(self.processed_dir, processed_name), ) def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]: @@ -316,9 +369,8 @@ def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, No def prepare_data(self) -> None: print("Checking for processed data in", self.processed_dir_main) - if not os.path.isfile( - self.processed_dir_main, self.processed_dir_main_names_dict["GO"] - ): + processed_name = self.processed_dir_main_file_names_dict["data"] + if not os.path.isfile(os.path.join(self.processed_dir, processed_name)): print("Missing Gene Ontology processed data") os.makedirs(self.processed_dir_main, exist_ok=True) # swiss_path = self._download_swiss_uni_prot_data() @@ -326,10 +378,10 @@ def prepare_data(self) -> None: go_path = self._download_gene_ontology_data() g = self._extract_go_class_hierarchy(go_path) data_df = self._graph_to_raw_dataset(g) - self.save_processed(data_df, self.processed_dir_main_file_names["data"]) + self.save_processed(data_df, processed_name) @abstractmethod - def select_classes(self, g, *args, **kwargs): + def select_classes(self, g: nx.DiGraph, *args, **kwargs): raise NotImplementedError def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: @@ -342,40 +394,40 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: Returns: pd.DataFrame: The raw dataset created from the graph. """ + sequences = nx.get_node_attributes(g, "sequence") names = nx.get_node_attributes(g, "name") - ids = nx.get_node_attributes(g, "id") - go_to_swiss_mapping = self._get_go_swiss_data_mapping() + swiss_idents = nx.get_node_attributes(g, "swiss_ident") print(f"Processing graph") - terms = list(g.nodes) - data = OrderedDict(id=terms) - data_list = [] - for node in terms: - data_list.append( - ( - names.get(node), - ids.get(node), - go_to_swiss_mapping.get(ids.get(node))["sequence"], - go_to_swiss_mapping.get(ids.get(node))["swiss_ident"], + for node_id, sequence in sequences.items(): + if sequence: + data_list.append( + ( + node_id, + names.get(node_id), + sequence, + swiss_idents.get(node_id), + ) ) - ) - names_list, ids_list, sequences_list, swiss_identifier_list = zip(*data_list) + node_ids, names_list, sequences_list, swiss_identifier_list = zip(*data_list) + data = OrderedDict(id=node_ids) - data["go_id"] = ids_list data["name"] = names_list data["sequence"] = sequences_list data["swiss_ident"] = swiss_identifier_list # Assuming select_classes is implemented and returns a list of class IDs for n in self.select_classes(g): - data[n] = [((n in g.predecessors(node)) or (n == node)) for node in terms] + data[n] = [ + ((n in g.predecessors(node)) or (n == node)) for node in node_ids + ] return pd.DataFrame(data) - def _get_go_swiss_data_mapping(self) -> Dict[int : Dict[str:str]]: + def _get_go_swiss_data_mapping(self) -> Dict[int, Dict[str, str]]: # --------- --------------------------- ------------------------------ # Line code Content Occurrence in an entry # --------- --------------------------- ------------------------------ @@ -395,7 +447,10 @@ def _get_go_swiss_data_mapping(self) -> Dict[int : Dict[str:str]]: swiss_go_mapping = {} swiss_data = SwissProt.parse( - open(self.raw_file_names_dict["SwissUniProt"], "r") + open( + os.path.join(self.raw_dir, self.raw_file_names_dict["SwissUniProt"]), + "r", + ) ) for record in swiss_data: @@ -425,6 +480,7 @@ def _extract_go_class_hierarchy(self, go_path: str) -> nx.DiGraph: # xref: BFO:0000050 # is_transitive: true continue + if ( term and isinstance(term.id, fastobo.id.PrefixedIdent) @@ -435,15 +491,19 @@ def _extract_go_class_hierarchy(self, go_path: str) -> nx.DiGraph: if term_dict: elements.append(term_dict) + go_to_swiss_mapping = self._get_go_swiss_data_mapping() + g = nx.DiGraph() for n in elements: - g.add_node(n["id"], **n) + node_mapping_dict = go_to_swiss_mapping.get(n["id"], {}) + # Combine the dictionaries for node attributes + node_attributes = {**n, **node_mapping_dict} + g.add_node(n["id"], **node_attributes) g.add_edges_from([(p, q["id"]) for q in elements for p in q["parents"]]) print("Compute transitive closure") - g = nx.transitive_closure_dag(g) # g = g.subgraph(list(nx.descendants(g, self.top_class_id)) + [self.top_class_id]) - return g + return nx.transitive_closure_dag(g) @staticmethod def term_callback(term: fastobo.term.TermFrame) -> dict: @@ -619,15 +679,123 @@ def processed_dir_main(self): ) @property - def processed_dir_main_file_names(self) -> dict: + def processed_dir_main_file_names_dict(self) -> dict: return {"data": "data.pkl"} @property - def processed_file_names(self) -> dict: + def processed_file_names_dict(self) -> dict: return {"data": "data.pt"} -class GOUniprotDataModule(_GOUniprotDataExtractor): +class _GoUniProtOverX(_GOUniprotDataExtractor, ABC): + """ + A class for extracting data from the ChEBI dataset with a threshold for selecting classes. + + Attributes: + LABEL_INDEX (int): The index of the label in the dataset. + SMILES_INDEX (int): The index of the SMILES string in the dataset. + READER (ChemDataReader): The reader used for reading the dataset. + THRESHOLD (None): The threshold for selecting classes. + """ + + LABEL_INDEX: int = 3 + SMILES_INDEX: int = 2 + READER: dr.ChemDataReader = dr.ChemDataReader + + THRESHOLD: int = None + @property - def _name(self): - return f"GoUniProt_v1" + @abstractmethod + def label_number(self) -> int: + raise NotImplementedError + + @property + def _name(self) -> str: + """ + Returns the name of the dataset. + + Returns: + str: The dataset name. + """ + return f"GoUniProt_OverX" + + def select_classes(self, g: nx.Graph, *args, **kwargs) -> List: + """ + Selects classes from the ChEBI dataset. + + Args: + g (nx.Graph): The graph representing the dataset. + go_to_swiss_mapping: Mapping from GO data to Swiss UniProt data. + *args: Additional arguments (not used). + **kwargs: Additional keyword arguments (not used). + + Returns: + list: The list of selected classes. + """ + sequences = nx.get_node_attributes(g, "sequence") + nodes = [] + for node in g.nodes: + # Counts the number of successors (child nodes) for each node and takes into account only the nodes + # with successors more than certain threshold for given node + no_of_successors = 0 + for s_node in g.successors(node): + if sequences.get(s_node, None): + no_of_successors += 1 + + if no_of_successors >= self.THRESHOLD: + nodes.append(node) + + nodes.sort() + + filename = "classes.txt" + with open(os.path.join(self.processed_dir_main, filename), "wt") as fout: + fout.writelines(str(node) + "\n" for node in nodes) + return nodes + + +class GoUniProtOver100(_GoUniProtOverX): + """ + A class for extracting data from the ChEBI dataset with a threshold of 100 for selecting classes. + + Inherits from ChEBIOverX. + + Attributes: + THRESHOLD (int): The threshold for selecting classes (100). + """ + + THRESHOLD: int = 100 + + def label_number(self) -> int: + """ + Returns the number of labels in the dataset. + + Overrides the base class method to return the correct number of labels for this threshold. + + Returns: + int: The number of labels. + """ + return 854 + + +class GoUniProtOver50(_GoUniProtOverX): + """ + A class for extracting data from the ChEBI dataset with a threshold of 50 for selecting classes. + + Inherits from ChEBIOverX. + + Attributes: + THRESHOLD (int): The threshold for selecting classes (50). + """ + + THRESHOLD: int = 50 + + def label_number(self) -> int: + """ + Returns the number of labels in the dataset. + + Overrides the base class method to return the correct number of labels for this threshold. + + Returns: + int: The number of labels. + """ + return 1332 From 795c017d0e40e11e03f5d62ef525297428412692 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 22 Jul 2024 01:46:18 +0200 Subject: [PATCH 03/30] prepare_data: filter out any rows without any True value --- chebai/preprocessing/datasets/go_uniprot.py | 37 +++++++++++++++++---- 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/go_uniprot.py index 39f75a82..ad962713 100644 --- a/chebai/preprocessing/datasets/go_uniprot.py +++ b/chebai/preprocessing/datasets/go_uniprot.py @@ -400,6 +400,7 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: print(f"Processing graph") + # Gets list of node ids, names, sequences, swiss identifier where sequence is not empty/None. data_list = [] for node_id, sequence in sequences.items(): if sequence: @@ -425,7 +426,11 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: ((n in g.predecessors(node)) or (n == node)) for node in node_ids ] - return pd.DataFrame(data) + data = pd.DataFrame(data) + # This filters the DataFrame to include only the rows where at least one value in the row from 5th column + # onwards is True/non-zero. + data = data[data.iloc[:, 4:].any(axis=1)] + return data def _get_go_swiss_data_mapping(self) -> Dict[int, Dict[str, str]]: # --------- --------------------------- ------------------------------ @@ -686,6 +691,10 @@ def processed_dir_main_file_names_dict(self) -> dict: def processed_file_names_dict(self) -> dict: return {"data": "data.pt"} + @property + def processed_file_names(self) -> List[str]: + return list(self.processed_file_names_dict.values()) + class _GoUniProtOverX(_GOUniprotDataExtractor, ABC): """ @@ -721,16 +730,31 @@ def _name(self) -> str: def select_classes(self, g: nx.Graph, *args, **kwargs) -> List: """ - Selects classes from the ChEBI dataset. + Selects classes from the GO dataset based on the number of successors meeting a specified threshold. + + This method iterates over the nodes in the graph, counting the number of successors for each node. + Nodes with a number of successors greater than or equal to the defined threshold are selected. Args: - g (nx.Graph): The graph representing the dataset. - go_to_swiss_mapping: Mapping from GO data to Swiss UniProt data. - *args: Additional arguments (not used). + g (nx.Graph): The graph representing the dataset. Each node should have a 'sequence' attribute. + *args: Additional positional arguments (not used). **kwargs: Additional keyword arguments (not used). Returns: - list: The list of selected classes. + List: A sorted list of node IDs that meet the successor threshold criteria. + + Side Effects: + Writes the list of selected nodes to a file named "classes.txt" in the specified processed directory. + + Example: + To use this method, ensure the graph `g` is populated with nodes that have the 'sequence' attribute. + Call the method with the graph as the argument: + + selected_classes = my_instance.select_classes(graph) + + Notes: + - The `THRESHOLD` attribute should be defined in the class. + - Nodes without a 'sequence' attribute are ignored in the successor count. """ sequences = nx.get_node_attributes(g, "sequence") nodes = [] @@ -747,6 +771,7 @@ def select_classes(self, g: nx.Graph, *args, **kwargs) -> List: nodes.sort() + # Write the selected node ids / classes to the file filename = "classes.txt" with open(os.path.join(self.processed_dir_main, filename), "wt") as fout: fout.writelines(str(node) + "\n" for node in nodes) From 4f06b62dfceb1d491dfd1f34215e9e196ddee1e0 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 26 Jul 2024 00:59:53 +0200 Subject: [PATCH 04/30] setup data phase : preprocessing --- chebai/preprocessing/datasets/go_uniprot.py | 233 +++++++++++++------- 1 file changed, 151 insertions(+), 82 deletions(-) diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/go_uniprot.py index ad962713..9c2ae3b8 100644 --- a/chebai/preprocessing/datasets/go_uniprot.py +++ b/chebai/preprocessing/datasets/go_uniprot.py @@ -50,9 +50,9 @@ class _GOUniprotDataExtractor(XYBaseDataModule, ABC): single_class (Optional[int]): The ID of the single class to predict. chebi_version_train (Optional[int]): The version of ChEBI to use for training and validation. dynamic_data_split_seed (int): The seed for random data splitting, default is 42. - dynamic_df_train (Optional[pd.DataFrame]): DataFrame to store the training data split. - dynamic_df_test (Optional[pd.DataFrame]): DataFrame to store the test data split. - dynamic_df_val (Optional[pd.DataFrame]): DataFrame to store the validation data split. + _dynamic_df_train (Optional[pd.DataFrame]): DataFrame to store the training data split. + _dynamic_df_test (Optional[pd.DataFrame]): DataFrame to store the test data split. + _dynamic_df_val (Optional[pd.DataFrame]): DataFrame to store the validation data split. splits_file_path (Optional[str]): Path to csv file containing split assignments. """ @@ -72,9 +72,9 @@ def __init__( # self.chebi_version_train = chebi_version_train self.dynamic_data_split_seed = int(kwargs.get("seed", 42)) # default is 42 # Class variables to store the dynamics splits - self.dynamic_df_train = None - self.dynamic_df_test = None - self.dynamic_df_val = None + self._dynamic_df_train = None + self._dynamic_df_test = None + self._dynamic_df_val = None # if self.chebi_version_train is not None: # # Instantiate another same class with "chebi_version" as "chebi_version_train", if train_version is given @@ -86,16 +86,57 @@ def __init__( # **_init_kwargs, # ) # Path of csv file which contains a list of chebi ids & their assignment to a dataset (either train, validation or test). - # self.splits_file_path = self._validate_splits_file_path( - # kwargs.get("splits_file_path", None) - # ) + self.splits_file_path = self._validate_splits_file_path( + kwargs.get("splits_file_path", None) + ) + + @staticmethod + def _validate_splits_file_path(splits_file_path: Optional[str]) -> Optional[str]: + """ + Validates the file in provided splits file path. + + Args: + splits_file_path (Optional[str]): Path to the splits CSV file. + + Returns: + Optional[str]: Validated splits file path if checks pass, None if splits_file_path is None. + + Raises: + FileNotFoundError: If the splits file does not exist. + ValueError: If the splits file is empty or missing required columns ('id' and/or 'split'), or not a CSV file. + """ + if splits_file_path is None: + return None + + if not os.path.isfile(splits_file_path): + raise FileNotFoundError(f"File {splits_file_path} does not exist") + + file_size = os.path.getsize(splits_file_path) + if file_size == 0: + raise ValueError(f"File {splits_file_path} is empty") + + # Check if the file has a CSV extension + if not splits_file_path.lower().endswith(".csv"): + raise ValueError(f"File {splits_file_path} is not a CSV file") + + # Read the first row of CSV file into a DataFrame + splits_df = pd.read_csv(splits_file_path, nrows=1) + + # Check if 'id' and 'split' columns are in the DataFrame + required_columns = {"id", "split"} + if not required_columns.issubset(splits_df.columns): + raise ValueError( + f"CSV file {splits_file_path} is missing required columns ('id' and/or 'split')." + ) + + return splits_file_path @property def dynamic_split_dfs(self) -> Dict[str, pd.DataFrame]: """ Property to retrieve dynamic train, validation, and test splits. - This property checks if dynamic data splits (`dynamic_df_train`, `dynamic_df_val`, `dynamic_df_test`) + This property checks if dynamic data splits (`_dynamic_df_train`, `_dynamic_df_val`, `_dynamic_df_test`) are already loaded. If any of them is None, it either generates them dynamically or retrieves them from data file with help of pre-existing Split csv file (`splits_file_path`) containing splits assignments. @@ -106,9 +147,9 @@ def dynamic_split_dfs(self) -> Dict[str, pd.DataFrame]: if any( split is None for split in [ - self.dynamic_df_test, - self.dynamic_df_val, - self.dynamic_df_train, + self._dynamic_df_test, + self._dynamic_df_val, + self._dynamic_df_train, ] ): if self.splits_file_path is None: @@ -118,9 +159,9 @@ def dynamic_split_dfs(self) -> Dict[str, pd.DataFrame]: # If user has provided splits file path, use it to get the splits from the data self._retrieve_splits_from_csv() return { - "train": self.dynamic_df_train, - "validation": self.dynamic_df_val, - "test": self.dynamic_df_test, + "train": self._dynamic_df_train, + "validation": self._dynamic_df_val, + "test": self._dynamic_df_test, } def _generate_dynamic_splits(self) -> None: @@ -153,42 +194,13 @@ def _generate_dynamic_splits(self) -> None: df_chebi_version, seed=self.dynamic_data_split_seed ) - if self.chebi_version_train is not None: - # Load encoded data derived from "chebi_version_train" - try: - filename_train = ( - self._chebi_version_train_obj.processed_file_names_dict["data"] - ) - data_chebi_train_version = torch.load( - os.path.join( - self._chebi_version_train_obj.processed_dir, filename_train - ) - ) - except FileNotFoundError: - raise FileNotFoundError( - f"File data.pt doesn't exists related to chebi_version_train {self.chebi_version_train}." - f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files" - ) - - df_chebi_train_version = pd.DataFrame(data_chebi_train_version) - # Get train/val split of data based on "chebi_version_train", but - # using test set from "chebi_version" - df_train, df_val = self.get_train_val_splits_given_test( - df_chebi_train_version, - df_test_chebi_ver, - seed=self.dynamic_data_split_seed, - ) - # Modify test set from "chebi_version" to only include the labels that - # exists in "chebi_version_train", all other entries remains same. - df_test = self._setup_pruned_test_set(df_test_chebi_ver) - else: - # Get all splits based on "chebi_version" - df_train, df_val = self.get_train_val_splits_given_test( - train_df_chebi_ver, - df_test_chebi_ver, - seed=self.dynamic_data_split_seed, - ) - df_test = df_test_chebi_ver + # Get all splits based on "chebi_version" + df_train, df_val = self.get_train_val_splits_given_test( + train_df_chebi_ver, + df_test_chebi_ver, + seed=self.dynamic_data_split_seed, + ) + df_test = df_test_chebi_ver # Generate splits.csv file to store ids of each corresponding split split_assignment_list: List[pd.DataFrame] = [ @@ -202,9 +214,9 @@ def _generate_dynamic_splits(self) -> None: ) # Store the splits in class variables - self.dynamic_df_train = df_train - self.dynamic_df_val = df_val - self.dynamic_df_test = df_test + self._dynamic_df_train = df_train + self._dynamic_df_val = df_val + self._dynamic_df_test = df_test def _retrieve_splits_from_csv(self) -> None: """ @@ -226,13 +238,13 @@ def _retrieve_splits_from_csv(self) -> None: validation_ids = splits_df[splits_df["split"] == "validation"]["id"] test_ids = splits_df[splits_df["split"] == "test"]["id"] - self.dynamic_df_train = df_chebi_version[ + self._dynamic_df_train = df_chebi_version[ df_chebi_version["ident"].isin(train_ids) ] - self.dynamic_df_val = df_chebi_version[ + self._dynamic_df_val = df_chebi_version[ df_chebi_version["ident"].isin(validation_ids) ] - self.dynamic_df_test = df_chebi_version[ + self._dynamic_df_test = df_chebi_version[ df_chebi_version["ident"].isin(test_ids) ] @@ -336,20 +348,16 @@ def get_train_val_splits_given_test( def setup_processed(self): print("Transform data") os.makedirs(self.processed_dir, exist_ok=True) - - processed_name = self.processed_file_names_dict["data"] - if not os.path.isfile(os.path.join(self.processed_dir, processed_name)): - print("Missing transformed `data.pt` file. Transforming data.... ") - - torch.save( - self._load_data_from_file( - os.path.join( - self.processed_dir_main, - self.processed_dir_main_file_names_dict["data"], - ) - ), - os.path.join(self.processed_dir, processed_name), - ) + print("Missing transformed `data.pt` file. Transforming data.... ") + torch.save( + self._load_data_from_file( + os.path.join( + self.processed_dir_main, + self.processed_dir_main_file_names_dict["data"], + ) + ), + os.path.join(self.processed_dir, self.processed_file_names_dict["data"]), + ) def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]: """ @@ -363,15 +371,23 @@ def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, No """ with open(input_file_path, "rb") as input_file: df = pd.read_pickle(input_file) + # "id" at row index 0 + # "name" at row index 1 + # "sequence" at row index 2 + # "swiss_ident" at row index 3 + # labels starting from row index 4 for row in df.values: - yield dict(features=row[2], labels=row[1], ident=row[0]) + labels = row[4:].astype(bool) + # chebai.preprocessing.reader.DataReader only needs features, labels, ident, group + # "group" set to None, by default as no such entity for this data + yield dict(features=row[2], labels=labels, ident=row[0]) def prepare_data(self) -> None: print("Checking for processed data in", self.processed_dir_main) processed_name = self.processed_dir_main_file_names_dict["data"] - if not os.path.isfile(os.path.join(self.processed_dir, processed_name)): - print("Missing Gene Ontology processed data") + if not os.path.isfile(os.path.join(self.processed_dir_main, processed_name)): + print("Missing Gene Ontology processed data (`data.pkl` file)") os.makedirs(self.processed_dir_main, exist_ok=True) # swiss_path = self._download_swiss_uni_prot_data() @@ -381,13 +397,34 @@ def prepare_data(self) -> None: self.save_processed(data_df, processed_name) @abstractmethod - def select_classes(self, g: nx.DiGraph, *args, **kwargs): + def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: + """ + Selects classes from the GO dataset based on a specified criteria. + + Args: + g (nx.Graph): The graph representing the dataset. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + List: A sorted list of node IDs that meet the specified criteria. + + """ raise NotImplementedError def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: """ - Preparation step before creating splits, uses the graph created by _extract_go_class_hierarchy(). - + Preparation step before creating splits, + uses the graph created by _extract_go_class_hierarchy() to extract the + raw data in Dataframe format with extra columns corresponding to each multi-label class. + + Data Format: pd.DataFrame + - Column 0 : ID (Identifier from GO dataset) + - Column 1 : Name of the protein + - Column 2 : Sequence representation of the protein + - Column 3 : Unique identifier of the protein from swiss dataset. + - Column 4 to Column "n": Each column corresponding to a class with value True/False indicating where the + data instance belong to this class or not. Args: g (nx.DiGraph): The class hierarchy graph. @@ -433,6 +470,14 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: return data def _get_go_swiss_data_mapping(self) -> Dict[int, Dict[str, str]]: + """ + Parse the swiss protein data and returns a mapping from GO data ID to swiss ID along with sequence + representation of the protein. + This mapping is needs as the GO data does not have the representation for protein sequence. + + Returns: + + """ # --------- --------------------------- ------------------------------ # Line code Content Occurrence in an entry # --------- --------------------------- ------------------------------ @@ -474,10 +519,19 @@ def _get_go_swiss_data_mapping(self) -> Dict[int, Dict[str, str]]: return swiss_go_mapping def _extract_go_class_hierarchy(self, go_path: str) -> nx.DiGraph: + """ + Extracts the class hierarchy from the GO ontology. + + Args: + go_path (str): The path to the GO ontology. + + Returns: + nx.DiGraph: The class hierarchy. + """ elements = [] for term in fastobo.load(go_path): if isinstance(term, fastobo.typedef.TypedefFrame): - # To avoid term frame of the below format/structure + # ---- To avoid term frame of the below format/structure ---- # [Typedef] # id: part_of # name: part of @@ -640,9 +694,6 @@ def _download_swiss_uni_prot_data(self) -> str: return uni_prot_file_path - def select_classes(self, g, split_name, *args, **kwargs): - raise NotImplementedError - def save_processed(self, data: pd.DataFrame, filename: str) -> None: """ Save the processed dataset to a pickle file. @@ -673,6 +724,7 @@ def raw_file_names_dict(self) -> dict: @property def base_dir(self): + # All the data related to GO-Uniprot will be stored in data/Go_UniProt return os.path.join("data", f"Go_UniProt") @property @@ -683,6 +735,13 @@ def processed_dir_main(self): "processed", ) + @property + def processed_dir(self) -> str: + return os.path.join( + self.processed_dir_main, + *self.identifier, + ) + @property def processed_dir_main_file_names_dict(self) -> dict: return {"data": "data.pkl"} @@ -814,6 +873,16 @@ class GoUniProtOver50(_GoUniProtOverX): THRESHOLD: int = 50 + @property + def _name(self) -> str: + """ + Returns the name of the dataset. + + Returns: + str: The dataset name. + """ + return f"GoUniProt_OverX" + def label_number(self) -> int: """ Returns the number of labels in the dataset. From 13679752f093770b4201564f8d4a61d9ed3f9ec3 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 26 Jul 2024 16:43:36 +0200 Subject: [PATCH 05/30] add reader for protein data --- chebai/preprocessing/reader.py | 42 ++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index 902f1e92..777d64d9 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -320,3 +320,45 @@ def name(cls) -> str: def _read_data(self, raw_data: str) -> List[int]: """Convert characters in raw data to their ordinal values.""" return [ord(s) for s in raw_data] + + +class ProteinDataReader(DataReader): + """ + Data reader for Protein data using protein-sequence tokens. + + Args: + collator_kwargs: Optional dictionary of keyword arguments for the collator. + token_path: Optional path for the token file. + kwargs: Additional keyword arguments. + """ + + COLLATOR = RaggedCollator + + @classmethod + def name(cls) -> str: + """Returns the name of the data reader.""" + return "sequence_token" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + with open(self.token_path, "r") as pk: + self.cache = [x.strip() for x in pk] + + def _get_token_index(self, token: str) -> int: + """Returns a unique number for each token, automatically adds new tokens.""" + if not str(token) in self.cache: + self.cache.append(str(token)) + return self.cache.index(str(token)) + EMBEDDING_OFFSET + + def _read_data(self, raw_data: str) -> List[int]: + """Read and tokenize raw data.""" + return [self._get_token_index(v[1]) for v in _tokenize(raw_data)] + + def on_finish(self) -> None: + """Write contents of self.cache into tokens.txt.""" + with open(self.token_path, "w") as pk: + print(f"saving {len(self.cache)} tokens to {self.token_path}...") + print(f"first 3 sequences tokens: {self.cache[:3]}") + for token in self.cache[:3]: + print(f"Sequence Token: {token}") + pk.writelines([f"{c}\n" for c in self.cache]) From f2025791699d2d6b9c9f4038af66a419a33565c0 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 26 Jul 2024 16:44:10 +0200 Subject: [PATCH 06/30] config : GO 50 --- configs/data/go50.yml | 1 + 1 file changed, 1 insertion(+) create mode 100644 configs/data/go50.yml diff --git a/configs/data/go50.yml b/configs/data/go50.yml new file mode 100644 index 00000000..a3e8ca60 --- /dev/null +++ b/configs/data/go50.yml @@ -0,0 +1 @@ +class_path: chebai.preprocessing.datasets.go_uniprot.GoUniProtOver50 From a07c020a7f5c82db99ab9703ff5acdbdf5817535 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 26 Jul 2024 16:44:30 +0200 Subject: [PATCH 07/30] Update setup.py --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 6518d892..1a15a4fb 100644 --- a/setup.py +++ b/setup.py @@ -50,6 +50,7 @@ "chardet", "pyyaml", "torchmetrics", + "biopython", ], extras_require={"dev": ["black", "isort", "pre-commit"]}, ) From 07e511445266ccec8ba1bf1444825d86ded29469 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 26 Jul 2024 16:59:46 +0200 Subject: [PATCH 08/30] fix - local permission error for swiss data --- chebai/preprocessing/datasets/go_uniprot.py | 114 ++++++++++---------- 1 file changed, 55 insertions(+), 59 deletions(-) diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/go_uniprot.py index 9c2ae3b8..5cc9d2e0 100644 --- a/chebai/preprocessing/datasets/go_uniprot.py +++ b/chebai/preprocessing/datasets/go_uniprot.py @@ -8,15 +8,15 @@ # https://www.ebi.ac.uk/GOA/downloads -# __all__ = ["_GOUniprotDataModule"] +__all__ = ["GoUniProtOver100", "GoUniProtOver50"] import gzip import os +import shutil from abc import ABC, abstractmethod from collections import OrderedDict -from tempfile import NamedTemporaryFile, TemporaryDirectory, gettempdir +from tempfile import NamedTemporaryFile from typing import Any, Dict, Generator, List, Optional, Tuple, Union -from urllib import request import fastobo import networkx as nx @@ -47,8 +47,6 @@ class _GOUniprotDataExtractor(XYBaseDataModule, ABC): **kwargs: Additional keyword arguments (passed to XYBaseDataModule). Attributes: - single_class (Optional[int]): The ID of the single class to predict. - chebi_version_train (Optional[int]): The version of ChEBI to use for training and validation. dynamic_data_split_seed (int): The seed for random data splitting, default is 42. _dynamic_df_train (Optional[pd.DataFrame]): DataFrame to store the training data split. _dynamic_df_test (Optional[pd.DataFrame]): DataFrame to store the test data split. @@ -57,6 +55,18 @@ class _GOUniprotDataExtractor(XYBaseDataModule, ABC): """ _GO_DATA_INIT = "GO" + # ---- Index for columns of processed `data.pkl` ------ + # "id" at row index 0 + # "name" at row index 1 + # "sequence" at row index 2 + # "swiss_ident" at row index 3 + # labels starting from row index 4 + _LABELS_STARTING_INDEX: int = 4 + _SEQUENCE_INDEX: int = 2 + _ID_INDEX = 0 + + _GO_DATA_URL = "http://purl.obolibrary.org/obo/go/go-basic.obo" + _SWISS_DATA_URL = "https://ftp.uniprot.org/pub/databases/uniprot/knowledgebase/complete/uniprot_sprot.dat.gz" def __init__( self, @@ -377,10 +387,14 @@ def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, No # "swiss_ident" at row index 3 # labels starting from row index 4 for row in df.values: - labels = row[4:].astype(bool) + labels = row[self._LABELS_STARTING_INDEX :].astype(bool) # chebai.preprocessing.reader.DataReader only needs features, labels, ident, group # "group" set to None, by default as no such entity for this data - yield dict(features=row[2], labels=labels, ident=row[0]) + yield dict( + features=row[self._SEQUENCE_INDEX], + labels=labels, + ident=row[self._ID_INDEX], + ) def prepare_data(self) -> None: print("Checking for processed data in", self.processed_dir_main) @@ -390,7 +404,7 @@ def prepare_data(self) -> None: print("Missing Gene Ontology processed data (`data.pkl` file)") os.makedirs(self.processed_dir_main, exist_ok=True) # swiss_path = self._download_swiss_uni_prot_data() - + self._download_swiss_uni_prot_data() go_path = self._download_gene_ontology_data() g = self._extract_go_class_hierarchy(go_path) data_df = self._graph_to_raw_dataset(g) @@ -466,7 +480,7 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: data = pd.DataFrame(data) # This filters the DataFrame to include only the rows where at least one value in the row from 5th column # onwards is True/non-zero. - data = data[data.iloc[:, 4:].any(axis=1)] + data = data[data.iloc[:, self._LABELS_STARTING_INDEX :].any(axis=1)] return data def _get_go_swiss_data_mapping(self) -> Dict[int, Dict[str, str]]: @@ -565,7 +579,7 @@ def _extract_go_class_hierarchy(self, go_path: str) -> nx.DiGraph: return nx.transitive_closure_dag(g) @staticmethod - def term_callback(term: fastobo.term.TermFrame) -> dict: + def term_callback(term: fastobo.term.TermFrame) -> Optional[Dict]: """ Extracts information from a Gene Ontology (GO) term document. @@ -634,13 +648,12 @@ def _download_gene_ontology_data(self) -> str: if not os.path.isfile(go_path): print("Missing Gene Ontology raw data") print(f"Downloading Gene Ontology data....") - url = f"http://purl.obolibrary.org/obo/go/go-basic.obo" - r = requests.get(url, allow_redirects=True) + r = requests.get(self._GO_DATA_URL, allow_redirects=True) r.raise_for_status() # Check if the request was successful open(go_path, "wb").write(r.content) return go_path - def _download_swiss_uni_prot_data(self) -> str: + def _download_swiss_uni_prot_data(self) -> Optional[str]: """ Download the Swiss-Prot data file from UniProt Knowledgebase. @@ -658,39 +671,37 @@ def _download_swiss_uni_prot_data(self) -> str: self.raw_dir, self.raw_file_names_dict["SwissUniProt"] ) os.makedirs(os.path.dirname(uni_prot_file_path), exist_ok=True) - temp_dir = gettempdir() if not os.path.isfile(uni_prot_file_path): print(f"Downloading Swiss UniProt data....") - url = f"https://ftp.uniprot.org/pub/databases/uniprot/knowledgebase/complete/uniprot_sprot.dat.gz" - # TODO : Permission error, manually extracted the data as of now - temp_file_path = os.path.join(temp_dir, "uniprot_sprot.dat.gz") + + # Create a temporary file + with NamedTemporaryFile(delete=False) as tf: + temp_filename = tf.name + print(f"Downloading to temporary file {temp_filename}") + + # Download the file + response = requests.get(self._SWISS_DATA_URL, stream=True) + with open(temp_filename, "wb") as temp_file: + shutil.copyfileobj(response.raw, temp_file) + + print(f"Downloaded to {temp_filename}") + + # Unpack the gzipped file try: - # Download the gzip file - request.urlretrieve(url, temp_file_path) - print(f"Downloaded to temporary file: {temp_file_path}") - - # Extract the gzip file - with gzip.open(temp_file_path, "rb") as gfile: - file_content = gfile.read() - print("Extracted the content from the gzip file.") - - # Decode and write the contents to the target file - with open(uni_prot_file_path, "wt", encoding="utf-8") as fout: - fout.write(file_content.decode("utf-8")) - print(f"Data written to: {uni_prot_file_path}") - - except PermissionError as e: - print(f"PermissionError: {e}") - return None + print(f"Unzipping the file....") + with gzip.open(temp_filename, "rb") as f_in: + output_file_path = uni_prot_file_path + with open(output_file_path, "wb") as f_out: + shutil.copyfileobj(f_in, f_out) + print(f"Unpacked and saved to {output_file_path}") + except Exception as e: - print(f"An error occurred: {e}") - return None + print(f"Failed to unpack the file: {e}") finally: # Clean up the temporary file - if os.path.exists(temp_file_path): - os.remove(temp_file_path) - print(f"Temporary file {temp_file_path} removed.") + os.remove(temp_filename) + print(f"Removed temporary file {temp_filename}") return uni_prot_file_path @@ -724,8 +735,8 @@ def raw_file_names_dict(self) -> dict: @property def base_dir(self): - # All the data related to GO-Uniprot will be stored in data/Go_UniProt - return os.path.join("data", f"Go_UniProt") + # All the data related to GO-Uniprot will be stored in data/GO_UniProt + return os.path.join("data", f"GO_UniProt") @property def processed_dir_main(self): @@ -760,16 +771,11 @@ class _GoUniProtOverX(_GOUniprotDataExtractor, ABC): A class for extracting data from the ChEBI dataset with a threshold for selecting classes. Attributes: - LABEL_INDEX (int): The index of the label in the dataset. - SMILES_INDEX (int): The index of the SMILES string in the dataset. READER (ChemDataReader): The reader used for reading the dataset. THRESHOLD (None): The threshold for selecting classes. """ - LABEL_INDEX: int = 3 - SMILES_INDEX: int = 2 - READER: dr.ChemDataReader = dr.ChemDataReader - + READER: dr.ProteinDataReader = dr.ProteinDataReader THRESHOLD: int = None @property @@ -785,9 +791,9 @@ def _name(self) -> str: Returns: str: The dataset name. """ - return f"GoUniProt_OverX" + return f"GO{self.THRESHOLD}" - def select_classes(self, g: nx.Graph, *args, **kwargs) -> List: + def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: """ Selects classes from the GO dataset based on the number of successors meeting a specified threshold. @@ -873,16 +879,6 @@ class GoUniProtOver50(_GoUniProtOverX): THRESHOLD: int = 50 - @property - def _name(self) -> str: - """ - Returns the name of the dataset. - - Returns: - str: The dataset name. - """ - return f"GoUniProt_OverX" - def label_number(self) -> int: """ Returns the number of labels in the dataset. From b3349290caf7168569ed06b9aacf3def48d63ab7 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 29 Jul 2024 00:23:17 +0200 Subject: [PATCH 09/30] go_uniprot : docstrings + variable namings --- chebai/preprocessing/datasets/go_uniprot.py | 267 ++++++++++++-------- 1 file changed, 160 insertions(+), 107 deletions(-) diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/go_uniprot.py index 5cc9d2e0..b5022514 100644 --- a/chebai/preprocessing/datasets/go_uniprot.py +++ b/chebai/preprocessing/datasets/go_uniprot.py @@ -4,7 +4,6 @@ # using a deep ontology-aware classifier, Bioinformatics, 2017. # https://doi.org/10.1093/bioinformatics/btx624 # Github: https://github.com/bio-ontology-research-group/deepgo -# https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/docs/keywlist.txt # https://www.ebi.ac.uk/GOA/downloads @@ -35,23 +34,16 @@ class _GOUniprotDataExtractor(XYBaseDataModule, ABC): """ - A class for extracting and processing data from the ChEBI dataset. + A class for extracting and processing data from the Gene Ontology (GO) dataset and the Swiss UniProt dataset. Args: - chebi_version_train (int, optional): The version of ChEBI to use for training and validation. If not set, - chebi_version will be used for training, validation and test. Defaults to None. - single_class (int, optional): The ID of the single class to predict. If not set, all available labels will be - predicted. Defaults to None. dynamic_data_split_seed (int, optional): The seed for random data splitting. Defaults to 42. splits_file_path (str, optional): Path to the splits CSV file. Defaults to None. - **kwargs: Additional keyword arguments (passed to XYBaseDataModule). + **kwargs: Additional keyword arguments passed to XYBaseDataModule. Attributes: dynamic_data_split_seed (int): The seed for random data splitting, default is 42. - _dynamic_df_train (Optional[pd.DataFrame]): DataFrame to store the training data split. - _dynamic_df_test (Optional[pd.DataFrame]): DataFrame to store the test data split. - _dynamic_df_val (Optional[pd.DataFrame]): DataFrame to store the validation data split. - splits_file_path (Optional[str]): Path to csv file containing split assignments. + splits_file_path (Optional[str]): Path to the CSV file containing split assignments. """ _GO_DATA_INIT = "GO" @@ -70,32 +62,16 @@ class _GOUniprotDataExtractor(XYBaseDataModule, ABC): def __init__( self, - # chebi_version_train: Optional[int] = None, - # single_class: Optional[int] = None, **kwargs, ): - # predict only single class (given as id of one of the classes present in the raw data set) - # self.single_class = single_class super(_GOUniprotDataExtractor, self).__init__(**kwargs) - # use different version of chebi for training and validation (if not None) - # (still uses self.chebi_version for test set) - # self.chebi_version_train = chebi_version_train self.dynamic_data_split_seed = int(kwargs.get("seed", 42)) # default is 42 # Class variables to store the dynamics splits self._dynamic_df_train = None self._dynamic_df_test = None self._dynamic_df_val = None - - # if self.chebi_version_train is not None: - # # Instantiate another same class with "chebi_version" as "chebi_version_train", if train_version is given - # # This is to get the data from respective directory related to "chebi_version_train" - # _init_kwargs = kwargs - # _init_kwargs["chebi_version"] = self.chebi_version_train - # self._chebi_version_train_obj = self.__class__( - # single_class=self.single_class, - # **_init_kwargs, - # ) - # Path of csv file which contains a list of chebi ids & their assignment to a dataset (either train, validation or test). + # Path of csv file which contains a list of go ids & their assignment to a dataset (either train, + # validation or test). self.splits_file_path = self._validate_splits_file_path( kwargs.get("splits_file_path", None) ) @@ -178,39 +154,34 @@ def _generate_dynamic_splits(self) -> None: """ Generate data splits during runtime and save them in class variables. - This method loads encoded data derived from either `chebi_version` or `chebi_version_train` - and generates train, validation, and test splits based on the loaded data. - If `chebi_version_train` is specified, the test set is pruned to include only labels that - exist in `chebi_version_train`. + This method loads encoded data generates train, validation, and test splits based on the loaded data. Raises: - FileNotFoundError: If the required data file (`data.pt`) for either `chebi_version` or `chebi_version_train` - does not exist. It advises calling `prepare_data` or `setup` methods to generate - the dataset files. + FileNotFoundError: If the required data file (`data.pt`) does not exist. It advises calling `prepare_data` + or `setup` methods to generate the dataset files. """ print("Generate dynamic splits...") - # Load encoded data derived from "chebi_version" + # Load encoded data try: filename = self.processed_file_names_dict["data"] - data_chebi_version = torch.load(os.path.join(self.processed_dir, filename)) + data_go = torch.load(os.path.join(self.processed_dir, filename)) except FileNotFoundError: raise FileNotFoundError( f"File data.pt doesn't exists. " f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files" ) - df_chebi_version = pd.DataFrame(data_chebi_version) - train_df_chebi_ver, df_test_chebi_ver = self.get_test_split( - df_chebi_version, seed=self.dynamic_data_split_seed + df_go_data = pd.DataFrame(data_go) + train_df_go, df_test = self.get_test_split( + df_go_data, seed=self.dynamic_data_split_seed ) - # Get all splits based on "chebi_version" + # Get all splits df_train, df_val = self.get_train_val_splits_given_test( - train_df_chebi_ver, - df_test_chebi_ver, + train_df_go, + df_test, seed=self.dynamic_data_split_seed, ) - df_test = df_test_chebi_ver # Generate splits.csv file to store ids of each corresponding split split_assignment_list: List[pd.DataFrame] = [ @@ -233,30 +204,23 @@ def _retrieve_splits_from_csv(self) -> None: Retrieve previously saved data splits from splits.csv file or from provided file path. This method loads the splits.csv file located at `self.splits_file_path`. - It then loads the encoded data (`data.pt`) derived from `chebi_version` and filters - it based on the IDs retrieved from splits.csv to reconstruct the train, validation, - and test splits. + It then loads the encoded data (`data.pt`) and filters it based on the IDs retrieved from + splits.csv to reconstruct the train, validation, and test splits. """ print(f"Loading splits from {self.splits_file_path}...") splits_df = pd.read_csv(self.splits_file_path) filename = self.processed_file_names_dict["data"] - data_chebi_version = torch.load(os.path.join(self.processed_dir, filename)) - df_chebi_version = pd.DataFrame(data_chebi_version) + data_go = torch.load(os.path.join(self.processed_dir, filename)) + df_go_data = pd.DataFrame(data_go) train_ids = splits_df[splits_df["split"] == "train"]["id"] validation_ids = splits_df[splits_df["split"] == "validation"]["id"] test_ids = splits_df[splits_df["split"] == "test"]["id"] - self._dynamic_df_train = df_chebi_version[ - df_chebi_version["ident"].isin(train_ids) - ] - self._dynamic_df_val = df_chebi_version[ - df_chebi_version["ident"].isin(validation_ids) - ] - self._dynamic_df_test = df_chebi_version[ - df_chebi_version["ident"].isin(test_ids) - ] + self._dynamic_df_train = df_go_data[df_go_data["ident"].isin(train_ids)] + self._dynamic_df_val = df_go_data[df_go_data["ident"].isin(validation_ids)] + self._dynamic_df_test = df_go_data[df_go_data["ident"].isin(test_ids)] def get_test_split( self, df: pd.DataFrame, seed: Optional[int] = None @@ -299,7 +263,7 @@ def get_train_val_splits_given_test( ) -> Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]: """ Split the dataset into train and validation sets, given a test set. - Use test set (e.g., loaded from another chebi version or generated in get_test_split), to avoid overlap + Use test set (e.g., loaded from another source or generated in get_test_split), to avoid overlap Args: df (pd.DataFrame): The original dataset. @@ -315,9 +279,6 @@ def get_train_val_splits_given_test( print(f"Split dataset into train / val with given test set") test_ids = test_df["ident"].tolist() - # ---- list comprehension degrades performance, dataframe operations are faster - # mask = [trainval_id not in test_ids for trainval_id in df_trainval["ident"]] - # df_trainval = df_trainval[mask] df_trainval = df[~df["ident"].isin(test_ids)] labels_list_trainval = df_trainval["labels"].tolist() @@ -355,7 +316,16 @@ def get_train_val_splits_given_test( df_train = df_trainval.iloc[train_indices] return df_train, df_validation - def setup_processed(self): + def setup_processed(self) -> None: + """ + Transforms `data.pkl` into a model input data format (`data.pt`), ensuring that the data is in a format + compatible for input to the model. + The transformed data contains the following keys: `ident`, `features`, `labels`, and `group`. + This method uses a subclass of Data Reader to perform the transformation. + + Returns: + None + """ print("Transform data") os.makedirs(self.processed_dir, exist_ok=True) print("Missing transformed `data.pt` file. Transforming data.... ") @@ -371,13 +341,24 @@ def setup_processed(self): def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]: """ - Loads a dictionary from a pickled file, yielding individual dictionaries for each row. + Loads data from a pickled file and yields individual dictionaries for each row. + + The pickled file is expected to contain rows with the following structure: + - Data at row index 0: ID of go data instance + - Data at row index 2: Sequence representation of protein + - Data from row index 4 onwards: Labels + + This method is used by `_load_data_from_file` to generate dictionaries that are then + processed and converted into a list of dictionaries containing the features and labels. Args: - input_file_path (str): The path to the file. + input_file_path (str): The path to the pickled input file. Yields: - Dict[str, Any]: The dictionary, keys are `features`, `labels` and `ident`. + Dict[str, Any]: A dictionary containing: + - `features` (str): The sequence data from the file. + - `labels` (np.ndarray): A boolean array of labels starting from row index 4. + - `ident` (Any): The identifier from row index 0. """ with open(input_file_path, "rb") as input_file: df = pd.read_pickle(input_file) @@ -396,7 +377,25 @@ def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, No ident=row[self._ID_INDEX], ) - def prepare_data(self) -> None: + def prepare_data(self, *args: Any, **kwargs: Any) -> None: + """ + Prepares the data for the Go dataset. + + This method checks for the presence of raw data in the specified directory. + If the raw data is missing, it fetches the ontology and creates a dataframe and saves it to a data.pkl file. + + The resulting dataframe/pickle file is expected to contain columns with the following structure: + - Column at index 0: ID of go data instance + - Column at index 2: Sequence representation of the protein + - Column from index 4 onwards: Labels + + Args: + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + None + """ print("Checking for processed data in", self.processed_dir_main) processed_name = self.processed_dir_main_file_names_dict["data"] @@ -433,7 +432,7 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: raw data in Dataframe format with extra columns corresponding to each multi-label class. Data Format: pd.DataFrame - - Column 0 : ID (Identifier from GO dataset) + - Column 0 : ID (Identifier for GO data instance) - Column 1 : Name of the protein - Column 2 : Sequence representation of the protein - Column 3 : Unique identifier of the protein from swiss dataset. @@ -485,13 +484,18 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: def _get_go_swiss_data_mapping(self) -> Dict[int, Dict[str, str]]: """ - Parse the swiss protein data and returns a mapping from GO data ID to swiss ID along with sequence - representation of the protein. - This mapping is needs as the GO data does not have the representation for protein sequence. + Parses the Swiss-Prot data and returns a mapping from Gene Ontology (GO) data ID to Swiss-Prot ID + along with the sequence representation of the protein. - Returns: + This mapping is necessary because the GO data does not include the protein sequence representation. + Returns: + Dict[int, Dict[str, str]]: A dictionary where the keys are GO data IDs (int) and the values are + dictionaries containing: + - "sequence" (str): The protein sequence. + - "swiss_ident" (str): The unique identifier for each Swiss-Prot record. """ + # # https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/docs/keywlist.txt # --------- --------------------------- ------------------------------ # Line code Content Occurrence in an entry # --------- --------------------------- ------------------------------ @@ -535,12 +539,15 @@ def _get_go_swiss_data_mapping(self) -> Dict[int, Dict[str, str]]: def _extract_go_class_hierarchy(self, go_path: str) -> nx.DiGraph: """ Extracts the class hierarchy from the GO ontology. + Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with GO term data + and corresponding Swiss-Prot data (obtained via `_get_go_swiss_data_mapping`). Args: go_path (str): The path to the GO ontology. Returns: - nx.DiGraph: The class hierarchy. + nx.DiGraph: A directed graph representing the class hierarchy, where nodes are GO terms and edges + represent parent-child relationships. """ elements = [] for term in fastobo.load(go_path): @@ -568,6 +575,7 @@ def _extract_go_class_hierarchy(self, go_path: str) -> nx.DiGraph: g = nx.DiGraph() for n in elements: + # Swiss data is mapped to respective go data instance node_mapping_dict = go_to_swiss_mapping.get(n["id"], {}) # Combine the dictionaries for node attributes node_attributes = {**n, **node_mapping_dict} @@ -575,22 +583,23 @@ def _extract_go_class_hierarchy(self, go_path: str) -> nx.DiGraph: g.add_edges_from([(p, q["id"]) for q in elements for p in q["parents"]]) print("Compute transitive closure") - # g = g.subgraph(list(nx.descendants(g, self.top_class_id)) + [self.top_class_id]) return nx.transitive_closure_dag(g) @staticmethod def term_callback(term: fastobo.term.TermFrame) -> Optional[Dict]: """ Extracts information from a Gene Ontology (GO) term document. + It also checks if the term is marked as obsolete and skips such terms. Args: term: A Gene Ontology term Frame document. Returns: - dict: A dictionary containing the extracted information: - - "id": The ID of the GO term. - - "parents": A list of parent term IDs. - - "name": The name of the GO term. + Optional[Dict]: A dictionary containing the extracted information if the term is not obsolete, + otherwise None. The dictionary includes: + - "id" (str): The ID of the GO term. + - "parents" (List[str]): A list of parent term IDs. + - "name" (str): The name of the GO term. """ parents = [] name = None @@ -731,15 +740,33 @@ def _get_data_size(input_file_path: str) -> int: @property def raw_file_names_dict(self) -> dict: + """ + Returns a dictionary of raw file names used in data processing. + + Returns: + dict: A dictionary mapping dataset names to their respective file names. + For example, {"GO": "go-basic.obo", "SwissUniProt": "uniprot_sprot.dat"}. + """ return {"GO": "go-basic.obo", "SwissUniProt": "uniprot_sprot.dat"} @property - def base_dir(self): - # All the data related to GO-Uniprot will be stored in data/GO_UniProt + def base_dir(self) -> str: + """ + Returns the base directory path for storing GO-Uniprot data. + + Returns: + str: The path to the base directory, which is "data/GO_UniProt". + """ return os.path.join("data", f"GO_UniProt") @property - def processed_dir_main(self): + def processed_dir_main(self) -> str: + """ + Returns the main directory path where processed data is stored. + + Returns: + str: The path to the main processed data directory, based on the base directory and the instance's name. + """ return os.path.join( self.base_dir, self._name, @@ -748,6 +775,12 @@ def processed_dir_main(self): @property def processed_dir(self) -> str: + """ + Returns the specific directory path for processed data, including identifiers. + + Returns: + str: The path to the processed data directory, including additional identifiers. + """ return os.path.join( self.processed_dir_main, *self.identifier, @@ -755,24 +788,51 @@ def processed_dir(self) -> str: @property def processed_dir_main_file_names_dict(self) -> dict: + """ + Returns a dictionary mapping processed data file names. + + Returns: + dict: A dictionary mapping dataset types to their respective processed file names. + For example, {"data": "data.pkl"}. + """ return {"data": "data.pkl"} @property def processed_file_names_dict(self) -> dict: + """ + Returns a dictionary mapping processed data file names to their final formats. + + Returns: + dict: A dictionary mapping dataset types to their respective final file names. + For example, {"data": "data.pt"}. + """ return {"data": "data.pt"} @property def processed_file_names(self) -> List[str]: + """ + Returns a list of file names for processed data. + + Returns: + List[str]: A list of file names corresponding to the processed data. + """ return list(self.processed_file_names_dict.values()) class _GoUniProtOverX(_GOUniprotDataExtractor, ABC): """ - A class for extracting data from the ChEBI dataset with a threshold for selecting classes. + A class for extracting data from the Gene Ontology (GO) dataset with a threshold for selecting classes based on + the number of subclasses. + + This class is designed to filter GO classes based on a specified threshold, selecting only those classes + which have a certain number of subclasses in the hierarchy. Attributes: - READER (ChemDataReader): The reader used for reading the dataset. - THRESHOLD (None): The threshold for selecting classes. + READER (dr.ProteinDataReader): The reader used for reading the dataset. + THRESHOLD (int): The threshold for selecting classes based on the number of subclasses. + + Property: + label_number (int): The number of labels in the dataset. This property must be implemented by subclasses. """ READER: dr.ProteinDataReader = dr.ProteinDataReader @@ -789,7 +849,7 @@ def _name(self) -> str: Returns the name of the dataset. Returns: - str: The dataset name. + str: The dataset name, formatted with the current threshold value. """ return f"GO{self.THRESHOLD}" @@ -801,7 +861,7 @@ def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: Nodes with a number of successors greater than or equal to the defined threshold are selected. Args: - g (nx.Graph): The graph representing the dataset. Each node should have a 'sequence' attribute. + g (nx.DiGraph): The graph representing the dataset. Each node should have a 'sequence' attribute. *args: Additional positional arguments (not used). **kwargs: Additional keyword arguments (not used). @@ -811,21 +871,14 @@ def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: Side Effects: Writes the list of selected nodes to a file named "classes.txt" in the specified processed directory. - Example: - To use this method, ensure the graph `g` is populated with nodes that have the 'sequence' attribute. - Call the method with the graph as the argument: - - selected_classes = my_instance.select_classes(graph) - Notes: - - The `THRESHOLD` attribute should be defined in the class. + - The `THRESHOLD` attribute should be defined in the subclass. - Nodes without a 'sequence' attribute are ignored in the successor count. """ sequences = nx.get_node_attributes(g, "sequence") nodes = [] for node in g.nodes: - # Counts the number of successors (child nodes) for each node and takes into account only the nodes - # with successors more than certain threshold for given node + # Count the number of successors (child nodes) for each node no_of_successors = 0 for s_node in g.successors(node): if sequences.get(s_node, None): @@ -836,7 +889,7 @@ def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: nodes.sort() - # Write the selected node ids / classes to the file + # Write the selected node IDs/classes to the file filename = "classes.txt" with open(os.path.join(self.processed_dir_main, filename), "wt") as fout: fout.writelines(str(node) + "\n" for node in nodes) @@ -845,9 +898,9 @@ def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: class GoUniProtOver100(_GoUniProtOverX): """ - A class for extracting data from the ChEBI dataset with a threshold of 100 for selecting classes. + A class for extracting data from the Gene Ontology (GO) dataset with a threshold of 100 for selecting classes. - Inherits from ChEBIOverX. + Inherits from `_GoUniProtOverX` and sets the threshold for selecting classes to 100. Attributes: THRESHOLD (int): The threshold for selecting classes (100). @@ -857,21 +910,21 @@ class GoUniProtOver100(_GoUniProtOverX): def label_number(self) -> int: """ - Returns the number of labels in the dataset. + Returns the number of labels in the dataset for this threshold. - Overrides the base class method to return the correct number of labels for this threshold. + Overrides the base class method to provide the correct number of labels for a threshold of 100. Returns: - int: The number of labels. + int: The number of labels (854). """ return 854 class GoUniProtOver50(_GoUniProtOverX): """ - A class for extracting data from the ChEBI dataset with a threshold of 50 for selecting classes. + A class for extracting data from the Gene Ontology (GO) dataset with a threshold of 50 for selecting classes. - Inherits from ChEBIOverX. + Inherits from `_GoUniProtOverX` and sets the threshold for selecting classes to 50. Attributes: THRESHOLD (int): The threshold for selecting classes (50). @@ -881,11 +934,11 @@ class GoUniProtOver50(_GoUniProtOverX): def label_number(self) -> int: """ - Returns the number of labels in the dataset. + Returns the number of labels in the dataset for this threshold. - Overrides the base class method to return the correct number of labels for this threshold. + Overrides the base class method to provide the correct number of labels for a threshold of 50. Returns: - int: The number of labels. + int: The number of labels (1332). """ return 1332 From 5cdc9b8d17ea009108abb15bf249ec421df42e09 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 31 Jul 2024 15:13:17 +0200 Subject: [PATCH 10/30] chebi.py : additional/more specific docstrings --- chebai/preprocessing/datasets/chebi.py | 63 +++++++++++++++++++++----- 1 file changed, 51 insertions(+), 12 deletions(-) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 5876577f..a1bfcf6b 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -159,7 +159,8 @@ def __init__( single_class=self.single_class, **_init_kwargs, ) - # Path of csv file which contains a list of chebi ids & their assignment to a dataset (either train, validation or test). + # Path of csv file which contains a list of chebi ids & their assignment to a dataset (either train, + # validation or test). self.splits_file_path = self._validate_splits_file_path( kwargs.get("splits_file_path", None) ) @@ -167,7 +168,7 @@ def __init__( @staticmethod def _validate_splits_file_path(splits_file_path: Optional[str]) -> Optional[str]: """ - Validates the provided splits file path. + Validates the file in provided splits file path. Args: splits_file_path (Optional[str]): Path to the splits CSV file. @@ -230,6 +231,7 @@ def extract_class_hierarchy(self, chebi_path: str) -> nx.DiGraph: return nx.transitive_closure_dag(g) def select_classes(self, g, split_name, *args, **kwargs): + raise NotImplementedError def graph_to_raw_dataset( @@ -269,6 +271,8 @@ def graph_to_raw_dataset( data = pd.DataFrame(data) data = data[~data["SMILES"].isnull()] data = data[[name not in CHEBI_BLACKLIST for name, _ in data.iterrows()]] + # This filters the DataFrame to include only the rows where at least one value in the row from 4th column + # onwards is True/non-zero. data = data[data.iloc[:, 3:].any(axis=1)] return data @@ -296,11 +300,26 @@ def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, No """ Loads a dictionary from a pickled file, yielding individual dictionaries for each row. + This method reads data from a specified pickled file, processes each row to extract relevant + information, and yields dictionaries containing the keys `features`, `labels`, and `ident`. + If `single_class` is specified, it only includes the label for that specific class; otherwise, + it includes labels for all classes starting from the fourth column. + + The pickled file is expected to contain rows with the following structure: + - Data at row index 0: ID of the chebi data instance + - Data at row index 2: SMILES representation for the chemical + - Data from row index 3 onwards: Labels + + This method is used in `_load_data_from_file` to process each row of data and convert it + into the desired dictionary format before loading it into the model. + Args: - input_file_path (str): The path to the file. + input_file_path (str): The path to the input pickled file. Yields: - Dict[str, Any]: The dictionary, keys are `features`, `labels` and `ident`. + Dict[str, Any]: A dictionary with keys `features`, `labels`, and `ident`. + `features` contains the sequence, `labels` contains the labels as boolean values, + and `ident` contains the identifier. """ with open(input_file_path, "rb") as input_file: df = pd.read_pickle(input_file) @@ -386,6 +405,11 @@ def setup_processed(self) -> None: """ Transform and prepare processed data for the ChEBI dataset. + Main function of this method is to transform `data.pkl` into a model input data format (`data.pt`), + ensuring that the data is in a format compatible for input to the model. + The transformed data must contain the following keys: `ident`, `features`, `labels`, and `group`. + This method uses a subclass of Data Reader to perform the transformation. + This method sets up the processed data directories and files based on the ChEBI version and train version (if specified). It ensures that the required processed data files exist by loading raw data, transforming it into processed format, and saving it. @@ -701,9 +725,12 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None: Prepares the data for the Chebi dataset. This method checks for the presence of raw data in the specified directory. - If the raw data is missing, it fetches the ontology and creates a test set. - If the test set already exists, it loads it from the file. - Then, it creates the train/validation split based on the test set. + If the raw data is missing, it fetches the ontology and creates a dataframe & saves it as data.pkl pickle file. + + The resulting dataframe/pickle file is expected to contain columns with the following structure: + - Column at index 0: ID of chebi data instance + - Column at index 2: SMILES representation of the chemical + - Column from index 3 onwards: Labels Args: *args: Variable length argument list. @@ -981,6 +1008,8 @@ def select_classes(self, g, *args, **kwargs): class ChEBIOverX(_ChEBIDataExtractor): """ A class for extracting data from the ChEBI dataset with a threshold for selecting classes. + This class is designed to filter Chebi classes based on a specified threshold, selecting only those classes + which have a certain number of subclasses in the hierarchy. Attributes: LABEL_INDEX (int): The index of the label in the dataset. @@ -1014,18 +1043,28 @@ def _name(self) -> str: """ return f"ChEBI{self.THRESHOLD}" - def select_classes(self, g: nx.Graph, split_name: str, *args, **kwargs) -> List: + def select_classes(self, g: nx.DiGraph, split_name: str, *args, **kwargs) -> List: """ - Selects classes from the ChEBI dataset. + Selects classes from the ChEBI dataset based on the number of successors meeting a specified threshold. + + This method iterates over the nodes in the graph, counting the number of successors for each node. + Nodes with a number of successors greater than or equal to the defined threshold are selected. Args: g (nx.Graph): The graph representing the dataset. - split_name (str): The name of the split. - *args: Additional arguments (not used). + split_name (str) : Name of the split (not used). + *args: Additional positional arguments (not used). **kwargs: Additional keyword arguments (not used). Returns: - list: The list of selected classes. + List: A sorted list of node IDs that meet the successor threshold criteria. + + Side Effects: + Writes the list of selected nodes to a file named "classes.txt" in the specified processed directory. + + Notes: + - The `THRESHOLD` attribute should be defined in the subclass of this class. + - Nodes without a 'sequence' attribute are ignored in the successor count. """ smiles = nx.get_node_attributes(g, "smiles") nodes = list( From 0ee241ad817aabbf4cfc3f923a382626da070f9f Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 3 Aug 2024 00:05:30 +0200 Subject: [PATCH 11/30] base class for datasets following new dynamics splits feature --- chebai/preprocessing/datasets/base.py | 507 +++++++++++++++++++++++++- 1 file changed, 506 insertions(+), 1 deletion(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 30aa6551..d0046cc4 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -1,10 +1,17 @@ import os import random -from typing import Any, Dict, Generator, List, Optional, Union +from abc import ABC, abstractmethod +from typing import Any, Dict, Generator, List, Optional, Tuple, Union import lightning as pl +import networkx as nx +import pandas as pd import torch import tqdm +from iterstrat.ml_stratifiers import ( + MultilabelStratifiedKFold, + MultilabelStratifiedShuffleSplit, +) from lightning.pytorch.core.datamodule import LightningDataModule from lightning_utilities.core.rank_zero import rank_zero_info from torch.utils.data import DataLoader @@ -583,3 +590,501 @@ def limits(self): Returns None, assuming no limits on data slicing. """ return None + + +class _DynamicDataset(XYBaseDataModule, ABC): + """ + A class for extracting and processing data from the given dataset. + + The processed and transformed data is stored in `data.pkl` and `data.pt` format as a whole respectively, + rather than as separate train, validation, and test splits, with dynamic splitting of data.pt occurring at runtime. + The `_DynamicDataset` class manages data splits by either generating them during execution or retrieving them from + a CSV file. + If no split file path is provided, `_generate_dynamic_splits` creates the training, validation, and test splits + from the encoded/transformed data, storing them in `_dynamic_df_train`, `_dynamic_df_val`, and `_dynamic_df_test`. + When a split file path is provided, `_retrieve_splits_from_csv` loads splits from the CSV file, which must + include 'id' and 'split' columns. + The `dynamic_split_dfs` property ensures that the necessary splits are loaded as required. + + Args: + dynamic_data_split_seed (int, optional): The seed for random data splitting. Defaults to 42. + splits_file_path (str, optional): Path to the splits CSV file. Defaults to None. + **kwargs: Additional keyword arguments passed to XYBaseDataModule. + + Attributes: + dynamic_data_split_seed (int): The seed for random data splitting, default is 42. + splits_file_path (Optional[str]): Path to the CSV file containing split assignments. + """ + + # ---- Index for columns of processed `data.pkl` ------ + _ID_IDX: int = None + _DATA_REPRESENTATION_IDX: int = None + _LABELS_START_IDX: int = None + + def __init__( + self, + **kwargs, + ): + super(_DynamicDataset, self).__init__(**kwargs) + self.dynamic_data_split_seed = int(kwargs.get("seed", 42)) # default is 42 + # Class variables to store the dynamics splits + self._dynamic_df_train = None + self._dynamic_df_test = None + self._dynamic_df_val = None + # Path of csv file which contains a list of ids & their assignment to a dataset (either train, + # validation or test). + self.splits_file_path = self._validate_splits_file_path( + kwargs.get("splits_file_path", None) + ) + + @staticmethod + def _validate_splits_file_path(splits_file_path: Optional[str]) -> Optional[str]: + """ + Validates the file in provided splits file path. + + Args: + splits_file_path (Optional[str]): Path to the splits CSV file. + + Returns: + Optional[str]: Validated splits file path if checks pass, None if splits_file_path is None. + + Raises: + FileNotFoundError: If the splits file does not exist. + ValueError: If splits file is empty or missing required columns ('id' and/or 'split'), or not a CSV file. + """ + if splits_file_path is None: + return None + + if not os.path.isfile(splits_file_path): + raise FileNotFoundError(f"File {splits_file_path} does not exist") + + file_size = os.path.getsize(splits_file_path) + if file_size == 0: + raise ValueError(f"File {splits_file_path} is empty") + + # Check if the file has a CSV extension + if not splits_file_path.lower().endswith(".csv"): + raise ValueError(f"File {splits_file_path} is not a CSV file") + + # Read the first row of CSV file into a DataFrame + splits_df = pd.read_csv(splits_file_path, nrows=1) + + # Check if 'id' and 'split' columns are in the DataFrame + required_columns = {"id", "split"} + if not required_columns.issubset(splits_df.columns): + raise ValueError( + f"CSV file {splits_file_path} is missing required columns ('id' and/or 'split')." + ) + + return splits_file_path + + # ------------------------------ Phase: Prepare data ----------------------------------- + def prepare_data(self, *args: Any, **kwargs: Any) -> None: + """ + Prepares the data for the dataset. + + This method checks for the presence of raw data in the specified directory. + If the raw data is missing, it fetches the ontology and creates a dataframe and saves it to a data.pkl file. + + The resulting dataframe/pickle file is expected to contain columns with the following structure: + - Column at index `self._ID_IDX`: ID of data instance + - Column at index `self._DATA_REPRESENTATION_IDX`: Sequence representation of the protein + - Column from index `self._LABELS_START_IDX` onwards: Labels + + Args: + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + None + """ + print("Checking for processed data in", self.processed_dir_main) + + processed_name = self.processed_dir_main_file_names_dict["data"] + if not os.path.isfile(os.path.join(self.processed_dir_main, processed_name)): + print("Missing processed data file (`data.pkl` file)") + os.makedirs(self.processed_dir_main, exist_ok=True) + data_path = self._download_required_data() + g = self._extract_class_hierarchy(data_path) + data_df = self._graph_to_raw_dataset(g) + self.save_processed(data_df, processed_name) + + @abstractmethod + def _download_required_data(self) -> str: + """ + Downloads the required raw data. + + Returns: + str: Path to the downloaded data. + """ + pass + + @abstractmethod + def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: + """ + Extracts the class hierarchy from the data. + Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with fields/terms from + the term documents. + + Args: + data_path (str): Path to the data. + + Returns: + nx.DiGraph: The class hierarchy graph. + """ + pass + + @abstractmethod + def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: + """ + Converts the graph to a raw dataset. + Uses the graph created by `_extract_class_hierarchy` method to extract the + raw data in Dataframe format with additional columns corresponding to each multi-label class. + + Args: + graph (nx.DiGraph): The class hierarchy graph. + + Returns: + pd.DataFrame: The raw dataset. + """ + pass + + @abstractmethod + def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: + """ + Selects classes from the dataset based on a specified criteria. + + Args: + g (nx.Graph): The graph representing the dataset. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + List: A sorted list of node IDs that meet the specified criteria. + """ + pass + + def save_processed(self, data: pd.DataFrame, filename: str) -> None: + """ + Save the processed dataset to a pickle file. + + Args: + data (pd.DataFrame): The processed dataset to be saved. + filename (str): The filename for the pickle file. + """ + pd.to_pickle(data, open(os.path.join(self.processed_dir_main, filename), "wb")) + + # ------------------------------ Phase: Setup data ----------------------------------- + def setup_processed(self) -> None: + """ + Transforms `data.pkl` into a model input data format (`data.pt`), ensuring that the data is in a format + compatible for input to the model. + The transformed data contains the following keys: `ident`, `features`, `labels`, and `group`. + This method uses a subclass of Data Reader to perform the transformation. + + Returns: + None + """ + os.makedirs(self.processed_dir, exist_ok=True) + print("Missing transformed data (`data.pt` file). Transforming data.... ") + torch.save( + self._load_data_from_file( + os.path.join( + self.processed_dir_main, + self.processed_dir_main_file_names_dict["data"], + ) + ), + os.path.join(self.processed_dir, self.processed_file_names_dict["data"]), + ) + + @staticmethod + def _get_data_size(input_file_path: str) -> int: + """ + Get the size of the data from a pickled file. + + Args: + input_file_path (str): The path to the file. + + Returns: + int: The size of the data. + """ + with open(input_file_path, "rb") as f: + return len(pd.read_pickle(f)) + + @abstractmethod + def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]: + """ + Loads data from given pickled file and yields individual dictionaries for each row. + + This method is used by `_load_data_from_file` to generate dictionaries that are then + processed and converted into a list of dictionaries containing the features and labels. + + Args: + input_file_path (str): The path to the pickled input file. + + Yields: + Generator[Dict[str, Any], None, None]: Generator yielding dictionaries. + + """ + pass + + # ------------------------------ Phase: Dynamic Splits ----------------------------------- + @property + def dynamic_split_dfs(self) -> Dict[str, pd.DataFrame]: + """ + Property to retrieve dynamic train, validation, and test splits. + + This property checks if dynamic data splits (`_dynamic_df_train`, `_dynamic_df_val`, `_dynamic_df_test`) + are already loaded. If any of them is None, it either generates them dynamically or retrieves them + from data file with help of pre-existing split csv file (`splits_file_path`) containing splits assignments. + + Returns: + dict: A dictionary containing the dynamic train, validation, and test DataFrames. + Keys are 'train', 'validation', and 'test'. + """ + if any( + split is None + for split in [ + self._dynamic_df_test, + self._dynamic_df_val, + self._dynamic_df_train, + ] + ): + if self.splits_file_path is None: + # Generate splits based on given seed, create csv file to records the splits + self._generate_dynamic_splits() + else: + # If user has provided splits file path, use it to get the splits from the data + self._retrieve_splits_from_csv() + return { + "train": self._dynamic_df_train, + "validation": self._dynamic_df_val, + "test": self._dynamic_df_test, + } + + def _generate_dynamic_splits(self) -> None: + """ + Generate data splits during runtime and save them in class variables. + + This method loads encoded data and generates train, validation, and test splits based on the loaded data. + """ + print("Generate dynamic splits...") + df_train, df_val, df_test = self._get_data_splits() + + # Generate splits.csv file to store ids of each corresponding split + split_assignment_list: List[pd.DataFrame] = [ + pd.DataFrame({"id": df_train["ident"], "split": "train"}), + pd.DataFrame({"id": df_val["ident"], "split": "validation"}), + pd.DataFrame({"id": df_test["ident"], "split": "test"}), + ] + combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) + combined_split_assignment.to_csv( + os.path.join(self.processed_dir_main, "splits.csv") + ) + + # Store the splits in class variables + self.dynamic_df_train = df_train + self.dynamic_df_val = df_val + self.dynamic_df_test = df_test + + @abstractmethod + def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + """ + Retrieve the train, validation, and test data splits for the dataset. + + This method returns data splits according to specific criteria implemented + in the subclasses. + + Returns: + tuple: A tuple containing DataFrames for train, validation, and test splits. + """ + pass + + def get_test_split( + self, df: pd.DataFrame, seed: Optional[int] = None + ) -> Tuple[pd.DataFrame, pd.DataFrame]: + """ + Split the input DataFrame into training and testing sets based on multilabel stratified sampling. + + This method uses MultilabelStratifiedShuffleSplit to split the data such that the distribution of labels + in the training and testing sets is approximately the same. The split is based on the "labels" column + in the DataFrame. + + Args: + df (pd.DataFrame): The input DataFrame containing the data to be split. It must contain a column + named "labels" with the multilabel data. + seed (int, optional): The random seed to be used for reproducibility. Default is None. + + Returns: + Tuple[pd.DataFrame, pd.DataFrame]: A tuple containing the training set and testing set DataFrames. + + Raises: + ValueError: If the DataFrame does not contain a column named "labels". + """ + print("\nGet test data split") + + labels_list = df["labels"].tolist() + + test_size = 1 - self.train_split - (1 - self.train_split) ** 2 + msss = MultilabelStratifiedShuffleSplit( + n_splits=1, test_size=test_size, random_state=seed + ) + + train_indices, test_indices = next(msss.split(labels_list, labels_list)) + + df_train = df.iloc[train_indices] + df_test = df.iloc[test_indices] + return df_train, df_test + + def get_train_val_splits_given_test( + self, df: pd.DataFrame, test_df: pd.DataFrame, seed: int = None + ) -> Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]: + """ + Split the dataset into train and validation sets, given a test set. + Use test set (e.g., loaded from another source or generated in get_test_split), to avoid overlap + + Args: + df (pd.DataFrame): The original dataset. + test_df (pd.DataFrame): The test dataset. + seed (int, optional): The random seed to be used for reproducibility. Default is None. + + Returns: + Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]: A dictionary containing train and + validation sets if self.use_inner_cross_validation is True, otherwise a tuple containing the train + and validation DataFrames. The keys are the names of the train and validation sets, and the values + are the corresponding DataFrames. + """ + print(f"Split dataset into train / val with given test set") + + test_ids = test_df["ident"].tolist() + df_trainval = df[~df["ident"].isin(test_ids)] + labels_list_trainval = df_trainval["labels"].tolist() + + if self.use_inner_cross_validation: + folds = {} + kfold = MultilabelStratifiedKFold( + n_splits=self.inner_k_folds, random_state=seed + ) + for fold, (train_ids, val_ids) in enumerate( + kfold.split( + labels_list_trainval, + labels_list_trainval, + ) + ): + df_validation = df_trainval.iloc[val_ids] + df_train = df_trainval.iloc[train_ids] + folds[self.raw_file_names_dict[f"fold_{fold}_train"]] = df_train + folds[self.raw_file_names_dict[f"fold_{fold}_validation"]] = ( + df_validation + ) + + return folds + + # scale val set size by 1/self.train_split to compensate for (hypothetical) test set size (1-self.train_split) + test_size = ((1 - self.train_split) ** 2) / self.train_split + msss = MultilabelStratifiedShuffleSplit( + n_splits=1, test_size=test_size, random_state=seed + ) + + train_indices, validation_indices = next( + msss.split(labels_list_trainval, labels_list_trainval) + ) + + df_validation = df_trainval.iloc[validation_indices] + df_train = df_trainval.iloc[train_indices] + return df_train, df_validation + + def _retrieve_splits_from_csv(self) -> None: + """ + Retrieve previously saved data splits from splits.csv file or from provided file path. + + This method loads the splits.csv file located at `self.splits_file_path`. + It then loads the encoded data (`data.pt`) and filters it based on the IDs retrieved from + splits.csv to reconstruct the train, validation, and test splits. + """ + print(f"Loading splits from {self.splits_file_path}...") + splits_df = pd.read_csv(self.splits_file_path) + + filename = self.processed_file_names_dict["data"] + data = torch.load(os.path.join(self.processed_dir, filename)) + df_data = pd.DataFrame(data) + + train_ids = splits_df[splits_df["split"] == "train"]["id"] + validation_ids = splits_df[splits_df["split"] == "validation"]["id"] + test_ids = splits_df[splits_df["split"] == "test"]["id"] + + self._dynamic_df_train = df_data[df_data["ident"].isin(train_ids)] + self._dynamic_df_val = df_data[df_data["ident"].isin(validation_ids)] + self._dynamic_df_test = df_data[df_data["ident"].isin(test_ids)] + + # ------------------------------ Phase: Raw Properties ----------------------------------- + @property + @abstractmethod + def base_dir(self) -> str: + """ + Returns the base directory path for storing data. + + Returns: + str: The path to the base directory. + """ + pass + + @property + def processed_dir_main(self) -> str: + """ + Returns the main directory path where processed data is stored. + + Returns: + str: The path to the main processed data directory, based on the base directory and the instance's name. + """ + return os.path.join( + self.base_dir, + self._name, + "processed", + ) + + @property + def processed_dir(self) -> str: + """ + Returns the specific directory path for processed data, including identifiers. + + Returns: + str: The path to the processed data directory, including additional identifiers. + """ + return os.path.join( + self.processed_dir_main, + *self.identifier, + ) + + @property + def processed_dir_main_file_names_dict(self) -> dict: + """ + Returns a dictionary mapping processed data file names, processed by `prepare_data` method. + + Returns: + dict: A dictionary mapping dataset types to their respective processed file names. + For example, {"data": "data.pkl"}. + """ + return {"data": "data.pkl"} + + @property + def processed_file_names_dict(self) -> dict: + """ + Returns a dictionary mapping processed and transformed data file names to their final formats, which are + processed by `setup` method. + + Returns: + dict: A dictionary mapping dataset types to their respective final file names. + For example, {"data": "data.pt"}. + """ + return {"data": "data.pt"} + + @property + def processed_file_names(self) -> List[str]: + """ + Returns a list of file names for processed data. + + Returns: + List[str]: A list of file names corresponding to the processed data. + """ + return list(self.processed_file_names_dict.values()) From d182a22ae92de8782ef49f66c7f5f2e031faa6b2 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 3 Aug 2024 00:29:51 +0200 Subject: [PATCH 12/30] update _ChEBIDataExtractor as per newly inherited _DynamicDataset base class --- chebai/preprocessing/datasets/base.py | 2 +- chebai/preprocessing/datasets/chebi.py | 822 ++++++------------------- 2 files changed, 187 insertions(+), 637 deletions(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index d0046cc4..fc665f3b 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -616,7 +616,7 @@ class _DynamicDataset(XYBaseDataModule, ABC): splits_file_path (Optional[str]): Path to the CSV file containing split assignments. """ - # ---- Index for columns of processed `data.pkl` ------ + # ---- Index for columns of processed `data.pkl` (should be derived from `_graph_to_raw_dataset` method) ------ _ID_IDX: int = None _DATA_REPRESENTATION_IDX: int = None _LABELS_START_IDX: int = None diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index a1bfcf6b..f04cab95 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -13,20 +13,16 @@ import pickle from abc import ABC from collections import OrderedDict -from typing import Any, Dict, Generator, List, Optional, Tuple, Union +from typing import Any, Dict, Generator, List, Optional, Tuple import fastobo import networkx as nx import pandas as pd import requests import torch -from iterstrat.ml_stratifiers import ( - MultilabelStratifiedKFold, - MultilabelStratifiedShuffleSplit, -) from chebai.preprocessing import reader as dr -from chebai.preprocessing.datasets.base import XYBaseDataModule +from chebai.preprocessing.datasets.base import XYBaseDataModule, _DynamicDataset # exclude some entities from the dataset because the violate disjointness axioms CHEBI_BLACKLIST = [ @@ -109,7 +105,7 @@ class JCITokenData(JCIBase): READER = dr.ChemDataReader -class _ChEBIDataExtractor(XYBaseDataModule, ABC): +class _ChEBIDataExtractor(_DynamicDataset, ABC): """ A class for extracting and processing data from the ChEBI dataset. @@ -126,12 +122,18 @@ class _ChEBIDataExtractor(XYBaseDataModule, ABC): single_class (Optional[int]): The ID of the single class to predict. chebi_version_train (Optional[int]): The version of ChEBI to use for training and validation. dynamic_data_split_seed (int): The seed for random data splitting, default is 42. - dynamic_df_train (Optional[pd.DataFrame]): DataFrame to store the training data split. - dynamic_df_test (Optional[pd.DataFrame]): DataFrame to store the test data split. - dynamic_df_val (Optional[pd.DataFrame]): DataFrame to store the validation data split. splits_file_path (Optional[str]): Path to csv file containing split assignments. """ + # ---- Index for columns of processed `data.pkl` (derived from `_graph_to_raw_dataset` method) ------ + # "id" at row index 0 + # "name" at row index 1 + # "SMILES" at row index 2 + # labels starting from row index 3 + _ID_IDX: int = 0 + _DATA_REPRESENTATION_IDX: int = 2 + _LABELS_START_IDX: int = 3 + def __init__( self, chebi_version_train: Optional[int] = None, @@ -144,11 +146,6 @@ def __init__( # use different version of chebi for training and validation (if not None) # (still uses self.chebi_version for test set) self.chebi_version_train = chebi_version_train - self.dynamic_data_split_seed = int(kwargs.get("seed", 42)) # default is 42 - # Class variables to store the dynamics splits - self.dynamic_df_train = None - self.dynamic_df_test = None - self.dynamic_df_val = None if self.chebi_version_train is not None: # Instantiate another same class with "chebi_version" as "chebi_version_train", if train_version is given @@ -159,64 +156,80 @@ def __init__( single_class=self.single_class, **_init_kwargs, ) - # Path of csv file which contains a list of chebi ids & their assignment to a dataset (either train, - # validation or test). - self.splits_file_path = self._validate_splits_file_path( - kwargs.get("splits_file_path", None) - ) - @staticmethod - def _validate_splits_file_path(splits_file_path: Optional[str]) -> Optional[str]: + # ------------------------------ Phase: Prepare data ----------------------------------- + def prepare_data(self, *args: Any, **kwargs: Any) -> None: """ - Validates the file in provided splits file path. + Prepares the data for the Chebi dataset. + + This method checks for the presence of raw data in the specified directory. + If the raw data is missing, it fetches the ontology and creates a dataframe & saves it as data.pkl pickle file. + + The resulting dataframe/pickle file is expected to contain columns with the following structure: + - Column at index `self._ID_IDX`: ID of chebi data instance + - Column at index `self._DATA_REPRESENTATION_IDX`: SMILES representation of the chemical + - Column from index `self._LABELS_START_IDX` onwards: Labels Args: - splits_file_path (Optional[str]): Path to the splits CSV file. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. Returns: - Optional[str]: Validated splits file path if checks pass, None if splits_file_path is None. - - Raises: - FileNotFoundError: If the splits file does not exist. - ValueError: If the splits file is empty or missing required columns ('id' and/or 'split'), or not a CSV file. + None """ - if splits_file_path is None: - return None + super().prepare_data(args, kwargs) - if not os.path.isfile(splits_file_path): - raise FileNotFoundError(f"File {splits_file_path} does not exist") + if self.chebi_version_train is not None: + if not os.path.isfile( + os.path.join( + self._chebi_version_train_obj.processed_dir_main, + self._chebi_version_train_obj.raw_file_names_dict["data"], + ) + ): + print( + f"Missing processed data related to train version: {self.chebi_version_train}" + ) + print("Calling the prepare_data method related to it") + # Generate the "chebi_version_train" data if it doesn't exist + self._chebi_version_train_obj.prepare_data(*args, **kwargs) - file_size = os.path.getsize(splits_file_path) - if file_size == 0: - raise ValueError(f"File {splits_file_path} is empty") + def _download_required_data(self) -> str: + return self._load_chebi(self.chebi_version) - # Check if the file has a CSV extension - if not splits_file_path.lower().endswith(".csv"): - raise ValueError(f"File {splits_file_path} is not a CSV file") + def _load_chebi(self, version: int) -> str: + """ + Load the ChEBI ontology file. - # Read the first row of CSV file into a DataFrame - splits_df = pd.read_csv(splits_file_path, nrows=1) + Args: + version (int): The version of the ChEBI ontology to load. - # Check if 'id' and 'split' columns are in the DataFrame - required_columns = {"id", "split"} - if not required_columns.issubset(splits_df.columns): - raise ValueError( - f"CSV file {splits_file_path} is missing required columns ('id' and/or 'split')." + Returns: + str: The file path of the loaded ChEBI ontology. + """ + chebi_name = ( + f"chebi.obo" if version == self.chebi_version else f"chebi_v{version}.obo" + ) + chebi_path = os.path.join(self.raw_dir, chebi_name) + if not os.path.isfile(chebi_path): + print( + f"Missing raw chebi data related to version: v_{version}, Downloading..." ) + url = f"http://purl.obolibrary.org/obo/chebi/{version}/chebi.obo" + r = requests.get(url, allow_redirects=True) + open(chebi_path, "wb").write(r.content) + return chebi_path - return splits_file_path - - def extract_class_hierarchy(self, chebi_path: str) -> nx.DiGraph: + def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: """ Extracts the class hierarchy from the ChEBI ontology. Args: - chebi_path (str): The path to the ChEBI ontology. + data_path (str): The path to the ChEBI ontology. Returns: nx.DiGraph: The class hierarchy. """ - with open(chebi_path, encoding="utf-8") as chebi: + with open(data_path, encoding="utf-8") as chebi: chebi = "\n".join(l for l in chebi if not l.startswith("xref:")) elements = [ term_callback(clause) @@ -230,20 +243,13 @@ def extract_class_hierarchy(self, chebi_path: str) -> nx.DiGraph: print("Compute transitive closure") return nx.transitive_closure_dag(g) - def select_classes(self, g, split_name, *args, **kwargs): - - raise NotImplementedError - - def graph_to_raw_dataset( - self, g: nx.DiGraph, split_name: Optional[str] = None - ) -> pd.DataFrame: + def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: """ Preparation step before creating splits, uses graph created by extract_class_hierarchy(), split_name is only relevant, if a separate train_version is set. Args: g (nx.DiGraph): The class hierarchy graph. - split_name (Optional[str], optional): Name of the split. Defaults to None. Returns: pd.DataFrame: The raw dataset created from the graph. @@ -260,10 +266,14 @@ def graph_to_raw_dataset( if smiles ) ) - data = OrderedDict(id=molecules) - data["name"] = [names.get(node) for node in molecules] - data["SMILES"] = smiles_list - for n in self.select_classes(g, split_name): + data = OrderedDict(id=molecules) # `id` column at index 0 + data["name"] = [ + names.get(node) for node in molecules + ] # `name` column at index 1 + data["SMILES"] = smiles_list # `SMILES` (data representation) column at index 2 + + # Labels columns from index 3 onwards + for n in self.select_classes(g): data[n] = [ ((n in g.predecessors(node)) or (n == node)) for node in molecules ] @@ -273,134 +283,10 @@ def graph_to_raw_dataset( data = data[[name not in CHEBI_BLACKLIST for name, _ in data.iterrows()]] # This filters the DataFrame to include only the rows where at least one value in the row from 4th column # onwards is True/non-zero. - data = data[data.iloc[:, 3:].any(axis=1)] + data = data[data.iloc[:, self._LABELS_START_IDX :].any(axis=1)] return data - def save_raw(self, data: pd.DataFrame, filename: str) -> None: - """ - Save the raw dataset to a pickle file. - - Args: - data (pd.DataFrame): The raw dataset to be saved. - filename (str): The filename for the pickle file. - """ - pd.to_pickle(data, open(os.path.join(self.raw_dir, filename), "wb")) - - def save_processed(self, data: pd.DataFrame, filename: str) -> None: - """ - Save the processed dataset to a pickle file. - - Args: - data (pd.DataFrame): The processed dataset to be saved. - filename (str): The filename for the pickle file. - """ - pd.to_pickle(data, open(os.path.join(self.processed_dir_main, filename), "wb")) - - def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]: - """ - Loads a dictionary from a pickled file, yielding individual dictionaries for each row. - - This method reads data from a specified pickled file, processes each row to extract relevant - information, and yields dictionaries containing the keys `features`, `labels`, and `ident`. - If `single_class` is specified, it only includes the label for that specific class; otherwise, - it includes labels for all classes starting from the fourth column. - - The pickled file is expected to contain rows with the following structure: - - Data at row index 0: ID of the chebi data instance - - Data at row index 2: SMILES representation for the chemical - - Data from row index 3 onwards: Labels - - This method is used in `_load_data_from_file` to process each row of data and convert it - into the desired dictionary format before loading it into the model. - - Args: - input_file_path (str): The path to the input pickled file. - - Yields: - Dict[str, Any]: A dictionary with keys `features`, `labels`, and `ident`. - `features` contains the sequence, `labels` contains the labels as boolean values, - and `ident` contains the identifier. - """ - with open(input_file_path, "rb") as input_file: - df = pd.read_pickle(input_file) - if self.single_class is not None: - single_cls_index = list(df.columns).index(int(self.single_class)) - for row in df.values: - if self.single_class is None: - labels = row[3:].astype(bool) - else: - labels = [bool(row[single_cls_index])] - yield dict(features=row[2], labels=labels, ident=row[0]) - - @staticmethod - def _get_data_size(input_file_path: str) -> int: - """ - Get the size of the data from a pickled file. - - Args: - input_file_path (str): The path to the file. - - Returns: - int: The size of the data. - """ - with open(input_file_path, "rb") as f: - return len(pd.read_pickle(f)) - - def _setup_pruned_test_set( - self, df_test_chebi_version: pd.DataFrame - ) -> pd.DataFrame: - """ - Create a test set with the same leaf nodes, but use only classes that appear in the training set. - - Args: - df_test_chebi_version (pd.DataFrame): The test dataset. - - Returns: - pd.DataFrame: The pruned test dataset. - """ - # TODO: find a more efficient way to do this - filename_old = "classes.txt" - # filename_new = f"classes_v{self.chebi_version_train}.txt" - # dataset = torch.load(os.path.join(self.processed_dir, "test.pt")) - - # Load original classes (from the current ChEBI version - chebi_version) - with open(os.path.join(self.processed_dir_main, filename_old), "r") as file: - orig_classes = file.readlines() - - # Load new classes (from the training ChEBI version - chebi_version_train) - with open( - os.path.join( - self._chebi_version_train_obj.processed_dir_main, filename_old - ), - "r", - ) as file: - new_classes = file.readlines() - - # Create a mapping which give index of a class from chebi_version, if the corresponding - # class exists in chebi_version_train, Size = Number of classes in chebi_version - mapping = [ - None if or_class not in new_classes else new_classes.index(or_class) - for or_class in orig_classes - ] - - # Iterate over each data instance in the test set which is derived from chebi_version - for _, row in df_test_chebi_version.iterrows(): - # Size = Number of classes in chebi_version_train - new_labels = [False for _ in new_classes] - for ind, label in enumerate(row["labels"]): - # If the chebi_version class exists in the chebi_version_train and has a True label, - # set the corresponding label in new_labels to True - if mapping[ind] is not None and label: - new_labels[mapping[ind]] = label - # Update the labels from test instance from chebi_version to the new labels, which are compatible to both versions - row["labels"] = new_labels - - # torch.save( - # chebi_ver_test_data, - # os.path.join(self.processed_dir, self.processed_file_names_dict["test"]), - # ) - return df_test_chebi_version - + # ------------------------------ Phase: Setup data ----------------------------------- def setup_processed(self) -> None: """ Transform and prepare processed data for the ChEBI dataset. @@ -419,60 +305,7 @@ def setup_processed(self) -> None: classes that appear in the training set. """ - print("Transform data") - os.makedirs(self.processed_dir, exist_ok=True) - # -------- Commented the code for Data Handling Restructure for Issue No.10 - # -------- https://github.com/ChEB-AI/python-chebai/issues/10 - # for k in self.processed_file_names_dict.keys(): - # processed_name = ( - # "test.pt" if k == "test" else self.processed_file_names_dict[k] - # ) - # if not os.path.isfile(os.path.join(self.processed_dir, processed_name)): - # print("transform", k) - # torch.save( - # self._load_data_from_file( - # os.path.join(self.raw_dir, self.raw_file_names_dict[k]) - # ), - # os.path.join(self.processed_dir, processed_name), - # ) - # # create second test set with classes used in train - # if self.chebi_version_train is not None and not os.path.isfile( - # os.path.join(self.processed_dir, self.processed_file_names_dict["test"]) - # ): - # print("transform test (select classes)") - # self._setup_pruned_test_set() - # - # processed_name = self.processed_file_names_dict[k] - # if not os.path.isfile(os.path.join(self.processed_dir, processed_name)): - # print( - # "Missing encoded data, transform processed data into encoded data", - # k, - # ) - # torch.save( - # self._load_data_from_file( - # os.path.join( - # self.processed_dir_main, self.raw_file_names_dict[k] - # ) - # ), - # os.path.join(self.processed_dir, processed_name), - # ) - - # Transform the processed data into encoded data - processed_name = self.processed_file_names_dict["data"] - if not os.path.isfile(os.path.join(self.processed_dir, processed_name)): - print( - f"Missing encoded data related to version {self.chebi_version}, transform processed data into encoded data:", - processed_name, - ) - torch.save( - self._load_data_from_file( - os.path.join( - self.processed_dir_main, - self.raw_file_names_dict["data"], - ) - ), - os.path.join(self.processed_dir, processed_name), - ) + super().setup_processed() # Transform the data related to "chebi_version_train" to encoded data, if it doesn't exist if self.chebi_version_train is not None and not os.path.isfile( @@ -484,338 +317,60 @@ def setup_processed(self) -> None: print( f"Missing encoded data related to train version: {self.chebi_version_train}" ) - print("Call the setup method related to it") + print("Calling the setup method related to it") self._chebi_version_train_obj.setup() - def get_test_split( - self, df: pd.DataFrame, seed: Optional[int] = None - ) -> Tuple[pd.DataFrame, pd.DataFrame]: - """ - Split the input DataFrame into training and testing sets based on multilabel stratified sampling. - - This method uses MultilabelStratifiedShuffleSplit to split the data such that the distribution of labels - in the training and testing sets is approximately the same. The split is based on the "labels" column - in the DataFrame. - - Args: - df (pd.DataFrame): The input DataFrame containing the data to be split. It must contain a column - named "labels" with the multilabel data. - seed (int, optional): The random seed to be used for reproducibility. Default is None. - - Returns: - Tuple[pd.DataFrame, pd.DataFrame]: A tuple containing the training set and testing set DataFrames. - - Raises: - ValueError: If the DataFrame does not contain a column named "labels". + def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]: """ - print("\nGet test data split") - - labels_list = df["labels"].tolist() - - test_size = 1 - self.train_split - (1 - self.train_split) ** 2 - msss = MultilabelStratifiedShuffleSplit( - n_splits=1, test_size=test_size, random_state=seed - ) + Loads a dictionary from a pickled file, yielding individual dictionaries for each row. - train_indices, test_indices = next(msss.split(labels_list, labels_list)) + This method reads data from a specified pickled file, processes each row to extract relevant + information, and yields dictionaries containing the keys `features`, `labels`, and `ident`. + If `single_class` is specified, it only includes the label for that specific class; otherwise, + it includes labels for all classes starting from the fourth column. - df_train = df.iloc[train_indices] - df_test = df.iloc[test_indices] - return df_train, df_test + The pickled file is expected to contain rows with the following structure: + - Data at row index `self._ID_IDX`: ID of the chebi data instance + - Data at row index `self._DATA_REPRESENTATION_IDX` : SMILES representation for the chemical + - Data from row index `self._LABELS_START_IDX` onwards: Labels - def get_train_val_splits_given_test( - self, df: pd.DataFrame, test_df: pd.DataFrame, seed: int = None - ) -> Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]: - """ - Split the dataset into train and validation sets, given a test set. - Use test set (e.g., loaded from another chebi version or generated in get_test_split), to avoid overlap + This method is used in `_load_data_from_file` to process each row of data and convert it + into the desired dictionary format before loading it into the model. Args: - df (pd.DataFrame): The original dataset. - test_df (pd.DataFrame): The test dataset. - seed (int, optional): The random seed to be used for reproducibility. Default is None. + input_file_path (str): The path to the input pickled file. - Returns: - Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]: A dictionary containing train and - validation sets if self.use_inner_cross_validation is True, otherwise a tuple containing the train - and validation DataFrames. The keys are the names of the train and validation sets, and the values - are the corresponding DataFrames. + Yields: + Dict[str, Any]: A dictionary with keys `features`, `labels`, and `ident`. + `features` contains the sequence, `labels` contains the labels as boolean values, + and `ident` contains the identifier. """ - print(f"Split dataset into train / val with given test set") - - test_ids = test_df["ident"].tolist() - # ---- list comprehension degrades performance, dataframe operations are faster - # mask = [trainval_id not in test_ids for trainval_id in df_trainval["ident"]] - # df_trainval = df_trainval[mask] - df_trainval = df[~df["ident"].isin(test_ids)] - labels_list_trainval = df_trainval["labels"].tolist() - - if self.use_inner_cross_validation: - folds = {} - kfold = MultilabelStratifiedKFold( - n_splits=self.inner_k_folds, random_state=seed - ) - for fold, (train_ids, val_ids) in enumerate( - kfold.split( - labels_list_trainval, - labels_list_trainval, - ) - ): - df_validation = df_trainval.iloc[val_ids] - df_train = df_trainval.iloc[train_ids] - folds[self.raw_file_names_dict[f"fold_{fold}_train"]] = df_train - folds[self.raw_file_names_dict[f"fold_{fold}_validation"]] = ( - df_validation + with open(input_file_path, "rb") as input_file: + df = pd.read_pickle(input_file) + if self.single_class is not None: + single_cls_index = list(df.columns).index(int(self.single_class)) + for row in df.values: + if self.single_class is None: + labels = row[self._LABELS_START_IDX :].astype(bool) + else: + labels = [bool(row[single_cls_index])] + yield dict( + features=row[self._DATA_REPRESENTATION_IDX], + labels=labels, + ident=row[self._ID_IDX], ) - return folds - - # scale val set size by 1/self.train_split to compensate for (hypothetical) test set size (1-self.train_split) - test_size = ((1 - self.train_split) ** 2) / self.train_split - msss = MultilabelStratifiedShuffleSplit( - n_splits=1, test_size=test_size, random_state=seed - ) - - train_indices, validation_indices = next( - msss.split(labels_list_trainval, labels_list_trainval) - ) - - df_validation = df_trainval.iloc[validation_indices] - df_train = df_trainval.iloc[train_indices] - return df_train, df_validation - - @property - def processed_dir_main(self) -> str: - """ - Return the main directory path for processed data. - - Returns: - str: The path to the main processed data directory. - """ - return os.path.join( - self.base_dir, - self._name, - "processed", - ) - - @property - def processed_dir(self) -> str: - """ - Return the directory path for processed data. - - Returns: - str: The path to the processed data directory. - """ - res = os.path.join( - self.processed_dir_main, - *self.identifier, - ) - if self.single_class is None: - return res - else: - return os.path.join(res, f"single_{self.single_class}") - - @property - def base_dir(self) -> str: - """ - Return the base directory path for data. - - Returns: - str: The base directory path for data. - """ - return os.path.join("data", f"chebi_v{self.chebi_version}") - - @property - def processed_file_names_dict(self) -> dict: - """ - Return a dictionary of processed file names. - - Returns: - dict: A dictionary where keys are file names and values are paths. - """ - train_v_str = ( - f"_v{self.chebi_version_train}" if self.chebi_version_train else "" - ) - # res = {"test": f"test{train_v_str}.pt"} - res = {} - - for set in ["train", "validation"]: - # TODO: code will be modified into CV issue for dynamic splits - if self.use_inner_cross_validation: - for i in range(self.inner_k_folds): - res[f"fold_{i}_{set}"] = os.path.join( - self.fold_dir, f"fold_{i}_{set}{train_v_str}.pt" - ) - # else: - # res[set] = f"{set}{train_v_str}.pt" - res["data"] = "data.pt" - return res - - @property - def raw_file_names_dict(self) -> dict: - """ - Return a dictionary of raw file names. - - Returns: - dict: A dictionary where keys are file names and values are paths. - """ - train_v_str = ( - f"_v{self.chebi_version_train}" if self.chebi_version_train else "" - ) - # res = { - # "test": f"test.pkl" - # } # no extra raw test version for chebi_version_train - use default test set and only - # adapt processed file - res = {} - for set in ["train", "validation"]: - # TODO: code will be modified into CV issue for dynamic splits - if self.use_inner_cross_validation: - for i in range(self.inner_k_folds): - res[f"fold_{i}_{set}"] = os.path.join( - self.fold_dir, f"fold_{i}_{set}{train_v_str}.pkl" - ) - # else: - # res[set] = f"{set}{train_v_str}.pkl" - res["data"] = "data.pkl" - return res - - @property - def processed_file_names(self) -> List[str]: - """ - Return a list of processed file names. - - Returns: - List[str]: A list containing processed file names. - """ - return list(self.processed_file_names_dict.values()) - - @property - def raw_file_names(self) -> List[str]: - """ - Return a list of raw file names. - - Returns: - List[str]: A list containing raw file names. - """ - return list(self.raw_file_names_dict.values()) - - def _load_chebi(self, version: int) -> str: - """ - Load the ChEBI ontology file. - - Args: - version (int): The version of the ChEBI ontology to load. - - Returns: - str: The file path of the loaded ChEBI ontology. - """ - chebi_name = ( - f"chebi.obo" if version == self.chebi_version else f"chebi_v{version}.obo" - ) - chebi_path = os.path.join(self.raw_dir, chebi_name) - if not os.path.isfile(chebi_path): - print(f"Load ChEBI ontology (v_{version})") - url = f"http://purl.obolibrary.org/obo/chebi/{version}/chebi.obo" - r = requests.get(url, allow_redirects=True) - open(chebi_path, "wb").write(r.content) - return chebi_path - - def prepare_data(self, *args: Any, **kwargs: Any) -> None: - """ - Prepares the data for the Chebi dataset. - - This method checks for the presence of raw data in the specified directory. - If the raw data is missing, it fetches the ontology and creates a dataframe & saves it as data.pkl pickle file. - - The resulting dataframe/pickle file is expected to contain columns with the following structure: - - Column at index 0: ID of chebi data instance - - Column at index 2: SMILES representation of the chemical - - Column from index 3 onwards: Labels - - Args: - *args: Variable length argument list. - **kwargs: Arbitrary keyword arguments. - - Returns: - None + # ------------------------------ Phase: Dynamic Splits ----------------------------------- + def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: """ - print("Check for processed data in", self.processed_dir_main) - if any( - not os.path.isfile(os.path.join(self.processed_dir_main, f)) - for f in self.raw_file_names - ): - os.makedirs(self.processed_dir_main, exist_ok=True) - print("Missing raw data. Go fetch...") - - # -------- Commented the code for Data Handling Restructure for Issue No.10 - # -------- https://github.com/ChEB-AI/python-chebai/issues/10 - # missing test set -> create - # if not os.path.isfile( - # os.path.join(self.raw_dir, self.raw_file_names_dict["test"]) - # ): - # chebi_path = self._load_chebi(self.chebi_version) - # g = self.extract_class_hierarchy(chebi_path) - # df = self.graph_to_raw_dataset(g, self.raw_file_names_dict["test"]) - # _, test_df = self.get_test_split(df) - # self.save_raw(test_df, self.raw_file_names_dict["test"]) - # # load test_split from file - # else: - # with open( - # os.path.join(self.raw_dir, self.raw_file_names_dict["test"]), "rb" - # ) as input_file: - # test_df = pickle.load(input_file) - # # create train/val split based on test set - # chebi_path = self._load_chebi( - # self.chebi_version_train - # if self.chebi_version_train is not None - # else self.chebi_version - # ) - # g = self.extract_class_hierarchy(chebi_path) - # if self.use_inner_cross_validation: - # df = self.graph_to_raw_dataset( - # g, self.raw_file_names_dict[f"fold_0_train"] - # ) - # else: - # df = self.graph_to_raw_dataset(g, self.raw_file_names_dict["train"]) - # train_val_dict = self.get_train_val_splits_given_test(df, test_df) - # for name, df in train_val_dict.items(): - # self.save_raw(df, name) - - # Data from chebi_version - chebi_path = self._load_chebi(self.chebi_version) - g = self.extract_class_hierarchy(chebi_path) - df = self.graph_to_raw_dataset(g, self.raw_file_names_dict["data"]) - self.save_processed(df, filename=self.raw_file_names_dict["data"]) - - if self.chebi_version_train is not None: - if not os.path.isfile( - os.path.join( - self._chebi_version_train_obj.processed_dir_main, - self._chebi_version_train_obj.raw_file_names_dict["data"], - ) - ): - print( - f"Missing processed data related to train version: {self.chebi_version_train}" - ) - print("Call the prepare_data method related to it") - # Generate the "chebi_version_train" data if it doesn't exist - self._chebi_version_train_obj.prepare_data(*args, **kwargs) - - def _generate_dynamic_splits(self) -> None: - """ - Generate data splits during runtime and save them in class variables. - This method loads encoded data derived from either `chebi_version` or `chebi_version_train` and generates train, validation, and test splits based on the loaded data. If `chebi_version_train` is specified, the test set is pruned to include only labels that exist in `chebi_version_train`. - Raises: - FileNotFoundError: If the required data file (`data.pt`) for either `chebi_version` or `chebi_version_train` - does not exist. It advises calling `prepare_data` or `setup` methods to generate - the dataset files. + Returns: + """ - print("Generate dynamic splits...") # Load encoded data derived from "chebi_version" try: filename = self.processed_file_names_dict["data"] @@ -868,84 +423,62 @@ def _generate_dynamic_splits(self) -> None: ) df_test = df_test_chebi_ver - # Generate splits.csv file to store ids of each corresponding split - split_assignment_list: List[pd.DataFrame] = [ - pd.DataFrame({"id": df_train["ident"], "split": "train"}), - pd.DataFrame({"id": df_val["ident"], "split": "validation"}), - pd.DataFrame({"id": df_test["ident"], "split": "test"}), - ] - combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) - combined_split_assignment.to_csv( - os.path.join(self.processed_dir_main, "splits.csv") - ) - - # Store the splits in class variables - self.dynamic_df_train = df_train - self.dynamic_df_val = df_val - self.dynamic_df_test = df_test + return df_train, df_val, df_test - def _retrieve_splits_from_csv(self) -> None: + def _setup_pruned_test_set( + self, df_test_chebi_version: pd.DataFrame + ) -> pd.DataFrame: """ - Retrieve previously saved data splits from splits.csv file or from provided file path. + Create a test set with the same leaf nodes, but use only classes that appear in the training set. + + Args: + df_test_chebi_version (pd.DataFrame): The test dataset. - This method loads the splits.csv file located at `self.splits_file_path`. - It then loads the encoded data (`data.pt`) derived from `chebi_version` and filters - it based on the IDs retrieved from splits.csv to reconstruct the train, validation, - and test splits. + Returns: + pd.DataFrame: The pruned test dataset. """ - print(f"Loading splits from {self.splits_file_path}...") - splits_df = pd.read_csv(self.splits_file_path) + # TODO: find a more efficient way to do this + filename_old = "classes.txt" + # filename_new = f"classes_v{self.chebi_version_train}.txt" + # dataset = torch.load(os.path.join(self.processed_dir, "test.pt")) - filename = self.processed_file_names_dict["data"] - data_chebi_version = torch.load(os.path.join(self.processed_dir, filename)) - df_chebi_version = pd.DataFrame(data_chebi_version) + # Load original classes (from the current ChEBI version - chebi_version) + with open(os.path.join(self.processed_dir_main, filename_old), "r") as file: + orig_classes = file.readlines() - train_ids = splits_df[splits_df["split"] == "train"]["id"] - validation_ids = splits_df[splits_df["split"] == "validation"]["id"] - test_ids = splits_df[splits_df["split"] == "test"]["id"] + # Load new classes (from the training ChEBI version - chebi_version_train) + with open( + os.path.join( + self._chebi_version_train_obj.processed_dir_main, filename_old + ), + "r", + ) as file: + new_classes = file.readlines() - self.dynamic_df_train = df_chebi_version[ - df_chebi_version["ident"].isin(train_ids) - ] - self.dynamic_df_val = df_chebi_version[ - df_chebi_version["ident"].isin(validation_ids) - ] - self.dynamic_df_test = df_chebi_version[ - df_chebi_version["ident"].isin(test_ids) + # Create a mapping which give index of a class from chebi_version, if the corresponding + # class exists in chebi_version_train, Size = Number of classes in chebi_version + mapping = [ + None if or_class not in new_classes else new_classes.index(or_class) + for or_class in orig_classes ] - @property - def dynamic_split_dfs(self) -> Dict[str, pd.DataFrame]: - """ - Property to retrieve dynamic train, validation, and test splits. - - This property checks if dynamic data splits (`dynamic_df_train`, `dynamic_df_val`, `dynamic_df_test`) - are already loaded. If any of them is None, it either generates them dynamically or retrieves them - from data file with help of pre-existing Split csv file (`splits_file_path`) containing splits assignments. + # Iterate over each data instance in the test set which is derived from chebi_version + for _, row in df_test_chebi_version.iterrows(): + # Size = Number of classes in chebi_version_train + new_labels = [False for _ in new_classes] + for ind, label in enumerate(row["labels"]): + # If the chebi_version class exists in the chebi_version_train and has a True label, + # set the corresponding label in new_labels to True + if mapping[ind] is not None and label: + new_labels[mapping[ind]] = label + # Update the labels from test instance from chebi_version to the new labels, which are compatible to both versions + row["labels"] = new_labels - Returns: - dict: A dictionary containing the dynamic train, validation, and test DataFrames. - Keys are 'train', 'validation', and 'test'. - """ - if any( - split is None - for split in [ - self.dynamic_df_test, - self.dynamic_df_val, - self.dynamic_df_train, - ] - ): - if self.splits_file_path is None: - # Generate splits based on given seed, create csv file to records the splits - self._generate_dynamic_splits() - else: - # If user has provided splits file path, use it to get the splits from the data - self._retrieve_splits_from_csv() - return { - "train": self.dynamic_df_train, - "validation": self.dynamic_df_val, - "test": self.dynamic_df_test, - } + # torch.save( + # chebi_ver_test_data, + # os.path.join(self.processed_dir, self.processed_file_names_dict["test"]), + # ) + return df_test_chebi_version def load_processed_data( self, kind: Optional[str] = None, filename: Optional[str] = None @@ -988,10 +521,36 @@ def load_processed_data( except FileNotFoundError: raise FileNotFoundError(f"File {filename} doesn't exist") + # ------------------------------ Phase: Raw Properties ----------------------------------- + @property + def base_dir(self) -> str: + """ + Return the base directory path for data. + + Returns: + str: The base directory path for data. + """ + return os.path.join("data", f"chebi_v{self.chebi_version}") + + @property + def processed_dir(self) -> str: + """ + Return the directory path for processed data. + + Returns: + str: The path to the processed data directory. + """ + res = os.path.join( + self.processed_dir_main, + *self.identifier, + ) + if self.single_class is None: + return res + else: + return os.path.join(res, f"single_{self.single_class}") + class JCIExtendedBase(_ChEBIDataExtractor): - LABEL_INDEX = 3 - SMILES_INDEX = 2 @property def label_number(self): @@ -1018,8 +577,6 @@ class ChEBIOverX(_ChEBIDataExtractor): THRESHOLD (None): The threshold for selecting classes. """ - LABEL_INDEX: int = 3 - SMILES_INDEX: int = 2 READER: dr.ChemDataReader = dr.ChemDataReader THRESHOLD: int = None @@ -1043,7 +600,7 @@ def _name(self) -> str: """ return f"ChEBI{self.THRESHOLD}" - def select_classes(self, g: nx.DiGraph, split_name: str, *args, **kwargs) -> List: + def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: """ Selects classes from the ChEBI dataset based on the number of successors meeting a specified threshold. @@ -1052,7 +609,6 @@ def select_classes(self, g: nx.DiGraph, split_name: str, *args, **kwargs) -> Lis Args: g (nx.Graph): The graph representing the dataset. - split_name (str) : Name of the split (not used). *args: Additional positional arguments (not used). **kwargs: Additional keyword arguments (not used). @@ -1080,12 +636,6 @@ def select_classes(self, g: nx.DiGraph, split_name: str, *args, **kwargs) -> Lis ) ) filename = "classes.txt" - # if ( - # self.chebi_version_train - # is not None - # # and self.raw_file_names_dict["test"] != split_name - # ): - # filename = f"classes_v{self.chebi_version_train}.txt" with open(os.path.join(self.processed_dir_main, filename), "wt") as fout: fout.writelines(str(node) + "\n" for node in nodes) return nodes From 25a9594299578221c62c364e20283ede63f0b522 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 3 Aug 2024 01:04:47 +0200 Subject: [PATCH 13/30] update _GOUniprotDataExtractor to inherit _DynamicDataset --- chebai/preprocessing/datasets/chebi.py | 54 +- chebai/preprocessing/datasets/go_uniprot.py | 822 ++++++-------------- 2 files changed, 267 insertions(+), 609 deletions(-) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index f04cab95..9b8866e6 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -170,6 +170,8 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None: - Column at index `self._DATA_REPRESENTATION_IDX`: SMILES representation of the chemical - Column from index `self._LABELS_START_IDX` onwards: Labels + It will pre-process the data related to `chebi_version_train`, if specified. + Args: *args: Variable length argument list. **kwargs: Arbitrary keyword arguments. @@ -194,6 +196,12 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None: self._chebi_version_train_obj.prepare_data(*args, **kwargs) def _download_required_data(self) -> str: + """ + Downloads the required raw data related to chebi. + + Returns: + str: Path to the downloaded data. + """ return self._load_chebi(self.chebi_version) def _load_chebi(self, version: int) -> str: @@ -222,6 +230,8 @@ def _load_chebi(self, version: int) -> str: def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: """ Extracts the class hierarchy from the ChEBI ontology. + Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with fields/terms from + the chebi term documents from `.obo` file. Args: data_path (str): The path to the ChEBI ontology. @@ -245,8 +255,9 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: """ - Preparation step before creating splits, uses graph created by extract_class_hierarchy(), - split_name is only relevant, if a separate train_version is set. + Converts the graph to a raw dataset. + Uses the graph created by `_extract_class_hierarchy` method to extract the + raw data in Dataframe format with additional columns corresponding to each multi-label class. Args: g (nx.DiGraph): The class hierarchy graph. @@ -296,14 +307,7 @@ def setup_processed(self) -> None: The transformed data must contain the following keys: `ident`, `features`, `labels`, and `group`. This method uses a subclass of Data Reader to perform the transformation. - This method sets up the processed data directories and files based on the ChEBI version - and train version (if specified). It ensures that the required processed data files exist - by loading raw data, transforming it into processed format, and saving it. - - It also handles special cases, such as generating a pruned test set if `chebi_version_train` - is specified and the test set does not already exist. This pruned test set includes only - classes that appear in the training set. - + It will transform the data related to `chebi_version_train`, if specified. """ super().setup_processed() @@ -363,15 +367,29 @@ def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, No # ------------------------------ Phase: Dynamic Splits ----------------------------------- def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: """ - This method loads encoded data derived from either `chebi_version` or `chebi_version_train` - and generates train, validation, and test splits based on the loaded data. - If `chebi_version_train` is specified, the test set is pruned to include only labels that - exist in `chebi_version_train`. + Loads encoded/transformed data and generates training, validation, and test splits. - Returns: + This method first loads encoded data from a file named `data.pt`, which is derived from either + `chebi_version` or `chebi_version_train`. It then splits the data into training, validation, and test sets. + + If `chebi_version_train` is provided: + - Loads additional encoded data from `chebi_version_train`. + - Splits this data into training and validation sets, while using the test set from `chebi_version`. + - Prunes the test set from `chebi_version` to include only labels that exist in `chebi_version_train`. + + If `chebi_version_train` is not provided: + - Splits the data from `chebi_version` into training, validation, and test sets without modification. + Raises: + FileNotFoundError: If the required `data.pt` file(s) do not exist. Ensure that `prepare_data` + and/or `setup` methods have been called to generate the dataset files. + + Returns: + Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: A tuple containing three DataFrames: + - Training set + - Validation set + - Test set """ - # Load encoded data derived from "chebi_version" try: filename = self.processed_file_names_dict["data"] data_chebi_version = torch.load(os.path.join(self.processed_dir, filename)) @@ -474,10 +492,6 @@ def _setup_pruned_test_set( # Update the labels from test instance from chebi_version to the new labels, which are compatible to both versions row["labels"] = new_labels - # torch.save( - # chebi_ver_test_data, - # os.path.join(self.processed_dir, self.processed_file_names_dict["test"]), - # ) return df_test_chebi_version def load_processed_data( diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/go_uniprot.py index b5022514..15fd40e0 100644 --- a/chebai/preprocessing/datasets/go_uniprot.py +++ b/chebai/preprocessing/datasets/go_uniprot.py @@ -15,7 +15,7 @@ from abc import ABC, abstractmethod from collections import OrderedDict from tempfile import NamedTemporaryFile -from typing import Any, Dict, Generator, List, Optional, Tuple, Union +from typing import Any, Dict, Generator, List, Optional, Tuple import fastobo import networkx as nx @@ -23,16 +23,12 @@ import requests import torch from Bio import SwissProt -from iterstrat.ml_stratifiers import ( - MultilabelStratifiedKFold, - MultilabelStratifiedShuffleSplit, -) from chebai.preprocessing import reader as dr -from chebai.preprocessing.datasets import XYBaseDataModule +from chebai.preprocessing.datasets.base import _DynamicDataset -class _GOUniprotDataExtractor(XYBaseDataModule, ABC): +class _GOUniprotDataExtractor(_DynamicDataset, ABC): """ A class for extracting and processing data from the Gene Ontology (GO) dataset and the Swiss UniProt dataset. @@ -47,15 +43,16 @@ class _GOUniprotDataExtractor(XYBaseDataModule, ABC): """ _GO_DATA_INIT = "GO" - # ---- Index for columns of processed `data.pkl` ------ + + # ---- Index for columns of processed `data.pkl` (derived from `_graph_to_raw_dataset` method) ------ # "id" at row index 0 # "name" at row index 1 # "sequence" at row index 2 # "swiss_ident" at row index 3 # labels starting from row index 4 - _LABELS_STARTING_INDEX: int = 4 - _SEQUENCE_INDEX: int = 2 - _ID_INDEX = 0 + _ID_IDX: int = 0 + _DATA_REPRESENTATION_IDX: int = 2 + _LABELS_START_IDX: int = 4 _GO_DATA_URL = "http://purl.obolibrary.org/obo/go/go-basic.obo" _SWISS_DATA_URL = "https://ftp.uniprot.org/pub/databases/uniprot/knowledgebase/complete/uniprot_sprot.dat.gz" @@ -65,492 +62,112 @@ def __init__( **kwargs, ): super(_GOUniprotDataExtractor, self).__init__(**kwargs) - self.dynamic_data_split_seed = int(kwargs.get("seed", 42)) # default is 42 - # Class variables to store the dynamics splits - self._dynamic_df_train = None - self._dynamic_df_test = None - self._dynamic_df_val = None - # Path of csv file which contains a list of go ids & their assignment to a dataset (either train, - # validation or test). - self.splits_file_path = self._validate_splits_file_path( - kwargs.get("splits_file_path", None) - ) - @staticmethod - def _validate_splits_file_path(splits_file_path: Optional[str]) -> Optional[str]: + # ------------------------------ Phase: Prepare data ----------------------------------- + def _download_required_data(self) -> str: """ - Validates the file in provided splits file path. - - Args: - splits_file_path (Optional[str]): Path to the splits CSV file. + Downloads the required raw data related to Gene Ontology (GO) and Swiss-UniProt dataset. Returns: - Optional[str]: Validated splits file path if checks pass, None if splits_file_path is None. - - Raises: - FileNotFoundError: If the splits file does not exist. - ValueError: If the splits file is empty or missing required columns ('id' and/or 'split'), or not a CSV file. + str: Path to the downloaded data. """ - if splits_file_path is None: - return None - - if not os.path.isfile(splits_file_path): - raise FileNotFoundError(f"File {splits_file_path} does not exist") - - file_size = os.path.getsize(splits_file_path) - if file_size == 0: - raise ValueError(f"File {splits_file_path} is empty") + self._download_swiss_uni_prot_data() + return self._download_gene_ontology_data() - # Check if the file has a CSV extension - if not splits_file_path.lower().endswith(".csv"): - raise ValueError(f"File {splits_file_path} is not a CSV file") - - # Read the first row of CSV file into a DataFrame - splits_df = pd.read_csv(splits_file_path, nrows=1) - - # Check if 'id' and 'split' columns are in the DataFrame - required_columns = {"id", "split"} - if not required_columns.issubset(splits_df.columns): - raise ValueError( - f"CSV file {splits_file_path} is missing required columns ('id' and/or 'split')." - ) - - return splits_file_path - - @property - def dynamic_split_dfs(self) -> Dict[str, pd.DataFrame]: + def _download_gene_ontology_data(self) -> str: """ - Property to retrieve dynamic train, validation, and test splits. + Download the Gene Ontology data `.obo` file. - This property checks if dynamic data splits (`_dynamic_df_train`, `_dynamic_df_val`, `_dynamic_df_test`) - are already loaded. If any of them is None, it either generates them dynamically or retrieves them - from data file with help of pre-existing Split csv file (`splits_file_path`) containing splits assignments. + Note: + Quote from : https://geneontology.org/docs/download-ontology/ + Three versions of the ontology are available, the one use in this method is described below: + http://purl.obolibrary.org/obo/go/go-basic.obo + The basic version of the GO, filtered such that the graph is guaranteed to be acyclic and annotations + can be propagated up the graph. The relations included are `is a, part of, regulates, negatively` + `regulates` and `positively regulates`. This version excludes relationships that cross the 3 GO + hierarchies. This version should be used with most GO-based annotation tools. Returns: - dict: A dictionary containing the dynamic train, validation, and test DataFrames. - Keys are 'train', 'validation', and 'test'. - """ - if any( - split is None - for split in [ - self._dynamic_df_test, - self._dynamic_df_val, - self._dynamic_df_train, - ] - ): - if self.splits_file_path is None: - # Generate splits based on given seed, create csv file to records the splits - self._generate_dynamic_splits() - else: - # If user has provided splits file path, use it to get the splits from the data - self._retrieve_splits_from_csv() - return { - "train": self._dynamic_df_train, - "validation": self._dynamic_df_val, - "test": self._dynamic_df_test, - } - - def _generate_dynamic_splits(self) -> None: - """ - Generate data splits during runtime and save them in class variables. - - This method loads encoded data generates train, validation, and test splits based on the loaded data. - - Raises: - FileNotFoundError: If the required data file (`data.pt`) does not exist. It advises calling `prepare_data` - or `setup` methods to generate the dataset files. - """ - print("Generate dynamic splits...") - # Load encoded data - try: - filename = self.processed_file_names_dict["data"] - data_go = torch.load(os.path.join(self.processed_dir, filename)) - except FileNotFoundError: - raise FileNotFoundError( - f"File data.pt doesn't exists. " - f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files" - ) - - df_go_data = pd.DataFrame(data_go) - train_df_go, df_test = self.get_test_split( - df_go_data, seed=self.dynamic_data_split_seed - ) - - # Get all splits - df_train, df_val = self.get_train_val_splits_given_test( - train_df_go, - df_test, - seed=self.dynamic_data_split_seed, - ) - - # Generate splits.csv file to store ids of each corresponding split - split_assignment_list: List[pd.DataFrame] = [ - pd.DataFrame({"id": df_train["ident"], "split": "train"}), - pd.DataFrame({"id": df_val["ident"], "split": "validation"}), - pd.DataFrame({"id": df_test["ident"], "split": "test"}), - ] - combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) - combined_split_assignment.to_csv( - os.path.join(self.processed_dir_main, "splits.csv") - ) - - # Store the splits in class variables - self._dynamic_df_train = df_train - self._dynamic_df_val = df_val - self._dynamic_df_test = df_test - - def _retrieve_splits_from_csv(self) -> None: - """ - Retrieve previously saved data splits from splits.csv file or from provided file path. - - This method loads the splits.csv file located at `self.splits_file_path`. - It then loads the encoded data (`data.pt`) and filters it based on the IDs retrieved from - splits.csv to reconstruct the train, validation, and test splits. - """ - print(f"Loading splits from {self.splits_file_path}...") - splits_df = pd.read_csv(self.splits_file_path) - - filename = self.processed_file_names_dict["data"] - data_go = torch.load(os.path.join(self.processed_dir, filename)) - df_go_data = pd.DataFrame(data_go) - - train_ids = splits_df[splits_df["split"] == "train"]["id"] - validation_ids = splits_df[splits_df["split"] == "validation"]["id"] - test_ids = splits_df[splits_df["split"] == "test"]["id"] - - self._dynamic_df_train = df_go_data[df_go_data["ident"].isin(train_ids)] - self._dynamic_df_val = df_go_data[df_go_data["ident"].isin(validation_ids)] - self._dynamic_df_test = df_go_data[df_go_data["ident"].isin(test_ids)] - - def get_test_split( - self, df: pd.DataFrame, seed: Optional[int] = None - ) -> Tuple[pd.DataFrame, pd.DataFrame]: + str: The file path of the loaded Gene Ontology data. """ - Split the input DataFrame into training and testing sets based on multilabel stratified sampling. - - This method uses MultilabelStratifiedShuffleSplit to split the data such that the distribution of labels - in the training and testing sets is approximately the same. The split is based on the "labels" column - in the DataFrame. - - Args: - df (pd.DataFrame): The input DataFrame containing the data to be split. It must contain a column - named "labels" with the multilabel data. - seed (int, optional): The random seed to be used for reproducibility. Default is None. + go_path = os.path.join(self.raw_dir, self.raw_file_names_dict["GO"]) + os.makedirs(os.path.dirname(go_path), exist_ok=True) - Returns: - Tuple[pd.DataFrame, pd.DataFrame]: A tuple containing the training set and testing set DataFrames. + if not os.path.isfile(go_path): + print("Missing Gene Ontology raw data") + print(f"Downloading Gene Ontology data....") + r = requests.get(self._GO_DATA_URL, allow_redirects=True) + r.raise_for_status() # Check if the request was successful + open(go_path, "wb").write(r.content) + return go_path - Raises: - ValueError: If the DataFrame does not contain a column named "labels". + def _download_swiss_uni_prot_data(self) -> Optional[str]: """ - print("\nGet test data split") - - labels_list = df["labels"].tolist() - - test_size = 1 - self.train_split - (1 - self.train_split) ** 2 - msss = MultilabelStratifiedShuffleSplit( - n_splits=1, test_size=test_size, random_state=seed - ) - - train_indices, test_indices = next(msss.split(labels_list, labels_list)) - - df_train = df.iloc[train_indices] - df_test = df.iloc[test_indices] - return df_train, df_test + Download the Swiss-Prot data file from UniProt Knowledgebase. - def get_train_val_splits_given_test( - self, df: pd.DataFrame, test_df: pd.DataFrame, seed: int = None - ) -> Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]: - """ - Split the dataset into train and validation sets, given a test set. - Use test set (e.g., loaded from another source or generated in get_test_split), to avoid overlap + Note: + UniProt Knowledgebase is collection of functional information on proteins, with accurate, consistent + and rich annotation. - Args: - df (pd.DataFrame): The original dataset. - test_df (pd.DataFrame): The test dataset. - seed (int, optional): The random seed to be used for reproducibility. Default is None. + Swiss-Prot contains manually-annotated records with information extracted from literature and + curator-evaluated computational analysis. Returns: - Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]: A dictionary containing train and - validation sets if self.use_inner_cross_validation is True, otherwise a tuple containing the train - and validation DataFrames. The keys are the names of the train and validation sets, and the values - are the corresponding DataFrames. - """ - print(f"Split dataset into train / val with given test set") - - test_ids = test_df["ident"].tolist() - df_trainval = df[~df["ident"].isin(test_ids)] - labels_list_trainval = df_trainval["labels"].tolist() - - if self.use_inner_cross_validation: - folds = {} - kfold = MultilabelStratifiedKFold( - n_splits=self.inner_k_folds, random_state=seed - ) - for fold, (train_ids, val_ids) in enumerate( - kfold.split( - labels_list_trainval, - labels_list_trainval, - ) - ): - df_validation = df_trainval.iloc[val_ids] - df_train = df_trainval.iloc[train_ids] - folds[self.raw_file_names_dict[f"fold_{fold}_train"]] = df_train - folds[self.raw_file_names_dict[f"fold_{fold}_validation"]] = ( - df_validation - ) - - return folds - - # scale val set size by 1/self.train_split to compensate for (hypothetical) test set size (1-self.train_split) - test_size = ((1 - self.train_split) ** 2) / self.train_split - msss = MultilabelStratifiedShuffleSplit( - n_splits=1, test_size=test_size, random_state=seed - ) - - train_indices, validation_indices = next( - msss.split(labels_list_trainval, labels_list_trainval) - ) - - df_validation = df_trainval.iloc[validation_indices] - df_train = df_trainval.iloc[train_indices] - return df_train, df_validation - - def setup_processed(self) -> None: + str: The file path of the loaded Swiss-Prot data file. """ - Transforms `data.pkl` into a model input data format (`data.pt`), ensuring that the data is in a format - compatible for input to the model. - The transformed data contains the following keys: `ident`, `features`, `labels`, and `group`. - This method uses a subclass of Data Reader to perform the transformation. - - Returns: - None - """ - print("Transform data") - os.makedirs(self.processed_dir, exist_ok=True) - print("Missing transformed `data.pt` file. Transforming data.... ") - torch.save( - self._load_data_from_file( - os.path.join( - self.processed_dir_main, - self.processed_dir_main_file_names_dict["data"], - ) - ), - os.path.join(self.processed_dir, self.processed_file_names_dict["data"]), + uni_prot_file_path = os.path.join( + self.raw_dir, self.raw_file_names_dict["SwissUniProt"] ) + os.makedirs(os.path.dirname(uni_prot_file_path), exist_ok=True) - def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]: - """ - Loads data from a pickled file and yields individual dictionaries for each row. - - The pickled file is expected to contain rows with the following structure: - - Data at row index 0: ID of go data instance - - Data at row index 2: Sequence representation of protein - - Data from row index 4 onwards: Labels - - This method is used by `_load_data_from_file` to generate dictionaries that are then - processed and converted into a list of dictionaries containing the features and labels. - - Args: - input_file_path (str): The path to the pickled input file. - - Yields: - Dict[str, Any]: A dictionary containing: - - `features` (str): The sequence data from the file. - - `labels` (np.ndarray): A boolean array of labels starting from row index 4. - - `ident` (Any): The identifier from row index 0. - """ - with open(input_file_path, "rb") as input_file: - df = pd.read_pickle(input_file) - # "id" at row index 0 - # "name" at row index 1 - # "sequence" at row index 2 - # "swiss_ident" at row index 3 - # labels starting from row index 4 - for row in df.values: - labels = row[self._LABELS_STARTING_INDEX :].astype(bool) - # chebai.preprocessing.reader.DataReader only needs features, labels, ident, group - # "group" set to None, by default as no such entity for this data - yield dict( - features=row[self._SEQUENCE_INDEX], - labels=labels, - ident=row[self._ID_INDEX], - ) - - def prepare_data(self, *args: Any, **kwargs: Any) -> None: - """ - Prepares the data for the Go dataset. - - This method checks for the presence of raw data in the specified directory. - If the raw data is missing, it fetches the ontology and creates a dataframe and saves it to a data.pkl file. - - The resulting dataframe/pickle file is expected to contain columns with the following structure: - - Column at index 0: ID of go data instance - - Column at index 2: Sequence representation of the protein - - Column from index 4 onwards: Labels - - Args: - *args: Variable length argument list. - **kwargs: Arbitrary keyword arguments. - - Returns: - None - """ - print("Checking for processed data in", self.processed_dir_main) - - processed_name = self.processed_dir_main_file_names_dict["data"] - if not os.path.isfile(os.path.join(self.processed_dir_main, processed_name)): - print("Missing Gene Ontology processed data (`data.pkl` file)") - os.makedirs(self.processed_dir_main, exist_ok=True) - # swiss_path = self._download_swiss_uni_prot_data() - self._download_swiss_uni_prot_data() - go_path = self._download_gene_ontology_data() - g = self._extract_go_class_hierarchy(go_path) - data_df = self._graph_to_raw_dataset(g) - self.save_processed(data_df, processed_name) - - @abstractmethod - def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: - """ - Selects classes from the GO dataset based on a specified criteria. - - Args: - g (nx.Graph): The graph representing the dataset. - *args: Additional positional arguments. - **kwargs: Additional keyword arguments. - - Returns: - List: A sorted list of node IDs that meet the specified criteria. - - """ - raise NotImplementedError - - def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: - """ - Preparation step before creating splits, - uses the graph created by _extract_go_class_hierarchy() to extract the - raw data in Dataframe format with extra columns corresponding to each multi-label class. - - Data Format: pd.DataFrame - - Column 0 : ID (Identifier for GO data instance) - - Column 1 : Name of the protein - - Column 2 : Sequence representation of the protein - - Column 3 : Unique identifier of the protein from swiss dataset. - - Column 4 to Column "n": Each column corresponding to a class with value True/False indicating where the - data instance belong to this class or not. - Args: - g (nx.DiGraph): The class hierarchy graph. - - Returns: - pd.DataFrame: The raw dataset created from the graph. - """ - sequences = nx.get_node_attributes(g, "sequence") - names = nx.get_node_attributes(g, "name") - swiss_idents = nx.get_node_attributes(g, "swiss_ident") - - print(f"Processing graph") - - # Gets list of node ids, names, sequences, swiss identifier where sequence is not empty/None. - data_list = [] - for node_id, sequence in sequences.items(): - if sequence: - data_list.append( - ( - node_id, - names.get(node_id), - sequence, - swiss_idents.get(node_id), - ) - ) - - node_ids, names_list, sequences_list, swiss_identifier_list = zip(*data_list) - data = OrderedDict(id=node_ids) - - data["name"] = names_list - data["sequence"] = sequences_list - data["swiss_ident"] = swiss_identifier_list - - # Assuming select_classes is implemented and returns a list of class IDs - for n in self.select_classes(g): - data[n] = [ - ((n in g.predecessors(node)) or (n == node)) for node in node_ids - ] + if not os.path.isfile(uni_prot_file_path): + print(f"Downloading Swiss UniProt data....") - data = pd.DataFrame(data) - # This filters the DataFrame to include only the rows where at least one value in the row from 5th column - # onwards is True/non-zero. - data = data[data.iloc[:, self._LABELS_STARTING_INDEX :].any(axis=1)] - return data + # Create a temporary file + with NamedTemporaryFile(delete=False) as tf: + temp_filename = tf.name + print(f"Downloading to temporary file {temp_filename}") - def _get_go_swiss_data_mapping(self) -> Dict[int, Dict[str, str]]: - """ - Parses the Swiss-Prot data and returns a mapping from Gene Ontology (GO) data ID to Swiss-Prot ID - along with the sequence representation of the protein. + # Download the file + response = requests.get(self._SWISS_DATA_URL, stream=True) + with open(temp_filename, "wb") as temp_file: + shutil.copyfileobj(response.raw, temp_file) - This mapping is necessary because the GO data does not include the protein sequence representation. + print(f"Downloaded to {temp_filename}") - Returns: - Dict[int, Dict[str, str]]: A dictionary where the keys are GO data IDs (int) and the values are - dictionaries containing: - - "sequence" (str): The protein sequence. - - "swiss_ident" (str): The unique identifier for each Swiss-Prot record. - """ - # # https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/docs/keywlist.txt - # --------- --------------------------- ------------------------------ - # Line code Content Occurrence in an entry - # --------- --------------------------- ------------------------------ - # ID Identifier (keyword) Once; starts a keyword entry - # IC Identifier (category) Once; starts a category entry - # AC Accession (KW-xxxx) Once - # DE Definition Once or more - # SY Synonyms Optional; once or more - # GO Gene ontology (GO) mapping Optional; once or more - # HI Hierarchy Optional; once or more - # WW Relevant WWW site Optional; once or more - # CA Category Once per keyword entry; - # absent in category entries - # // Terminator Once; ends an entry - # --------------------------------------------------------------------------- - print("Parsing swiss uniprot raw data....") + # Unpack the gzipped file + try: + print(f"Unzipping the file....") + with gzip.open(temp_filename, "rb") as f_in: + output_file_path = uni_prot_file_path + with open(output_file_path, "wb") as f_out: + shutil.copyfileobj(f_in, f_out) + print(f"Unpacked and saved to {output_file_path}") - swiss_go_mapping = {} - swiss_data = SwissProt.parse( - open( - os.path.join(self.raw_dir, self.raw_file_names_dict["SwissUniProt"]), - "r", - ) - ) + except Exception as e: + print(f"Failed to unpack the file: {e}") + finally: + # Clean up the temporary file + os.remove(temp_filename) + print(f"Removed temporary file {temp_filename}") - for record in swiss_data: - if record.data_class != "Reviewed": - # To consider only manually-annotated swiss data - continue - # Cross-reference has mapping for each protein to each type of data set - for cross_ref in record.cross_references: - if cross_ref[0] == self._GO_DATA_INIT: - # Only consider cross-reference related to GO dataset - go_id = _GOUniprotDataExtractor._parse_go_id(cross_ref[1]) - swiss_go_mapping[go_id] = { - "sequence": record.sequence, - "swiss_ident": record.entry_name, # Unique identifier for each swiss data record - } - return swiss_go_mapping + return uni_prot_file_path - def _extract_go_class_hierarchy(self, go_path: str) -> nx.DiGraph: + def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: """ Extracts the class hierarchy from the GO ontology. Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with GO term data and corresponding Swiss-Prot data (obtained via `_get_go_swiss_data_mapping`). Args: - go_path (str): The path to the GO ontology. + data_path (str): The path to the GO ontology. Returns: nx.DiGraph: A directed graph representing the class hierarchy, where nodes are GO terms and edges represent parent-child relationships. """ elements = [] - for term in fastobo.load(go_path): + for term in fastobo.load(data_path): if isinstance(term, fastobo.typedef.TypedefFrame): # ---- To avoid term frame of the below format/structure ---- # [Typedef] @@ -635,188 +252,215 @@ def _parse_go_id(go_id: str) -> int: # is_a: GO:0009968 ! negative regulation of signal transduction return int(str(go_id).split(":")[1].split("!")[0].strip()) - def _download_gene_ontology_data(self) -> str: + def _get_go_swiss_data_mapping(self) -> Dict[int, Dict[str, str]]: """ - Download the Gene Ontology data `.obo` file. + Parses the Swiss-Prot data and returns a mapping from Gene Ontology (GO) data ID to Swiss-Prot ID + along with the sequence representation of the protein. - Note: - Quote from : https://geneontology.org/docs/download-ontology/ - Three versions of the ontology are available, the one use in this method is described below: - http://purl.obolibrary.org/obo/go/go-basic.obo - The basic version of the GO, filtered such that the graph is guaranteed to be acyclic and annotations - can be propagated up the graph. The relations included are `is a, part of, regulates, negatively` - `regulates` and `positively regulates`. This version excludes relationships that cross the 3 GO - hierarchies. This version should be used with most GO-based annotation tools. + This mapping is necessary because the GO data does not include the protein sequence representation. Returns: - str: The file path of the loaded Gene Ontology data. + Dict[int, Dict[str, str]]: A dictionary where the keys are GO data IDs (int) and the values are + dictionaries containing: + - "sequence" (str): The protein sequence. + - "swiss_ident" (str): The unique identifier for each Swiss-Prot record. """ - go_path = os.path.join(self.raw_dir, self.raw_file_names_dict["GO"]) - os.makedirs(os.path.dirname(go_path), exist_ok=True) + # # https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/docs/keywlist.txt + # --------- --------------------------- ------------------------------ + # Line code Content Occurrence in an entry + # --------- --------------------------- ------------------------------ + # ID Identifier (keyword) Once; starts a keyword entry + # IC Identifier (category) Once; starts a category entry + # AC Accession (KW-xxxx) Once + # DE Definition Once or more + # SY Synonyms Optional; once or more + # GO Gene ontology (GO) mapping Optional; once or more + # HI Hierarchy Optional; once or more + # WW Relevant WWW site Optional; once or more + # CA Category Once per keyword entry; + # absent in category entries + # // Terminator Once; ends an entry + # --------------------------------------------------------------------------- + print("Parsing swiss uniprot raw data....") - if not os.path.isfile(go_path): - print("Missing Gene Ontology raw data") - print(f"Downloading Gene Ontology data....") - r = requests.get(self._GO_DATA_URL, allow_redirects=True) - r.raise_for_status() # Check if the request was successful - open(go_path, "wb").write(r.content) - return go_path + swiss_go_mapping = {} + swiss_data = SwissProt.parse( + open( + os.path.join(self.raw_dir, self.raw_file_names_dict["SwissUniProt"]), + "r", + ) + ) - def _download_swiss_uni_prot_data(self) -> Optional[str]: - """ - Download the Swiss-Prot data file from UniProt Knowledgebase. + for record in swiss_data: + if record.data_class != "Reviewed": + # To consider only manually-annotated swiss data + continue + # Cross-reference has mapping for each protein to each type of data set + for cross_ref in record.cross_references: + if cross_ref[0] == self._GO_DATA_INIT: + # Only consider cross-reference related to GO dataset + go_id = _GOUniprotDataExtractor._parse_go_id(cross_ref[1]) + swiss_go_mapping[go_id] = { + "sequence": record.sequence, + "swiss_ident": record.entry_name, # Unique identifier for each swiss data record + } + return swiss_go_mapping - Note: - UniProt Knowledgebase is collection of functional information on proteins, with accurate, consistent - and rich annotation. + def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: + """ + Preparation step before creating splits, + uses the graph created by _extract_go_class_hierarchy() to extract the + raw data in Dataframe format with extra columns corresponding to each multi-label class. - Swiss-Prot contains manually-annotated records with information extracted from literature and - curator-evaluated computational analysis. + Data Format: pd.DataFrame + - Column 0 : ID (Identifier for GO data instance) + - Column 1 : Name of the protein + - Column 2 : Sequence representation of the protein + - Column 3 : Unique identifier of the protein from swiss dataset. + - Column 4 to Column "n": Each column corresponding to a class with value True/False indicating where the + data instance belong to this class or not. + Args: + g (nx.DiGraph): The class hierarchy graph. Returns: - str: The file path of the loaded Swiss-Prot data file. + pd.DataFrame: The raw dataset created from the graph. """ - uni_prot_file_path = os.path.join( - self.raw_dir, self.raw_file_names_dict["SwissUniProt"] - ) - os.makedirs(os.path.dirname(uni_prot_file_path), exist_ok=True) - - if not os.path.isfile(uni_prot_file_path): - print(f"Downloading Swiss UniProt data....") + sequences = nx.get_node_attributes(g, "sequence") + names = nx.get_node_attributes(g, "name") + swiss_idents = nx.get_node_attributes(g, "swiss_ident") - # Create a temporary file - with NamedTemporaryFile(delete=False) as tf: - temp_filename = tf.name - print(f"Downloading to temporary file {temp_filename}") + print(f"Processing graph") - # Download the file - response = requests.get(self._SWISS_DATA_URL, stream=True) - with open(temp_filename, "wb") as temp_file: - shutil.copyfileobj(response.raw, temp_file) + # Gets list of node ids, names, sequences, swiss identifier where sequence is not empty/None. + data_list = [] + for node_id, sequence in sequences.items(): + if sequence: + data_list.append( + ( + node_id, + names.get(node_id), + sequence, + swiss_idents.get(node_id), + ) + ) - print(f"Downloaded to {temp_filename}") + node_ids, names_list, sequences_list, swiss_identifier_list = zip(*data_list) + data = OrderedDict(id=node_ids) # ID column at index 0 - # Unpack the gzipped file - try: - print(f"Unzipping the file....") - with gzip.open(temp_filename, "rb") as f_in: - output_file_path = uni_prot_file_path - with open(output_file_path, "wb") as f_out: - shutil.copyfileobj(f_in, f_out) - print(f"Unpacked and saved to {output_file_path}") + data["name"] = names_list # Name column at index 1 + data["sequence"] = ( + sequences_list # Sequence (data representation) column at index 2 + ) + data["swiss_ident"] = swiss_identifier_list # Swiss_ident column at index 3 - except Exception as e: - print(f"Failed to unpack the file: {e}") - finally: - # Clean up the temporary file - os.remove(temp_filename) - print(f"Removed temporary file {temp_filename}") + # Assuming select_classes is implemented and returns a list of class IDs + for n in self.select_classes(g): + data[n] = [ + ((n in g.predecessors(node)) or (n == node)) for node in node_ids + ] - return uni_prot_file_path + data = pd.DataFrame(data) + # This filters the DataFrame to include only the rows where at least one value in the row from 5th column + # onwards is True/non-zero. + data = data[data.iloc[:, self._LABELS_START_IDX :].any(axis=1)] + return data - def save_processed(self, data: pd.DataFrame, filename: str) -> None: + # ------------------------------ Phase: Setup data ----------------------------------- + def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]: """ - Save the processed dataset to a pickle file. + Loads data from a pickled file and yields individual dictionaries for each row. - Args: - data (pd.DataFrame): The processed dataset to be saved. - filename (str): The filename for the pickle file. - """ - pd.to_pickle(data, open(os.path.join(self.processed_dir_main, filename), "wb")) + The pickled file is expected to contain rows with the following structure: + - Data at row index `self._ID_IDX`: ID of go data instance + - Data at row index `self._DATA_REPRESENTATION_IDX`: Sequence representation of protein + - Data from row index `self._LABELS_START_IDX` onwards: Labels - @staticmethod - def _get_data_size(input_file_path: str) -> int: - """ - Get the size of the data from a pickled file. + This method is used by `_load_data_from_file` to generate dictionaries that are then + processed and converted into a list of dictionaries containing the features and labels. Args: - input_file_path (str): The path to the file. + input_file_path (str): The path to the pickled input file. - Returns: - int: The size of the data. + Yields: + Dict[str, Any]: A dictionary containing: + - `features` (str): The sequence data from the file. + - `labels` (np.ndarray): A boolean array of labels starting from row index 4. + - `ident` (Any): The identifier from row index 0. """ - with open(input_file_path, "rb") as f: - return len(pd.read_pickle(f)) + with open(input_file_path, "rb") as input_file: + df = pd.read_pickle(input_file) + for row in df.values: + labels = row[self._LABELS_START_IDX :].astype(bool) + # chebai.preprocessing.reader.DataReader only needs features, labels, ident, group + # "group" set to None, by default as no such entity for this data + yield dict( + features=row[self._DATA_REPRESENTATION_IDX], + labels=labels, + ident=row[self._ID_IDX], + ) - @property - def raw_file_names_dict(self) -> dict: + # ------------------------------ Phase: Dynamic Splits ----------------------------------- + def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: """ - Returns a dictionary of raw file names used in data processing. + Loads encoded data and generates training, validation, and test splits. - Returns: - dict: A dictionary mapping dataset names to their respective file names. - For example, {"GO": "go-basic.obo", "SwissUniProt": "uniprot_sprot.dat"}. - """ - return {"GO": "go-basic.obo", "SwissUniProt": "uniprot_sprot.dat"} + This method attempts to load encoded data from a file named `data.pt`. It then splits this data into + training, validation, and test sets. - @property - def base_dir(self) -> str: - """ - Returns the base directory path for storing GO-Uniprot data. + Raises: + FileNotFoundError: If the `data.pt` file does not exist. Ensure that `prepare_data` and/or + `setup` methods are called to generate the necessary dataset files. Returns: - str: The path to the base directory, which is "data/GO_UniProt". - """ - return os.path.join("data", f"GO_UniProt") - - @property - def processed_dir_main(self) -> str: + Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: A tuple containing three DataFrames: + - Training set + - Validation set + - Test set """ - Returns the main directory path where processed data is stored. + try: + filename = self.processed_file_names_dict["data"] + data_go = torch.load(os.path.join(self.processed_dir, filename)) + except FileNotFoundError: + raise FileNotFoundError( + f"File data.pt doesn't exists. " + f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files" + ) - Returns: - str: The path to the main processed data directory, based on the base directory and the instance's name. - """ - return os.path.join( - self.base_dir, - self._name, - "processed", + df_go_data = pd.DataFrame(data_go) + train_df_go, df_test = self.get_test_split( + df_go_data, seed=self.dynamic_data_split_seed ) - @property - def processed_dir(self) -> str: - """ - Returns the specific directory path for processed data, including identifiers. - - Returns: - str: The path to the processed data directory, including additional identifiers. - """ - return os.path.join( - self.processed_dir_main, - *self.identifier, + # Get all splits + df_train, df_val = self.get_train_val_splits_given_test( + train_df_go, + df_test, + seed=self.dynamic_data_split_seed, ) - @property - def processed_dir_main_file_names_dict(self) -> dict: - """ - Returns a dictionary mapping processed data file names. - - Returns: - dict: A dictionary mapping dataset types to their respective processed file names. - For example, {"data": "data.pkl"}. - """ - return {"data": "data.pkl"} + return df_train, df_val, df_test + # ------------------------------ Phase: Raw Properties ----------------------------------- @property - def processed_file_names_dict(self) -> dict: + def base_dir(self) -> str: """ - Returns a dictionary mapping processed data file names to their final formats. + Returns the base directory path for storing GO-Uniprot data. Returns: - dict: A dictionary mapping dataset types to their respective final file names. - For example, {"data": "data.pt"}. + str: The path to the base directory, which is "data/GO_UniProt". """ - return {"data": "data.pt"} + return os.path.join("data", f"GO_UniProt") @property - def processed_file_names(self) -> List[str]: + def raw_file_names_dict(self) -> dict: """ - Returns a list of file names for processed data. + Returns a dictionary of raw file names used in data processing. Returns: - List[str]: A list of file names corresponding to the processed data. + dict: A dictionary mapping dataset names to their respective file names. + For example, {"GO": "go-basic.obo", "SwissUniProt": "uniprot_sprot.dat"}. """ - return list(self.processed_file_names_dict.values()) + return {"GO": "go-basic.obo", "SwissUniProt": "uniprot_sprot.dat"} class _GoUniProtOverX(_GOUniprotDataExtractor, ABC): From 5a4860d03eda7c28617ba987464f92f1fe5edb3a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 10 Aug 2024 13:01:20 +0200 Subject: [PATCH 14/30] add load_processed_data to base --- chebai/preprocessing/datasets/base.py | 61 ++++++++++++++++++++++++-- chebai/preprocessing/datasets/chebi.py | 47 ++------------------ 2 files changed, 61 insertions(+), 47 deletions(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index fc665f3b..f8d3892c 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -877,15 +877,16 @@ def _generate_dynamic_splits(self) -> None: pd.DataFrame({"id": df_val["ident"], "split": "validation"}), pd.DataFrame({"id": df_test["ident"], "split": "test"}), ] + combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) combined_split_assignment.to_csv( - os.path.join(self.processed_dir_main, "splits.csv") + os.path.join(self.processed_dir_main, "splits.csv"), index=False ) # Store the splits in class variables - self.dynamic_df_train = df_train - self.dynamic_df_val = df_val - self.dynamic_df_test = df_test + self._dynamic_df_train = df_train + self._dynamic_df_val = df_val + self._dynamic_df_test = df_test @abstractmethod def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: @@ -1017,6 +1018,58 @@ def _retrieve_splits_from_csv(self) -> None: self._dynamic_df_val = df_data[df_data["ident"].isin(validation_ids)] self._dynamic_df_test = df_data[df_data["ident"].isin(test_ids)] + def load_processed_data( + self, kind: Optional[str] = None, filename: Optional[str] = None + ) -> List[Dict[str, Any]]: + """ + Loads processed data from a specified dataset type or file. + + This method retrieves processed data based on the dataset type (`kind`) such as "train", + "val", or "test", or directly from a provided filename. When `kind` is specified, the method + leverages the `dynamic_split_dfs` property to dynamically generate or retrieve the corresponding + data splits if they are not already loaded. If both `kind` and `filename` are provided, `filename` + takes precedence. + + Args: + kind (str, optional): The type of dataset to load ("train", "val", or "test"). + If `filename` is provided, this argument is ignored. Defaults to None. + filename (str, optional): The name of the file to load the dataset from. + If provided, this takes precedence over `kind`. Defaults to None. + + Returns: + List[Dict[str, Any]]: A list of dictionaries, where each dictionary contains + the processed data for an individual data point. + + Raises: + ValueError: If both `kind` and `filename` are None, as one of them is required to load the dataset. + KeyError: If the specified `kind` does not exist in the `dynamic_split_dfs` property or + `processed_file_names_dict`, when expected. + FileNotFoundError: If the file corresponding to the provided `filename` does not exist. + """ + if kind is None and filename is None: + raise ValueError( + "Either kind or filename is required to load the correct dataset, both are None" + ) + + # If both kind and filename are given, use filename + if kind is not None and filename is None: + try: + if self.use_inner_cross_validation and kind != "test": + filename = self.processed_file_names_dict[ + f"fold_{self.fold_index}_{kind}" + ] + else: + data_df = self.dynamic_split_dfs[kind] + return data_df.to_dict(orient="records") + except KeyError: + kind = f"{kind}" + + # If filename is provided + try: + return torch.load(os.path.join(self.processed_dir, filename)) + except FileNotFoundError: + raise FileNotFoundError(f"File {filename} doesn't exist") + # ------------------------------ Phase: Raw Properties ----------------------------------- @property @abstractmethod diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 9b8866e6..c17347b4 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -185,7 +185,9 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None: if not os.path.isfile( os.path.join( self._chebi_version_train_obj.processed_dir_main, - self._chebi_version_train_obj.raw_file_names_dict["data"], + self._chebi_version_train_obj.processed_dir_main_file_names_dict[ + "data" + ], ) ): print( @@ -315,7 +317,7 @@ def setup_processed(self) -> None: if self.chebi_version_train is not None and not os.path.isfile( os.path.join( self._chebi_version_train_obj.processed_dir, - self._chebi_version_train_obj.raw_file_names_dict["data"], + self._chebi_version_train_obj.processed_file_names_dict["data"], ) ): print( @@ -494,47 +496,6 @@ def _setup_pruned_test_set( return df_test_chebi_version - def load_processed_data( - self, kind: Optional[str] = None, filename: Optional[str] = None - ) -> List[Dict[str, Any]]: - """ - Load processed data from a file. - - Args: - kind (str, optional): The kind of dataset to load such as "train", "val", or "test". Defaults to None. - filename (str, optional): The name of the file to load the dataset from. Defaults to None. - - Returns: - List[Dict[str, Any]] : The loaded processed data. - - Raises: - KeyError: If specified kind key doesn't exist. - FileNotFoundError: If the specified file does not exist. - """ - if kind is None and filename is None: - raise ValueError( - "Either kind or filename is required to load the correct dataset, both are None" - ) - - # If both kind and filename are given, use filename - if kind is not None and filename is None: - try: - if self.use_inner_cross_validation and kind != "test": - filename = self.processed_file_names_dict[ - f"fold_{self.fold_index}_{kind}" - ] - else: - data_df = self.dynamic_split_dfs[kind] - return data_df.to_dict(orient="records") - except KeyError: - kind = f"{kind}" - - # If filename is provided - try: - return torch.load(os.path.join(self.processed_dir, filename)) - except FileNotFoundError: - raise FileNotFoundError(f"File {filename} doesn't exist") - # ------------------------------ Phase: Raw Properties ----------------------------------- @property def base_dir(self) -> str: From 53daf97471dbcd4af17b60bce671a97b0af7b0ef Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 13 Aug 2024 14:53:01 +0200 Subject: [PATCH 15/30] go data: changes - logic to select go data branch based on given input - update class hierarchy and raw data logic --- chebai/preprocessing/datasets/go_uniprot.py | 222 +++++++++++--------- 1 file changed, 121 insertions(+), 101 deletions(-) diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/go_uniprot.py index 15fd40e0..db8c6791 100644 --- a/chebai/preprocessing/datasets/go_uniprot.py +++ b/chebai/preprocessing/datasets/go_uniprot.py @@ -5,7 +5,8 @@ # https://doi.org/10.1093/bioinformatics/btx624 # Github: https://github.com/bio-ontology-research-group/deepgo # https://www.ebi.ac.uk/GOA/downloads - +# https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/docs/keywlist.txt +# https://www.uniprot.org/uniprotkb __all__ = ["GoUniProtOver100", "GoUniProtOver50"] @@ -15,7 +16,7 @@ from abc import ABC, abstractmethod from collections import OrderedDict from tempfile import NamedTemporaryFile -from typing import Any, Dict, Generator, List, Optional, Tuple +from typing import Any, Dict, Generator, List, Optional, Tuple, Union import fastobo import networkx as nx @@ -43,26 +44,46 @@ class _GOUniprotDataExtractor(_DynamicDataset, ABC): """ _GO_DATA_INIT = "GO" + _SWISS_DATA_INIT = "SWISS" # ---- Index for columns of processed `data.pkl` (derived from `_graph_to_raw_dataset` method) ------ - # "id" at row index 0 - # "name" at row index 1 - # "sequence" at row index 2 - # "swiss_ident" at row index 3 + # "swiss_id" at row index 0 + # "accession" at row index 1 + # "go_ids" at row index 2 + # "sequence" at row index 3 # labels starting from row index 4 _ID_IDX: int = 0 - _DATA_REPRESENTATION_IDX: int = 2 + _DATA_REPRESENTATION_IDX: int = 3 # here `sequence` column _LABELS_START_IDX: int = 4 - _GO_DATA_URL = "http://purl.obolibrary.org/obo/go/go-basic.obo" - _SWISS_DATA_URL = "https://ftp.uniprot.org/pub/databases/uniprot/knowledgebase/complete/uniprot_sprot.dat.gz" - - def __init__( - self, - **kwargs, - ): + _GO_DATA_URL: str = "https://purl.obolibrary.org/obo/go/go-basic.obo" + _SWISS_DATA_URL: str = ( + "https://ftp.uniprot.org/pub/databases/uniprot/knowledgebase/complete/uniprot_sprot.dat.gz" + ) + + # Gene Ontology (GO) has three major branches, one for biological processes (BP), molecular functions (MF) and + # cellular components (CC). The value "all" will take data related to all three branches into account. + _ALL_GO_BRANCHES: str = "all" + _GO_BRANCH_NAMESPACE: Dict[str, str] = { + "BP": "biological_process", + "MF": "molecular_function", + "CC": "cellular_component", + } + + def __init__(self, **kwargs): + self.go_branch: str = self._get_go_branch(**kwargs) super(_GOUniprotDataExtractor, self).__init__(**kwargs) + @classmethod + def _get_go_branch(cls, **kwargs) -> str: + go_branch_value = kwargs.get("go_branch", cls._ALL_GO_BRANCHES) + allowed_values = list(cls._GO_BRANCH_NAMESPACE.keys()) + [cls._ALL_GO_BRANCHES] + if go_branch_value not in allowed_values: + raise ValueError( + f"Invalid value for go_branch: {go_branch_value}, Allowed values: {allowed_values}" + ) + return go_branch_value + # ------------------------------ Phase: Prepare data ----------------------------------- def _download_required_data(self) -> str: """ @@ -81,7 +102,7 @@ def _download_gene_ontology_data(self) -> str: Note: Quote from : https://geneontology.org/docs/download-ontology/ Three versions of the ontology are available, the one use in this method is described below: - http://purl.obolibrary.org/obo/go/go-basic.obo + https://purl.obolibrary.org/obo/go/go-basic.obo The basic version of the GO, filtered such that the graph is guaranteed to be acyclic and annotations can be propagated up the graph. The relations included are `is a, part of, regulates, negatively` `regulates` and `positively regulates`. This version excludes relationships that cross the 3 GO @@ -166,6 +187,7 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: nx.DiGraph: A directed graph representing the class hierarchy, where nodes are GO terms and edges represent parent-child relationships. """ + print("Extracting class hierarchy...") elements = [] for term in fastobo.load(data_path): if isinstance(term, fastobo.typedef.TypedefFrame): @@ -188,22 +210,27 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: if term_dict: elements.append(term_dict) - go_to_swiss_mapping = self._get_go_swiss_data_mapping() - g = nx.DiGraph() + + # Add GO term nodes to the graph and their hierarchical ontology for n in elements: - # Swiss data is mapped to respective go data instance - node_mapping_dict = go_to_swiss_mapping.get(n["id"], {}) - # Combine the dictionaries for node attributes - node_attributes = {**n, **node_mapping_dict} - g.add_node(n["id"], **node_attributes) - g.add_edges_from([(p, q["id"]) for q in elements for p in q["parents"]]) + g.add_node(n["go_id"], node_type=f"{self._GO_DATA_INIT}", **n) + g.add_edges_from( + [(parent, node["go_id"]) for node in elements for parent in node["parents"]] + ) + + swiss_to_go_mapping = self._get_swiss_to_go_mapping() + # Add SwissProt proteins and their associations with GO terms + for swiss_id, swiss_info in swiss_to_go_mapping.items(): + g.add_node(swiss_id, node_type=f"{self._SWISS_DATA_INIT}", **swiss_info) + for go_id in swiss_info.get("go_ids", []): + if go_id in g.nodes: + g.add_edges_from((swiss_id, go_id)) print("Compute transitive closure") return nx.transitive_closure_dag(g) - @staticmethod - def term_callback(term: fastobo.term.TermFrame) -> Optional[Dict]: + def term_callback(self, term: fastobo.term.TermFrame) -> Union[Dict, bool]: """ Extracts information from a Gene Ontology (GO) term document. It also checks if the term is marked as obsolete and skips such terms. @@ -222,17 +249,25 @@ def term_callback(term: fastobo.term.TermFrame) -> Optional[Dict]: name = None for clause in term: + if isinstance(clause, fastobo.term.NamespaceClause): + if ( + self.go_branch != self._ALL_GO_BRANCHES + and clause.namespace != self._GO_BRANCH_NAMESPACE[self.go_branch] + ): + return False + + if isinstance(clause, fastobo.term.IsObsoleteClause): + if clause.obsolete: + # if the term contains clause as obsolete as true, skips this term + return False + if isinstance(clause, fastobo.term.IsAClause): parents.append(_GOUniprotDataExtractor._parse_go_id(clause.term)) elif isinstance(clause, fastobo.term.NameClause): name = clause.name - elif isinstance(clause, fastobo.term.IsObsoleteClause): - if clause.obsolete: - # if the term contains clause as obsolete as true, skips this term - return None return { - "id": _GOUniprotDataExtractor._parse_go_id(term.id), + "go_id": _GOUniprotDataExtractor._parse_go_id(term.id), "parents": parents, "name": name, } @@ -249,41 +284,30 @@ def _parse_go_id(go_id: str) -> int: str: The parsed and normalized GO term ID. """ # `is_a` clause has GO id in the following format: - # is_a: GO:0009968 ! negative regulation of signal transduction + # GO:0009968 ! negative regulation of signal transduction return int(str(go_id).split(":")[1].split("!")[0].strip()) - def _get_go_swiss_data_mapping(self) -> Dict[int, Dict[str, str]]: + def _get_swiss_to_go_mapping(self) -> Dict[str, Dict[str, Union[str, List[int]]]]: """ Parses the Swiss-Prot data and returns a mapping from Gene Ontology (GO) data ID to Swiss-Prot ID along with the sequence representation of the protein. This mapping is necessary because the GO data does not include the protein sequence representation. + Note: + Check below link for keyword details. + https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/docs/keywlist.txt + Returns: Dict[int, Dict[str, str]]: A dictionary where the keys are GO data IDs (int) and the values are dictionaries containing: - "sequence" (str): The protein sequence. - "swiss_ident" (str): The unique identifier for each Swiss-Prot record. """ - # # https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/docs/keywlist.txt - # --------- --------------------------- ------------------------------ - # Line code Content Occurrence in an entry - # --------- --------------------------- ------------------------------ - # ID Identifier (keyword) Once; starts a keyword entry - # IC Identifier (category) Once; starts a category entry - # AC Accession (KW-xxxx) Once - # DE Definition Once or more - # SY Synonyms Optional; once or more - # GO Gene ontology (GO) mapping Optional; once or more - # HI Hierarchy Optional; once or more - # WW Relevant WWW site Optional; once or more - # CA Category Once per keyword entry; - # absent in category entries - # // Terminator Once; ends an entry - # --------------------------------------------------------------------------- print("Parsing swiss uniprot raw data....") - swiss_go_mapping = {} + swiss_to_go_mapping = {} + swiss_data = SwissProt.parse( open( os.path.join(self.raw_dir, self.raw_file_names_dict["SwissUniProt"]), @@ -295,68 +319,67 @@ def _get_go_swiss_data_mapping(self) -> Dict[int, Dict[str, str]]: if record.data_class != "Reviewed": # To consider only manually-annotated swiss data continue - # Cross-reference has mapping for each protein to each type of data set + + go_ids = [] for cross_ref in record.cross_references: if cross_ref[0] == self._GO_DATA_INIT: - # Only consider cross-reference related to GO dataset - go_id = _GOUniprotDataExtractor._parse_go_id(cross_ref[1]) - swiss_go_mapping[go_id] = { - "sequence": record.sequence, - "swiss_ident": record.entry_name, # Unique identifier for each swiss data record - } - return swiss_go_mapping + + # One swiss data protein can correspond to many GO data instances + go_ids.append(cross_ref[1]) + + swiss_to_go_mapping[record.entry_name] = { + "sequence": record.sequence, + "accessions": ",".join(record.accessions), + "go_ids": go_ids, + } + + return swiss_to_go_mapping def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: """ - Preparation step before creating splits, - uses the graph created by _extract_go_class_hierarchy() to extract the - raw data in Dataframe format with extra columns corresponding to each multi-label class. + Uses the graph created by _extract_class_hierarchy() to extract the + raw data in DataFrame format with extra columns corresponding to each multi-label class. Data Format: pd.DataFrame - - Column 0 : ID (Identifier for GO data instance) - - Column 1 : Name of the protein - - Column 2 : Sequence representation of the protein - - Column 3 : Unique identifier of the protein from swiss dataset. - - Column 4 to Column "n": Each column corresponding to a class with value True/False indicating where the - data instance belong to this class or not. + - Column 0 : swiss_id (Identifier for SwissProt protein) + - Column 1 : Accession of the protein + - Column 2 : GO IDs (associated GO terms) + - Column 3 : Sequence of the protein + - Column 4 to Column "n": Each column corresponding to a class with value True/False indicating whether the + protein is associated with this GO term. + Args: g (nx.DiGraph): The class hierarchy graph. Returns: pd.DataFrame: The raw dataset created from the graph. """ - sequences = nx.get_node_attributes(g, "sequence") - names = nx.get_node_attributes(g, "name") - swiss_idents = nx.get_node_attributes(g, "swiss_ident") - print(f"Processing graph") - # Gets list of node ids, names, sequences, swiss identifier where sequence is not empty/None. - data_list = [] - for node_id, sequence in sequences.items(): - if sequence: - data_list.append( - ( - node_id, - names.get(node_id), - sequence, - swiss_idents.get(node_id), - ) - ) - - node_ids, names_list, sequences_list, swiss_identifier_list = zip(*data_list) - data = OrderedDict(id=node_ids) # ID column at index 0 - - data["name"] = names_list # Name column at index 1 - data["sequence"] = ( - sequences_list # Sequence (data representation) column at index 2 + sequences, accessions, go_ids, swiss_nodes, go_nodes = [], [], [], [], [] + for node_id, attr in g.nodes(data=True): + if attr.get("node_type") == self._SWISS_DATA_INIT: + if attr["sequence"]: + sequences.append(attr["sequence"]) + accessions.append(attr["accessions"]) + go_ids.append(attr["go_ids"]) + swiss_nodes.append(node_id) + elif attr.get("node_type") == self._GO_DATA_INIT: + go_nodes.append(node_id) + + data = OrderedDict( + swiss_id=swiss_nodes, # swiss_id column at index 0 + accession=accessions, # Accession column at index 1 + go_ids=go_ids, # Go_ids (data representation) column at index 2 + sequence=sequences, # Sequence column at index 3 ) - data["swiss_ident"] = swiss_identifier_list # Swiss_ident column at index 3 - # Assuming select_classes is implemented and returns a list of class IDs - for n in self.select_classes(g): - data[n] = [ - ((n in g.predecessors(node)) or (n == node)) for node in node_ids + # For each selected GO node, a new column is added to data with True/False values indicating whether the + # SwissProt node is associated with that GO node. + go_subgraph = g.subgraph(go_nodes).copy() + for go_node in self.select_classes(go_subgraph): + data[go_node] = [ + go_node in g.successors(swiss_node) for swiss_node in swiss_nodes ] data = pd.DataFrame(data) @@ -495,6 +518,9 @@ def _name(self) -> str: Returns: str: The dataset name, formatted with the current threshold value. """ + if self.go_branch != self._ALL_GO_BRANCHES: + return f"GO{self.THRESHOLD}_{self.go_branch}" + return f"GO{self.THRESHOLD}" def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: @@ -519,16 +545,10 @@ def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: - The `THRESHOLD` attribute should be defined in the subclass. - Nodes without a 'sequence' attribute are ignored in the successor count. """ - sequences = nx.get_node_attributes(g, "sequence") nodes = [] for node in g.nodes: # Count the number of successors (child nodes) for each node - no_of_successors = 0 - for s_node in g.successors(node): - if sequences.get(s_node, None): - no_of_successors += 1 - - if no_of_successors >= self.THRESHOLD: + if len(list(g.successors(node))) >= self.THRESHOLD: nodes.append(node) nodes.sort() From 499fafc7707c60b796ef84f06ad7efaa3b1ffa80 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 14 Aug 2024 14:51:31 +0200 Subject: [PATCH 16/30] update _graph_to_raw_dataset method - combines the swiss data with GO data --- chebai/preprocessing/datasets/go_uniprot.py | 185 +++++++++++--------- 1 file changed, 100 insertions(+), 85 deletions(-) diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/go_uniprot.py index db8c6791..e6568eb5 100644 --- a/chebai/preprocessing/datasets/go_uniprot.py +++ b/chebai/preprocessing/datasets/go_uniprot.py @@ -46,7 +46,7 @@ class _GOUniprotDataExtractor(_DynamicDataset, ABC): _GO_DATA_INIT = "GO" _SWISS_DATA_INIT = "SWISS" - # ---- Index for columns of processed `data.pkl` (derived from `_graph_to_raw_dataset` method) ------ + # -- Index for columns of processed `data.pkl` (derived from `_get_swiss_to_go_mapping` & `_graph_to_raw_dataset` # "swiss_id" at row index 0 # "accession" at row index 1 # "go_ids" at row index 2 @@ -76,6 +76,20 @@ def __init__(self, **kwargs): @classmethod def _get_go_branch(cls, **kwargs) -> str: + """ + Retrieves the Gene Ontology (GO) branch based on provided keyword arguments. + This method checks if a valid GO branch value is provided in the keyword arguments. + + Args: + **kwargs: Arbitrary keyword arguments. Specifically looks for: + - "go_branch" (str): The desired GO branch. + Returns: + str: The GO branch value. This will be one of the allowed values. + + Raises: + ValueError: If the provided 'go_branch' value is not in the allowed list of values. + """ + go_branch_value = kwargs.get("go_branch", cls._ALL_GO_BRANCHES) allowed_values = list(cls._GO_BRANCH_NAMESPACE.keys()) + [cls._ALL_GO_BRANCHES] if go_branch_value not in allowed_values: @@ -177,8 +191,7 @@ def _download_swiss_uni_prot_data(self) -> Optional[str]: def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: """ Extracts the class hierarchy from the GO ontology. - Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with GO term data - and corresponding Swiss-Prot data (obtained via `_get_go_swiss_data_mapping`). + Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with GO term data. Args: data_path (str): The path to the GO ontology. @@ -214,26 +227,17 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: # Add GO term nodes to the graph and their hierarchical ontology for n in elements: - g.add_node(n["go_id"], node_type=f"{self._GO_DATA_INIT}", **n) + g.add_node(n["go_id"], **n) g.add_edges_from( [(parent, node["go_id"]) for node in elements for parent in node["parents"]] ) - swiss_to_go_mapping = self._get_swiss_to_go_mapping() - # Add SwissProt proteins and their associations with GO terms - for swiss_id, swiss_info in swiss_to_go_mapping.items(): - g.add_node(swiss_id, node_type=f"{self._SWISS_DATA_INIT}", **swiss_info) - for go_id in swiss_info.get("go_ids", []): - if go_id in g.nodes: - g.add_edges_from((swiss_id, go_id)) - print("Compute transitive closure") return nx.transitive_closure_dag(g) def term_callback(self, term: fastobo.term.TermFrame) -> Union[Dict, bool]: """ Extracts information from a Gene Ontology (GO) term document. - It also checks if the term is marked as obsolete and skips such terms. Args: term: A Gene Ontology term Frame document. @@ -254,20 +258,21 @@ def term_callback(self, term: fastobo.term.TermFrame) -> Union[Dict, bool]: self.go_branch != self._ALL_GO_BRANCHES and clause.namespace != self._GO_BRANCH_NAMESPACE[self.go_branch] ): + # if the term document is not related to given go branch (except `all`), skip this document. return False if isinstance(clause, fastobo.term.IsObsoleteClause): if clause.obsolete: - # if the term contains clause as obsolete as true, skips this term + # if the term document contains clause as obsolete as true, skips this document. return False if isinstance(clause, fastobo.term.IsAClause): - parents.append(_GOUniprotDataExtractor._parse_go_id(clause.term)) + parents.append(self._parse_go_id(clause.term)) elif isinstance(clause, fastobo.term.NameClause): name = clause.name return { - "go_id": _GOUniprotDataExtractor._parse_go_id(term.id), + "go_id": self._parse_go_id(term.id), "parents": parents, "name": name, } @@ -283,30 +288,76 @@ def _parse_go_id(go_id: str) -> int: Returns: str: The parsed and normalized GO term ID. """ - # `is_a` clause has GO id in the following format: + # `is_a` clause has GO id in the following formats: # GO:0009968 ! negative regulation of signal transduction + # GO:0046780 return int(str(go_id).split(":")[1].split("!")[0].strip()) - def _get_swiss_to_go_mapping(self) -> Dict[str, Dict[str, Union[str, List[int]]]]: + def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: + """ + Processes a directed acyclic graph (DAG) to create a raw dataset in DataFrame format. The dataset includes + Swiss-Prot protein data and their associations with Gene Ontology (GO) terms. + + Note: + - GO classes are used as labels in the dataset. Each GO term is represented as a column, and its value + indicates whether a Swiss-Prot protein is associated with that GO term. + - Swiss-Prot proteins serve as samples. There is no 1-to-1 correspondence between Swiss-Prot proteins + and GO terms. + + Data Format: pd.DataFrame + - Column 0 : swiss_id (Identifier for SwissProt protein) + - Column 1 : Accession of the protein + - Column 2 : GO IDs (associated GO terms) + - Column 3 : Sequence of the protein + - Column 4 to Column "n": Each column corresponding to a class with value True/False indicating whether the + protein is associated with this GO term. + + Args: + g (nx.DiGraph): The class hierarchy graph. + + Returns: + pd.DataFrame: The raw dataset created from the graph. + """ + print(f"Processing graph") + + data_df = self._get_swiss_to_go_mapping() + + # Initialize the GO term labels/columns to False + data_df[self.select_classes(g)] = False + # Set True for the corresponding GO IDs in the DataFrame go labels/columns + for index, row in data_df.iterrows(): + for go_id in row["go_ids"]: + if go_id in data_df.columns: + data_df.at[index, go_id] = True + + # This filters the DataFrame to include only the rows where at least one value in the row from 5th column + # onwards is True/non-zero. + data_df = data_df[data_df.iloc[:, self._LABELS_START_IDX :].any(axis=1)] + return data_df + + def _get_swiss_to_go_mapping(self) -> pd.DataFrame: """ - Parses the Swiss-Prot data and returns a mapping from Gene Ontology (GO) data ID to Swiss-Prot ID - along with the sequence representation of the protein. + Parses the Swiss-Prot data and returns a DataFrame mapping Swiss-Prot records to Gene Ontology (GO) data. - This mapping is necessary because the GO data does not include the protein sequence representation. + The DataFrame includes the following columns: + - "swiss_id": The unique identifier for each Swiss-Prot record. + - "sequence": The protein sequence. + - "accessions": Comma-separated list of accession numbers. + - "go_ids": List of GO IDs associated with the Swiss-Prot record. Note: - Check below link for keyword details. + This mapping is necessary because the GO data does not include the protein sequence representation. + + Check the link below for keyword details: https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/docs/keywlist.txt Returns: - Dict[int, Dict[str, str]]: A dictionary where the keys are GO data IDs (int) and the values are - dictionaries containing: - - "sequence" (str): The protein sequence. - - "swiss_ident" (str): The unique identifier for each Swiss-Prot record. + pd.DataFrame: A DataFrame where each row corresponds to a Swiss-Prot record with its associated GO data. """ + print("Parsing swiss uniprot raw data....") - swiss_to_go_mapping = {} + swiss_ids, sequences, accessions, go_ids_list = [], [], [], [] swiss_data = SwissProt.parse( open( @@ -320,73 +371,35 @@ def _get_swiss_to_go_mapping(self) -> Dict[str, Dict[str, Union[str, List[int]]] # To consider only manually-annotated swiss data continue + if not record.sequence: + # Consider protein with only sequence representation + continue + go_ids = [] for cross_ref in record.cross_references: if cross_ref[0] == self._GO_DATA_INIT: - # One swiss data protein can correspond to many GO data instances - go_ids.append(cross_ref[1]) + go_ids.append(self._parse_go_id(cross_ref[1])) - swiss_to_go_mapping[record.entry_name] = { - "sequence": record.sequence, - "accessions": ",".join(record.accessions), - "go_ids": go_ids, - } - - return swiss_to_go_mapping - - def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: - """ - Uses the graph created by _extract_class_hierarchy() to extract the - raw data in DataFrame format with extra columns corresponding to each multi-label class. - - Data Format: pd.DataFrame - - Column 0 : swiss_id (Identifier for SwissProt protein) - - Column 1 : Accession of the protein - - Column 2 : GO IDs (associated GO terms) - - Column 3 : Sequence of the protein - - Column 4 to Column "n": Each column corresponding to a class with value True/False indicating whether the - protein is associated with this GO term. + if not go_ids: + # Swiss protein with no mapping to Gene Ontology is skipped + continue - Args: - g (nx.DiGraph): The class hierarchy graph. + go_ids.sort() - Returns: - pd.DataFrame: The raw dataset created from the graph. - """ - print(f"Processing graph") + swiss_ids.append(record.entry_name) + sequences.append(record.sequence) + accessions.append(",".join(record.accessions)) + go_ids_list.append(go_ids) - sequences, accessions, go_ids, swiss_nodes, go_nodes = [], [], [], [], [] - for node_id, attr in g.nodes(data=True): - if attr.get("node_type") == self._SWISS_DATA_INIT: - if attr["sequence"]: - sequences.append(attr["sequence"]) - accessions.append(attr["accessions"]) - go_ids.append(attr["go_ids"]) - swiss_nodes.append(node_id) - elif attr.get("node_type") == self._GO_DATA_INIT: - go_nodes.append(node_id) - - data = OrderedDict( - swiss_id=swiss_nodes, # swiss_id column at index 0 + data_dict = OrderedDict( + swiss_id=swiss_ids, # swiss_id column at index 0 accession=accessions, # Accession column at index 1 - go_ids=go_ids, # Go_ids (data representation) column at index 2 + go_ids=go_ids_list, # Go_ids (data representation) column at index 2 sequence=sequences, # Sequence column at index 3 ) - # For each selected GO node, a new column is added to data with True/False values indicating whether the - # SwissProt node is associated with that GO node. - go_subgraph = g.subgraph(go_nodes).copy() - for go_node in self.select_classes(go_subgraph): - data[go_node] = [ - go_node in g.successors(swiss_node) for swiss_node in swiss_nodes - ] - - data = pd.DataFrame(data) - # This filters the DataFrame to include only the rows where at least one value in the row from 5th column - # onwards is True/non-zero. - data = data[data.iloc[:, self._LABELS_START_IDX :].any(axis=1)] - return data + return pd.DataFrame(data_dict) # ------------------------------ Phase: Setup data ----------------------------------- def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]: @@ -516,7 +529,7 @@ def _name(self) -> str: Returns the name of the dataset. Returns: - str: The dataset name, formatted with the current threshold value. + str: The dataset name, formatted with the current threshold value and/or given go_branch. """ if self.go_branch != self._ALL_GO_BRANCHES: return f"GO{self.THRESHOLD}_{self.go_branch}" @@ -530,6 +543,9 @@ def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: This method iterates over the nodes in the graph, counting the number of successors for each node. Nodes with a number of successors greater than or equal to the defined threshold are selected. + Note: + The input graph must be transitive closure of a directed acyclic graph. + Args: g (nx.DiGraph): The graph representing the dataset. Each node should have a 'sequence' attribute. *args: Additional positional arguments (not used). @@ -543,11 +559,10 @@ def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: Notes: - The `THRESHOLD` attribute should be defined in the subclass. - - Nodes without a 'sequence' attribute are ignored in the successor count. """ nodes = [] for node in g.nodes: - # Count the number of successors (child nodes) for each node + # Count the number of successors (direct child nodes) for each node if len(list(g.successors(node))) >= self.THRESHOLD: nodes.append(node) From 19c47c17d0412a83a048d82d25693157c9afee82 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 14 Aug 2024 20:44:32 +0200 Subject: [PATCH 17/30] fix tokenizing process in reader class for protein --- chebai/preprocessing/reader.py | 76 +++++++++++++++++++++++++++------- 1 file changed, 60 insertions(+), 16 deletions(-) diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index 777d64d9..d00a0be9 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -146,11 +146,21 @@ def _get_token_index(self, token: str) -> int: return self.cache.index(str(token)) + EMBEDDING_OFFSET def _read_data(self, raw_data: str) -> List[int]: - """Read and tokenize raw data.""" + """ + Reads and tokenizes raw SMILES data into a list of token indices. + + Args: + raw_data (str): The raw SMILES string to be tokenized. + + Returns: + List[int]: A list of integers representing the indices of the SMILES tokens. + """ return [self._get_token_index(v[1]) for v in _tokenize(raw_data)] def on_finish(self) -> None: - """Write contents of self.cache into tokens.txt.""" + """ + Saves the current cache of tokens to the token file. This method is called after all data processing is complete. + """ with open(self.token_path, "w") as pk: print(f"saving {len(self.cache)} tokens to {self.token_path}...") print(f"first 10 tokens: {self.cache[:10]}") @@ -324,11 +334,15 @@ def _read_data(self, raw_data: str) -> List[int]: class ProteinDataReader(DataReader): """ - Data reader for Protein data using protein-sequence tokens. + Data reader for protein sequences using amino acid tokens. This class processes raw protein sequences into a format + suitable for model input by tokenizing them and assigning unique indices to each token. + + Note: + Refer for amino acid sequence: https://en.wikipedia.org/wiki/Protein_primary_structure Args: - collator_kwargs: Optional dictionary of keyword arguments for the collator. - token_path: Optional path for the token file. + collator_kwargs (Optional[Dict[str, Any]]): Optional dictionary of keyword arguments for configuring the collator. + token_path (Optional[str]): Path to the token file. If not provided, it will be created automatically. kwargs: Additional keyword arguments. """ @@ -336,29 +350,59 @@ class ProteinDataReader(DataReader): @classmethod def name(cls) -> str: - """Returns the name of the data reader.""" - return "sequence_token" + """ + Returns the name of the data reader. This method identifies the specific type of data reader. + + Returns: + str: The name of the data reader, which is "protein_token". + """ + return "protein_token" def __init__(self, *args, **kwargs): + """ + Initializes the ProteinDataReader, loading existing tokens from the specified token file. + + Args: + *args: Additional positional arguments passed to the base class. + **kwargs: Additional keyword arguments passed to the base class. + """ super().__init__(*args, **kwargs) + # Load the existing tokens from the token file into a cache with open(self.token_path, "r") as pk: self.cache = [x.strip() for x in pk] def _get_token_index(self, token: str) -> int: - """Returns a unique number for each token, automatically adds new tokens.""" - if not str(token) in self.cache: + """ + Returns a unique index for each token (amino acid). If the token is not already in the cache, it is added. + + Args: + token (str): The amino acid token to retrieve or add. + + Returns: + int: The index of the token, offset by the predefined EMBEDDING_OFFSET. + """ + if str(token) not in self.cache: self.cache.append(str(token)) return self.cache.index(str(token)) + EMBEDDING_OFFSET def _read_data(self, raw_data: str) -> List[int]: - """Read and tokenize raw data.""" - return [self._get_token_index(v[1]) for v in _tokenize(raw_data)] + """ + Reads and tokenizes raw protein sequence data into a list of token indices. + + Args: + raw_data (str): The raw protein sequence to be tokenized (e.g., "MKTFF..."). + + Returns: + List[int]: A list of integers representing the indices of the amino acid tokens. + """ + # In the case of protein sequences, each amino acid is typically represented by a single letter. + return [self._get_token_index(aa) for aa in raw_data] def on_finish(self) -> None: - """Write contents of self.cache into tokens.txt.""" + """ + Saves the current cache of tokens to the token file. This method is called after all data processing is complete. + """ with open(self.token_path, "w") as pk: - print(f"saving {len(self.cache)} tokens to {self.token_path}...") - print(f"first 3 sequences tokens: {self.cache[:3]}") - for token in self.cache[:3]: - print(f"Sequence Token: {token}") + print(f"Saving {len(self.cache)} tokens to {self.token_path}...") + print(f"First 10 tokens: {self.cache[:10]}") pk.writelines([f"{c}\n" for c in self.cache]) From ecb276ace806f27c0baa212528410727e7258422 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 14 Aug 2024 20:46:21 +0200 Subject: [PATCH 18/30] protein tokens - 20 natural amino acid tokens - 20 natural amino acid notation tokens as per below wiki - https://en.wikipedia.org/wiki/Protein_primary_structure --- .../bin/protein_token/tokens.txt | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 chebai/preprocessing/bin/protein_token/tokens.txt diff --git a/chebai/preprocessing/bin/protein_token/tokens.txt b/chebai/preprocessing/bin/protein_token/tokens.txt new file mode 100644 index 00000000..72ad1b6d --- /dev/null +++ b/chebai/preprocessing/bin/protein_token/tokens.txt @@ -0,0 +1,20 @@ +M +S +I +G +A +T +R +L +Q +N +D +K +Y +P +C +F +W +E +V +H From 5f9ff93bad38b81f06d4823e51a3cd5baef9ce7f Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 14 Aug 2024 22:47:35 +0200 Subject: [PATCH 19/30] minor updates --- chebai/preprocessing/bin/protein_token/tokens.txt | 5 +++++ chebai/preprocessing/datasets/base.py | 6 +++--- chebai/preprocessing/datasets/chebi.py | 8 +++++++- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/chebai/preprocessing/bin/protein_token/tokens.txt b/chebai/preprocessing/bin/protein_token/tokens.txt index 72ad1b6d..0d32d3ac 100644 --- a/chebai/preprocessing/bin/protein_token/tokens.txt +++ b/chebai/preprocessing/bin/protein_token/tokens.txt @@ -18,3 +18,8 @@ W E V H +X +Z +B +U +O diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index f8d3892c..22cdb9db 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -868,7 +868,7 @@ def _generate_dynamic_splits(self) -> None: This method loads encoded data and generates train, validation, and test splits based on the loaded data. """ - print("Generate dynamic splits...") + print("\nGenerate dynamic splits...") df_train, df_val, df_test = self._get_data_splits() # Generate splits.csv file to store ids of each corresponding split @@ -922,7 +922,7 @@ def get_test_split( Raises: ValueError: If the DataFrame does not contain a column named "labels". """ - print("\nGet test data split") + print("Get test data split") labels_list = df["labels"].tolist() @@ -1003,7 +1003,7 @@ def _retrieve_splits_from_csv(self) -> None: It then loads the encoded data (`data.pt`) and filters it based on the IDs retrieved from splits.csv to reconstruct the train, validation, and test splits. """ - print(f"Loading splits from {self.splits_file_path}...") + print(f"\nLoading splits from {self.splits_file_path}...") splits_df = pd.read_csv(self.splits_file_path) filename = self.processed_file_names_dict["data"] diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index c17347b4..23f02f68 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -243,15 +243,18 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: """ with open(data_path, encoding="utf-8") as chebi: chebi = "\n".join(l for l in chebi if not l.startswith("xref:")) + elements = [ term_callback(clause) for clause in fastobo.loads(chebi) if clause and ":" in str(clause.id) ] + g = nx.DiGraph() for n in elements: g.add_node(n["id"], **n) g.add_edges_from([(p, q["id"]) for q in elements for p in q["parents"]]) + print("Compute transitive closure") return nx.transitive_closure_dag(g) @@ -582,6 +585,9 @@ def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: This method iterates over the nodes in the graph, counting the number of successors for each node. Nodes with a number of successors greater than or equal to the defined threshold are selected. + Note: + The input graph must be transitive closure of a directed acyclic graph. + Args: g (nx.Graph): The graph representing the dataset. *args: Additional positional arguments (not used). @@ -595,7 +601,7 @@ def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: Notes: - The `THRESHOLD` attribute should be defined in the subclass of this class. - - Nodes without a 'sequence' attribute are ignored in the successor count. + - Nodes without a 'smiles' attribute are ignored in the successor count. """ smiles = nx.get_node_attributes(g, "smiles") nodes = list( From b9169943c64a767a098fad8d2497134bb270e1a7 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 15 Aug 2024 00:52:37 +0200 Subject: [PATCH 20/30] filter out swiss protein as per given criterias in paper - ambiguous_amino_acids - sequence_length - experimental_evidence_codes --- .../bin/protein_token/tokens.txt | 5 --- chebai/preprocessing/datasets/go_uniprot.py | 38 ++++++++++++++++--- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/chebai/preprocessing/bin/protein_token/tokens.txt b/chebai/preprocessing/bin/protein_token/tokens.txt index 0d32d3ac..72ad1b6d 100644 --- a/chebai/preprocessing/bin/protein_token/tokens.txt +++ b/chebai/preprocessing/bin/protein_token/tokens.txt @@ -18,8 +18,3 @@ W E V H -X -Z -B -U -O diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/go_uniprot.py index e6568eb5..5668f5b7 100644 --- a/chebai/preprocessing/datasets/go_uniprot.py +++ b/chebai/preprocessing/datasets/go_uniprot.py @@ -348,6 +348,12 @@ def _get_swiss_to_go_mapping(self) -> pd.DataFrame: Note: This mapping is necessary because the GO data does not include the protein sequence representation. + Quote from the DeepGo Paper: + `We select proteins with annotations having experimental evidence codes + (EXP, IDA, IPI, IMP, IGI, IEP, TAS, IC) and filter the proteins by a + maximum length of 1002, ignoring proteins with ambiguous amino acid codes + (B, O, J, U, X, Z) in their sequence.` + Check the link below for keyword details: https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/docs/keywlist.txt @@ -366,30 +372,50 @@ def _get_swiss_to_go_mapping(self) -> pd.DataFrame: ) ) + experimental_evidence_codes = { + "EXP", + "IDA", + "IPI", + "IMP", + "IGI", + "IEP", + "TAS", + "IC", + } + ambiguous_amino_acids = {"B", "O", "J", "U", "X", "Z"} + max_length = 1002 + for record in swiss_data: if record.data_class != "Reviewed": # To consider only manually-annotated swiss data continue - if not record.sequence: - # Consider protein with only sequence representation + if not record.sequence or record.sequence_length > max_length: + # Consider protein with only sequence representation and a maximum length of 1002 + continue + + if any(aa in ambiguous_amino_acids for aa in record.sequence): + # Skip proteins with ambiguous amino acid codes continue go_ids = [] + evidence_codes = set() + for cross_ref in record.cross_references: if cross_ref[0] == self._GO_DATA_INIT: # One swiss data protein can correspond to many GO data instances go_ids.append(self._parse_go_id(cross_ref[1])) + if len(cross_ref) > 3: + evidence_codes.add(cross_ref[3].split(":")[0]) - if not go_ids: - # Swiss protein with no mapping to Gene Ontology is skipped + if not go_ids or not (experimental_evidence_codes & evidence_codes): + # Skip Swiss proteins without mapping to GO data or without the required experimental evidence codes continue - go_ids.sort() - swiss_ids.append(record.entry_name) sequences.append(record.sequence) accessions.append(",".join(record.accessions)) + go_ids.sort() go_ids_list.append(go_ids) data_dict = OrderedDict( From 079269baae9f813a4a690163d2f84075aa45a5d3 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 15 Aug 2024 23:01:54 +0200 Subject: [PATCH 21/30] fixes: go_branch filtering, protein sequence --- chebai/preprocessing/datasets/go_uniprot.py | 37 +++++++++++++-------- chebai/preprocessing/reader.py | 33 ++++++++++++++++++ configs/data/go250.yml | 3 ++ 3 files changed, 59 insertions(+), 14 deletions(-) create mode 100644 configs/data/go250.yml diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/go_uniprot.py index 5668f5b7..04feabad 100644 --- a/chebai/preprocessing/datasets/go_uniprot.py +++ b/chebai/preprocessing/datasets/go_uniprot.py @@ -8,7 +8,7 @@ # https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/docs/keywlist.txt # https://www.uniprot.org/uniprotkb -__all__ = ["GoUniProtOver100", "GoUniProtOver50"] +__all__ = ["GoUniProtOver250", "GoUniProtOver50"] import gzip import os @@ -229,7 +229,12 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: for n in elements: g.add_node(n["go_id"], **n) g.add_edges_from( - [(parent, node["go_id"]) for node in elements for parent in node["parents"]] + [ + (parent_id, node_id) + for node_id in g.nodes + for parent_id in g.nodes[node_id]["parents"] + if parent_id in g.nodes + ] ) print("Compute transitive closure") @@ -256,7 +261,8 @@ def term_callback(self, term: fastobo.term.TermFrame) -> Union[Dict, bool]: if isinstance(clause, fastobo.term.NamespaceClause): if ( self.go_branch != self._ALL_GO_BRANCHES - and clause.namespace != self._GO_BRANCH_NAMESPACE[self.go_branch] + and clause.namespace.escaped + != self._GO_BRANCH_NAMESPACE[self.go_branch] ): # if the term document is not related to given go branch (except `all`), skip this document. return False @@ -332,6 +338,8 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: # This filters the DataFrame to include only the rows where at least one value in the row from 5th column # onwards is True/non-zero. + # Quote from DeepGo Paper: `For training and testing, we use proteins which have been annotated with at least + # one GO term from the set of the GO terms for the model` data_df = data_df[data_df.iloc[:, self._LABELS_START_IDX :].any(axis=1)] return data_df @@ -372,7 +380,7 @@ def _get_swiss_to_go_mapping(self) -> pd.DataFrame: ) ) - experimental_evidence_codes = { + EXPERIMENTAL_EVIDENCE_CODES = { "EXP", "IDA", "IPI", @@ -382,19 +390,20 @@ def _get_swiss_to_go_mapping(self) -> pd.DataFrame: "TAS", "IC", } - ambiguous_amino_acids = {"B", "O", "J", "U", "X", "Z"} - max_length = 1002 + # https://github.com/bio-ontology-research-group/deepgo/blob/d97447a05c108127fee97982fd2c57929b2cf7eb/aaindex.py#L8 + AMBIGUOUS_AMINO_ACIDS = {"B", "O", "J", "U", "X", "Z", "*"} + MAX_LENGTH = 1002 for record in swiss_data: if record.data_class != "Reviewed": # To consider only manually-annotated swiss data continue - if not record.sequence or record.sequence_length > max_length: + if not record.sequence or record.sequence_length > MAX_LENGTH: # Consider protein with only sequence representation and a maximum length of 1002 continue - if any(aa in ambiguous_amino_acids for aa in record.sequence): + if any(aa in AMBIGUOUS_AMINO_ACIDS for aa in record.sequence): # Skip proteins with ambiguous amino acid codes continue @@ -408,7 +417,7 @@ def _get_swiss_to_go_mapping(self) -> pd.DataFrame: if len(cross_ref) > 3: evidence_codes.add(cross_ref[3].split(":")[0]) - if not go_ids or not (experimental_evidence_codes & evidence_codes): + if not go_ids or not (EXPERIMENTAL_EVIDENCE_CODES & evidence_codes): # Skip Swiss proteins without mapping to GO data or without the required experimental evidence codes continue @@ -601,17 +610,17 @@ def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: return nodes -class GoUniProtOver100(_GoUniProtOverX): +class GoUniProtOver250(_GoUniProtOverX): """ - A class for extracting data from the Gene Ontology (GO) dataset with a threshold of 100 for selecting classes. + A class for extracting data from the Gene Ontology (GO) dataset with a threshold of 250 for selecting classes. - Inherits from `_GoUniProtOverX` and sets the threshold for selecting classes to 100. + Inherits from `_GoUniProtOverX` and sets the threshold for selecting classes to 250. Attributes: - THRESHOLD (int): The threshold for selecting classes (100). + THRESHOLD (int): The threshold for selecting classes (250). """ - THRESHOLD: int = 100 + THRESHOLD: int = 250 def label_number(self) -> int: """ diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index d00a0be9..d10185ed 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -348,6 +348,30 @@ class ProteinDataReader(DataReader): COLLATOR = RaggedCollator + # 20 natural amino acid notation + AA_LETTER = [ + "A", + "R", + "N", + "D", + "C", + "Q", + "E", + "G", + "H", + "I", + "L", + "K", + "M", + "F", + "P", + "S", + "T", + "W", + "Y", + "V", + ] + @classmethod def name(cls) -> str: """ @@ -381,6 +405,15 @@ def _get_token_index(self, token: str) -> int: Returns: int: The index of the token, offset by the predefined EMBEDDING_OFFSET. """ + if str(token) not in self.AA_LETTER: + raise KeyError( + f"Invalid token '{token}' encountered. " + f"Please ensure that the input only contains valid amino acids " + f"20 Valid natural amino acid notation: {self.AA_LETTER}" + f"Refer to the amino acid sequence details here: " + f"https://en.wikipedia.org/wiki/Protein_primary_structure" + ) + if str(token) not in self.cache: self.cache.append(str(token)) return self.cache.index(str(token)) + EMBEDDING_OFFSET diff --git a/configs/data/go250.yml b/configs/data/go250.yml new file mode 100644 index 00000000..6724471b --- /dev/null +++ b/configs/data/go250.yml @@ -0,0 +1,3 @@ +class_path: chebai.preprocessing.datasets.go_uniprot.GoUniProtOver250 +init_args: + go_branch: "BP" From 638598a4459268c4279b22bd94f2a82b7b596c64 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 16 Aug 2024 01:54:41 +0200 Subject: [PATCH 22/30] update logic to select go classes based on proteins dataset --- chebai/preprocessing/datasets/go_uniprot.py | 102 +++++++++++++++----- 1 file changed, 77 insertions(+), 25 deletions(-) diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/go_uniprot.py index 04feabad..86643a0d 100644 --- a/chebai/preprocessing/datasets/go_uniprot.py +++ b/chebai/preprocessing/datasets/go_uniprot.py @@ -329,7 +329,8 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: data_df = self._get_swiss_to_go_mapping() # Initialize the GO term labels/columns to False - data_df[self.select_classes(g)] = False + data_df[self.select_classes(g, data_df=data_df)] = False + # Set True for the corresponding GO IDs in the DataFrame go labels/columns for index, row in data_df.iterrows(): for go_id in row["go_ids"]: @@ -408,17 +409,25 @@ def _get_swiss_to_go_mapping(self) -> pd.DataFrame: continue go_ids = [] - evidence_codes = set() for cross_ref in record.cross_references: if cross_ref[0] == self._GO_DATA_INIT: # One swiss data protein can correspond to many GO data instances + + if len(cross_ref) <= 3: + # No evidence code + continue + + # https://github.com/bio-ontology-research-group/deepgo/blob/master/get_functions.py#L63-L66 + evidence_code = cross_ref[3].split(":")[0] + if evidence_code not in EXPERIMENTAL_EVIDENCE_CODES: + # Skip GO id without the required experimental evidence codes + continue + go_ids.append(self._parse_go_id(cross_ref[1])) - if len(cross_ref) > 3: - evidence_codes.add(cross_ref[3].split(":")[0]) - if not go_ids or not (EXPERIMENTAL_EVIDENCE_CODES & evidence_codes): - # Skip Swiss proteins without mapping to GO data or without the required experimental evidence codes + if not go_ids: + # Skip Swiss proteins without mapping to GO data continue swiss_ids.append(record.entry_name) @@ -571,43 +580,86 @@ def _name(self) -> str: return f"GO{self.THRESHOLD}" - def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: + def select_classes( + self, g: nx.DiGraph, *args: Any, **kwargs: Dict[str, Any] + ) -> List[int]: """ - Selects classes from the GO dataset based on the number of successors meeting a specified threshold. + Selects classes (GO terms) from the Gene Ontology (GO) dataset based on the number of annotations meeting a + specified threshold. - This method iterates over the nodes in the graph, counting the number of successors for each node. - Nodes with a number of successors greater than or equal to the defined threshold are selected. + The selection process is based on the annotations of the GO terms with its ancestors across the dataset. - Note: - The input graph must be transitive closure of a directed acyclic graph. + Annotations are calculated by counting how many times each GO term, along with its ancestral hierarchy, + is annotated per protein across the dataset. + This means that for each protein, the GO terms associated with it are considered, and the entire hierarchical + structure (ancestors) of each GO term is taken into account. The total count for each GO term and its ancestors + reflects how frequently these terms are annotated across all proteins in the dataset. Args: - g (nx.DiGraph): The graph representing the dataset. Each node should have a 'sequence' attribute. + g (nx.DiGraph): The directed acyclic graph representing the GO dataset, where each node corresponds to a GO term. *args: Additional positional arguments (not used). - **kwargs: Additional keyword arguments (not used). + **kwargs: Additional keyword arguments, including: + - data_df (pd.DataFrame): A DataFrame containing the GO annotations for various proteins. + It should include a 'go_ids' column with the GO terms associated with each protein. Returns: - List: A sorted list of node IDs that meet the successor threshold criteria. + List[int]: A sorted list of selected GO term IDs that meet the annotation threshold criteria. Side Effects: - Writes the list of selected nodes to a file named "classes.txt" in the specified processed directory. + - Writes the list of selected GO term IDs to a file named "classes.txt" in the specified processed directory. + + Raises: + AttributeError: If the 'data_df' argument is not provided in kwargs. Notes: - - The `THRESHOLD` attribute should be defined in the subclass. + - The `THRESHOLD` attribute, which defines the minimum number of annotations required to select a GO term, should be defined in the subclass. """ - nodes = [] - for node in g.nodes: - # Count the number of successors (direct child nodes) for each node - if len(list(g.successors(node))) >= self.THRESHOLD: - nodes.append(node) + # Retrieve the DataFrame containing GO annotations per protein from the keyword arguments + data_df: pd.DataFrame = kwargs.get("data_df", None) + if data_df is None or not isinstance(data_df, pd.DataFrame) or data_df.empty: + raise AttributeError( + "The 'data_df' argument must be provided and must be a non-empty pandas DataFrame." + ) + + print(f"Selecting GO terms based on given threshold: {self.THRESHOLD} ...") - nodes.sort() + # https://github.com/bio-ontology-research-group/deepgo/blob/master/get_functions.py#L59-L77 + go_term_annot: Dict[int, int] = {} + for idx, row in data_df.iterrows(): + # Set will contain go terms associated with the protein, along with all the ancestors of those + # associated go terms + associated_go_ids_with_ancestors = set() + + # Collect all ancestors of the GO terms associated with this protein + for go_id in row["go_ids"]: + if go_id in g.nodes: + associated_go_ids_with_ancestors.add(go_id) + associated_go_ids_with_ancestors.update( + g.predecessors(go_id) + ) # Add all predecessors (ancestors) of go_id + + # Count the annotations for each go_id **`per protein`** + for go_id in associated_go_ids_with_ancestors: + if go_id not in go_term_annot: + go_term_annot[go_id] = 0 + go_term_annot[go_id] += 1 + + # Select GO terms that meet or exceed the threshold of annotations + selected_nodes: List[int] = [ + go_id + for go_id in g.nodes + if go_id in go_term_annot and go_term_annot[go_id] >= self.THRESHOLD + ] + + # Sort the selected nodes (optional but often useful for consistent output) + selected_nodes.sort() # Write the selected node IDs/classes to the file filename = "classes.txt" with open(os.path.join(self.processed_dir_main, filename), "wt") as fout: - fout.writelines(str(node) + "\n" for node in nodes) - return nodes + fout.writelines(str(node) + "\n" for node in selected_nodes) + + return selected_nodes class GoUniProtOver250(_GoUniProtOverX): From 9200b737ba4860c5b1ebe4eb07fa8757993b0665 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 16 Aug 2024 17:00:55 +0200 Subject: [PATCH 23/30] fix: dataframe column addition performance warning PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling frame.insert many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use newframe = frame.copy() data_df[self.select_classes(g, data_df=data_df)] = False --- chebai/preprocessing/datasets/go_uniprot.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/go_uniprot.py index 86643a0d..de3d2c0d 100644 --- a/chebai/preprocessing/datasets/go_uniprot.py +++ b/chebai/preprocessing/datasets/go_uniprot.py @@ -329,7 +329,11 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: data_df = self._get_swiss_to_go_mapping() # Initialize the GO term labels/columns to False - data_df[self.select_classes(g, data_df=data_df)] = False + selected_classes = self.select_classes(g, data_df=data_df) + new_label_columns = pd.DataFrame( + False, index=data_df.index, columns=selected_classes + ) + data_df = pd.concat([data_df, new_label_columns], axis=1) # Set True for the corresponding GO IDs in the DataFrame go labels/columns for index, row in data_df.iterrows(): From f9c10f7fa8abe31abee2ee9805c4253218365d12 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 25 Aug 2024 14:32:35 +0200 Subject: [PATCH 24/30] consistent prefix "GOUniProt" for all classes --- chebai/preprocessing/datasets/go_uniprot.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/go_uniprot.py index de3d2c0d..cc82793d 100644 --- a/chebai/preprocessing/datasets/go_uniprot.py +++ b/chebai/preprocessing/datasets/go_uniprot.py @@ -8,7 +8,7 @@ # https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/docs/keywlist.txt # https://www.uniprot.org/uniprotkb -__all__ = ["GoUniProtOver250", "GoUniProtOver50"] +__all__ = ["GOUniProtOver250", "GOUniProtOver50"] import gzip import os @@ -29,7 +29,7 @@ from chebai.preprocessing.datasets.base import _DynamicDataset -class _GOUniprotDataExtractor(_DynamicDataset, ABC): +class _GOUniProtDataExtractor(_DynamicDataset, ABC): """ A class for extracting and processing data from the Gene Ontology (GO) dataset and the Swiss UniProt dataset. @@ -72,7 +72,7 @@ class _GOUniprotDataExtractor(_DynamicDataset, ABC): def __init__(self, **kwargs): self.go_branch: str = self._get_go_branch(**kwargs) - super(_GOUniprotDataExtractor, self).__init__(**kwargs) + super(_GOUniProtDataExtractor, self).__init__(**kwargs) @classmethod def _get_go_branch(cls, **kwargs) -> str: @@ -547,7 +547,7 @@ def raw_file_names_dict(self) -> dict: return {"GO": "go-basic.obo", "SwissUniProt": "uniprot_sprot.dat"} -class _GoUniProtOverX(_GOUniprotDataExtractor, ABC): +class _GOUniProtOverX(_GOUniProtDataExtractor, ABC): """ A class for extracting data from the Gene Ontology (GO) dataset with a threshold for selecting classes based on the number of subclasses. @@ -666,11 +666,11 @@ def select_classes( return selected_nodes -class GoUniProtOver250(_GoUniProtOverX): +class GOUniProtOver250(_GOUniProtOverX): """ A class for extracting data from the Gene Ontology (GO) dataset with a threshold of 250 for selecting classes. - Inherits from `_GoUniProtOverX` and sets the threshold for selecting classes to 250. + Inherits from `_GOUniProtOverX` and sets the threshold for selecting classes to 250. Attributes: THRESHOLD (int): The threshold for selecting classes (250). @@ -690,11 +690,11 @@ def label_number(self) -> int: return 854 -class GoUniProtOver50(_GoUniProtOverX): +class GOUniProtOver50(_GOUniProtOverX): """ A class for extracting data from the Gene Ontology (GO) dataset with a threshold of 50 for selecting classes. - Inherits from `_GoUniProtOverX` and sets the threshold for selecting classes to 50. + Inherits from `_GOUniProtOverX` and sets the threshold for selecting classes to 50. Attributes: THRESHOLD (int): The threshold for selecting classes (50). From f39916b6f66d8c3622823957d6a49c4a7d30faa3 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 25 Aug 2024 20:24:28 +0200 Subject: [PATCH 25/30] update go configs for new class names --- configs/data/go250.yml | 2 +- configs/data/go50.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/data/go250.yml b/configs/data/go250.yml index 6724471b..5598495c 100644 --- a/configs/data/go250.yml +++ b/configs/data/go250.yml @@ -1,3 +1,3 @@ -class_path: chebai.preprocessing.datasets.go_uniprot.GoUniProtOver250 +class_path: chebai.preprocessing.datasets.go_uniprot.GOUniProtOver250 init_args: go_branch: "BP" diff --git a/configs/data/go50.yml b/configs/data/go50.yml index a3e8ca60..2ed4d14c 100644 --- a/configs/data/go50.yml +++ b/configs/data/go50.yml @@ -1 +1 @@ -class_path: chebai.preprocessing.datasets.go_uniprot.GoUniProtOver50 +class_path: chebai.preprocessing.datasets.go_uniprot.GOUniProtOver50 From 4db76ce1b317448172058faab4eb756818cd4fc4 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 9 Sep 2024 11:38:50 +0200 Subject: [PATCH 26/30] extra documentation for ragged coll as per the comment - https://github.com/ChEB-AI/python-chebai/pull/48#issuecomment-2324393829 --- chebai/preprocessing/collate.py | 43 ++++++++++++++++++++++++++++----- 1 file changed, 37 insertions(+), 6 deletions(-) diff --git a/chebai/preprocessing/collate.py b/chebai/preprocessing/collate.py index 4e5e9e16..ecbcb876 100644 --- a/chebai/preprocessing/collate.py +++ b/chebai/preprocessing/collate.py @@ -41,19 +41,41 @@ def __call__(self, data: List[Dict]) -> XYData: class RaggedCollator(Collator): - """Collator for handling ragged data samples.""" + """ + Collator for handling ragged data samples, designed to support scenarios where some labels may be missing (None). + + This class is specifically designed for preparing batches of "ragged" data, where the samples may have varying sizes, + such as molecular representations or variable-length protein sequences. Additionally, it supports cases where some + of the data samples might be partially labeled, which is useful for certain loss functions that allow training + with incomplete or fuzzy data (e.g., fuzzy loss). + + During batching, the class pads the data samples to a uniform length, applies appropriate masks to differentiate + between valid and padded elements, and ensures that label misalignment is handled by filtering out unlabelled + data points. The indices of valid labels are stored in the `non_null_labels` field, which can be used later for + metrics computation such as F1-score or MSE, especially in cases where some data points lack labels. + + Reference: https://github.com/ChEB-AI/python-chebai/pull/48#issuecomment-2324393829 + """ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData: - """Collate ragged data samples (i.e., samples of unequal size such as string representations of molecules) into - a batch. + """ + Collate ragged data samples (i.e., samples of unequal size, such as molecular sequences) into a batch. + + Handles both fully and partially labeled data, where some samples may have `None` as their label. The indices + of non-null labels are stored in the `non_null_labels` field, which is used to filter out predictions for + unlabeled data during evaluation (e.g., F1, MSE). For models supporting partially labeled data, this method + ensures alignment between features and labels. Args: - data (List[Union[Dict, Tuple]]): List of ragged data samples. + data (List[Union[Dict, Tuple]]): List of ragged data samples. Each sample can be a dictionary or tuple + with 'features', 'labels', and 'ident'. Returns: - XYData: Batched data with appropriate padding and masks. + XYData: A batch of padded sequences and labels, including masks for valid positions and indices of + non-null labels for metric computation. """ model_kwargs: Dict = dict() + # Indices of non-null labels are stored in key `non_null_labels` of loss_kwargs. loss_kwargs: Dict = dict() if isinstance(data[0], tuple): @@ -64,18 +86,23 @@ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData: *((d["features"], d["labels"], d.get("ident")) for d in data) ) if any(x is not None for x in y): + # If any label is not None: (None, None, `1`, None) if any(x is None for x in y): + # If any label is None: (`None`, `None`, 1, `None`) non_null_labels = [i for i, r in enumerate(y) if r is not None] y = self.process_label_rows( tuple(ye for i, ye in enumerate(y) if i in non_null_labels) ) loss_kwargs["non_null_labels"] = non_null_labels else: + # If all labels are not None: (`0`, `2`, `1`, `3`) y = self.process_label_rows(y) else: + # If all labels are None : (`None`, `None`, `None`, `None`) y = None loss_kwargs["non_null_labels"] = [] + # Calculate the lengths of each sequence, create a binary mask for valid (non-padded) positions lens = torch.tensor(list(map(len, x))) model_kwargs["mask"] = torch.arange(max(lens))[None, :] < lens[:, None] model_kwargs["lens"] = lens @@ -89,7 +116,11 @@ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData: ) def process_label_rows(self, labels: Tuple) -> torch.Tensor: - """Process label rows by padding sequences. + """ + Process label rows by padding sequences to ensure uniform shape across the batch. + + This method pads the label rows, converting sequences of labels of different lengths into a uniform tensor. + It ensures that `None` values in the labels are handled by substituting them with a default value(e.g.,`False`). Args: labels (Tuple): Tuple of label rows. From 06ab981dcd018760f6bf94c4a276967b138e538e Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 9 Sep 2024 11:43:30 +0200 Subject: [PATCH 27/30] minor changes --- chebai/preprocessing/datasets/base.py | 2 ++ chebai/preprocessing/datasets/chebi.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 22cdb9db..02877ad3 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -155,6 +155,8 @@ def _name(self) -> str: def _filter_labels(self, row: dict) -> dict: """ Filter labels based on `label_filter`. + This method selects specific labels from the `labels` list within the row dictionary + according to the index or indices provided by the `label_filter` attribute of the class. Args: row (dict): A dictionary containing the row data. diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 23f02f68..1c0cb2f9 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -754,7 +754,7 @@ def processed_dir_main(self) -> str: "processed", ) - def extract_class_hierarchy(self, chebi_path: str) -> nx.DiGraph: + def _extract_class_hierarchy(self, chebi_path: str) -> nx.DiGraph: """ Extracts a subset of ChEBI based on subclasses of the top class ID. From 62a3f45e59b5d80282926873d2578fde79756203 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 21 Sep 2024 10:48:55 +0200 Subject: [PATCH 28/30] parameter for maximum length (default: 1002) --- chebai/preprocessing/datasets/go_uniprot.py | 36 ++++++++++++++++++--- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/go_uniprot.py index cc82793d..319160c6 100644 --- a/chebai/preprocessing/datasets/go_uniprot.py +++ b/chebai/preprocessing/datasets/go_uniprot.py @@ -24,6 +24,7 @@ import requests import torch from Bio import SwissProt +from torch.utils.data import DataLoader from chebai.preprocessing import reader as dr from chebai.preprocessing.datasets.base import _DynamicDataset @@ -74,6 +75,11 @@ def __init__(self, **kwargs): self.go_branch: str = self._get_go_branch(**kwargs) super(_GOUniProtDataExtractor, self).__init__(**kwargs) + self.max_sequence_length: int = int(kwargs.get("max_sequence_length", 1002)) + assert ( + self.max_sequence_length >= 1 + ), "Max sequence length should be greater than or equal to 1." + @classmethod def _get_go_branch(cls, **kwargs) -> str: """ @@ -397,15 +403,14 @@ def _get_swiss_to_go_mapping(self) -> pd.DataFrame: } # https://github.com/bio-ontology-research-group/deepgo/blob/d97447a05c108127fee97982fd2c57929b2cf7eb/aaindex.py#L8 AMBIGUOUS_AMINO_ACIDS = {"B", "O", "J", "U", "X", "Z", "*"} - MAX_LENGTH = 1002 for record in swiss_data: if record.data_class != "Reviewed": # To consider only manually-annotated swiss data continue - if not record.sequence or record.sequence_length > MAX_LENGTH: - # Consider protein with only sequence representation and a maximum length of 1002 + if not record.sequence: + # Consider protein with only sequence representation continue if any(aa in AMBIGUOUS_AMINO_ACIDS for aa in record.sequence): @@ -524,6 +529,29 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: return df_train, df_val, df_test + # ------------------------------ Phase: DataLoaders ----------------------------------- + def dataloader(self, kind: str, **kwargs) -> DataLoader: + """ + Returns a DataLoader object with truncated sequences for the specified kind of data (train, val, or test). + + This method overrides the dataloader method from the superclass. After fetching the dataset from the + superclass, it truncates the 'features' of each data instance to a maximum length specified by + `self.max_sequence_length`. + + Args: + kind (str): The kind of data to load (e.g., 'train', 'val', 'test'). + **kwargs: Additional keyword arguments passed to the superclass dataloader method. + + Returns: + DataLoader: A DataLoader object with the truncated sequences. + """ + dataloader = super().dataloader(kind, **kwargs) + + # Truncate the 'features' to max_sequence_length for each instance + for instance in dataloader.dataset: + instance["features"] = instance["features"][: self.max_sequence_length] + return dataloader + # ------------------------------ Phase: Raw Properties ----------------------------------- @property def base_dir(self) -> str: @@ -619,7 +647,7 @@ def select_classes( - The `THRESHOLD` attribute, which defines the minimum number of annotations required to select a GO term, should be defined in the subclass. """ # Retrieve the DataFrame containing GO annotations per protein from the keyword arguments - data_df: pd.DataFrame = kwargs.get("data_df", None) + data_df = kwargs.get("data_df", None) if data_df is None or not isinstance(data_df, pd.DataFrame) or data_df.empty: raise AttributeError( "The 'data_df' argument must be provided and must be a non-empty pandas DataFrame." From 6f463dec061e31de07ee917e691244309ed3b236 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 21 Sep 2024 10:53:03 +0200 Subject: [PATCH 29/30] remove label number for GO_UniProt classes --- chebai/preprocessing/datasets/go_uniprot.py | 27 --------------------- 1 file changed, 27 deletions(-) diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/go_uniprot.py index 319160c6..c59b3d4a 100644 --- a/chebai/preprocessing/datasets/go_uniprot.py +++ b/chebai/preprocessing/datasets/go_uniprot.py @@ -594,11 +594,6 @@ class _GOUniProtOverX(_GOUniProtDataExtractor, ABC): READER: dr.ProteinDataReader = dr.ProteinDataReader THRESHOLD: int = None - @property - @abstractmethod - def label_number(self) -> int: - raise NotImplementedError - @property def _name(self) -> str: """ @@ -706,17 +701,6 @@ class GOUniProtOver250(_GOUniProtOverX): THRESHOLD: int = 250 - def label_number(self) -> int: - """ - Returns the number of labels in the dataset for this threshold. - - Overrides the base class method to provide the correct number of labels for a threshold of 100. - - Returns: - int: The number of labels (854). - """ - return 854 - class GOUniProtOver50(_GOUniProtOverX): """ @@ -729,14 +713,3 @@ class GOUniProtOver50(_GOUniProtOverX): """ THRESHOLD: int = 50 - - def label_number(self) -> int: - """ - Returns the number of labels in the dataset for this threshold. - - Overrides the base class method to provide the correct number of labels for a threshold of 50. - - Returns: - int: The number of labels (1332). - """ - return 1332 From 108d9cab046eceb5d7eeb8179295b2246d801a81 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 21 Sep 2024 16:38:26 +0200 Subject: [PATCH 30/30] trigrams / n-grams combining several amino acids into one token --- chebai/preprocessing/datasets/go_uniprot.py | 10 +++++ chebai/preprocessing/reader.py | 48 ++++++++++++++++----- 2 files changed, 48 insertions(+), 10 deletions(-) diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/go_uniprot.py index c59b3d4a..574ecdbd 100644 --- a/chebai/preprocessing/datasets/go_uniprot.py +++ b/chebai/preprocessing/datasets/go_uniprot.py @@ -563,6 +563,16 @@ def base_dir(self) -> str: """ return os.path.join("data", f"GO_UniProt") + @property + def identifier(self) -> tuple: + """Identifier for the dataset.""" + # overriding identifier instead of reader.name to keep same tokens.txt file, but different processed_dir folder + if not isinstance(self.reader, dr.ProteinDataReader): + raise ValueError("Need Protein DataReader for identifier") + if self.reader.n_gram is not None: + return (f"{self.reader.name()}_{self.reader.n_gram}_gram",) + return (self.reader.name(),) + @property def raw_file_names_dict(self) -> dict: """ diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index d10185ed..46cd558a 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -382,7 +382,7 @@ def name(cls) -> str: """ return "protein_token" - def __init__(self, *args, **kwargs): + def __init__(self, *args, n_gram: Optional[int] = None, **kwargs): """ Initializes the ProteinDataReader, loading existing tokens from the specified token file. @@ -390,7 +390,16 @@ def __init__(self, *args, **kwargs): *args: Additional positional arguments passed to the base class. **kwargs: Additional keyword arguments passed to the base class. """ + if n_gram is not None: + assert ( + int(n_gram) >= 2 + ), "Ngrams must be greater than or equal to 2 if provided." + self.n_gram = int(n_gram) + else: + self.n_gram = None + super().__init__(*args, **kwargs) + # Load the existing tokens from the token file into a cache with open(self.token_path, "r") as pk: self.cache = [x.strip() for x in pk] @@ -405,14 +414,25 @@ def _get_token_index(self, token: str) -> int: Returns: int: The index of the token, offset by the predefined EMBEDDING_OFFSET. """ - if str(token) not in self.AA_LETTER: - raise KeyError( - f"Invalid token '{token}' encountered. " - f"Please ensure that the input only contains valid amino acids " - f"20 Valid natural amino acid notation: {self.AA_LETTER}" - f"Refer to the amino acid sequence details here: " - f"https://en.wikipedia.org/wiki/Protein_primary_structure" - ) + error_str = ( + f"Please ensure that the input only contains valid amino acids " + f"20 Valid natural amino acid notation: {self.AA_LETTER}" + f"Refer to the amino acid sequence details here: " + f"https://en.wikipedia.org/wiki/Protein_primary_structure" + ) + + if self.n_gram is None: + # Single-letter amino acid token check + if str(token) not in self.AA_LETTER: + raise KeyError(f"Invalid token '{token}' encountered. " + error_str) + else: + # n-gram token validation, ensure that each component of the n-gram is valid + for aa in token: + if aa not in self.AA_LETTER: + raise KeyError( + f"Invalid token '{token}' encountered as part of n-gram {self.n_gram}. " + + error_str + ) if str(token) not in self.cache: self.cache.append(str(token)) @@ -428,7 +448,15 @@ def _read_data(self, raw_data: str) -> List[int]: Returns: List[int]: A list of integers representing the indices of the amino acid tokens. """ - # In the case of protein sequences, each amino acid is typically represented by a single letter. + if self.n_gram is not None: + # Tokenize the sequence into n-grams + tokens = [ + raw_data[i : i + self.n_gram] + for i in range(len(raw_data) - self.n_gram + 1) + ] + return [self._get_token_index(gram) for gram in tokens] + + # If n_gram is None, tokenize the sequence at the amino acid level (single-letter representation) return [self._get_token_index(aa) for aa in raw_data] def on_finish(self) -> None: