 
 
 class _TemperatureSoftmax(torch.autograd.Function):
-    """Custom autograd op implementing temperature-adjusted softmax gradients."""
+    """Custom autograd op implementing temperature-adjusted softmax gradients.
+
+    Implements the Temperature-Scaled Gradients (TSG) rule from GIM Sec. 4.1 by
+    recomputing the backward Jacobian with a higher temperature while leaving
+    the forward softmax unchanged.
+    """
 
     @staticmethod
     def forward(
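For reference, a minimal sketch (not this PR's actual `_TemperatureSoftmax`) of the TSG behaviour the docstring above describes: the forward pass keeps the ordinary softmax, while the backward pass applies the softmax Jacobian recomputed at a higher temperature `tau`. All names below are illustrative, and whether the real implementation keeps the extra `1/tau` factor is an assumption.

```python
import torch


class _TSGSoftmaxSketch(torch.autograd.Function):
    """Illustrative only: forward softmax at T=1, backward Jacobian at T=tau."""

    @staticmethod
    def forward(ctx, logits: torch.Tensor, tau: float, dim: int = -1):
        ctx.tau, ctx.dim = tau, dim
        ctx.save_for_backward(logits)
        return torch.softmax(logits, dim=dim)  # forward pass is unchanged

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        (logits,) = ctx.saved_tensors
        # Recompute the softmax at the higher temperature for the backward pass.
        p = torch.softmax(logits / ctx.tau, dim=ctx.dim)
        # Vector-Jacobian product of softmax(z / tau): p * (g - sum(g * p)) / tau.
        inner = (grad_output * p).sum(dim=ctx.dim, keepdim=True)
        return p * (grad_output - inner) / ctx.tau, None, None
```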
@@ -63,7 +68,12 @@ def apply(self, name: str, tensor: torch.Tensor, **kwargs) -> torch.Tensor:
 
 
 class _GIMHookContext(contextlib.AbstractContextManager):
-    """Context manager that wires GIM hooks if the model supports them."""
+    """Context manager that wires GIM hooks if the model supports them.
+
+    TSG needs to intercept every activation that calls ``torch.softmax``.
+    StageNet exposes DeepLIFT-style hook setters, so we reuse that surface
+    unless a dedicated ``set_gim_hooks`` is provided.
+    """
 
     def __init__(self, model: BaseModel, temperature: float):
         self.model = model
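A minimal sketch of the fallback logic this docstring describes: prefer a dedicated ``set_gim_hooks`` and otherwise reuse the DeepLIFT-style hook surface. Apart from `set_gim_hooks`, the method names (`set_deeplift_hooks`, `clear_hooks`) and their signatures are assumptions, not StageNet's real API.

```python
import contextlib


class _HookContextSketch(contextlib.AbstractContextManager):
    """Illustrative only: pick a GIM hook setter, else fall back to DeepLIFT-style hooks."""

    def __init__(self, model, temperature: float):
        self.model = model
        self.temperature = temperature

    def __enter__(self):
        if hasattr(self.model, "set_gim_hooks"):
            # Dedicated GIM surface, if the model provides one.
            self.model.set_gim_hooks(temperature=self.temperature)
        else:
            # Otherwise reuse the DeepLIFT-style setter (name assumed).
            self.model.set_deeplift_hooks()
        return self

    def __exit__(self, *exc_info):
        if hasattr(self.model, "clear_hooks"):  # teardown name assumed
            self.model.clear_hooks()
        return False
```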
@@ -108,25 +118,24 @@ class GIM(BaseInterpreter):
        softmax redistribution.
     2. **LayerNorm freeze:** Layer normalization parameters are treated as
        constants during backpropagation. StageNet does not employ layer norm,
-       so this step is effectively a no-op but kept for API parity.
-    3. **Gradient normalization:** Gradients are reported as gradient-times-input and collapsed onto the original token axes, ensuring consistent scale
-       across visits. This mirrors the normalization heuristic proposed in
-       the paper for multiplicative interactions.
+       so this rule becomes a mathematical no-op, matching the paper when
+       σ is constant.
+    3. **Gradient normalization:** When no multiplicative fan-in exists (as in
+       StageNet's embedding → recurrent pipeline), the uniform division rule
+       effectively multiplies by 1, so propagating raw gradients remains
+       faithful to Section 4.2.
 
     Args:
         model: Trained PyHealth model supporting ``forward_from_embedding``
             (StageNet is currently supported).
         temperature: Softmax temperature used exclusively for the backward
             pass. A value of ``2.0`` matches the paper's best setting.
-        multiply_by_input: Whether to return gradient×input (default) or raw
-            gradients in embedding space.
     """
 
     def __init__(
         self,
         model: BaseModel,
         temperature: float = 2.0,
-        multiply_by_input: bool = True,
     ):
         super().__init__(model)
         if not hasattr(model, "forward_from_embedding"):
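A hedged usage sketch of the constructor surface documented above; `model` is assumed to be a trained PyHealth StageNet exposing `forward_from_embedding` and `embedding_model`, and only behaviour visible in this diff (including the clamp to a minimum temperature of 1.0 below) is shown.

```python
gim = GIM(model, temperature=2.0)   # the paper's best backward temperature
cool = GIM(model, temperature=0.5)
assert cool.temperature == 1.0      # clamped: backward is never sharper than forward
```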
@@ -138,7 +147,6 @@ def __init__(
                 "GIM requires access to the model's embedding_model."
             )
         self.temperature = max(float(temperature), 1.0)
-        self.multiply_by_input = multiply_by_input
 
     def attribute(
         self,
@@ -154,6 +162,8 @@ def attribute(
         self.model.zero_grad(set_to_none=True)
 
         time_kwarg = time_info if time_info else None
+        # Step 1 (TSG): install the temperature-adjusted softmax hooks so all
+        # backward passes through StageNet's cumax operations use the higher τ.
         with _GIMHookContext(self.model, self.temperature):
             forward_kwargs = {**label_data} if label_data else {}
             output = self.model.forward_from_embedding(
@@ -165,6 +175,9 @@ def attribute(
         logits = output["logit"]
         target = self._compute_target_output(logits, target_class_idx)
 
+        # Step 2 (LayerNorm freeze): StageNet does not contain layer norms, so
+        # there are no σ parameters to freeze; the reset below ensures any
+        # hypothetical normalization buffers would stay constant as in Sec. 4.2.
         self.model.zero_grad(set_to_none=True)
         for emb in embeddings.values():
             if emb.grad is not None:
@@ -177,10 +190,10 @@ def attribute(
             grad = emb.grad
             if grad is None:
                 grad = torch.zeros_like(emb)
-            attr = grad
-            if self.multiply_by_input:
-                attr = attr * emb
-            token_attr = self._collapse_to_input_shape(attr, input_shapes[key])
+            # Step 3 (Gradient normalization): StageNet lacks the multi-input
+            # products targeted by the uniform rule, so dividing by 1 (identity)
+            # yields the same gradients the paper would propagate.
+            token_attr = self._collapse_to_input_shape(grad, input_shapes[key])
             attributions[key] = token_attr.detach()
 
         return attributions
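`_collapse_to_input_shape` is not part of this diff; a common way to collapse embedding-space gradients onto token axes is to reduce over the trailing embedding dimension, sketched below under that assumption (the tensor layout is likewise assumed).

```python
import torch


def collapse_to_tokens_sketch(grad_emb: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for _collapse_to_input_shape: drop the embedding
    axis so each (visit, code) position receives a single scalar attribution."""
    return grad_emb.sum(dim=-1)


grad_emb = torch.randn(2, 5, 7, 32)               # assumed (batch, visits, codes, emb_dim)
print(collapse_to_tokens_sketch(grad_emb).shape)  # torch.Size([2, 5, 7])
```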