Commit 4fdeb08

Add dataset SUPPORT2 (#614)
* Add dataset SUPPORT2
* Update test dataset under test-resources; add link to full dataset
* Add a preprocessing task for dataset SUPPORT2
* Enhance documentation for Survival Preprocess task and add detailed example usage in tutorials
1 parent 028208d commit 4fdeb08

11 files changed, +840 -0 lines changed

docs/api/datasets.rst

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ Available Datasets
     datasets/pyhealth.datasets.SHHSDataset
     datasets/pyhealth.datasets.SleepEDFDataset
     datasets/pyhealth.datasets.EHRShotDataset
+    datasets/pyhealth.datasets.Support2Dataset
     datasets/pyhealth.datasets.BMDHSDataset
     datasets/pyhealth.datasets.COVID19CXRDataset
     datasets/pyhealth.datasets.ChestXray14Dataset
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+pyhealth.datasets.Support2Dataset
+==================================
+
+Overview
+--------
+
+The SUPPORT2 (Study to Understand Prognoses and Preferences for Outcomes and Risks of Treatments) dataset contains data on seriously ill hospitalized adults. It includes patient demographics, diagnoses, clinical measurements, and outcomes such as survival and hospital mortality.
+
+The dataset is commonly used for mortality prediction, length of stay prediction, and other clinical outcome prediction tasks.
+
+.. autoclass:: pyhealth.datasets.Support2Dataset
+    :members:
+    :undoc-members:
+    :show-inheritance:
+

docs/tutorials.rst

Lines changed: 12 additions & 0 deletions
@@ -99,6 +99,18 @@ Readmission Prediction
    * - ``readmission_mimic3_fairness.py``
      - Fairness-aware readmission prediction on MIMIC-III
 
+Survival Prediction
+-------------------
+
+.. list-table::
+   :widths: 50 50
+   :header-rows: 1
+
+   * - Example File
+     - Description
+   * - ``survival_preprocess_support2_demo.py``
+     - Survival probability prediction preprocessing with SUPPORT2 dataset. Demonstrates feature extraction (demographics, vitals, labs, scores, comorbidities) and ground truth survival probability labels for 2-month and 6-month horizons. Shows how to decode processed tensors back to human-readable features.
+
 Drug Recommendation
 -------------------
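
Reviewer note: the tutorial entry above mentions ground truth labels for 2-month and 6-month horizons, and the config in this commit lists `surv2m` and `surv6m` columns. The task implementation is not shown on this page, so the following is only a minimal sketch of how a horizon string could select a label column. The names `HORIZON_TO_COLUMN` and `survival_label` are hypothetical and are not taken from the commit.

```python
from typing import Dict, Mapping, Optional

# Column names come from the config in this commit; the mapping itself is
# an illustrative assumption, not the committed task code.
HORIZON_TO_COLUMN: Dict[str, str] = {"2m": "surv2m", "6m": "surv6m"}


def survival_label(row: Mapping[str, str], time_horizon: str = "2m") -> Optional[float]:
    """Return the survival probability for the horizon, or None if unusable."""
    if time_horizon not in HORIZON_TO_COLUMN:
        raise ValueError(f"time_horizon must be one of {sorted(HORIZON_TO_COLUMN)}")
    raw = row.get(HORIZON_TO_COLUMN[time_horizon])
    if raw is None or raw == "":
        return None
    try:
        value = float(raw)
    except ValueError:
        return None
    # Anything outside [0, 1] (including NaN) is not a valid probability.
    return value if 0.0 <= value <= 1.0 else None


# Made-up row for illustration, not real SUPPORT2 data.
example_row = {"surv2m": "0.63", "surv6m": ""}
label_2m = survival_label(example_row, "2m")
label_6m = survival_label(example_row, "6m")
print(label_2m, label_6m)
```

Returning `None` for a missing or out-of-range value lets the caller decide whether to skip the patient or impute; whether the committed task does either is not visible here.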

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
+"""
+Demo script for survival prediction preprocessing using SUPPORT2 dataset.
+
+This example demonstrates how to:
+1. Load the SUPPORT2 dataset (using test data with 3 patients)
+2. Apply the preprocessing task to extract features and labels
+3. Examine preprocessed samples ready for model training
+
+The preprocessing task extracts:
+- Features from raw patient data (demographics, vitals, labs, scores, etc.)
+- Ground truth survival probabilities from the dataset (surv2m/surv6m fields)
+- Structures data into samples ready for training a prediction model
+
+Note: The survival probabilities shown are ground truth labels extracted from the
+dataset (surv2m/surv6m columns). These are the target variables that a model
+would learn to predict from the extracted features.
+
+This example uses the synthetic test dataset from test-resources (3 patients).
+For real usage, replace the path with your actual SUPPORT2 dataset.
+"""
+
+import warnings
+import logging
+from pathlib import Path
+
+# Suppress warnings and reduce logging verbosity
+warnings.filterwarnings("ignore")
+logging.basicConfig(level=logging.WARNING)
+logging.getLogger("pyhealth").setLevel(logging.WARNING)
+logging.getLogger("pyhealth.datasets").setLevel(logging.WARNING)
+logging.getLogger("pyhealth.datasets.support2").setLevel(logging.WARNING)
+logging.getLogger("pyhealth.datasets.base_dataset").setLevel(logging.WARNING)
+
+# Import pyhealth modules
+from pyhealth.datasets import Support2Dataset
+from pyhealth.tasks import SurvivalPreprocessSupport2
+
+# Suppress tqdm progress bars for cleaner output
+try:
+    def noop_tqdm(iterable, *args, **kwargs):
+        return iterable
+    from pyhealth.datasets import base_dataset, sample_dataset
+    base_dataset.tqdm = noop_tqdm
+    sample_dataset.tqdm = noop_tqdm
+    import tqdm
+    tqdm.tqdm = noop_tqdm
+except (ImportError, AttributeError):
+    pass
+
+# Step 1: Load dataset using test data
+print("=" * 70)
+print("Step 1: Load SUPPORT2 Dataset")
+print("=" * 70)
+script_dir = Path(__file__).parent
+test_data_path = script_dir.parent / "test-resources" / "core" / "support2"
+
+dataset = Support2Dataset(
+    root=str(test_data_path),
+    tables=["support2"],
+)
+
+print(f"Loaded dataset with {len(dataset.unique_patient_ids)} patients\n")
+
+# Step 2: Apply preprocessing task to extract features and labels (2-month horizon)
+print("=" * 70)
+print("Step 2: Apply Survival Preprocessing Task")
+print("=" * 70)
+task = SurvivalPreprocessSupport2(time_horizon="2m")
+sample_dataset = dataset.set_task(task=task)
+
+print(f"Generated {len(sample_dataset)} samples")
+print(f"Input schema: {sample_dataset.input_schema}")
+print(f"Output schema: {sample_dataset.output_schema}\n")
+
+# Helper function to decode tensor indices to feature strings
+def decode_features(tensor, processor):
+    """Decode tensor indices back to original feature strings."""
+    if processor is None or not hasattr(processor, 'code_vocab'):
+        return [str(idx.item()) for idx in tensor]
+    reverse_vocab = {idx: token for token, idx in processor.code_vocab.items()}
+    return [reverse_vocab.get(idx.item(), f"<unk:{idx.item()}>") for idx in tensor]
+
+# Step 3: Display features for all samples
+print("=" * 70)
+print("Step 3: Examine Preprocessed Samples")
+print("=" * 70)
+# Sort samples by patient_id to ensure consistent order
+samples = sorted(sample_dataset, key=lambda x: int(x['patient_id']))
+for sample in samples:
+    # Display patient ID and tensor shapes first
+    print(f"\nPatient {sample['patient_id']}:")
+    print(f" Demographics tensor shape: {sample['demographics'].shape}")
+    print(f" Disease codes tensor shape: {sample['disease_codes'].shape}")
+    print(f" Vitals tensor shape: {sample['vitals'].shape}")
+    print(f" Labs tensor shape: {sample['labs'].shape}")
+    print(f" Scores tensor shape: {sample['scores'].shape}")
+    print(f" Comorbidities tensor shape: {sample['comorbidities'].shape}")
+
+    # Decode and display features for this sample
+    demographics = decode_features(
+        sample['demographics'],
+        sample_dataset.input_processors.get('demographics')
+    )
+    disease_codes = decode_features(
+        sample['disease_codes'],
+        sample_dataset.input_processors.get('disease_codes')
+    )
+    vitals = decode_features(
+        sample['vitals'],
+        sample_dataset.input_processors.get('vitals')
+    )
+    labs = decode_features(
+        sample['labs'],
+        sample_dataset.input_processors.get('labs')
+    )
+    scores = decode_features(
+        sample['scores'],
+        sample_dataset.input_processors.get('scores')
+    )
+    comorbidities = decode_features(
+        sample['comorbidities'],
+        sample_dataset.input_processors.get('comorbidities')
+    )
+
+    # Display decoded features
+    print(f" Demographics: {', '.join(demographics)}")
+    print(f" Disease Codes: {', '.join(disease_codes)}")
+    print(f" Vitals: {', '.join(vitals)}")
+    print(f" Labs: {', '.join(labs)}")
+    print(f" Scores: {', '.join(scores)}")
+    print(f" Comorbidities: {', '.join(comorbidities)}")
+    print(f" Survival Probability (2m): {sample['survival_probability'].item():.4f}")
+
+print("\n")
+print("=" * 70)
+print("Preprocessing Complete!")
+print("=" * 70)
+print("The samples are ready for model training.")
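
Reviewer note: `decode_features` in the script above inverts a token-to-index vocabulary to print readable features. The sketch below shows the same idea on plain Python integers so it can be run without PyHealth or PyTorch. The function name `decode_indices` and the example vocabulary are hypothetical; only the `code_vocab` attribute name and the `<unk:N>` fallback format are taken from the script.

```python
from typing import Dict, Iterable, List


def decode_indices(indices: Iterable[int], code_vocab: Dict[str, int]) -> List[str]:
    """Map integer indices back to tokens; unknown indices become '<unk:N>'."""
    # Invert token -> index into index -> token once, before the lookup loop.
    reverse_vocab = {idx: token for token, idx in code_vocab.items()}
    return [reverse_vocab.get(i, f"<unk:{i}>") for i in indices]


# Hypothetical vocabulary for illustration only; real token strings depend
# on how the preprocessing task formats its features.
example_vocab = {"<pad>": 0, "age_62.8": 1, "sex_male": 2}
decoded = decode_indices([1, 2, 7], example_vocab)
print(decoded)
```

Building the reverse mapping once per call keeps each lookup O(1); for repeated decoding with the same processor, the reverse mapping could be cached outside the function.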

pyhealth/datasets/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@ def __init__(self, *args, **kwargs):
 from .shhs import SHHSDataset
 from .sleepedf import SleepEDFDataset
 from .bmd_hs import BMDHSDataset
+from .support2 import Support2Dataset
 from .splitter import (
     split_by_patient,
     split_by_patient_conformal,
Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+version: "1.0"
+tables:
+  support2:
+    file_path: "support2.csv"
+    patient_id: "sno"
+    timestamp: null
+    attributes:
+      - "d.time"
+      - "age"
+      - "death"
+      - "sex"
+      - "hospdead"
+      - "slos"
+      - "dzgroup"
+      - "dzclass"
+      - "num.co"
+      - "edu"
+      - "income"
+      - "scoma"
+      - "charges"
+      - "totcst"
+      - "totmcst"
+      - "avtisst"
+      - "race"
+      - "sps"
+      - "aps"
+      - "surv2m"
+      - "surv6m"
+      - "hday"
+      - "diabetes"
+      - "dementia"
+      - "ca"
+      - "prg2m"
+      - "prg6m"
+      - "dnr"
+      - "dnrday"
+      - "meanbp"
+      - "wblc"
+      - "hrt"
+      - "resp"
+      - "temp"
+      - "pafi"
+      - "alb"
+      - "bili"
+      - "crea"
+      - "sod"
+      - "ph"
+      - "glucose"
+      - "bun"
+      - "urine"
+      - "adlp"
+      - "adls"
+      - "sfdm2"
+      - "adlsc"
+
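
Reviewer note: the config above names `sno` as the patient ID and lists the attribute columns expected in `support2.csv`. Before pointing the loader at a downloaded copy, it may be worth checking that the CSV header actually contains those columns. The sketch below does this with the standard library only; the helper `missing_columns` and the sample CSV text are hypothetical and are not part of the commit.

```python
import csv
import io
from typing import List, Sequence


def missing_columns(csv_text: str, required: Sequence[str]) -> List[str]:
    """Return the required column names that are absent from the CSV header."""
    reader = csv.reader(io.StringIO(csv_text))
    header = next(reader, [])  # An empty file yields an empty header.
    present = set(header)
    return [name for name in required if name not in present]


# Tiny made-up CSV for illustration, not real SUPPORT2 data.
sample_csv = "sno,age,sex,surv2m\n1,62.8,male,0.63\n"
required = ["sno", "age", "sex", "surv2m", "surv6m"]
absent = missing_columns(sample_csv, required)
print(absent)
```

For a file on disk, the same check can be applied by reading only the first line, so it stays cheap even for the full dataset.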

pyhealth/datasets/support2.py

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
+import logging
+from pathlib import Path
+from typing import List, Optional
+
+from .base_dataset import BaseDataset
+
+logger = logging.getLogger(__name__)
+
+
+class Support2Dataset(BaseDataset):
+    """
+    A dataset class for handling SUPPORT2 (Study to Understand Prognoses and Preferences
+    for Outcomes and Risks of Treatments) data.
+
+    The SUPPORT2 dataset contains data on 9,105 seriously ill hospitalized adults from
+    five U.S. medical centers (1989-1994), including patient demographics, diagnoses,
+    clinical measurements, and outcomes.
+
+    Dataset is available for download from:
+    - UCI Machine Learning Repository: https://archive.ics.uci.edu/dataset/880/support2
+    - Vanderbilt Biostatistics: https://hbiostat.org/data/repo/supportdesc
+    - Hugging Face: https://huggingface.co/datasets/jarrydmartinx/support2
+    - R packages: "rms" and "Hmisc"
+
+    Citation:
+        Knaus WA, Harrell FE, Lynn J, et al. The SUPPORT prognostic model:
+        Objective estimates of survival for seriously ill hospitalized adults.
+        Ann Intern Med. 1995;122(3):191-203.
+
+    Args:
+        root (str): The root directory where the dataset CSV file is stored.
+        tables (List[str]): A list of tables to be included (typically ["support2"]).
+        dataset_name (Optional[str]): The name of the dataset. Defaults to "support2".
+        config_path (Optional[str]): The path to the configuration file. If not provided,
+            uses the default config.
+        **kwargs: Additional arguments passed to BaseDataset.
+
+    Examples:
+        >>> from pyhealth.datasets import Support2Dataset
+        >>> dataset = Support2Dataset(
+        ...     root="/path/to/support2/data",
+        ...     tables=["support2"]
+        ... )
+        >>> dataset.stats()
+
+    Attributes:
+        root (str): The root directory where the dataset is stored.
+        tables (List[str]): A list of tables to be included in the dataset.
+        dataset_name (str): The name of the dataset.
+        config_path (str): The path to the configuration file.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        tables: List[str],
+        dataset_name: Optional[str] = None,
+        config_path: Optional[str] = None,
+        **kwargs
+    ) -> None:
+        if config_path is None:
+            logger.info("No config path provided, using default config")
+            config_path = Path(__file__).parent / "configs" / "support2.yaml"
+        super().__init__(
+            root=root,
+            tables=tables,
+            dataset_name=dataset_name or "support2",
+            config_path=config_path,
+            **kwargs
+        )
+        return
+
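
Reviewer note: `__init__` above falls back to a packaged `configs/support2.yaml` when `config_path` is `None`. The sketch below isolates that fallback as a standalone function so the behavior can be exercised without PyHealth. The function `resolve_config_path` and the directory `/opt/pkg/datasets` are hypothetical stand-ins; in the commit the base directory is `Path(__file__).parent`.

```python
from pathlib import Path
from typing import Optional, Union


def resolve_config_path(
    config_path: Optional[Union[str, Path]],
    package_dir: Union[str, Path],
    default_name: str = "support2.yaml",
) -> Path:
    """Return the caller's config path, or the packaged default when none is given."""
    if config_path is None:
        # Mirrors the fallback in the commit: <package_dir>/configs/<default_name>.
        return Path(package_dir) / "configs" / default_name
    return Path(config_path)


default_path = resolve_config_path(None, "/opt/pkg/datasets")
custom_path = resolve_config_path("my_config.yaml", "/opt/pkg/datasets")
print(default_path.name, custom_path.name)
```

Normalizing both branches to `Path` gives callers one return type, whereas the committed constructor passes either the caller's `str` or a `Path` through to `BaseDataset`.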

pyhealth/tasks/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@
     MortalityPredictionMIMIC4,
     MortalityPredictionOMOP,
 )
+from .survival_preprocess_support2 import SurvivalPreprocessSupport2
 from .mortality_prediction_stagenet_mimic4 import (
     MortalityPredictionStageNetMIMIC4,
 )
