Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyhealth/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(self, *args, **kwargs):
from .medical_transcriptions import MedicalTranscriptionsDataset
from .mimic3 import MIMIC3Dataset
from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset
from .mimic4derived import MIMIC4DerivedDataset
from .mimicextract import MIMICExtractDataset
from .omop import OMOPDataset
from .sample_dataset import SampleDataset
Expand Down
31 changes: 31 additions & 0 deletions pyhealth/datasets/configs/mimic4_derived.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Table definitions used by MIMIC4DerivedDataset when no config_path is given.
# Each entry under "tables" names the CSV export to read, the column used as
# the event timestamp, and the attribute columns kept for that event type.
version: "3.1"
tables:
  # Vasopressor rows; the event timestamp comes from the "starttime" column.
  vasopressorduration:
    file_path: "mimic_iv_vasopressor.csv.gz"
    # NOTE(review): no patient id column is mapped for this table -- confirm
    # how the loader assigns patient_id when this is null.
    patient_id: null
    timestamp: "starttime"
    attributes:
      - "stay_id"
      - "vasonum"
      - "endtime"
      - "duration_hours"

  # Ventilation duration rows; the event timestamp comes from "starttime".
  ventduration:
    file_path: "mimic_iv_ventilation_duration.csv.gz"
    # NOTE(review): patient_id is unmapped here as well -- see note above.
    patient_id: null
    timestamp: "starttime"
    attributes:
      - "stay_id"
      - "ventnum"
      - "endtime"
      - "duration_hours"
  # Ventilation classification rows; the event timestamp comes from "charttime".
  # NOTE(review): the four mixed-case attributes look like per-row status
  # flags -- confirm their encoding against the generating notebook.
  ventclassification:
    file_path: "mimic_iv_ventilation_classification.csv.gz"
    patient_id: null
    timestamp: "charttime"
    attributes:
      - "stay_id"
      - "MechVent"
      - "OxygenTherapy"
      - "Extubated"
      - "SelfExtubated"
16 changes: 15 additions & 1 deletion pyhealth/datasets/mimic4.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pandas as pd
import polars as pl

# Import straight from the sibling module rather than through the package:
# pyhealth/datasets/__init__.py imports this module before mimic4derived, so
# going through `pyhealth.datasets` here would be a circular import.
from .mimic4derived import MIMIC4DerivedDataset

