Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ Available Datasets
datasets/pyhealth.datasets.SHHSDataset
datasets/pyhealth.datasets.SleepEDFDataset
datasets/pyhealth.datasets.EHRShotDataset
datasets/pyhealth.datasets.Support2Dataset
datasets/pyhealth.datasets.BMDHSDataset
datasets/pyhealth.datasets.COVID19CXRDataset
datasets/pyhealth.datasets.ChestXray14Dataset
Expand Down
15 changes: 15 additions & 0 deletions docs/api/datasets/pyhealth.datasets.Support2Dataset.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
pyhealth.datasets.Support2Dataset
==================================

Overview
--------

The SUPPORT2 (Study to Understand Prognoses and Preferences for Outcomes and Risks of Treatments) dataset contains data on seriously ill hospitalized adults. It includes patient demographics, diagnoses, clinical measurements, and outcomes such as survival and hospital mortality.

The dataset is commonly used for mortality prediction, length of stay prediction, and other clinical outcome prediction tasks.

.. autoclass:: pyhealth.datasets.Support2Dataset
:members:
:undoc-members:
:show-inheritance:

1 change: 1 addition & 0 deletions pyhealth/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(self, *args, **kwargs):
from .shhs import SHHSDataset
from .sleepedf import SleepEDFDataset
from .bmd_hs import BMDHSDataset
from .support2 import Support2Dataset
from .splitter import (
split_by_patient,
split_by_patient_conformal,
Expand Down
55 changes: 55 additions & 0 deletions pyhealth/datasets/configs/support2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
version: "1.0"
tables:
support2:
file_path: "support2.csv"
patient_id: "sno"
timestamp: null
attributes:
- "d.time"
- "age"
- "death"
- "sex"
- "hospdead"
- "slos"
- "dzgroup"
- "dzclass"
- "num.co"
- "edu"
- "income"
- "scoma"
- "charges"
- "totcst"
- "totmcst"
- "avtisst"
- "race"
- "sps"
- "aps"
- "surv2m"
- "surv6m"
- "hday"
- "diabetes"
- "dementia"
- "ca"
- "prg2m"
- "prg6m"
- "dnr"
- "dnrday"
- "meanbp"
- "wblc"
- "hrt"
- "resp"
- "temp"
- "pafi"
- "alb"
- "bili"
- "crea"
- "sod"
- "ph"
- "glucose"
- "bun"
- "urine"
- "adlp"
- "adls"
- "sfdm2"
- "adlsc"

72 changes: 72 additions & 0 deletions pyhealth/datasets/support2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import logging
from pathlib import Path
from typing import List, Optional

from .base_dataset import BaseDataset

logger = logging.getLogger(__name__)


class Support2Dataset(BaseDataset):
    """Dataset wrapper for the SUPPORT2 study data.

    SUPPORT2 (Study to Understand Prognoses and Preferences for Outcomes and
    Risks of Treatments) describes 9,105 seriously ill hospitalized adults
    from five U.S. medical centers (1989-1994). Records include patient
    demographics, diagnoses, clinical measurements, and outcomes.

    The data can be downloaded from:
        - UCI Machine Learning Repository:
          https://archive.ics.uci.edu/dataset/880/support2
        - Vanderbilt Biostatistics:
          https://hbiostat.org/data/repo/supportdesc
        - Hugging Face:
          https://huggingface.co/datasets/jarrydmartinx/support2
        - R packages: "rms" and "Hmisc"

    Citation:
        Knaus WA, Harrell FE, Lynn J, et al. The SUPPORT prognostic model:
        Objective estimates of survival for seriously ill hospitalized adults.
        Ann Intern Med. 1995;122(3):191-203.

    Args:
        root (str): Directory in which the dataset CSV file is stored.
        tables (List[str]): Tables to include (typically ``["support2"]``).
        dataset_name (Optional[str]): Name given to the dataset. Falls back
            to ``"support2"`` when omitted or empty.
        config_path (Optional[str]): Location of the configuration file. When
            ``None``, the ``configs/support2.yaml`` file that sits next to
            this module is used.
        **kwargs: Extra keyword arguments forwarded to ``BaseDataset``.

    Examples:
        >>> from pyhealth.datasets import Support2Dataset
        >>> dataset = Support2Dataset(
        ...     root="/path/to/support2/data",
        ...     tables=["support2"]
        ... )
        >>> dataset.stats()

    Attributes:
        root (str): The root directory where the dataset is stored.
        tables (List[str]): A list of tables to be included in the dataset.
        dataset_name (str): The name of the dataset.
        config_path (str): The path to the configuration file.
    """

    def __init__(
        self,
        root: str,
        tables: List[str],
        dataset_name: Optional[str] = None,
        config_path: Optional[str] = None,
        **kwargs
    ) -> None:
        # Fall back to the YAML config that lives in the sibling "configs"
        # directory when the caller did not point at one explicitly.
        if config_path is None:
            logger.info("No config path provided, using default config")
            module_dir = Path(__file__).parent
            config_path = module_dir.joinpath("configs", "support2.yaml")

        # A missing or empty name resolves to the canonical dataset name.
        resolved_name = dataset_name if dataset_name else "support2"

        super().__init__(
            root=root,
            tables=tables,
            dataset_name=resolved_name,
            config_path=config_path,
            **kwargs
        )

4 changes: 4 additions & 0 deletions test-resources/core/support2/support2.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
sno,age,death,sex,hospdead,slos,d.time,dzgroup,dzclass,num.co,edu,income,scoma,charges,totcst,totmcst,avtisst,race,sps,aps,surv2m,surv6m,hday,diabetes,dementia,ca,prg2m,prg6m,dnr,dnrday,meanbp,wblc,hrt,resp,temp,pafi,alb,bili,crea,sod,ph,glucose,bun,urine,adlp,adls,sfdm2,adlsc
1,62.85,0,male,0,5,2029,Lung Cancer,Cancer,0,11,$11-$25k,0,9715.0,,,7.0,other,33.9,20,0.26,0.037,1,0,0,metastatic,0.5,0.25,no dnr,5,97.0,6.0,69.0,22,36.0,388.0,1.8,0.2,1.2,141,7.46,,,,7,7,,7.0
2,60.34,1,female,1,4,4,Cirrhosis,COPD/CHF/Cirrhosis,2,12,$11-$25k,44,34496.0,,,29.0,white,52.7,74,0.001,0.0,3,0,0,no,0.0,0.0,,,43.0,17.1,112.0,34,34.6,98.0,,,5.5,132,7.25,,,,1,,<2 mo. follow-up,1.0
3,52.75,1,female,0,17,47,Cirrhosis,COPD/CHF/Cirrhosis,2,12,under $11k,0,41094.0,,,13.0,white,20.5,45,0.79,0.66,4,0,0,no,0.75,0.5,no dnr,17,70.0,8.5,88.0,28,37.4,231.7,,2.2,2.0,134,7.46,,,,1,0,<2 mo. follow-up,0.0
125 changes: 125 additions & 0 deletions tests/core/test_support2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import unittest
import tempfile
import shutil
from pathlib import Path
import pandas as pd

from pyhealth.datasets import Support2Dataset

class TestSupport2Dataset(unittest.TestCase):
    """Test Support2Dataset with synthetic test data."""

    def setUp(self):
        """Create a temp directory holding a three-patient synthetic support2.csv."""
        self.temp_dir = tempfile.mkdtemp()
        # Register removal right away. unittest does not call tearDown() when
        # setUp() raises, but registered cleanups still run, so the directory
        # cannot leak if building or writing the CSV below fails.
        self.addCleanup(shutil.rmtree, self.temp_dir)
        self.root = Path(self.temp_dir)

        # Minimal synthetic table mirroring the column structure of the
        # SUPPORT2 dataset; None marks a missing value.
        support2_data = {
            'sno': ['1', '2', '3'],
            'age': [62.85, 60.34, 52.75],
            'death': [0, 1, 1],
            'sex': ['male', 'female', 'female'],
            'hospdead': [0, 1, 0],
            'slos': [5, 4, 17],
            'd.time': [2029, 4, 47],
            'dzgroup': ['Lung Cancer', 'Cirrhosis', 'Cirrhosis'],
            'dzclass': ['Cancer', 'COPD/CHF/Cirrhosis', 'COPD/CHF/Cirrhosis'],
            'num.co': [0, 2, 2],
            'edu': [11, 12, 12],
            'income': ['$11-$25k', '$11-$25k', 'under $11k'],
            'scoma': [0, 44, 0],
            'charges': [9715.0, 34496.0, 41094.0],
            'race': ['other', 'white', 'white'],
            'sps': [33.9, 52.7, 20.5],
            'aps': [20, 74, 45],
            'surv2m': [0.26, 0.001, 0.79],
            'surv6m': [0.037, 0.0, 0.66],
            'hday': [1, 3, 4],
            'diabetes': [0, 0, 0],
            'dementia': [0, 0, 0],
            'ca': ['metastatic', 'no', 'no'],
            'meanbp': [97.0, 43.0, 70.0],
            'wblc': [6.0, 17.1, 8.5],
            'hrt': [69.0, 112.0, 88.0],
            'resp': [22, 34, 28],
            'temp': [36.0, 34.6, 37.4],
            'pafi': [388.0, 98.0, 231.7],
            'alb': [1.8, None, None],
            'bili': [0.2, None, 2.2],
            'crea': [1.2, 5.5, 2.0],
            'sod': [141, 132, 134],
            'ph': [7.46, 7.25, 7.46],
            'glucose': [None, None, None],
            'bun': [None, None, None],
            'urine': [None, None, None],
            'adlp': [7, 1, 1],
            'adls': [7, None, 0],
            'sfdm2': [None, '<2 mo. follow-up', '<2 mo. follow-up'],
            'adlsc': [7.0, 1.0, 0.0],
            'prg2m': [0.5, 0.0, 0.75],
            'prg6m': [0.25, 0.0, 0.5],
            'dnr': ['no dnr', None, 'no dnr'],
            'dnrday': [5, None, 17],
            'totcst': [None, None, None],
            'totmcst': [None, None, None],
            'avtisst': [7.0, 29.0, 13.0],
        }

        support2_df = pd.DataFrame(support2_data)
        support2_df.to_csv(self.root / "support2.csv", index=False)

    def _load_dataset(self):
        """Build a Support2Dataset over the synthetic temp directory."""
        return Support2Dataset(
            root=str(self.root),
            tables=["support2"]
        )

    def test_dataset_initialization(self):
        """Test Support2Dataset initialization."""
        dataset = self._load_dataset()
        self.assertIsInstance(dataset, Support2Dataset)
        self.assertEqual(dataset.dataset_name, "support2")

    def test_load_data(self):
        """Test that data loads correctly."""
        dataset = self._load_dataset()
        self.assertIsNotNone(dataset.global_event_df)

    def test_patient_count(self):
        """Test that the dataset contains the expected number of patients."""
        dataset = self._load_dataset()
        unique_patients = dataset.unique_patient_ids
        self.assertEqual(len(unique_patients), 3)

    def test_get_patient(self):
        """Test retrieving a single patient by ID."""
        dataset = self._load_dataset()
        patient = dataset.get_patient("1")
        self.assertIsNotNone(patient)
        self.assertEqual(patient.patient_id, "1")

    def test_stats(self):
        """Test that stats method executes without errors."""
        dataset = self._load_dataset()
        dataset.stats()


if __name__ == "__main__":
unittest.main()