|
| 1 | +# %% Loading MIMIC-IV dataset |
| 2 | +from pathlib import Path |
| 3 | + |
| 4 | +import polars as pl |
| 5 | +import torch |
| 6 | + |
| 7 | +from pyhealth.datasets import ( |
| 8 | + MIMIC4EHRDataset, |
| 9 | + get_dataloader, |
| 10 | + load_processors, |
| 11 | + split_by_patient, |
| 12 | +) |
| 13 | +from pyhealth.interpret.methods import GIM |
| 14 | +from pyhealth.models import StageNet |
| 15 | +from pyhealth.tasks import MortalityPredictionStageNetMIMIC4 |
| 16 | + |
# Point the loader at the local MIMIC-IV demo release and list the tables
# that should be read from it.
mimic4_tables = [
    "patients",
    "admissions",
    "diagnoses_icd",
    "procedures_icd",
    "labevents",
]
dataset = MIMIC4EHRDataset(
    root="/home/logic/physionet.org/files/mimic-iv-demo/2.2/",
    tables=mimic4_tables,
)

# %% Setting StageNet Mortality Prediction Task
# Processors are loaded from ../resources/ and handed to set_task below
# (presumably so the encoding matches the pretrained checkpoint — confirm).
input_processors, output_processors = load_processors("../resources/")

mortality_task = MortalityPredictionStageNetMIMIC4()
sample_dataset = dataset.set_task(
    mortality_task,
    cache_dir="~/.cache/pyhealth/mimic4_stagenet_mortality",
    input_processors=input_processors,
    output_processors=output_processors,
)
sample_count = len(sample_dataset)
print(f"Total samples: {sample_count}")
| 39 | + |
| 40 | + |
def load_icd_description_map(dataset_root: str) -> dict:
    """Load ICD code → description mappings from reference tables.

    Looks for ``hosp/d_icd_diagnoses.csv.gz`` and
    ``hosp/d_icd_procedures.csv.gz`` under ``dataset_root`` (``~`` is
    expanded) and merges their ``icd_code`` → ``long_title`` pairs into a
    single dictionary. A reference file that does not exist is skipped.

    Args:
        dataset_root: Root directory of the dataset on disk.

    Returns:
        A dict mapping each ICD code string to its long title. Empty when
        neither reference file is present.
    """
    hosp_dir = Path(dataset_root).expanduser() / "hosp"
    mapping = {}

    # Diagnoses are merged first and procedures second, so a code that
    # appears in both tables ends up with the procedure title.
    for filename in ("d_icd_diagnoses.csv.gz", "d_icd_procedures.csv.gz"):
        table_path = hosp_dir / filename
        if not table_path.exists():
            continue

        # Both columns are forced to string dtype instead of being inferred.
        # NOTE(review): newer polars releases rename `dtypes` to
        # `schema_overrides` — confirm against the pinned polars version.
        table = pl.read_csv(
            table_path,
            columns=["icd_code", "long_title"],
            dtypes={"icd_code": pl.Utf8, "long_title": pl.Utf8},
        )
        mapping.update(
            zip(table["icd_code"].to_list(), table["long_title"].to_list())
        )

    return mapping
| 71 | + |
| 72 | + |
# Build the ICD code → description lookup once at import time; decode_token
# reads it when formatting "icd_codes" tokens.
ICD_CODE_TO_DESC = load_icd_description_map(dataset.root)

# %% Loading Pretrained StageNet model
use_gpu = torch.cuda.is_available()
device = torch.device("cuda:0" if use_gpu else "cpu")

stagenet_kwargs = {
    "embedding_dim": 128,
    "chunk_size": 128,
    "levels": 3,
    "dropout": 0.3,
}
model = StageNet(dataset=sample_dataset, **stagenet_kwargs)

# NOTE(review): torch.load unpickles the checkpoint, so only load files from
# a trusted source (consider weights_only=True if the installed torch
# supports it).
state_dict = torch.load("../resources/best.ckpt", map_location=device)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()
print(model)

# %% Preparing dataloaders
# Only the third split is kept; it is served one sample per batch, unshuffled.
split_ratios = [0.7, 0.1, 0.2]
_, _, test_data = split_by_patient(sample_dataset, split_ratios, seed=42)
test_loader = get_dataloader(test_data, batch_size=1, shuffle=False)
| 94 | + |
| 95 | + |
def move_batch_to_device(batch, target_device):
    """Return a copy of ``batch`` with its tensors moved to ``target_device``.

    Args:
        batch: Mapping from feature key to a tensor, a tuple, or any other
            value.
        target_device: Device passed to ``Tensor.to``.

    Returns:
        A new dict with the same keys. Tensor values are moved; tuple values
        are rebuilt as tuples with each tensor element moved; everything
        else is passed through unchanged.
    """

    def _to_device(item):
        # Non-tensor entries (e.g. a None placeholder inside a tuple) have no
        # usable `.to`; leave them as they are instead of raising.
        if isinstance(item, torch.Tensor):
            return item.to(target_device)
        return item

    moved = {}
    for key, value in batch.items():
        if isinstance(value, tuple):
            moved[key] = tuple(_to_device(element) for element in value)
        else:
            moved[key] = _to_device(value)
    return moved
