# %% Loading MIMIC-IV dataset
from pathlib import Path

import polars as pl
import torch

from pyhealth.datasets import (
    MIMIC4EHRDataset,
    get_dataloader,
    load_processors,
    split_by_patient,
)
from pyhealth.interpret.methods import GIM
from pyhealth.models import Transformer
from pyhealth.tasks import MortalityPredictionMIMIC4


def maybe_load_processors(resource_dir: str, task):
    """Load cached processors if they match the current task schema."""

    try:
        input_processors, output_processors = load_processors(resource_dir)
    except Exception as exc:
        print(f"Falling back to rebuilding processors: {exc}")
        return None, None

    expected_inputs = set(task.input_schema.keys())
    expected_outputs = set(task.output_schema.keys())
    if set(input_processors.keys()) != expected_inputs:
        print(
            "Cached input processors do not match MortalityPredictionMIMIC4 schema; rebuilding."
        )
        return None, None
    if set(output_processors.keys()) != expected_outputs:
        print(
            "Cached output processors do not match MortalityPredictionMIMIC4 schema; rebuilding."
        )
        return None, None
    return input_processors, output_processors


# Configure dataset location and optionally load cached processors
dataset = MIMIC4EHRDataset(
    root="/home/logic/physionet.org/files/mimic-iv-demo/2.2/",
    tables=[
        "patients",
        "admissions",
        "diagnoses_icd",
        "procedures_icd",
        "prescriptions",
    ],
)

task = MortalityPredictionMIMIC4()
input_processors, output_processors = maybe_load_processors("../resources/", task)

sample_dataset = dataset.set_task(
    task,
    # Expand ~ explicitly so the cache path resolves even if the loader
    # does not call expanduser() itself.
    cache_dir=str(Path("~/.cache/pyhealth/mimic4_transformer_mortality").expanduser()),
    input_processors=input_processors,
    output_processors=output_processors,
)
print(f"Total samples: {len(sample_dataset)}")


def load_icd_description_map(dataset_root: str) -> dict:
    """Load ICD code → description mappings from reference tables."""

    mapping = {}
    root_path = Path(dataset_root).expanduser()
    diag_path = root_path / "hosp" / "d_icd_diagnoses.csv.gz"
    proc_path = root_path / "hosp" / "d_icd_procedures.csv.gz"

    icd_dtype = {"icd_code": pl.Utf8, "long_title": pl.Utf8}

    # Diagnoses and procedures share the same reference-table layout, so one
    # loop covers both files.
    for table_path in (diag_path, proc_path):
        if not table_path.exists():
            continue
        table_df = pl.read_csv(
            table_path,
            columns=["icd_code", "long_title"],
            dtypes=icd_dtype,
        )
        mapping.update(
            zip(table_df["icd_code"].to_list(), table_df["long_title"].to_list())
        )

    return mapping


ICD_CODE_TO_DESC = load_icd_description_map(dataset.root)
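
# Quick sanity check: the demo subset ships the reference tables, so this
# should be non-zero; an empty map just means descriptions are skipped later.
print(f"Loaded {len(ICD_CODE_TO_DESC)} ICD code descriptions")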

# %% Loading Pretrained Transformer model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Transformer(
    dataset=sample_dataset,
    embedding_dim=128,
    heads=2,
    dropout=0.2,
    num_layers=2,
)

ckpt_path = Path("../resources/transformer_best.ckpt")
if not ckpt_path.exists():
    raise FileNotFoundError(
        f"Missing pretrained weights at {ckpt_path}. "
        "Train the Transformer model and place the checkpoint in ../resources/."
    )
state_dict = torch.load(str(ckpt_path), map_location=device)
# Some training loops nest the weights under a "state_dict" key
# (e.g. Lightning-style checkpoints); unwrap if present.
if isinstance(state_dict, dict) and "state_dict" in state_dict:
    state_dict = state_dict["state_dict"]
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()
print(model)
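
# A rough size check; this counts all parameters, trainable or not.
n_params = sum(p.numel() for p in model.parameters())
print(f"Parameters: {n_params:,}")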

# %% Preparing dataloaders
_, _, test_data = split_by_patient(sample_dataset, [0.7, 0.1, 0.2], seed=42)
test_loader = get_dataloader(test_data, batch_size=1, shuffle=False)
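
# batch_size=1 keeps the attribution output readable: every tensor in a batch
# belongs to a single patient. (Assumes the patient-level split returns sized
# datasets, so len() reports the number of samples.)
print(f"Test split size: {len(test_data)} samples")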


def move_batch_to_device(batch, target_device):
    """Move tensor values (including tuples of tensors) onto target_device."""
    moved = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            moved[key] = value.to(target_device)
        elif isinstance(value, tuple):
            moved[key] = tuple(v.to(target_device) for v in value)
        else:
            moved[key] = value
    return moved

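# Minimal sketch of the helper's contract: non-tensor values pass through
# untouched while tensors land on the target device.
_demo = move_batch_to_device({"x": torch.zeros(2, 3), "meta": "id-1"}, device)
assert _demo["meta"] == "id-1" and _demo["x"].device == device
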

def decode_token(idx: int, processor, feature_key: str):
    """Map a vocabulary index back to its token, adding ICD descriptions when known."""
    if processor is None or not hasattr(processor, "code_vocab"):
        return str(idx)
    # Rebuilding the reverse vocabulary per call is wasteful but keeps the
    # helper stateless; cache it if decoding many tokens.
    reverse_vocab = {index: token for token, index in processor.code_vocab.items()}
    token = reverse_vocab.get(idx, f"<UNK:{idx}>")

    if feature_key in {"conditions", "procedures"} and token not in {"<unk>", "<pad>"}:
        desc = ICD_CODE_TO_DESC.get(token)
        if desc:
            return f"{token}: {desc}"

    return token

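# Without a processor vocabulary the helper degrades to the raw index.
assert decode_token(0, None, "conditions") == "0"
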

def unravel(flat_index: int, shape: torch.Size):
    """Convert a flat index into per-dimension coordinates for the given shape."""
    coords = []
    remaining = flat_index
    for dim in reversed(shape):
        coords.append(remaining % dim)
        remaining //= dim
    return list(reversed(coords))

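# Worked example: flat index 5 in a (2, 3) grid sits at row 1, column 2.
# (Recent PyTorch releases also ship torch.unravel_index, which could
# replace this helper.)
assert unravel(5, torch.Size([2, 3])) == [1, 2]
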

def print_top_attributions(
    attributions,
    batch,
    processors,
    top_k: int = 10,
):
    """Print the top-k attribution scores per feature, decoding token indices."""
    for feature_key, attr in attributions.items():
        attr_cpu = attr.detach().cpu()
        if attr_cpu.dim() == 0 or attr_cpu.size(0) == 0:
            continue

        feature_input = batch[feature_key]
        if isinstance(feature_input, tuple):
            feature_input = feature_input[1]
        feature_input = feature_input.detach().cpu()

        flattened = attr_cpu[0].flatten()
        if flattened.numel() == 0:
            continue

        print(f"\nFeature: {feature_key}")
        k = min(top_k, flattened.numel())
        top_values, top_indices = torch.topk(flattened.abs(), k=k)
        processor = processors.get(feature_key) if processors else None
        is_continuous = torch.is_floating_point(feature_input)

        for rank, (_, flat_idx) in enumerate(zip(top_values, top_indices), 1):
            attribution_value = flattened[flat_idx].item()
            coords = unravel(flat_idx.item(), attr_cpu[0].shape)

            if is_continuous:
                actual_value = feature_input[0][tuple(coords)].item()
                print(
                    f"  {rank:2d}. idx={coords} value={actual_value:.4f} "
                    f"attr={attribution_value:+.6f}"
                )
            else:
                token_idx = int(feature_input[0][tuple(coords)].item())
                token = decode_token(token_idx, processor, feature_key)
                print(
                    f"  {rank:2d}. idx={coords} token='{token}' "
                    f"attr={attribution_value:+.6f}"
                )


# %% Run GIM on a held-out sample
gim = GIM(model, temperature=2.0)
processors_for_display = sample_dataset.input_processors

sample_batch = next(iter(test_loader))
sample_batch_device = move_batch_to_device(sample_batch, device)

with torch.no_grad():
    output = model(**sample_batch_device)
    probs = output["y_prob"]
    # A single-column sigmoid head would make argmax always return 0, so
    # threshold binary outputs and keep argmax for multi-class heads.
    if probs.shape[-1] == 1:
        preds = (probs > 0.5).long()
    else:
        preds = torch.argmax(probs, dim=-1)
    label_key = model.label_key
    true_label = sample_batch_device[label_key]

    print("\nModel prediction for the sampled patient:")
    print(f"  True label: {int(true_label.item())}")
    print(f"  Predicted class: {int(preds.item())}")
    print(f"  Probabilities: {probs[0].cpu().numpy()}")

attributions = gim.attribute(**sample_batch_device)
print_top_attributions(attributions, sample_batch_device, processors_for_display, top_k=10)
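
# %% Optional: aggregate attributions over several patients
# A minimal sketch, not part of the original analysis: summing |attribution|
# per feature across a handful of test batches hints at which feature groups
# the model leans on overall. Assumes every batch yields the same feature keys.
feature_totals: dict = {}
for i, extra_batch in enumerate(test_loader):
    if i >= 5:  # keep the demo cheap
        break
    extra_batch = move_batch_to_device(extra_batch, device)
    for key, attr in gim.attribute(**extra_batch).items():
        feature_totals[key] = feature_totals.get(key, 0.0) + attr.detach().abs().sum().item()

print("\nSummed |attribution| over 5 test patients:")
for key, total in sorted(feature_totals.items(), key=lambda kv: -kv[1]):
    print(f"  {key}: {total:.4f}")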

# %%