Skip to content

Commit 8dd0248

Browse files
committed
Examples in Docs
1 parent 61687dd commit 8dd0248

File tree

1 file changed

+26
-0
lines changed
  • pyhealth/interpret/methods

1 file changed

+26
-0
lines changed

pyhealth/interpret/methods/gim.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,32 @@ class GIM(BaseInterpreter):
130130
(StageNet is currently supported).
131131
temperature: Softmax temperature used exclusively for the backward
132132
pass. A value of ``2.0`` matches the paper's best setting.
133+
134+
Examples:
135+
>>> import torch
136+
>>> from pyhealth.datasets import get_dataloader
137+
>>> from pyhealth.interpret.methods.gim import GIM
138+
>>> from pyhealth.models import StageNet
139+
>>>
140+
>>> # Assume ``sample_dataset`` and trained StageNet weights are available.
141+
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
142+
>>> model = StageNet(dataset=sample_dataset, mode="binary")
143+
>>> model = model.to(device).eval()
144+
>>> test_loader = get_dataloader(sample_dataset, batch_size=1, shuffle=False)
145+
>>> gim = GIM(model, temperature=2.0)
146+
>>>
147+
>>> batch = next(iter(test_loader))
148+
>>> batch_device = {}
149+
>>> for key, value in batch.items():
150+
... if isinstance(value, torch.Tensor):
151+
... batch_device[key] = value.to(device)
152+
... elif isinstance(value, tuple):
153+
... batch_device[key] = tuple(v.to(device) for v in value)
154+
... else:
155+
... batch_device[key] = value
156+
>>>
157+
>>> attributions = gim.attribute(**batch_device)
158+
>>> print({k: v.shape for k, v in attributions.items()})
133159
"""
134160

135161
def __init__(

0 commit comments

Comments
 (0)