| 106 | + |
| 107 | + |
# Alias of the task class's LAB_CATEGORY_NAMES. print_top_attributions indexes
# it with the last coordinate of a "labs" attribution to label that entry.
LAB_CATEGORY_NAMES = MortalityPredictionStageNetMIMIC4.LAB_CATEGORY_NAMES
| 109 | + |
| 110 | + |
def decode_token(idx: int, processor, feature_key: str):
    """Turn a vocabulary index into a printable token string.

    Args:
        idx: Index to decode.
        processor: Object expected to expose a ``code_vocab`` mapping of
            token → index; may be ``None``.
        feature_key: Name of the feature the index belongs to.

    Returns:
        ``str(idx)`` when no usable processor is given; otherwise the token
        whose index is ``idx`` (or ``"<UNK:idx>"`` if there is none). For the
        ``"icd_codes"`` feature, a token with a known description is rendered
        as ``"token: description"``.
    """
    has_vocab = processor is not None and hasattr(processor, "code_vocab")
    if not has_vocab:
        return str(idx)

    # Invert token → index so the index can be looked up directly.
    index_to_token = {}
    for tok, position in processor.code_vocab.items():
        index_to_token[position] = tok

    if idx in index_to_token:
        token = index_to_token[idx]
    else:
        token = f"<UNK:{idx}>"

    # Only real ICD codes get a description appended.
    if feature_key != "icd_codes" or token in ("<unk>", "<pad>"):
        return token

    description = ICD_CODE_TO_DESC.get(token)
    return f"{token}: {description}" if description else token
| 123 | + |
| 124 | + |
def unravel(flat_index: int, shape: torch.Size):
    """Convert a flat index into per-dimension coordinates for ``shape``.

    The last dimension varies fastest, so ``unravel(5, (2, 3))`` is
    ``[1, 2]``.

    Args:
        flat_index: Position in the flattened view.
        shape: Sizes of each dimension.

    Returns:
        A list with one coordinate per dimension, in the same order as
        ``shape``.
    """
    coords = [0] * len(shape)
    remaining = flat_index
    # Peel coordinates off from the innermost dimension outwards.
    for axis in range(len(shape) - 1, -1, -1):
        remaining, coords[axis] = divmod(remaining, shape[axis])
    return coords
| 132 | + |
| 133 | + |
def print_top_attributions(
    attributions,
    batch,
    processors,
    top_k: int = 10,
):
    """Print the highest-magnitude attributions for each feature.

    Only the first sample of the batch (index 0) is inspected. Entries are
    ranked by absolute attribution, and the signed value is what gets
    printed.

    Args:
        attributions: Mapping from feature key to an attribution tensor.
        batch: Mapping from feature key to the model input. When an input is
            a tuple, its second element is used (presumably the values part
            of a (time, values) pair — confirm against the processors).
        processors: Mapping from feature key to a processor used for token
            decoding; may be ``None`` or empty.
        top_k: Maximum number of entries printed per feature.
    """
    for feature_key, attr in attributions.items():
        attr_cpu = attr.detach().cpu()
        # Nothing to rank for scalar or zero-length attributions.
        if attr_cpu.dim() == 0 or attr_cpu.size(0) == 0:
            continue

        feature_input = batch[feature_key]
        if isinstance(feature_input, tuple):
            feature_input = feature_input[1]
        feature_input = feature_input.detach().cpu()

        sample_attr = attr_cpu[0]
        flattened = sample_attr.flatten()
        if flattened.numel() == 0:
            continue

        print(f"\nFeature: {feature_key}")
        k = min(top_k, flattened.numel())
        top_indices = torch.topk(flattened.abs(), k=k).indices
        processor = processors.get(feature_key) if processors else None
        is_continuous = torch.is_floating_point(feature_input)

        for rank, flat_idx in enumerate(top_indices.tolist(), 1):
            attribution_value = flattened[flat_idx].item()
            coords = unravel(flat_idx, sample_attr.shape)
            # Input element at the same coordinates as the attribution.
            raw_value = feature_input[0][tuple(coords)].item()

            if not is_continuous:
                token = decode_token(int(raw_value), processor, feature_key)
                print(
                    f" {rank:2d}. idx={coords} token='{token}' "
                    f"attr={attribution_value:+.6f}"
                )
                continue

            # For labs, the last coordinate selects the lab category name.
            label = ""
            if (
                feature_key == "labs"
                and coords
                and coords[-1] < len(LAB_CATEGORY_NAMES)
            ):
                label = f"{LAB_CATEGORY_NAMES[coords[-1]]} "
            print(
                f" {rank:2d}. idx={coords} {label}value={raw_value:.4f} "
                f"attr={attribution_value:+.6f}"
            )
| 182 | + |
| 183 | + |
# %% Run GIM on a held-out sample
gim = GIM(model, temperature=2.0)

# Take the first batch of the test loader and move it to the model's device.
sample_batch = next(iter(test_loader))
sample_batch_device = move_batch_to_device(sample_batch, device)

with torch.no_grad():
    output = model(**sample_batch_device)
    probs = output["y_prob"]
    if probs.dim() >= 1 and probs.shape[-1] == 1:
        # A single probability column (assumed to be the positive-class
        # probability of a binary task — confirm) would make argmax always
        # return 0, so threshold it at 0.5 instead.
        preds = (probs >= 0.5).long().squeeze(-1)
    else:
        preds = torch.argmax(probs, dim=-1)
    label_key = model.label_key
    true_label = sample_batch_device[label_key]

    print("\nModel prediction for the sampled patient:")
    print(f" True label: {int(true_label.item())}")
    print(f" Predicted class: {int(preds.item())}")
    print(f" Probabilities: {probs[0].cpu().numpy()}")

# Attribution runs outside no_grad.
attributions = gim.attribute(**sample_batch_device)
print_top_attributions(attributions, sample_batch_device, input_processors, top_k=10)
| 204 | + |
| 205 | +# %% |
0 commit comments