
Commit cf59e1b

LogicFan and jhnwu3 authored
[BMS] Gradient Interaction Modification (#605)
* Initial attempts for GIM
* Add testcase and notebook
* Fix incorrect implementation, and add comments
* Add docs
* Examples in Docs
* Adapt GIM for transformer

Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com>
1 parent 8eceb3a commit cf59e1b

File tree

8 files changed: +1132 / -5 lines changed


docs/api/interpret.rst

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@ Available Methods
 .. toctree::
    :maxdepth: 4
 
+   interpret/pyhealth.interpret.methods.gim
    interpret/pyhealth.interpret.methods.basic_gradient
    interpret/pyhealth.interpret.methods.chefer
    interpret/pyhealth.interpret.methods.deeplift

docs/api/interpret/pyhealth.interpret.methods.gim.rst

Lines changed: 25 additions & 0 deletions
pyhealth.interpret.methods.gim
================================

Overview
--------

The Gradient Interaction Modifications (GIM) interpreter adapts the
attribution method described by Edin et al. (2025) to StageNet models. It
recomputes softmax gradients with a higher temperature so that token-level
interactions remain visible when cumulative softmax layers are present.

Use this interpreter with StageNet-style models that expose
``forward_from_embedding`` and ``embedding_model``.

For a complete working example, see
``examples/gim_stagenet_mimic4.py``.

API Reference
-------------

.. autoclass:: pyhealth.interpret.methods.GIM
   :members:
   :undoc-members:
   :show-inheritance:
   :member-order: bysource
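
To make the recomputed-gradient idea concrete, here is a minimal sketch (the function name and signature are illustrative, not part of the pyhealth API): the forward pass keeps the ordinary softmax, while the vector-Jacobian product used during backpropagation is rebuilt from probabilities computed at temperature T, which flattens the Jacobian and keeps small cross-token interactions visible.

import torch

def softmax_vjp_at_temperature(logits, upstream_grad, T=2.0):
    # Standard softmax VJP: J^T g = p * (g - <p, g>). Recomputing p at
    # temperature T (with the extra 1/T chain factor) spreads gradient
    # mass across tokens instead of concentrating it on the top one.
    p = torch.softmax(logits / T, dim=-1)
    dot = (p * upstream_grad).sum(dim=-1, keepdim=True)
    return p * (upstream_grad - dot) / T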

examples/gim_stagenet_mimic4.py

Lines changed: 205 additions & 0 deletions
# %% Loading MIMIC-IV dataset
from pathlib import Path

import polars as pl
import torch

from pyhealth.datasets import (
    MIMIC4EHRDataset,
    get_dataloader,
    load_processors,
    split_by_patient,
)
from pyhealth.interpret.methods import GIM
from pyhealth.models import StageNet
from pyhealth.tasks import MortalityPredictionStageNetMIMIC4

# Configure dataset location and load cached processors
dataset = MIMIC4EHRDataset(
    root="/home/logic/physionet.org/files/mimic-iv-demo/2.2/",
    tables=[
        "patients",
        "admissions",
        "diagnoses_icd",
        "procedures_icd",
        "labevents",
    ],
)
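# NOTE: `root` above points at the author's local MIMIC-IV demo download;
# adjust it to the location of your own copy.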

# %% Setting StageNet Mortality Prediction Task
input_processors, output_processors = load_processors("../resources/")

sample_dataset = dataset.set_task(
    MortalityPredictionStageNetMIMIC4(),
    cache_dir="~/.cache/pyhealth/mimic4_stagenet_mortality",
    input_processors=input_processors,
    output_processors=output_processors,
)
print(f"Total samples: {len(sample_dataset)}")


def load_icd_description_map(dataset_root: str) -> dict:
    """Load ICD code → description mappings from reference tables."""
    mapping = {}
    root_path = Path(dataset_root).expanduser()
    diag_path = root_path / "hosp" / "d_icd_diagnoses.csv.gz"
    proc_path = root_path / "hosp" / "d_icd_procedures.csv.gz"

    icd_dtype = {"icd_code": pl.Utf8, "long_title": pl.Utf8}

    if diag_path.exists():
        diag_df = pl.read_csv(
            diag_path,
            columns=["icd_code", "long_title"],
            dtypes=icd_dtype,
        )
        mapping.update(
            zip(diag_df["icd_code"].to_list(), diag_df["long_title"].to_list())
        )

    if proc_path.exists():
        proc_df = pl.read_csv(
            proc_path,
            columns=["icd_code", "long_title"],
            dtypes=icd_dtype,
        )
        mapping.update(
            zip(proc_df["icd_code"].to_list(), proc_df["long_title"].to_list())
        )

    return mapping


ICD_CODE_TO_DESC = load_icd_description_map(dataset.root)

# %% Loading Pretrained StageNet model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = StageNet(
    dataset=sample_dataset,
    embedding_dim=128,
    chunk_size=128,
    levels=3,
    dropout=0.3,
)

state_dict = torch.load("../resources/best.ckpt", map_location=device)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()
print(model)
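# `../resources/best.ckpt` is assumed to hold StageNet weights trained on
# this task with the same hyperparameters as above.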

# %% Preparing dataloaders
_, _, test_data = split_by_patient(sample_dataset, [0.7, 0.1, 0.2], seed=42)
test_loader = get_dataloader(test_data, batch_size=1, shuffle=False)


def move_batch_to_device(batch, target_device):
    moved = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            moved[key] = value.to(target_device)
        elif isinstance(value, tuple):
            moved[key] = tuple(v.to(target_device) for v in value)
        else:
            moved[key] = value
    return moved


LAB_CATEGORY_NAMES = MortalityPredictionStageNetMIMIC4.LAB_CATEGORY_NAMES


def decode_token(idx: int, processor, feature_key: str):
    if processor is None or not hasattr(processor, "code_vocab"):
        return str(idx)
    reverse_vocab = {index: token for token, index in processor.code_vocab.items()}
    token = reverse_vocab.get(idx, f"<UNK:{idx}>")

    if feature_key == "icd_codes" and token not in {"<unk>", "<pad>"}:
        desc = ICD_CODE_TO_DESC.get(token)
        if desc:
            return f"{token}: {desc}"

    return token
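

# `unravel` converts a flat index back into per-dimension coordinates for a
# tensor of the given shape (the same idea as numpy's `unravel_index`).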
def unravel(flat_index: int, shape: torch.Size):
    coords = []
    remaining = flat_index
    for dim in reversed(shape):
        coords.append(remaining % dim)
        remaining //= dim
    return list(reversed(coords))
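

# `print_top_attributions` ranks each feature's attribution entries by
# absolute value, then renders continuous inputs with their raw values and
# discrete inputs as decoded tokens (with ICD descriptions where available).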
def print_top_attributions(
    attributions,
    batch,
    processors,
    top_k: int = 10,
):
    for feature_key, attr in attributions.items():
        attr_cpu = attr.detach().cpu()
        if attr_cpu.dim() == 0 or attr_cpu.size(0) == 0:
            continue

        feature_input = batch[feature_key]
        if isinstance(feature_input, tuple):
            feature_input = feature_input[1]
        feature_input = feature_input.detach().cpu()

        flattened = attr_cpu[0].flatten()
        if flattened.numel() == 0:
            continue

        print(f"\nFeature: {feature_key}")
        k = min(top_k, flattened.numel())
        top_values, top_indices = torch.topk(flattened.abs(), k=k)
        processor = processors.get(feature_key) if processors else None
        is_continuous = torch.is_floating_point(feature_input)

        for rank, (_, flat_idx) in enumerate(zip(top_values, top_indices), 1):
            attribution_value = flattened[flat_idx].item()
            coords = unravel(flat_idx.item(), attr_cpu[0].shape)

            if is_continuous:
                actual_value = feature_input[0][tuple(coords)].item()
                label = ""
                if feature_key == "labs" and len(coords) >= 1:
                    lab_idx = coords[-1]
                    if lab_idx < len(LAB_CATEGORY_NAMES):
                        label = f"{LAB_CATEGORY_NAMES[lab_idx]} "
                print(
                    f" {rank:2d}. idx={coords} {label}value={actual_value:.4f} "
                    f"attr={attribution_value:+.6f}"
                )
            else:
                token_idx = int(feature_input[0][tuple(coords)].item())
                token = decode_token(token_idx, processor, feature_key)
                print(
                    f" {rank:2d}. idx={coords} token='{token}' "
                    f"attr={attribution_value:+.6f}"
                )


# %% Run GIM on a held-out sample
gim = GIM(model, temperature=2.0)

sample_batch = next(iter(test_loader))
sample_batch_device = move_batch_to_device(sample_batch, device)

with torch.no_grad():
    output = model(**sample_batch_device)
    probs = output["y_prob"]
    preds = torch.argmax(probs, dim=-1)
    label_key = model.label_key
    true_label = sample_batch_device[label_key]

print("\nModel prediction for the sampled patient:")
print(f" True label: {int(true_label.item())}")
print(f" Predicted class: {int(preds.item())}")
print(f" Probabilities: {probs[0].cpu().numpy()}")

attributions = gim.attribute(**sample_batch_device)
print_top_attributions(attributions, sample_batch_device, input_processors, top_k=10)

# %%
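
# %% (Optional) Aggregate attributions across several test batches
# A rough sketch, not part of the original example: average the absolute
# attribution per feature over a few batches to see which input streams
# dominate overall. `n_batches` is an arbitrary illustrative choice.
n_batches = 10
totals = {}
for i, batch in enumerate(test_loader):
    if i >= n_batches:
        break
    batch = move_batch_to_device(batch, device)
    for key, attr in gim.attribute(**batch).items():
        totals[key] = totals.get(key, 0.0) + attr.detach().abs().mean().item()

for key, total in sorted(totals.items(), key=lambda kv: -kv[1]):
    print(f"{key}: mean |attribution| = {total / n_batches:.6f}")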
