diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py
index f0e4f53e7..2c498ce01 100644
--- a/pyhealth/datasets/__init__.py
+++ b/pyhealth/datasets/__init__.py
@@ -34,6 +34,7 @@ def __init__(self, *args, **kwargs):
 from .medical_transcriptions import MedicalTranscriptionsDataset
 from .mimic3 import MIMIC3Dataset
 from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset
+from .mimic4derived import MIMIC4DerivedDataset
 from .mimicextract import MIMICExtractDataset
 from .omop import OMOPDataset
 from .sample_dataset import SampleDataset
diff --git a/pyhealth/datasets/configs/mimic4_derived.yaml b/pyhealth/datasets/configs/mimic4_derived.yaml
new file mode 100644
index 000000000..8bff53975
--- /dev/null
+++ b/pyhealth/datasets/configs/mimic4_derived.yaml
@@ -0,0 +1,31 @@
+version: "3.1"
+tables:
+  vasopressorduration:
+    file_path: "mimic_iv_vasopressor.csv.gz"
+    patient_id: null
+    timestamp: "starttime"
+    attributes:
+      - "stay_id"
+      - "vasonum"
+      - "endtime"
+      - "duration_hours"
+
+  ventduration:
+    file_path: "mimic_iv_ventilation_duration.csv.gz"
+    patient_id: null
+    timestamp: "starttime"
+    attributes:
+      - "stay_id"
+      - "ventnum"
+      - "endtime"
+      - "duration_hours"
+  ventclassification:
+    file_path: "mimic_iv_ventilation_classification.csv.gz"
+    patient_id: null
+    timestamp: "charttime"
+    attributes:
+      - "stay_id"
+      - "MechVent"
+      - "OxygenTherapy"
+      - "Extubated"
+      - "SelfExtubated"
\ No newline at end of file
diff --git a/pyhealth/datasets/mimic4.py b/pyhealth/datasets/mimic4.py
index 0fda096db..c52a68c14 100644
--- a/pyhealth/datasets/mimic4.py
+++ b/pyhealth/datasets/mimic4.py
@@ -5,6 +5,7 @@
 
 import pandas as pd
 import polars as pl
+from .mimic4derived import MIMIC4DerivedDataset
 
 try:
     import psutil
@@ -203,12 +204,15 @@ def __init__(
         ehr_root: Optional[str] = None,
         note_root: Optional[str] = None,
         cxr_root: Optional[str] = None,
         ehr_tables: Optional[List[str]] = None,
         note_tables: Optional[List[str]] = None,
         cxr_tables: Optional[List[str]] = None,
         ehr_config_path: Optional[str] = None,
         note_config_path: Optional[str] = None,
         cxr_config_path: Optional[str] = None,
         dataset_name: str = "mimic4",
         dev: bool = False,  # Added dev parameter
+        der_root: Optional[str] = None,
+        der_tables: Optional[List[str]] = None,
+        der_config_path: Optional[str] = None,
     ):
