
Commit 06566fe

Adapt GIM for transformer

1 parent 9f9a592

File tree

3 files changed, +327 -7 lines changed


examples/gim_transformer_mimic4.py

Lines changed: 230 additions & 0 deletions
@@ -0,0 +1,230 @@
# %% Loading MIMIC-IV dataset
from pathlib import Path

import polars as pl
import torch

from pyhealth.datasets import (
    MIMIC4EHRDataset,
    get_dataloader,
    load_processors,
    split_by_patient,
)
from pyhealth.interpret.methods import GIM
from pyhealth.models import Transformer
from pyhealth.tasks import MortalityPredictionMIMIC4


def maybe_load_processors(resource_dir: str, task):
    """Load cached processors if they match the current task schema."""

    try:
        input_processors, output_processors = load_processors(resource_dir)
    except Exception as exc:
        print(f"Falling back to rebuilding processors: {exc}")
        return None, None

    expected_inputs = set(task.input_schema.keys())
    expected_outputs = set(task.output_schema.keys())
    if set(input_processors.keys()) != expected_inputs:
        print(
            "Cached input processors do not match MortalityPredictionMIMIC4 schema; rebuilding."
        )
        return None, None
    if set(output_processors.keys()) != expected_outputs:
        print(
            "Cached output processors do not match MortalityPredictionMIMIC4 schema; rebuilding."
        )
        return None, None
    return input_processors, output_processors


# Configure dataset location and optionally load cached processors
dataset = MIMIC4EHRDataset(
    root="/home/logic/physionet.org/files/mimic-iv-demo/2.2/",
    tables=[
        "patients",
        "admissions",
        "diagnoses_icd",
        "procedures_icd",
        "prescriptions",
    ],
)

task = MortalityPredictionMIMIC4()
input_processors, output_processors = maybe_load_processors("../resources/", task)

sample_dataset = dataset.set_task(
    task,
    cache_dir="~/.cache/pyhealth/mimic4_transformer_mortality",
    input_processors=input_processors,
    output_processors=output_processors,
)
print(f"Total samples: {len(sample_dataset)}")


def load_icd_description_map(dataset_root: str) -> dict:
    """Load ICD code → description mappings from reference tables."""

    mapping = {}
    root_path = Path(dataset_root).expanduser()
    diag_path = root_path / "hosp" / "d_icd_diagnoses.csv.gz"
    proc_path = root_path / "hosp" / "d_icd_procedures.csv.gz"

    icd_dtype = {"icd_code": pl.Utf8, "long_title": pl.Utf8}

    if diag_path.exists():
        diag_df = pl.read_csv(
            diag_path,
            columns=["icd_code", "long_title"],
            dtypes=icd_dtype,
        )
        mapping.update(
            zip(diag_df["icd_code"].to_list(), diag_df["long_title"].to_list())
        )

    if proc_path.exists():
        proc_df = pl.read_csv(
            proc_path,
            columns=["icd_code", "long_title"],
            dtypes=icd_dtype,
        )
        mapping.update(
            zip(proc_df["icd_code"].to_list(), proc_df["long_title"].to_list())
        )

    return mapping


ICD_CODE_TO_DESC = load_icd_description_map(dataset.root)

# %% Loading Pretrained Transformer model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Transformer(
    dataset=sample_dataset,
    embedding_dim=128,
    heads=2,
    dropout=0.2,
    num_layers=2,
)

ckpt_path = Path("../resources/transformer_best.ckpt")
if not ckpt_path.exists():
    raise FileNotFoundError(
        f"Missing pretrained weights at {ckpt_path}. "
        "Train the Transformer model and place the checkpoint in ../resources/."
    )
state_dict = torch.load(str(ckpt_path), map_location=device)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()
print(model)

# %% Preparing dataloaders
_, _, test_data = split_by_patient(sample_dataset, [0.7, 0.1, 0.2], seed=42)
test_loader = get_dataloader(test_data, batch_size=1, shuffle=False)

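# Note: batch_size=1 keeps a single patient per batch, so the attribution
# report below can index position [0] and call .item() on the label without
# mixing patients.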

def move_batch_to_device(batch, target_device):
    moved = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            moved[key] = value.to(target_device)
        elif isinstance(value, tuple):
            moved[key] = tuple(v.to(target_device) for v in value)
        else:
            moved[key] = value
    return moved


def decode_token(idx: int, processor, feature_key: str):
    if processor is None or not hasattr(processor, "code_vocab"):
        return str(idx)
    reverse_vocab = {index: token for token, index in processor.code_vocab.items()}
    token = reverse_vocab.get(idx, f"<UNK:{idx}>")

    if feature_key in {"conditions", "procedures"} and token not in {"<unk>", "<pad>"}:
        desc = ICD_CODE_TO_DESC.get(token)
        if desc:
            return f"{token}: {desc}"

    return token


def unravel(flat_index: int, shape: torch.Size):
    coords = []
    remaining = flat_index
    for dim in reversed(shape):
        coords.append(remaining % dim)
        remaining //= dim
    return list(reversed(coords))

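# Illustrative sanity check for the helper above: flat index 7 in a tensor of
# shape (2, 4) corresponds to row-major coordinates [1, 3].
assert unravel(7, torch.Size([2, 4])) == [1, 3]
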
def print_top_attributions(
    attributions,
    batch,
    processors,
    top_k: int = 10,
):
    for feature_key, attr in attributions.items():
        attr_cpu = attr.detach().cpu()
        if attr_cpu.dim() == 0 or attr_cpu.size(0) == 0:
            continue

        feature_input = batch[feature_key]
        if isinstance(feature_input, tuple):
            feature_input = feature_input[1]
        feature_input = feature_input.detach().cpu()

        flattened = attr_cpu[0].flatten()
        if flattened.numel() == 0:
            continue

        print(f"\nFeature: {feature_key}")
        k = min(top_k, flattened.numel())
        top_values, top_indices = torch.topk(flattened.abs(), k=k)
        processor = processors.get(feature_key) if processors else None
        is_continuous = torch.is_floating_point(feature_input)

        for rank, (_, flat_idx) in enumerate(zip(top_values, top_indices), 1):
            attribution_value = flattened[flat_idx].item()
            coords = unravel(flat_idx.item(), attr_cpu[0].shape)

            if is_continuous:
                actual_value = feature_input[0][tuple(coords)].item()
                print(
                    f" {rank:2d}. idx={coords} value={actual_value:.4f} "
                    f"attr={attribution_value:+.6f}"
                )
            else:
                token_idx = int(feature_input[0][tuple(coords)].item())
                token = decode_token(token_idx, processor, feature_key)
                print(
                    f" {rank:2d}. idx={coords} token='{token}' "
                    f"attr={attribution_value:+.6f}"
                )


# %% Run GIM on a held-out sample
gim = GIM(model, temperature=2.0)
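# The temperature argument is GIM's tau: during attribution, backward passes
# through the model's softmax operations behave as if the softmax had been
# computed at tau=2.0, smoothing the gradients used to assign credit.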
processors_for_display = sample_dataset.input_processors

sample_batch = next(iter(test_loader))
sample_batch_device = move_batch_to_device(sample_batch, device)

with torch.no_grad():
    output = model(**sample_batch_device)
probs = output["y_prob"]
preds = torch.argmax(probs, dim=-1)
label_key = model.label_key
true_label = sample_batch_device[label_key]

print("\nModel prediction for the sampled patient:")
print(f" True label: {int(true_label.item())}")
print(f" Predicted class: {int(preds.item())}")
print(f" Probabilities: {probs[0].cpu().numpy()}")

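# gim.attribute runs the hooked backward pass and, judging from its use below,
# returns a dict mapping each feature key to an attribution tensor shaped like
# that feature's input.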
attributions = gim.attribute(**sample_batch_device)
print_top_attributions(attributions, sample_batch_device, processors_for_display, top_k=10)

# %%

pyhealth/interpret/methods/gim.py

Lines changed: 3 additions & 3 deletions
@@ -105,7 +105,7 @@ def __exit__(self, exc_type, exc, exc_tb) -> bool:
 
 
 class GIM(BaseInterpreter):
-    """Gradient Interaction Modifications for StageNet-style models.
+    """Gradient Interaction Modifications for StageNet-style and Transformer models.
 
     This interpreter adapts the Gradient Interaction Modifications (GIM)
     technique (Edin et al., 2025) to PyHealth, focusing on StageNet where
@@ -187,14 +187,14 @@ def attribute(
         # Clear stale gradients before the attribution pass.
         self.model.zero_grad(set_to_none=True)
 
-        time_kwarg = time_info if time_info else None
         # Step 1 (TSG): install the temperature-adjusted softmax hooks so all
         # backward passes through StageNet's cumax operations use the higher τ.
         with _GIMHookContext(self.model, self.temperature):
             forward_kwargs = {**label_data} if label_data else {}
+            if time_info:
+                forward_kwargs["time_info"] = time_info
             output = self.model.forward_from_embedding(
                 feature_embeddings=embeddings,
-                time_info=time_kwarg,
                 **forward_kwargs,
             )
 
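The "Step 1 (TSG)" comment above refers to GIM's temperature-scaled gradients: the forward pass is left unchanged, but backward passes through softmax behave as if the softmax had been computed at the higher temperature τ. The following is a minimal self-contained sketch of that idea in plain PyTorch; it is illustrative only, since _GIMHookContext installs hooks on the existing model rather than swapping in a custom function, and the real implementation may differ:

import torch

class TemperedSoftmaxFn(torch.autograd.Function):
    """Ordinary softmax forward; gradient of softmax(x / tau) backward."""

    @staticmethod
    def forward(ctx, x, tau):
        # Save the tempered distribution; it defines the backward Jacobian.
        s_tau = torch.softmax(x / tau, dim=-1)
        ctx.save_for_backward(s_tau)
        ctx.tau = tau
        return torch.softmax(x, dim=-1)  # forward pass is unchanged

    @staticmethod
    def backward(ctx, grad_out):
        (s_tau,) = ctx.saved_tensors
        # Vector-Jacobian product of softmax(x / tau):
        # (1 / tau) * s_tau * (grad_out - <grad_out, s_tau>)
        dot = (grad_out * s_tau).sum(dim=-1, keepdim=True)
        return s_tau * (grad_out - dot) / ctx.tau, None

x = torch.randn(5, requires_grad=True)
TemperedSoftmaxFn.apply(x, 2.0)[0].backward()
print(x.grad)  # smoother than the gradient of a tau=1 softmax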
