Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ Available Datasets
datasets/pyhealth.datasets.SHHSDataset
datasets/pyhealth.datasets.SleepEDFDataset
datasets/pyhealth.datasets.EHRShotDataset
datasets/pyhealth.datasets.Support2Dataset
datasets/pyhealth.datasets.BMDHSDataset
datasets/pyhealth.datasets.COVID19CXRDataset
datasets/pyhealth.datasets.ChestXray14Dataset
Expand Down
15 changes: 15 additions & 0 deletions docs/api/datasets/pyhealth.datasets.Support2Dataset.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
pyhealth.datasets.Support2Dataset
==================================

Overview
--------

The SUPPORT2 (Study to Understand Prognoses and Preferences for Outcomes and Risks of Treatments) dataset contains data on seriously ill hospitalized adults. It includes patient demographics, diagnoses, clinical measurements, and outcomes such as survival and hospital mortality.

The dataset is commonly used for mortality prediction, length of stay prediction, and other clinical outcome prediction tasks.

.. autoclass:: pyhealth.datasets.Support2Dataset
:members:
:undoc-members:
:show-inheritance:

12 changes: 12 additions & 0 deletions docs/tutorials.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,18 @@ Readmission Prediction
* - ``readmission_mimic3_fairness.py``
- Fairness-aware readmission prediction on MIMIC-III

Survival Prediction
-------------------

.. list-table::
:widths: 50 50
:header-rows: 1

* - Example File
- Description
* - ``survival_preprocess_support2_demo.py``
- Survival probability prediction preprocessing with SUPPORT2 dataset. Demonstrates feature extraction (demographics, vitals, labs, scores, comorbidities) and ground truth survival probability labels for 2-month and 6-month horizons. Shows how to decode processed tensors back to human-readable features.

Drug Recommendation
-------------------

Expand Down
138 changes: 138 additions & 0 deletions examples/survival_preprocess_support2_demo.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Can you change your example to use the PyHealth task here you wrote? To make it more complete?
  2. Can you also update the docs to point to this example here? Just in case people want to see how to use it? I will probably do another refactor/aggregate in the tutorials/additional_examples.rst here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure!

  1. Updated example in the task file
  2. Updated docstring in the task file to point to the example demo file
  3. Updated tutorials/additional_examples.rst to also point to the example demo file

Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""
Demo script for survival prediction preprocessing using SUPPORT2 dataset.

This example demonstrates how to:
1. Load the SUPPORT2 dataset (using test data with 3 patients)
2. Apply the preprocessing task to extract features and labels
3. Examine preprocessed samples ready for model training

The preprocessing task extracts:
- Features from raw patient data (demographics, vitals, labs, scores, etc.)
- Ground truth survival probabilities from the dataset (surv2m/surv6m fields)
- Structures data into samples ready for training a prediction model

Note: The survival probabilities shown are ground truth labels extracted from the
dataset (surv2m/surv6m columns). These are the target variables that a model
would learn to predict from the extracted features.