@@ -223,13 +227,14 @@ def __init__(
         self.dev = dev  # Store dev mode flag
 
         # We need at least one root directory
-        if not any([ehr_root, note_root, cxr_root]):
+        if not any([ehr_root, note_root, cxr_root, der_root]):
             raise ValueError("At least one root directory must be provided")
 
         # Initialize empty lists if None provided
         ehr_tables = ehr_tables or []
         note_tables = note_tables or []
         cxr_tables = cxr_tables or []
+        der_tables = der_tables or []
 
         # Initialize EHR dataset if root is provided
         if ehr_root:
@@ -263,6 +268,17 @@ def __init__(
                 dev=dev  # Pass dev mode flag
             )
             log_memory_usage("After CXR dataset initialization")
+
+        # Initialize derived dataset if root is provided
+        if der_root:
+            logger.info(f"Initializing MIMIC4DerivedDataset with tables: {der_tables} (dev mode: {dev})")
+            self.sub_datasets["der"] = MIMIC4DerivedDataset(
+                root=der_root,
+                tables=der_tables,
+                config_path=der_config_path,
+                dev=dev,  # Pass dev mode flag
+            )
+            log_memory_usage("After Derived dataset initialization")
 
         # Combine data from all sub-datasets
         log_memory_usage("Before combining data")
diff --git a/pyhealth/datasets/mimic4derived.py b/pyhealth/datasets/mimic4derived.py
new file mode 100644
index 000000000..b13d87a58
--- /dev/null
+++ b/pyhealth/datasets/mimic4derived.py
@@ -0,0 +1,161 @@
+import logging
+from pathlib import Path
+from typing import Dict, List, Optional
+
+import polars as pl
+
+from .base_dataset import BaseDataset
+
+logger = logging.getLogger(__name__)
+
+# Attribute columns declared for each table in configs/mimic4_derived.yaml.
+_TABLE_ATTRIBUTES: Dict[str, List[str]] = {
+    "vasopressorduration": ["stay_id", "vasonum", "endtime", "duration_hours"],
+    "ventduration": ["stay_id", "ventnum", "endtime", "duration_hours"],
+    "ventclassification": [
+        "stay_id",
+        "MechVent",
+        "OxygenTherapy",
+        "Extubated",
+        "SelfExtubated",
+    ],
+}
+
+
+class MIMIC4DerivedDataset(BaseDataset):
+    """Derived dataset for MIMIC-IV with ventilation and vasopressor durations.
+
+    Contributors: Matthew Rimland (NetId: rimland2), Athan Liu (NetId: athanl2)
+
+    Inspiration Paper: Racial Disparities and Mistrust in End-of-Life Care
+    (Boag et al. 2018)
+    Link to paper: https://proceedings.mlr.press/v85/boag18a.html
+
+    The source data (MetaVision information system) is available at:
+    https://physionet.org/content/mimiciv/3.1/
+
+    Transformations are adapted for MIMIC-IV from:
+    https://github.com/MIT-LCP/mimic-code/blob/main/mimic-iii/concepts/durations/ventilation_durations.sql
+    Official derivation queries, once available in that repository, can be
+    used to generate the expected csv.gz files.
+
+    The derivation queries used to generate the expected data are at:
+    https://github.com/mrimland/DL4H-Racial-Disparities-And-Mistrust/blob/main/Notebooks/GenerateVentVasoTables.ipynb
+
+    Note: The expected tables are available in both MIMIC-III and MIMIC-IV.
+    The configuration in mimic4_derived.yaml contains column names specific
+    to MIMIC-IV.
+
+    Args:
+        root: Root directory of the raw data.
+        dataset_name: Name of the dataset. Defaults to "mimic4_derived".
+        config_path: Path to the configuration file. If None, uses the
+            default config.
+        tables: Extra tables to load in addition to the default ones.
+        dev: Dev mode flag, forwarded to BaseDataset.
+
+    Attributes:
+        root: Root directory of the raw data (should contain many csv files).
+        dataset_name: Name of the dataset.
+        config_path: Path to the configuration file.
+        vent_duration: Events of the "ventduration" table.
+        vasopressor_duration: Events of the "vasopressorduration" table.
+
+    Examples:
+        >>> from pyhealth.datasets import MIMIC4DerivedDataset
+        >>> dataset = MIMIC4DerivedDataset(
+        ...     root="path/to/mimic4derived",
+        ...     dataset_name="VentData"
+        ... )
+        >>> dataset.stats()
+    """
+
+    def __init__(
+        self,
+        root: str,
+        dataset_name: Optional[str] = None,
+        config_path: Optional[str] = None,
+        tables: Optional[List[str]] = None,
+        dev: bool = False,
+    ) -> None:
+        if config_path is None:
+            logger.info("No config path provided, using default config")
+            config_path = str(
+                Path(__file__).parent / "configs" / "mimic4_derived.yaml"
+            )
+        default_tables = list(_TABLE_ATTRIBUTES)
+        # The default tables are always loaded because the attributes set
+        # below depend on them; extra tables are appended without duplicates.
+        extra_tables = [t for t in (tables or []) if t not in default_tables]
+        super().__init__(
+            root=root,
+            tables=default_tables + extra_tables,
+            dataset_name=dataset_name or "mimic4_derived",
+            config_path=config_path,
+            dev=dev,
+        )
+        self.vent_duration = self.filterTable("ventduration")
+        self.vasopressor_duration = self.filterTable("vasopressorduration")
+
+    def filterTable(self, table_name: str) -> pl.DataFrame:
+        """Return the events of a single derived table.
+
+        Args:
+            table_name: Event type to keep. Only the default tables
+                ("vasopressorduration", "ventduration",
+                "ventclassification") are supported.
+
+        Returns:
+            The rows of the collected global event dataframe whose
+            event_type equals table_name, restricted to patient_id,
+            event_type, timestamp and that table's attribute columns.
+
+        Raises:
+            ValueError: If table_name is not a supported table.
+        """
+        if table_name not in _TABLE_ATTRIBUTES:
+            raise ValueError(
+                f"Unknown table {table_name!r}; "
+                f"expected one of {sorted(_TABLE_ATTRIBUTES)}"
+            )
+        cols = [f"{table_name}/{attr}" for attr in _TABLE_ATTRIBUTES[table_name]]
+        df = self.collected_global_event_df
+        return df.filter(pl.col("event_type") == table_name).select(
+            ["patient_id", "event_type", "timestamp"] + cols
+        )
+
+    @staticmethod
+    def _print_duration_stats(title: str, table: pl.DataFrame, column: str) -> None:
+        """Print the mean, median and max of a duration column in hours."""
+        print(title)
+        # Cast to float: durations may be fractional, which an integer
+        # cast cannot represent.
+        hours = table.select(pl.col(column).cast(pl.Float64))
+        aggregates = (
+            ("Mean", hours.mean()),
+            ("Median", hours.median()),
+            ("Max", hours.max()),
+        )
+        for label, frame in aggregates:
+            value = frame[0, 0]
+            # Aggregates over an empty or all-null column come back as None.
+            shown = "n/a" if value is None else float(value)
+            print(f"{label} duration (hrs): {shown}")
+
+    def stats(self) -> None:
+        """Print duration statistics and the first rows of the event table."""
+        df = self.collected_global_event_df
+        if df.is_empty():
+            logger.error("Data is not loaded")
+            return
+        self._print_duration_stats(
+            "---Vasopressor Duration Statistics---",
+            self.vasopressor_duration,
+            "vasopressorduration/duration_hours",
+        )
+        self._print_duration_stats(
+            "---Ventilation Duration Statistics---",
+            self.vent_duration,
+            "ventduration/duration_hours",
+        )
+        print(df.head())