diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index ed2845126..913439ee8 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -42,6 +42,7 @@ Available Datasets datasets/pyhealth.datasets.SHHSDataset datasets/pyhealth.datasets.SleepEDFDataset datasets/pyhealth.datasets.EHRShotDataset + datasets/pyhealth.datasets.Support2Dataset datasets/pyhealth.datasets.BMDHSDataset datasets/pyhealth.datasets.COVID19CXRDataset datasets/pyhealth.datasets.ChestXray14Dataset diff --git a/docs/api/datasets/pyhealth.datasets.Support2Dataset.rst b/docs/api/datasets/pyhealth.datasets.Support2Dataset.rst new file mode 100644 index 000000000..01232af50 --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.Support2Dataset.rst @@ -0,0 +1,15 @@ +pyhealth.datasets.Support2Dataset +================================== + +Overview +-------- + +The SUPPORT2 (Study to Understand Prognoses and Preferences for Outcomes and Risks of Treatments) dataset contains data on seriously ill hospitalized adults. It includes patient demographics, diagnoses, clinical measurements, and outcomes such as survival and hospital mortality. + +The dataset is commonly used for mortality prediction, length of stay prediction, and other clinical outcome prediction tasks. + +.. autoclass:: pyhealth.datasets.Support2Dataset + :members: + :undoc-members: + :show-inheritance: + diff --git a/docs/tutorials.rst b/docs/tutorials.rst index 8fef3506f..642fbd7c4 100644 --- a/docs/tutorials.rst +++ b/docs/tutorials.rst @@ -88,6 +88,18 @@ Readmission Prediction * - ``readmission_mimic3_fairness.py`` - Fairness-aware readmission prediction on MIMIC-III +Survival Prediction +------------------- + +.. list-table:: + :widths: 50 50 + :header-rows: 1 + + * - Example File + - Description + * - ``survival_preprocess_support2_demo.py`` + - Survival probability prediction preprocessing with SUPPORT2 dataset. Demonstrates feature extraction (demographics, vitals, labs, scores, comorbidities) and ground truth survival probability labels for 2-month and 6-month horizons. Shows how to decode processed tensors back to human-readable features. + Drug Recommendation ------------------- diff --git a/examples/survival_preprocess_support2_demo.py b/examples/survival_preprocess_support2_demo.py new file mode 100644 index 000000000..c3b7c1803 --- /dev/null +++ b/examples/survival_preprocess_support2_demo.py @@ -0,0 +1,138 @@ +""" +Demo script for survival prediction preprocessing using SUPPORT2 dataset. + +This example demonstrates how to: +1. Load the SUPPORT2 dataset (using test data with 3 patients) +2. Apply the preprocessing task to extract features and labels +3. Examine preprocessed samples ready for model training + +The preprocessing task extracts: +- Features from raw patient data (demographics, vitals, labs, scores, etc.) +- Ground truth survival probabilities from the dataset (surv2m/surv6m fields) +- Structures data into samples ready for training a prediction model + +Note: The survival probabilities shown are ground truth labels extracted from the +dataset (surv2m/surv6m columns). These are the target variables that a model +would learn to predict from the extracted features. + +This example uses the synthetic test dataset from test-resources (3 patients). +For real usage, replace the path with your actual SUPPORT2 dataset. +""" + +import warnings +import logging +from pathlib import Path + +# Suppress warnings and reduce logging verbosity +warnings.filterwarnings("ignore") +logging.basicConfig(level=logging.WARNING) +logging.getLogger("pyhealth").setLevel(logging.WARNING) +logging.getLogger("pyhealth.datasets").setLevel(logging.WARNING) +logging.getLogger("pyhealth.datasets.support2").setLevel(logging.WARNING) +logging.getLogger("pyhealth.datasets.base_dataset").setLevel(logging.WARNING) + +# Import pyhealth modules +from pyhealth.datasets import Support2Dataset +from pyhealth.tasks import SurvivalPreprocessSupport2 + +# Suppress tqdm progress bars for cleaner output +try: + def noop_tqdm(iterable, *args, **kwargs): + return iterable + from pyhealth.datasets import base_dataset, sample_dataset + base_dataset.tqdm = noop_tqdm + sample_dataset.tqdm = noop_tqdm + import tqdm + tqdm.tqdm = noop_tqdm +except (ImportError, AttributeError): + pass + +# Step 1: Load dataset using test data +print("=" * 70) +print("Step 1: Load SUPPORT2 Dataset") +print("=" * 70) +script_dir = Path(__file__).parent +test_data_path = script_dir.parent / "test-resources" / "core" / "support2" + +dataset = Support2Dataset( + root=str(test_data_path), + tables=["support2"], +) + +print(f"Loaded dataset with {len(dataset.unique_patient_ids)} patients\n") + +# Step 2: Apply preprocessing task to extract features and labels (2-month horizon) +print("=" * 70) +print("Step 2: Apply Survival Preprocessing Task") +print("=" * 70) +task = SurvivalPreprocessSupport2(time_horizon="2m") +sample_dataset = dataset.set_task(task=task) + +print(f"Generated {len(sample_dataset)} samples") +print(f"Input schema: {sample_dataset.input_schema}") +print(f"Output schema: {sample_dataset.output_schema}\n") + +# Helper function to decode tensor indices to feature strings +def decode_features(tensor, processor): + """Decode tensor indices back to original feature strings.""" + if processor is None or not hasattr(processor, 'code_vocab'): + return [str(idx.item()) for idx in tensor] + reverse_vocab = {idx: token for token, idx in processor.code_vocab.items()} + return [reverse_vocab.get(idx.item(), f"") for idx in tensor] + +# Step 3: Display features for all samples +print("=" * 70) +print("Step 3: Examine Preprocessed Samples") +print("=" * 70) +# Sort samples by patient_id to ensure consistent order +samples = sorted(sample_dataset, key=lambda x: int(x['patient_id'])) +for sample in samples: + # Display patient ID and tensor shapes first + print(f"\nPatient {sample['patient_id']}:") + print(f" Demographics tensor shape: {sample['demographics'].shape}") + print(f" Disease codes tensor shape: {sample['disease_codes'].shape}") + print(f" Vitals tensor shape: {sample['vitals'].shape}") + print(f" Labs tensor shape: {sample['labs'].shape}") + print(f" Scores tensor shape: {sample['scores'].shape}") + print(f" Comorbidities tensor shape: {sample['comorbidities'].shape}") + + # Decode and display features for this sample + demographics = decode_features( + sample['demographics'], + sample_dataset.input_processors.get('demographics') + ) + disease_codes = decode_features( + sample['disease_codes'], + sample_dataset.input_processors.get('disease_codes') + ) + vitals = decode_features( + sample['vitals'], + sample_dataset.input_processors.get('vitals') + ) + labs = decode_features( + sample['labs'], + sample_dataset.input_processors.get('labs') + ) + scores = decode_features( + sample['scores'], + sample_dataset.input_processors.get('scores') + ) + comorbidities = decode_features( + sample['comorbidities'], + sample_dataset.input_processors.get('comorbidities') + ) + + # Display decoded features + print(f" Demographics: {', '.join(demographics)}") + print(f" Disease Codes: {', '.join(disease_codes)}") + print(f" Vitals: {', '.join(vitals)}") + print(f" Labs: {', '.join(labs)}") + print(f" Scores: {', '.join(scores)}") + print(f" Comorbidities: {', '.join(comorbidities)}") + print(f" Survival Probability (2m): {sample['survival_probability'].item():.4f}") + +print("\n") +print("=" * 70) +print("Preprocessing Complete!") +print("=" * 70) +print("The samples are ready for model training.") diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 91a8da937..ced02afd7 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -63,6 +63,7 @@ def __init__(self, *args, **kwargs): from .shhs import SHHSDataset from .sleepedf import SleepEDFDataset from .bmd_hs import BMDHSDataset +from .support2 import Support2Dataset from .splitter import ( split_by_patient, split_by_patient_conformal, diff --git a/pyhealth/datasets/configs/support2.yaml b/pyhealth/datasets/configs/support2.yaml new file mode 100644 index 000000000..f61e730e2 --- /dev/null +++ b/pyhealth/datasets/configs/support2.yaml @@ -0,0 +1,55 @@ +version: "1.0" +tables: + support2: + file_path: "support2.csv" + patient_id: "sno" + timestamp: null + attributes: + - "d.time" + - "age" + - "death" + - "sex" + - "hospdead" + - "slos" + - "dzgroup" + - "dzclass" + - "num.co" + - "edu" + - "income" + - "scoma" + - "charges" + - "totcst" + - "totmcst" + - "avtisst" + - "race" + - "sps" + - "aps" + - "surv2m" + - "surv6m" + - "hday" + - "diabetes" + - "dementia" + - "ca" + - "prg2m" + - "prg6m" + - "dnr" + - "dnrday" + - "meanbp" + - "wblc" + - "hrt" + - "resp" + - "temp" + - "pafi" + - "alb" + - "bili" + - "crea" + - "sod" + - "ph" + - "glucose" + - "bun" + - "urine" + - "adlp" + - "adls" + - "sfdm2" + - "adlsc" + diff --git a/pyhealth/datasets/support2.py b/pyhealth/datasets/support2.py new file mode 100644 index 000000000..72c5b095b --- /dev/null +++ b/pyhealth/datasets/support2.py @@ -0,0 +1,72 @@ +import logging +from pathlib import Path +from typing import List, Optional + +from .base_dataset import BaseDataset + +logger = logging.getLogger(__name__) + + +class Support2Dataset(BaseDataset): + """ + A dataset class for handling SUPPORT2 (Study to Understand Prognoses and Preferences + for Outcomes and Risks of Treatments) data. + + The SUPPORT2 dataset contains data on 9,105 seriously ill hospitalized adults from + five U.S. medical centers (1989-1994), including patient demographics, diagnoses, + clinical measurements, and outcomes. + + Dataset is available for download from: + - UCI Machine Learning Repository: https://archive.ics.uci.edu/dataset/880/support2 + - Vanderbilt Biostatistics: https://hbiostat.org/data/repo/supportdesc + - Hugging Face: https://huggingface.co/datasets/jarrydmartinx/support2 + - R packages: "rms" and "Hmisc" + + Citation: + Knaus WA, Harrell FE, Lynn J, et al. The SUPPORT prognostic model: + Objective estimates of survival for seriously ill hospitalized adults. + Ann Intern Med. 1995;122(3):191-203. + + Args: + root (str): The root directory where the dataset CSV file is stored. + tables (List[str]): A list of tables to be included (typically ["support2"]). + dataset_name (Optional[str]): The name of the dataset. Defaults to "support2". + config_path (Optional[str]): The path to the configuration file. If not provided, + uses the default config. + **kwargs: Additional arguments passed to BaseDataset. + + Examples: + >>> from pyhealth.datasets import Support2Dataset + >>> dataset = Support2Dataset( + ... root="/path/to/support2/data", + ... tables=["support2"] + ... ) + >>> dataset.stats() + + Attributes: + root (str): The root directory where the dataset is stored. + tables (List[str]): A list of tables to be included in the dataset. + dataset_name (str): The name of the dataset. + config_path (str): The path to the configuration file. + """ + + def __init__( + self, + root: str, + tables: List[str], + dataset_name: Optional[str] = None, + config_path: Optional[str] = None, + **kwargs + ) -> None: + if config_path is None: + logger.info("No config path provided, using default config") + config_path = Path(__file__).parent / "configs" / "support2.yaml" + super().__init__( + root=root, + tables=tables, + dataset_name=dataset_name or "support2", + config_path=config_path, + **kwargs + ) + return + diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 639426aaa..82c890e5e 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -37,6 +37,7 @@ MortalityPredictionMIMIC4, MortalityPredictionOMOP, ) +from .survival_preprocess_support2 import SurvivalPreprocessSupport2 from .mortality_prediction_stagenet_mimic4 import ( MortalityPredictionStageNetMIMIC4, ) diff --git a/pyhealth/tasks/survival_preprocess_support2.py b/pyhealth/tasks/survival_preprocess_support2.py new file mode 100644 index 000000000..7861e34ea --- /dev/null +++ b/pyhealth/tasks/survival_preprocess_support2.py @@ -0,0 +1,330 @@ +import logging +from typing import Any, Dict, List, Optional + +from .base_task import BaseTask + +logger = logging.getLogger(__name__) + + +class SurvivalPreprocessSupport2(BaseTask): + """Preprocessing task for survival probability prediction models using SUPPORT2 dataset. + + This task extracts features and labels from raw patient data to prepare samples + for model training. It extracts patient demographics, diagnoses, clinical + measurements, and vital signs, and pairs them with ground truth survival + probabilities from the dataset (surv2m or surv6m fields). + + The task performs feature extraction and data structuring: + - Extracts features from raw patient data (demographics, vitals, labs, scores, etc.) + - Extracts ground truth survival probabilities from surv2m/surv6m fields + - Structures data into samples ready for model training + + The SUPPORT2 dataset contains data on seriously ill hospitalized adults, + with each patient represented by a single record at admission. + + Task Schema: + Input: + - demographics: sequence of demographic features (age, sex, race, education, income) + - disease_codes: sequence of disease group and class codes + - vitals: sequence of vital signs (meanbp, hrt, resp, temp, pafi) + - labs: sequence of lab values (wblc, alb, bili, crea, sod, ph, glucose, bun) + - scores: sequence of clinical scores (sps, aps, scoma) + - comorbidities: sequence of comorbidity indicators (diabetes, dementia, ca) + Output: + - survival_probability: regression label (0-1, ground truth survival probability) + + Args: + time_horizon (str): Which survival probability to extract. + Options: "2m" (2 months) or "6m" (6 months). Default is "2m". + + Returns: + List[Dict[str, Any]]: A list containing a single sample per patient with: + - patient_id: The patient identifier + - demographics: List of demographic feature strings (e.g., ["age_62.85", "sex_male"]) + - disease_codes: List of disease code strings + - vitals: List of vital sign strings + - labs: List of lab value strings + - scores: List of clinical score strings + - comorbidities: List of comorbidity indicator strings + - survival_probability: Float between 0 and 1 (ground truth label) + + Examples: + >>> from pyhealth.datasets import Support2Dataset + >>> from pyhealth.tasks import SurvivalPreprocessSupport2 + >>> + >>> # Step 1: Load SUPPORT2 dataset + >>> print("Step 1: Load SUPPORT2 Dataset") + >>> # For real usage, use your dataset path: + >>> # dataset = Support2Dataset( + >>> # root="/path/to/support2/data", + >>> # tables=["support2"] + >>> # ) + >>> # For local testing with test data: + >>> from pathlib import Path + >>> test_data_path = Path("test-resources/core/support2") + >>> dataset = Support2Dataset( + ... root=str(test_data_path), + ... tables=["support2"] + ... ) + >>> print(f"Loaded dataset with {len(dataset.unique_patient_ids)} patients\n") + >>> + >>> # Step 2: Apply preprocessing task to extract features and labels + >>> print("Step 2: Apply Survival Preprocessing Task") + >>> task = SurvivalPreprocessSupport2(time_horizon="2m") + >>> sample_dataset = dataset.set_task(task=task) + >>> print(f"Generated {len(sample_dataset)} samples") + >>> print(f"Input schema: {sample_dataset.input_schema}") + >>> print(f"Output schema: {sample_dataset.output_schema}\n") + >>> + >>> # Helper function to decode tensor indices to feature strings + >>> def decode_features(tensor, processor): + ... if processor is None or not hasattr(processor, 'code_vocab'): + ... return [str(idx.item()) for idx in tensor] + ... reverse_vocab = {idx: token for token, idx in processor.code_vocab.items()} + ... return [reverse_vocab.get(idx.item(), f"") for idx in tensor] + >>> + >>> # Step 3: Display features for one sample + >>> print("Step 3: Examine Preprocessed Samples") + >>> sample = sample_dataset[0] + >>> print(f"Patient {sample['patient_id']}:") + >>> print(f"Demographics tensor shape: {sample['demographics'].shape}") + >>> print(f"Disease codes tensor shape: {sample['disease_codes'].shape}") + >>> print(f"Vitals tensor shape: {sample['vitals'].shape}") + >>> print(f"Labs tensor shape: {sample['labs'].shape}") + >>> print(f"Scores tensor shape: {sample['scores'].shape}") + >>> print(f"Comorbidities tensor shape: {sample['comorbidities'].shape}") + >>> + >>> # Decode and display features for this sample + >>> demographics_decoded = decode_features( + ... sample['demographics'], + ... sample_dataset.input_processors.get('demographics') + ... ) + >>> print(f" Demographics: {', '.join(demographics_decoded)}") + >>> disease_codes_decoded = decode_features( + ... sample['disease_codes'], + ... sample_dataset.input_processors.get('disease_codes') + ... ) + >>> print(f" Disease Codes: {', '.join(disease_codes_decoded)}") + >>> vitals_decoded = decode_features( + ... sample['vitals'], + ... sample_dataset.input_processors.get('vitals') + ... ) + >>> print(f" Vitals: {', '.join(vitals_decoded)}") + >>> print(f" Survival Probability (2m): {sample['survival_probability'].item():.4f}") + >>> + >>> # For a complete working example displaying all feature groups for all samples, + >>> # see: examples/survival_preprocess_support2_demo.py + + Note: + - Each patient produces exactly one sample (single-row-per-patient dataset) + - Missing values in labs/vitals are handled by excluding None values + - Survival probabilities are ground truth labels extracted from surv2m/surv6m fields + - Processors will automatically convert string features to tensors for model training + """ + + task_name: str = "SurvivalPreprocessSupport2" + input_schema: Dict[str, str] = { + "demographics": "sequence", + "disease_codes": "sequence", + "vitals": "sequence", + "labs": "sequence", + "scores": "sequence", + "comorbidities": "sequence", + } + output_schema: Dict[str, str] = {"survival_probability": "regression"} + + def __init__(self, time_horizon: str = "2m"): + """Initialize the SurvivalPreprocessSupport2 preprocessing task. + + Args: + time_horizon (str): Which survival probability to extract as the label. + Options: "2m" (2 months) or "6m" (6 months). Default is "2m". + """ + super().__init__() + self.time_horizon = time_horizon + if time_horizon == "2m": + self.survival_field = "surv2m" + self.task_name = "SurvivalPreprocessSupport2_2m" + elif time_horizon == "6m": + self.survival_field = "surv6m" + self.task_name = "SurvivalPreprocessSupport2_6m" + else: + raise ValueError( + f"time_horizon must be '2m' or '6m', got {time_horizon}" + ) + + def _clean_value(self, value: Any) -> Optional[float]: + """Clean a value by converting to float, handling None and empty strings. + + Args: + value: The value to clean + + Returns: + Optional[float]: Cleaned float value, or None if invalid + """ + if value is None: + return None + if isinstance(value, str): + value = value.strip() + if value == "" or value.lower() == "none": + return None + try: + return float(value) + except (ValueError, TypeError): + return None + + def _get_attr_safe(self, event: Any, attr: str) -> Any: + """Safely get an attribute from an event. + + Args: + event: The event object + attr: The attribute name + + Returns: + The attribute value or None if not present + """ + try: + return getattr(event, attr, None) + except AttributeError: + try: + return event[attr] + except (KeyError, TypeError): + return None + + def _extract_numeric_features( + self, event: Any, features: Dict[str, str] + ) -> List[str]: + """Extract numeric features from an event and format them. + + Args: + event: The event object + features: Dict mapping prefix to attribute name (e.g., {"age": "age"}) + + Returns: + List of formatted feature strings (e.g., ["age_62.85"]) + """ + result = [] + for prefix, attr in features.items(): + value = self._clean_value(self._get_attr_safe(event, attr)) + if value is not None: + result.append(f"{prefix}_{value}") + return result + + def _extract_string_features( + self, event: Any, features: Dict[str, str] + ) -> List[str]: + """Extract string features from an event and format them. + + Args: + event: The event object + features: Dict mapping prefix to attribute name (e.g., {"sex": "sex"}) + + Returns: + List of formatted feature strings (e.g., ["sex_male"]) + """ + result = [] + for prefix, attr in features.items(): + value = self._get_attr_safe(event, attr) + if value is not None: + result.append(f"{prefix}_{str(value)}") + return result + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + """Extracts features and labels from a patient for survival probability prediction models. + + This is a preprocessing step that structures raw patient data into + features (demographics, vitals, labs, etc.) and ground truth labels + (survival probabilities) ready for model training. + + Args: + patient (Any): A Patient object containing SUPPORT2 data. + + Returns: + List[Dict[str, Any]]: A list containing a single sample per patient + with extracted features and survival probability label. + """ + # Get the single support2 event per patient + events = patient.get_events(event_type="support2") + + if len(events) == 0: + return [] + + # Should be exactly one event per patient (single-row-per-patient dataset) + if len(events) > 1: + logger.warning( + f"Patient {patient.patient_id} has {len(events)} support2 events, " + "expected 1. Using first event." + ) + + event = events[0] + + # Extract demographics (mixed numeric and string features) + demographics = [] + demographics.extend( + self._extract_numeric_features(event, {"age": "age", "edu": "edu"}) + ) + demographics.extend( + self._extract_string_features(event, {"sex": "sex", "race": "race", "income": "income"}) + ) + + # Extract disease codes + disease_codes = self._extract_string_features( + event, {"dzgroup": "dzgroup", "dzclass": "dzclass"} + ) + + # Extract vital signs + vitals = self._extract_numeric_features( + event, {"meanbp": "meanbp", "hrt": "hrt", "resp": "resp", "temp": "temp", "pafi": "pafi"} + ) + + # Extract lab values + labs = self._extract_numeric_features( + event, + { + "wblc": "wblc", + "alb": "alb", + "bili": "bili", + "crea": "crea", + "sod": "sod", + "ph": "ph", + "glucose": "glucose", + "bun": "bun", + }, + ) + + # Extract clinical scores + scores = self._extract_numeric_features( + event, {"sps": "sps", "aps": "aps", "scoma": "scoma"} + ) + + # Extract comorbidities (mixed numeric and string features) + comorbidities = [] + comorbidities.extend( + self._extract_numeric_features(event, {"diabetes": "diabetes", "dementia": "dementia"}) + ) + comorbidities.extend(self._extract_string_features(event, {"ca": "ca"})) + + # Extract ground truth survival probability label from dataset + # (surv2m or surv6m field contains pre-computed survival probabilities) + survival_prob = self._clean_value(self._get_attr_safe(event, self.survival_field)) + + # Skip if survival probability is missing + if survival_prob is None: + return [] + + # Ensure survival probability is in valid range [0, 1] + survival_prob = max(0.0, min(1.0, survival_prob)) + + # Create single sample per patient + sample = { + "patient_id": patient.patient_id, + "demographics": demographics, + "disease_codes": disease_codes, + "vitals": vitals, + "labs": labs, + "scores": scores, + "comorbidities": comorbidities, + "survival_probability": survival_prob, + } + + return [sample] diff --git a/test-resources/core/support2/support2.csv b/test-resources/core/support2/support2.csv new file mode 100644 index 000000000..2a0951101 --- /dev/null +++ b/test-resources/core/support2/support2.csv @@ -0,0 +1,4 @@ +sno,age,death,sex,hospdead,slos,d.time,dzgroup,dzclass,num.co,edu,income,scoma,charges,totcst,totmcst,avtisst,race,sps,aps,surv2m,surv6m,hday,diabetes,dementia,ca,prg2m,prg6m,dnr,dnrday,meanbp,wblc,hrt,resp,temp,pafi,alb,bili,crea,sod,ph,glucose,bun,urine,adlp,adls,sfdm2,adlsc +1,62.85,0,male,0,5,2029,Lung Cancer,Cancer,0,11,$11-$25k,0,9715.0,,,7.0,other,33.9,20,0.26,0.037,1,0,0,metastatic,0.5,0.25,no dnr,5,97.0,6.0,69.0,22,36.0,388.0,1.8,0.2,1.2,141,7.46,,,,7,7,,7.0 +2,60.34,1,female,1,4,4,Cirrhosis,COPD/CHF/Cirrhosis,2,12,$11-$25k,44,34496.0,,,29.0,white,52.7,74,0.001,0.0,3,0,0,no,0.0,0.0,,,43.0,17.1,112.0,34,34.6,98.0,,,5.5,132,7.25,,,,1,<2 mo. follow-up,1.0 +3,52.75,1,female,0,17,47,Cirrhosis,COPD/CHF/Cirrhosis,2,12,under $11k,0,41094.0,,,13.0,white,20.5,45,0.79,0.66,4,0,0,no,0.75,0.5,no dnr,17,70.0,8.5,88.0,28,37.4,231.7,,2.2,2.0,134,7.46,,,,1,0,<2 mo. follow-up,0.0 diff --git a/tests/core/test_support2.py b/tests/core/test_support2.py new file mode 100644 index 000000000..6a21c84f6 --- /dev/null +++ b/tests/core/test_support2.py @@ -0,0 +1,211 @@ +import unittest +import tempfile +import shutil +import io +import sys +from pathlib import Path +import pandas as pd +import torch + +from pyhealth.datasets import Support2Dataset +from pyhealth.tasks import SurvivalPreprocessSupport2 + + +class TestSupport2Dataset(unittest.TestCase): + """Test Support2Dataset with synthetic test data.""" + + def setUp(self): + """Set up test data files and directory structure with synthetic data.""" + self.temp_dir = tempfile.mkdtemp() + self.root = Path(self.temp_dir) + + support2_data = { + 'sno': ['1', '2', '3'], + 'age': [62.85, 60.34, 52.75], + 'death': [0, 1, 1], + 'sex': ['male', 'female', 'female'], + 'hospdead': [0, 1, 0], + 'slos': [5, 4, 17], + 'd.time': [2029, 4, 47], + 'dzgroup': ['Lung Cancer', 'Cirrhosis', 'Cirrhosis'], + 'dzclass': ['Cancer', 'COPD/CHF/Cirrhosis', 'COPD/CHF/Cirrhosis'], + 'num.co': [0, 2, 2], + 'edu': [11, 12, 12], + 'income': ['$11-$25k', '$11-$25k', 'under $11k'], + 'scoma': [0, 44, 0], + 'charges': [9715.0, 34496.0, 41094.0], + 'race': ['other', 'white', 'white'], + 'sps': [33.9, 52.7, 20.5], + 'aps': [20, 74, 45], + 'surv2m': [0.26, 0.001, 0.79], + 'surv6m': [0.037, 0.0, 0.66], + 'hday': [1, 3, 4], + 'diabetes': [0, 0, 0], + 'dementia': [0, 0, 0], + 'ca': ['metastatic', 'no', 'no'], + 'meanbp': [97.0, 43.0, 70.0], + 'wblc': [6.0, 17.1, 8.5], + 'hrt': [69.0, 112.0, 88.0], + 'resp': [22, 34, 28], + 'temp': [36.0, 34.6, 37.4], + 'pafi': [388.0, 98.0, 231.7], + 'alb': [1.8, None, None], + 'bili': [0.2, None, 2.2], + 'crea': [1.2, 5.5, 2.0], + 'sod': [141, 132, 134], + 'ph': [7.46, 7.25, 7.46], + 'glucose': [None, None, None], + 'bun': [None, None, None], + 'urine': [None, None, None], + 'adlp': [7, 1, 1], + 'adls': [7, None, 0], + 'sfdm2': [None, '<2 mo. follow-up', '<2 mo. follow-up'], + 'adlsc': [7.0, 1.0, 0.0], + 'prg2m': [0.5, 0.0, 0.75], + 'prg6m': [0.25, 0.0, 0.5], + 'dnr': ['no dnr', None, 'no dnr'], + 'dnrday': [5, None, 17], + 'totcst': [None, None, None], + 'totmcst': [None, None, None], + 'avtisst': [7.0, 29.0, 13.0] + } + + support2_df = pd.DataFrame(support2_data) + support2_df.to_csv(self.root / "support2.csv", index=False) + + def tearDown(self): + """Clean up temporary files.""" + shutil.rmtree(self.temp_dir) + + def test_dataset_initialization(self): + """Test Support2Dataset initialization.""" + print("Dataset Tests:") + dataset = Support2Dataset(root=str(self.root), tables=["support2"]) + self.assertIsInstance(dataset, Support2Dataset) + self.assertEqual(dataset.dataset_name, "support2") + print("Test passed: dataset_initialization\n") + + def test_load_data(self): + """Test that data loads correctly.""" + dataset = Support2Dataset(root=str(self.root), tables=["support2"]) + self.assertIsNotNone(dataset.global_event_df) + print("Test passed: load_data\n") + + def test_patient_count(self): + """Test that the dataset contains the expected number of patients.""" + dataset = Support2Dataset(root=str(self.root), tables=["support2"]) + unique_patients = dataset.unique_patient_ids + self.assertEqual(len(unique_patients), 3) + print("Test passed: patient_count\n") + + def test_get_patient(self): + """Test retrieving a single patient by ID.""" + dataset = Support2Dataset(root=str(self.root), tables=["support2"]) + patient = dataset.get_patient("1") + self.assertIsNotNone(patient) + self.assertEqual(patient.patient_id, "1") + print("Test passed: get_patient\n") + + def test_stats(self): + """Test that stats method executes without errors.""" + dataset = Support2Dataset(root=str(self.root), tables=["support2"]) + # Suppress stats output + old_stdout = sys.stdout + sys.stdout = io.StringIO() + dataset.stats() + sys.stdout = old_stdout + print("Test passed: stats\n") + + def test_survival_preprocess_2m(self): + """Test SurvivalPreprocessSupport2 task with 2-month horizon.""" + print("\nTask Tests:") + dataset = Support2Dataset(root=str(self.root), tables=["support2"]) + + task = SurvivalPreprocessSupport2(time_horizon="2m") + self.assertEqual(task.task_name, "SurvivalPreprocessSupport2_2m") + self.assertEqual(task.survival_field, "surv2m") + + # Test input and output schemas + self.assertIn("demographics", task.input_schema) + self.assertIn("disease_codes", task.input_schema) + self.assertIn("vitals", task.input_schema) + self.assertIn("labs", task.input_schema) + self.assertIn("scores", task.input_schema) + self.assertIn("comorbidities", task.input_schema) + self.assertIn("survival_probability", task.output_schema) + self.assertEqual(task.output_schema["survival_probability"], "regression") + + sample_dataset = dataset.set_task(task) + self.assertIsNotNone(sample_dataset) + self.assertTrue(hasattr(sample_dataset, "samples")) + self.assertEqual(len(sample_dataset.samples), 3) + + # Check first sample structure + sample = sample_dataset.samples[0] + required_keys = [ + "patient_id", + "demographics", + "disease_codes", + "vitals", + "labs", + "scores", + "comorbidities", + "survival_probability", + ] + for key in required_keys: + self.assertIn(key, sample, f"Sample should contain key: {key}") + + # Verify survival probabilities are in valid range [0, 1] + for s in sample_dataset.samples: + survival_prob = s["survival_probability"] + self.assertIsInstance(survival_prob, torch.Tensor) + prob_value = survival_prob.item() + self.assertGreaterEqual(prob_value, 0.0) + self.assertLessEqual(prob_value, 1.0) + + # Check that features are tensors after processing + self.assertIsInstance(sample["demographics"], torch.Tensor) + self.assertIsInstance(sample["disease_codes"], torch.Tensor) + self.assertIsInstance(sample["vitals"], torch.Tensor) + self.assertIsInstance(sample["labs"], torch.Tensor) + self.assertIsInstance(sample["scores"], torch.Tensor) + self.assertIsInstance(sample["comorbidities"], torch.Tensor) + + # Check that tensors are non-empty + self.assertGreater(len(sample["demographics"]), 0) + self.assertGreater(len(sample["disease_codes"]), 0) + print("Test passed: survival_preprocess_2m\n") + + def test_survival_preprocess_6m(self): + """Test SurvivalPreprocessSupport2 task with 6-month horizon.""" + dataset = Support2Dataset(root=str(self.root), tables=["support2"]) + + task = SurvivalPreprocessSupport2(time_horizon="6m") + self.assertEqual(task.task_name, "SurvivalPreprocessSupport2_6m") + self.assertEqual(task.survival_field, "surv6m") + + sample_dataset = dataset.set_task(task) + self.assertIsNotNone(sample_dataset) + self.assertEqual(len(sample_dataset.samples), 3) + + # Verify all samples have valid survival probabilities + for s in sample_dataset.samples: + survival_prob = s["survival_probability"] + self.assertIsInstance(survival_prob, torch.Tensor) + prob_value = survival_prob.item() + self.assertGreaterEqual(prob_value, 0.0) + self.assertLessEqual(prob_value, 1.0) + print("Test passed: survival_preprocess_6m\n") + + def test_survival_preprocess_invalid_horizon(self): + """Test that invalid time_horizon raises ValueError.""" + with self.assertRaises(ValueError): + SurvivalPreprocessSupport2(time_horizon="12m") + + with self.assertRaises(ValueError): + SurvivalPreprocessSupport2(time_horizon="invalid") + print("Test passed: survival_preprocess_invalid_horizon\n") + + +if __name__ == "__main__": + unittest.main()