This example uses the synthetic test dataset from test-resources (3 patients).
For real usage, replace the path with your actual SUPPORT2 dataset.
"""

import warnings
import logging
from pathlib import Path

# Suppress warnings and reduce logging verbosity
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.WARNING)
logging.getLogger("pyhealth").setLevel(logging.WARNING)
logging.getLogger("pyhealth.datasets").setLevel(logging.WARNING)
logging.getLogger("pyhealth.datasets.support2").setLevel(logging.WARNING)
logging.getLogger("pyhealth.datasets.base_dataset").setLevel(logging.WARNING)

# Import pyhealth modules
from pyhealth.datasets import Support2Dataset
from pyhealth.tasks import SurvivalPreprocessSupport2

# Suppress tqdm progress bars for cleaner output
try:
def noop_tqdm(iterable, *args, **kwargs):
return iterable
from pyhealth.datasets import base_dataset, sample_dataset
base_dataset.tqdm = noop_tqdm
sample_dataset.tqdm = noop_tqdm
import tqdm
tqdm.tqdm = noop_tqdm
except (ImportError, AttributeError):
pass

# Step 1: Load dataset using test data
print("=" * 70)
print("Step 1: Load SUPPORT2 Dataset")
print("=" * 70)
script_dir = Path(__file__).parent
test_data_path = script_dir.parent / "test-resources" / "core" / "support2"

dataset = Support2Dataset(
root=str(test_data_path),
tables=["support2"],
)

print(f"Loaded dataset with {len(dataset.unique_patient_ids)} patients\n")

# Step 2: Apply preprocessing task to extract features and labels (2-month horizon)
print("=" * 70)
print("Step 2: Apply Survival Preprocessing Task")
print("=" * 70)
task = SurvivalPreprocessSupport2(time_horizon="2m")
sample_dataset = dataset.set_task(task=task)

print(f"Generated {len(sample_dataset)} samples")
print(f"Input schema: {sample_dataset.input_schema}")
print(f"Output schema: {sample_dataset.output_schema}\n")

# Helper function to decode tensor indices to feature strings
def decode_features(tensor, processor):
"""Decode tensor indices back to original feature strings."""
if processor is None or not hasattr(processor, 'code_vocab'):
return [str(idx.item()) for idx in tensor]
reverse_vocab = {idx: token for token, idx in processor.code_vocab.items()}
return [reverse_vocab.get(idx.item(), f"<unk:{idx.item()}>") for idx in tensor]

# Step 3: Display features for all samples
print("=" * 70)
print("Step 3: Examine Preprocessed Samples")
print("=" * 70)
# Sort samples by patient_id to ensure consistent order
samples = sorted(sample_dataset, key=lambda x: int(x['patient_id']))
for sample in samples:
# Display patient ID and tensor shapes first
print(f"\nPatient {sample['patient_id']}:")
print(f" Demographics tensor shape: {sample['demographics'].shape}")
print(f" Disease codes tensor shape: {sample['disease_codes'].shape}")
print(f" Vitals tensor shape: {sample['vitals'].shape}")
print(f" Labs tensor shape: {sample['labs'].shape}")
print(f" Scores tensor shape: {sample['scores'].shape}")
print(f" Comorbidities tensor shape: {sample['comorbidities'].shape}")

# Decode and display features for this sample
demographics = decode_features(
sample['demographics'],
sample_dataset.input_processors.get('demographics')
)
disease_codes = decode_features(
sample['disease_codes'],
sample_dataset.input_processors.get('disease_codes')
)
vitals = decode_features(
sample['vitals'],
sample_dataset.input_processors.get('vitals')
)
labs = decode_features(
sample['labs'],
sample_dataset.input_processors.get('labs')
)
scores = decode_features(
sample['scores'],
sample_dataset.input_processors.get('scores')
)
comorbidities = decode_features(
sample['comorbidities'],
sample_dataset.input_processors.get('comorbidities')
)

# Display decoded features
print(f" Demographics: {', '.join(demographics)}")
print(f" Disease Codes: {', '.join(disease_codes)}")
print(f" Vitals: {', '.join(vitals)}")
print(f" Labs: {', '.join(labs)}")
print(f" Scores: {', '.join(scores)}")
print(f" Comorbidities: {', '.join(comorbidities)}")
print(f" Survival Probability (2m): {sample['survival_probability'].item():.4f}")

print("\n")
print("=" * 70)
print("Preprocessing Complete!")
print("=" * 70)
print("The samples are ready for model training.")
1 change: 1 addition & 0 deletions pyhealth/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(self, *args, **kwargs):
from .shhs import SHHSDataset
from .sleepedf import SleepEDFDataset
from .bmd_hs import BMDHSDataset
from .support2 import Support2Dataset
from .splitter import (
split_by_patient,
split_by_patient_conformal,
Expand Down
55 changes: 55 additions & 0 deletions pyhealth/datasets/configs/support2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
version: "1.0"
tables:
support2:
file_path: "support2.csv"
patient_id: "sno"
timestamp: null
attributes:
- "d.time"
- "age"
- "death"
- "sex"
- "hospdead"
- "slos"
- "dzgroup"
- "dzclass"
- "num.co"
- "edu"
- "income"
- "scoma"
- "charges"
- "totcst"
- "totmcst"
- "avtisst"
- "race"
- "sps"
- "aps"
- "surv2m"
- "surv6m"
- "hday"
- "diabetes"
- "dementia"
- "ca"
- "prg2m"
- "prg6m"
- "dnr"
- "dnrday"
- "meanbp"
- "wblc"
- "hrt"
- "resp"
- "temp"
- "pafi"
- "alb"
- "bili"
- "crea"
- "sod"
- "ph"
- "glucose"
- "bun"
- "urine"
- "adlp"
- "adls"
- "sfdm2"
- "adlsc"

72 changes: 72 additions & 0 deletions pyhealth/datasets/support2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import logging
from pathlib import Path
from typing import List, Optional

from .base_dataset import BaseDataset

logger = logging.getLogger(__name__)


class Support2Dataset(BaseDataset):
"""
A dataset class for handling SUPPORT2 (Study to Understand Prognoses and Preferences
for Outcomes and Risks of Treatments) data.

The SUPPORT2 dataset contains data on 9,105 seriously ill hospitalized adults from
five U.S. medical centers (1989-1994), including patient demographics, diagnoses,
clinical measurements, and outcomes.

Dataset is available for download from:
- UCI Machine Learning Repository: https://archive.ics.uci.edu/dataset/880/support2
- Vanderbilt Biostatistics: https://hbiostat.org/data/repo/supportdesc
- Hugging Face: https://huggingface.co/datasets/jarrydmartinx/support2
- R packages: "rms" and "Hmisc"

Citation:
Knaus WA, Harrell FE, Lynn J, et al. The SUPPORT prognostic model:
Objective estimates of survival for seriously ill hospitalized adults.
Ann Intern Med. 1995;122(3):191-203.

Args:
root (str): The root directory where the dataset CSV file is stored.
tables (List[str]): A list of tables to be included (typically ["support2"]).
dataset_name (Optional[str]): The name of the dataset. Defaults to "support2".
config_path (Optional[str]): The path to the configuration file. If not provided,
uses the default config.
**kwargs: Additional arguments passed to BaseDataset.

Examples:
>>> from pyhealth.datasets import Support2Dataset
>>> dataset = Support2Dataset(
... root="/path/to/support2/data",
... tables=["support2"]
... )
>>> dataset.stats()

Attributes:
root (str): The root directory where the dataset is stored.
tables (List[str]): A list of tables to be included in the dataset.
dataset_name (str): The name of the dataset.
config_path (str): The path to the configuration file.
"""

def __init__(
self,
root: str,
tables: List[str],
dataset_name: Optional[str] = None,
config_path: Optional[str] = None,
**kwargs
) -> None:
if config_path is None:
logger.info("No config path provided, using default config")
config_path = Path(__file__).parent / "configs" / "support2.yaml"
super().__init__(
root=root,
tables=tables,
dataset_name=dataset_name or "support2",
config_path=config_path,
**kwargs
)
return

1 change: 1 addition & 0 deletions pyhealth/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
MortalityPredictionMIMIC4,
MortalityPredictionOMOP,
)
from .survival_preprocess_support2 import SurvivalPreprocessSupport2
from .mortality_prediction_stagenet_mimic4 import (
MortalityPredictionStageNetMIMIC4,
)
Expand Down
Loading