@@ -130,6 +130,32 @@ class GIM(BaseInterpreter):
130130 (StageNet is currently supported).
131131 temperature: Softmax temperature used exclusively for the backward
132132 pass. A value of ``2.0`` matches the paper's best setting.
133+
134+ Examples:
135+ >>> import torch
136+ >>> from pyhealth.datasets import get_dataloader
137+ >>> from pyhealth.interpret.methods.gim import GIM
138+ >>> from pyhealth.models import StageNet
139+ >>>
140+ >>> # Assume ``sample_dataset`` and trained StageNet weights are available.
141+ >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
142+ >>> model = StageNet(dataset=sample_dataset, mode="binary")
143+ >>> model = model.to(device).eval()
144+ >>> test_loader = get_dataloader(sample_dataset, batch_size=1, shuffle=False)
145+ >>> gim = GIM(model, temperature=2.0)
146+ >>>
147+ >>> batch = next(iter(test_loader))
148+ >>> batch_device = {}
149+ >>> for key, value in batch.items():
150+ ... if isinstance(value, torch.Tensor):
151+ ... batch_device[key] = value.to(device)
152+ ... elif isinstance(value, tuple):
153+ ... batch_device[key] = tuple(v.to(device) for v in value)
154+ ... else:
155+ ... batch_device[key] = value
156+ >>>
157+ >>> attributions = gim.attribute(**batch_device)
158+ >>> print({k: v.shape for k, v in attributions.items()})
133159 """
134160
135161 def __init__ (
0 commit comments