
Commit 6e6192f

Fix incorrect implementation, and add comments
1 parent 5097228 commit 6e6192f

File tree

2 files changed: +33 -26 lines changed


pyhealth/interpret/methods/gim.py

Lines changed: 27 additions & 14 deletions
@@ -11,7 +11,12 @@
 
 
 class _TemperatureSoftmax(torch.autograd.Function):
-    """Custom autograd op implementing temperature-adjusted softmax gradients."""
+    """Custom autograd op implementing temperature-adjusted softmax gradients.
+
+    Implements the Temperature-Scaled Gradients (TSG) rule from GIM Sec. 4.1 by
+    recomputing the backward Jacobian with a higher temperature while leaving
+    the forward softmax unchanged.
+    """
 
     @staticmethod
     def forward(
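To make the TSG rule concrete, here is a minimal sketch of a temperature-adjusted softmax as a custom autograd op. It is illustrative only: the class name and the explicit ``tau`` argument are assumptions, since the commit's ``_TemperatureSoftmax.forward`` signature is only partially visible above. The forward pass is an ordinary softmax, while the backward pass applies the Jacobian of a softmax evaluated at temperature τ.

import torch

class TemperatureSoftmaxSketch(torch.autograd.Function):
    """Forward: plain softmax. Backward: Jacobian of softmax at temperature tau."""

    @staticmethod
    def forward(ctx, logits: torch.Tensor, tau: float) -> torch.Tensor:
        probs = torch.softmax(logits, dim=-1)            # forward pass unchanged
        probs_tau = torch.softmax(logits / tau, dim=-1)  # only used for the backward pass
        ctx.save_for_backward(probs_tau)
        ctx.tau = tau
        return probs

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        (probs_tau,) = ctx.saved_tensors
        # Vector-Jacobian product of softmax at temperature tau:
        #   dL/dz_j = (1 / tau) * p_j * (g_j - sum_i g_i * p_i)
        dot = (grad_output * probs_tau).sum(dim=-1, keepdim=True)
        grad_logits = probs_tau * (grad_output - dot) / ctx.tau
        return grad_logits, None  # no gradient for the temperature argument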
@@ -63,7 +68,12 @@ def apply(self, name: str, tensor: torch.Tensor, **kwargs) -> torch.Tensor:
 
 
 class _GIMHookContext(contextlib.AbstractContextManager):
-    """Context manager that wires GIM hooks if the model supports them."""
+    """Context manager that wires GIM hooks if the model supports them.
+
+    TSG needs to intercept every activation that calls ``torch.softmax``.
+    StageNet exposes DeepLIFT-style hook setters, so we reuse that surface
+    unless a dedicated ``set_gim_hooks`` is provided.
+    """
 
     def __init__(self, model: BaseModel, temperature: float):
         self.model = model
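A rough sketch of the wiring that docstring describes, under stated assumptions: only ``set_gim_hooks`` is named in the diff, so the DeepLIFT-style setter and the cleanup method below are placeholder names, not PyHealth API.

import contextlib

class HookContextSketch(contextlib.AbstractContextManager):
    """Prefer a dedicated GIM hook setter; otherwise reuse a DeepLIFT-style one."""

    def __init__(self, model, softmax_override):
        self.model = model
        self.softmax_override = softmax_override  # e.g. a temperature-adjusted softmax

    def __enter__(self):
        if hasattr(self.model, "set_gim_hooks"):          # dedicated GIM surface
            self.model.set_gim_hooks(self.softmax_override)
        elif hasattr(self.model, "set_deeplift_hooks"):   # placeholder fallback name
            self.model.set_deeplift_hooks(self.softmax_override)
        return self

    def __exit__(self, exc_type, exc, tb):
        if hasattr(self.model, "clear_interpret_hooks"):  # placeholder cleanup name
            self.model.clear_interpret_hooks()
        return False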
@@ -108,25 +118,24 @@ class GIM(BaseInterpreter):
        softmax redistribution.
     2. **LayerNorm freeze:** Layer normalization parameters are treated as
        constants during backpropagation. StageNet does not employ layer norm,
-       so this step is effectively a no-op but kept for API parity.
-    3. **Gradient normalization:** Gradients are reported as gradient-times-input and collapsed onto the original token axes, ensuring consistent scale
-       across visits. This mirrors the normalization heuristic proposed in
-       the paper for multiplicative interactions.
+       so this rule becomes a mathematical no-op, matching the paper when
+       σ is constant.
+    3. **Gradient normalization:** When no multiplicative fan-in exists (as in
+       StageNet's embedding → recurrent pipeline) the uniform division rule
+       effectively multiplies by 1, so propagating raw gradients remains
+       faithful to Section 4.2.
 
     Args:
         model: Trained PyHealth model supporting ``forward_from_embedding``
             (StageNet is currently supported).
         temperature: Softmax temperature used exclusively for the backward
            pass. A value of ``2.0`` matches the paper's best setting.
-        multiply_by_input: Whether to return gradient×input (default) or raw
-            gradients in embedding space.
     """
 
     def __init__(
         self,
         model: BaseModel,
         temperature: float = 2.0,
-        multiply_by_input: bool = True,
     ):
         super().__init__(model)
         if not hasattr(model, "forward_from_embedding"):
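For orientation, a minimal usage sketch consistent with the signatures visible in this diff. The import path simply follows the file location above, and the ``codes``/``label`` keyword arguments mirror the toy model used in the tests below rather than a real StageNet dataset; the helper itself is illustrative and not part of the commit.

import torch
from pyhealth.interpret.methods.gim import GIM

def explain_visit(model, codes: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Return per-token attributions for class 0 (illustrative helper)."""
    gim = GIM(model, temperature=2.0)  # 2.0 matches the paper's best setting
    attrs = gim.attribute(target_class_idx=0, codes=codes, label=labels)
    return attrs["codes"]              # collapsed to the shape of ``codes``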
@@ -138,7 +147,6 @@ def __init__(
                 "GIM requires access to the model's embedding_model."
             )
         self.temperature = max(float(temperature), 1.0)
-        self.multiply_by_input = multiply_by_input
 
     def attribute(
         self,
@@ -154,6 +162,8 @@ def attribute(
         self.model.zero_grad(set_to_none=True)
 
         time_kwarg = time_info if time_info else None
+        # Step 1 (TSG): install the temperature-adjusted softmax hooks so all
+        # backward passes through StageNet's cumax operations use the higher τ.
         with _GIMHookContext(self.model, self.temperature):
             forward_kwargs = {**label_data} if label_data else {}
             output = self.model.forward_from_embedding(
@@ -165,6 +175,9 @@ def attribute(
         logits = output["logit"]
         target = self._compute_target_output(logits, target_class_idx)
 
+        # Step 2 (LayerNorm freeze): StageNet does not contain layer norms, so
+        # there are no σ parameters to freeze; the reset below ensures any
+        # hypothetical normalization buffers would stay constant as in Sec. 4.2.
         self.model.zero_grad(set_to_none=True)
         for emb in embeddings.values():
             if emb.grad is not None:
@@ -177,10 +190,10 @@ def attribute(
             grad = emb.grad
             if grad is None:
                 grad = torch.zeros_like(emb)
-            attr = grad
-            if self.multiply_by_input:
-                attr = attr * emb
-            token_attr = self._collapse_to_input_shape(attr, input_shapes[key])
+            # Step 3 (Gradient normalization): StageNet lacks the multi-input
+            # products targeted by the uniform rule, so dividing by 1 (identity)
+            # yields the same gradients the paper would propagate.
+            token_attr = self._collapse_to_input_shape(grad, input_shapes[key])
             attributions[key] = token_attr.detach()
 
         return attributions
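A toy check of the Step 3 reasoning, under my reading that the uniform rule splits a product's gradient evenly across its multiplicative inputs; with a fan-in of one the division is by 1, so raw gradients pass through unchanged. This snippet is mine, not part of the commit.

import torch

x = torch.tensor(3.0, requires_grad=True)
y = torch.tensor(2.0, requires_grad=True)
(x * y).backward()          # d(xy)/dx = y = 2, d(xy)/dy = x = 3

fan_in = 2                  # two multiplicative inputs feed the product
x_share = x.grad / fan_in   # uniform rule: each input keeps 1/fan_in of the gradient
y_share = y.grad / fan_in

# With fan_in == 1 (no product, as in StageNet's embedding -> recurrent path)
# the divisor is 1 and the propagated gradient equals the raw gradient,
# which is what the Step 3 comment above relies on.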

tests/core/test_gim.py

Lines changed: 6 additions & 12 deletions
@@ -127,7 +127,6 @@ def _manual_token_attribution(
     model: _ToyGIMModel,
     tokens: torch.Tensor,
     labels: torch.Tensor,
-    multiply_by_input: bool = True,
 ) -> torch.Tensor:
     """Reference implementation mimicking GIM without temperature scaling."""
 
@@ -145,8 +144,6 @@ def _manual_token_attribution(
     target.backward()
 
     grad = embeddings.grad.detach()
-    if multiply_by_input:
-        grad = grad * embeddings.detach()
     token_attr = grad.sum(dim=-1)
     return token_attr
 
@@ -160,7 +157,7 @@ def setUp(self):
         self.labels = torch.zeros((1, 1))
 
     def test_matches_manual_gradient_when_temperature_one(self):
-        """Temperature=1 should collapse to plain gradient×input."""
+        """Temperature=1 should collapse to plain gradients."""
 
         model = _ToyGIMModel()
         gim = GIM(model, temperature=1.0)
@@ -204,17 +201,14 @@ def test_prefers_custom_gim_hooks(self):
         self.assertEqual(model.gim_hook_calls, 1)
         self.assertEqual(model.deeplift_hook_calls, 0)
 
-    def test_disable_multiply_by_input_returns_raw_gradient(self):
-        """Setting multiply_by_input=False should return pure gradients."""
+    def test_attributions_match_input_shape(self):
+        """Collapsed gradients should align with the token tensor shape."""
 
         model = _ToyGIMModel()
-        gim = GIM(model, temperature=1.0, multiply_by_input=False)
+        gim = GIM(model, temperature=1.0)
 
-        attrs = gim.attribute(target_class_idx=0, codes=self.tokens, label=self.labels)["codes"]
-        manual_grad = _manual_token_attribution(
-            model, self.tokens, self.labels, multiply_by_input=False
-        )
-        torch.testing.assert_close(attrs, manual_grad, atol=1e-6, rtol=1e-5)
+        attrs = gim.attribute(target_class_idx=0, codes=self.tokens, label=self.labels)
+        self.assertEqual(tuple(attrs["codes"].shape), tuple(self.tokens.shape))
 
     def test_handles_temporal_tuple_inputs(self):
         """StageNet-style (time, value) tuples should be processed seamlessly."""