try:
import psutil
Expand Down Expand Up @@ -203,12 +204,15 @@ def __init__(
ehr_root: Optional[str] = None,
note_root: Optional[str] = None,
cxr_root: Optional[str] = None,
der_root: Optional[str] = None,
ehr_tables: Optional[List[str]] = None,
note_tables: Optional[List[str]] = None,
cxr_tables: Optional[List[str]] = None,
der_tables: Optional[List[str]] = None,
ehr_config_path: Optional[str] = None,
note_config_path: Optional[str] = None,
cxr_config_path: Optional[str] = None,
der_config_path: Optional[str] = None,
dataset_name: str = "mimic4",
dev: bool = False, # Added dev parameter
):
Expand All @@ -223,13 +227,14 @@ def __init__(
self.dev = dev # Store dev mode flag

# We need at least one root directory
if not any([ehr_root, note_root, cxr_root]):
if not any([ehr_root, note_root, cxr_root, der_root]):
raise ValueError("At least one root directory must be provided")

# Initialize empty lists if None provided
ehr_tables = ehr_tables or []
note_tables = note_tables or []
cxr_tables = cxr_tables or []
der_tables = der_tables or []

# Initialize EHR dataset if root is provided
if ehr_root:
Expand Down Expand Up @@ -263,6 +268,15 @@ def __init__(
dev=dev # Pass dev mode flag
)
log_memory_usage("After CXR dataset initialization")
if der_root is not None:
logger.info(f"Initializing MIMIC4DerivedDataset with tables: {der_tables} (dev mode: {dev})")
self.sub_datasets["der"] = MIMIC4DerivedDataset(
root=der_root,
tables=der_tables,
config_path=der_config_path,
dev=dev # Pass dev mode flag
)
log_memory_usage("After Derived dataset initialization")

# Combine data from all sub-datasets
log_memory_usage("Before combining data")
Expand Down
122 changes: 122 additions & 0 deletions pyhealth/datasets/mimic4derived.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import logging
from pathlib import Path
from typing import List, Optional

import polars as pl

from .base_dataset import BaseDataset

logger = logging.getLogger(__name__)


class MIMIC4DerivedDataset(BaseDataset):
    """Derived dataset for MIMIC-IV with ventilation and vasopressor durations.

    Contributors: Matthew Rimland (NetId: rimland2), Athan Liu (NetId: athanl2)

    Inspiration Paper: Racial Disparities and Mistrust in End-of-Life Care
    (Boag et al. 2018)
    Link to paper: https://proceedings.mlr.press/v85/boag18a.html

    The dataset can be derived from the following source for the Metavision
    information system:
    https://physionet.org/content/mimiciv/3.1/

    Transformations are derived from the following and adapted for MIMIC-IV:
    https://github.com/MIT-LCP/mimic-code/blob/main/mimic-iii/concepts/durations/ventilation_durations.sql
    Official derivation queries that generate the expected csv.gz files are
    expected to become available in the same repository.

    The derivation queries used to generate the expected data can be found at
    https://github.com/mrimland/DL4H-Racial-Disparities-And-Mistrust/blob/main/Notebooks/GenerateVentVasoTables.ipynb

    Note: The expected tables are available in both MIMIC-III and MIMIC-IV.
    The configuration in mimic4_derived.yaml contains column names specific to
    MIMIC-IV.

    Args:
        root: Root directory of the raw data.
        dataset_name: Name of the dataset. Defaults to "mimic4_derived".
        config_path: Path to the configuration file. If None, the bundled
            ``configs/mimic4_derived.yaml`` is used.
        tables: Optional extra table names to load in addition to the three
            default tables. The defaults are always loaded because
            ``vent_duration`` and ``vasopressor_duration`` are built from them.
        dev: Development-mode flag, forwarded to ``BaseDataset``.

    Attributes:
        root: Root directory of the raw data (should contain many csv files).
        dataset_name: Name of the dataset.
        config_path: Path to the configuration file.
        vent_duration: Events of type "ventduration" (see ``filterTable``).
        vasopressor_duration: Events of type "vasopressorduration".

    Examples:
        >>> from pyhealth.datasets import MIMIC4DerivedDataset
        >>> dataset = MIMIC4DerivedDataset(
        ...     root="path/to/mimic4derived",
        ...     dataset_name="VentData"
        ... )
        >>> dataset.stats()
    """

    # Attribute columns kept for each supported event type; these mirror the
    # "attributes" lists in configs/mimic4_derived.yaml.
    _TABLE_ATTRIBUTES = {
        "ventduration": ("stay_id", "ventnum", "endtime", "duration_hours"),
        "ventclassification": (
            "stay_id",
            "MechVent",
            "OxygenTherapy",
            "Extubated",
            "SelfExtubated",
        ),
        "vasopressorduration": (
            "stay_id",
            "vasonum",
            "endtime",
            "duration_hours",
        ),
    }

    def __init__(
        self,
        root: str,
        dataset_name: Optional[str] = None,
        config_path: Optional[str] = None,
        tables: Optional[List[str]] = None,
        dev: bool = False,
    ) -> None:
        if config_path is None:
            logger.info("No config path provided, using default config")
            config_path = (
                Path(__file__).parent / "configs" / "mimic4_derived.yaml"
            )
        default_tables = [
            "vasopressorduration",
            "ventduration",
            "ventclassification",
        ]
        # MIMIC4Dataset passes `tables=` (possibly an empty list); keep the
        # defaults first and append only names that are not already present.
        requested_tables = list(default_tables)
        for table in tables or []:
            if table not in requested_tables:
                requested_tables.append(table)
        super().__init__(
            root=root,
            tables=requested_tables,
            dataset_name=dataset_name or "mimic4_derived",
            config_path=config_path,
            # NOTE(review): assumes BaseDataset accepts `dev`, as the sibling
            # MIMIC-IV sub-datasets are constructed with it -- confirm.
            dev=dev,
        )
        self.vent_duration = self.filterTable("ventduration")
        self.vasopressor_duration = self.filterTable("vasopressorduration")

    def filterTable(self, table_name):
        """Return the events of one derived table.

        Args:
            table_name: The event_type to filter on. Only the three default
                tables ("ventduration", "ventclassification",
                "vasopressorduration") are supported.

        Returns:
            The rows of ``collected_global_event_df`` whose ``event_type``
            equals ``table_name``, restricted to ``patient_id``,
            ``event_type``, ``timestamp`` and that table's
            ``"<table_name>/<attribute>"`` columns. Returns None (after
            logging an error) when ``table_name`` is not supported.
        """
        attributes = self._TABLE_ATTRIBUTES.get(table_name)
        if attributes is None:
            logger.error("Unknown table specified")
            return None
        columns = ["patient_id", "event_type", "timestamp"]
        columns += [f"{table_name}/{name}" for name in attributes]
        events = self.collected_global_event_df
        return events.filter(pl.col("event_type") == table_name).select(columns)

    @staticmethod
    def _print_duration_stats(title, table, column, dtype):
        """Print mean, median and max of ``column`` in ``table`` under ``title``."""
        print(title)
        durations = table.select(pl.col(column).cast(dtype))
        mean = durations.mean()[0, 0]
        if mean is None:
            # The aggregate is null when the table has no (non-null) rows;
            # converting it with float() would raise a TypeError.
            print("No duration values available")
            return
        print(f"Mean duration (hrs): {float(mean)}")
        # Report the median as a float: truncating with int() drops the
        # fractional part (e.g. the .5 produced by an even number of rows).
        print(f"Median duration (hrs): {float(durations.median()[0, 0])}")
        print(f"Max duration (hrs): {durations.max()[0, 0]}")

    def stats(self):
        """Print duration statistics for the vasopressor and ventilation tables.

        Logs an error and returns without printing when the collected event
        frame is empty. Otherwise prints mean/median/max of ``duration_hours``
        for both tables, followed by the head of the full event frame.
        """
        df = self.collected_global_event_df
        if df.is_empty():
            logger.error("Data is not loaded")
            return
        # NOTE(review): vasopressor hours are cast to Int64 as before, which
        # presumably relies on whole-hour values in the export -- confirm.
        self._print_duration_stats(
            "---Vasopressor Duration Statistics---",
            self.vasopressor_duration,
            "vasopressorduration/duration_hours",
            pl.Int64,
        )
        self._print_duration_stats(
            "---Ventilation Duration Statistics---",
            self.vent_duration,
            "ventduration/duration_hours",
            pl.Float64,
        )
        print(df.head())