diff --git a/docs/api/interpret.rst b/docs/api/interpret.rst index 37c5a07a2..096405fa5 100644 --- a/docs/api/interpret.rst +++ b/docs/api/interpret.rst @@ -49,6 +49,7 @@ Available Methods .. toctree:: :maxdepth: 4 + interpret/pyhealth.interpret.methods.gim interpret/pyhealth.interpret.methods.basic_gradient interpret/pyhealth.interpret.methods.chefer interpret/pyhealth.interpret.methods.deeplift diff --git a/docs/api/interpret/pyhealth.interpret.methods.gim.rst b/docs/api/interpret/pyhealth.interpret.methods.gim.rst new file mode 100644 index 000000000..66b78c43b --- /dev/null +++ b/docs/api/interpret/pyhealth.interpret.methods.gim.rst @@ -0,0 +1,25 @@ +pyhealth.interpret.methods.gim +================================ + +Overview +-------- + +The Gradient Interaction Modifications (GIM) interpreter adapts the +attribution method of Edin et al. (2025) to PyHealth. It recomputes softmax +gradients with a higher temperature in the backward pass so that token-level +interactions remain visible when cumulative softmax layers are present. + +Use this interpreter with models that expose ``forward_from_embedding`` and +``embedding_model``, such as StageNet and the Transformer. + +For complete working examples, see ``examples/gim_stagenet_mimic4.py`` and +``examples/gim_transformer_mimic4.py``. + +API Reference +------------- + +.. autoclass:: pyhealth.interpret.methods.GIM + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource diff --git a/examples/gim_stagenet_mimic4.py b/examples/gim_stagenet_mimic4.py new file mode 100644 index 000000000..8d45ad9a7 --- /dev/null +++ b/examples/gim_stagenet_mimic4.py @@ -0,0 +1,205 @@ +# %% Loading MIMIC-IV dataset +from pathlib import Path + +import polars as pl +import torch + +from pyhealth.datasets import ( + MIMIC4EHRDataset, + get_dataloader, + load_processors, + split_by_patient, +) +from pyhealth.interpret.methods import GIM +from pyhealth.models import StageNet +from pyhealth.tasks import MortalityPredictionStageNetMIMIC4 + +# Configure dataset location and load cached processors +dataset = MIMIC4EHRDataset( + root="/home/logic/physionet.org/files/mimic-iv-demo/2.2/", + tables=[ + "patients", + "admissions", + "diagnoses_icd", + "procedures_icd", + "labevents", + ], +) + +# %% Setting StageNet Mortality Prediction Task +input_processors, output_processors = load_processors("../resources/") +
+sample_dataset = dataset.set_task( + MortalityPredictionStageNetMIMIC4(), + cache_dir="~/.cache/pyhealth/mimic4_stagenet_mortality", + input_processors=input_processors, + output_processors=output_processors, +) +print(f"Total samples: {len(sample_dataset)}") + + +def load_icd_description_map(dataset_root: str) -> dict: + """Load ICD code → description mappings from reference tables.""" + mapping = {} + root_path = Path(dataset_root).expanduser() + diag_path = root_path / "hosp" / "d_icd_diagnoses.csv.gz" + proc_path = root_path / "hosp" / "d_icd_procedures.csv.gz" + + icd_dtype = {"icd_code": pl.Utf8, "long_title": pl.Utf8} + + if diag_path.exists(): + diag_df = pl.read_csv( + diag_path, + columns=["icd_code", "long_title"], + dtypes=icd_dtype, + ) + mapping.update( + zip(diag_df["icd_code"].to_list(), diag_df["long_title"].to_list()) + ) + + if proc_path.exists(): + proc_df = pl.read_csv( + proc_path, + columns=["icd_code", "long_title"], + dtypes=icd_dtype, + ) + mapping.update( + zip(proc_df["icd_code"].to_list(), proc_df["long_title"].to_list()) + ) + + return mapping + + +ICD_CODE_TO_DESC = load_icd_description_map(dataset.root) + +# %% Loading Pretrained StageNet model +device = torch.device("cuda:0" if 
torch.cuda.is_available() else "cpu") +model = StageNet( + dataset=sample_dataset, + embedding_dim=128, + chunk_size=128, + levels=3, + dropout=0.3, +) + +state_dict = torch.load("../resources/best.ckpt", map_location=device) +model.load_state_dict(state_dict) +model = model.to(device) +model.eval() +print(model) + +# %% Preparing dataloaders +_, _, test_data = split_by_patient(sample_dataset, [0.7, 0.1, 0.2], seed=42) +test_loader = get_dataloader(test_data, batch_size=1, shuffle=False) + + +def move_batch_to_device(batch, target_device): + moved = {} + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + moved[key] = value.to(target_device) + elif isinstance(value, tuple): + moved[key] = tuple(v.to(target_device) for v in value) + else: + moved[key] = value + return moved + + +LAB_CATEGORY_NAMES = MortalityPredictionStageNetMIMIC4.LAB_CATEGORY_NAMES + + +def decode_token(idx: int, processor, feature_key: str): + if processor is None or not hasattr(processor, "code_vocab"): + return str(idx) + reverse_vocab = {index: token for token, index in processor.code_vocab.items()} + token = reverse_vocab.get(idx, f"") + + if feature_key == "icd_codes" and token not in {"", ""}: + desc = ICD_CODE_TO_DESC.get(token) + if desc: + return f"{token}: {desc}" + + return token + + +def unravel(flat_index: int, shape: torch.Size): + coords = [] + remaining = flat_index + for dim in reversed(shape): + coords.append(remaining % dim) + remaining //= dim + return list(reversed(coords)) + + +def print_top_attributions( + attributions, + batch, + processors, + top_k: int = 10, +): + for feature_key, attr in attributions.items(): + attr_cpu = attr.detach().cpu() + if attr_cpu.dim() == 0 or attr_cpu.size(0) == 0: + continue + + feature_input = batch[feature_key] + if isinstance(feature_input, tuple): + feature_input = feature_input[1] + feature_input = feature_input.detach().cpu() + + flattened = attr_cpu[0].flatten() + if flattened.numel() == 0: + continue + + print(f"\nFeature: {feature_key}") + k = min(top_k, flattened.numel()) + top_values, top_indices = torch.topk(flattened.abs(), k=k) + processor = processors.get(feature_key) if processors else None + is_continuous = torch.is_floating_point(feature_input) + + for rank, (_, flat_idx) in enumerate(zip(top_values, top_indices), 1): + attribution_value = flattened[flat_idx].item() + coords = unravel(flat_idx.item(), attr_cpu[0].shape) + + if is_continuous: + actual_value = feature_input[0][tuple(coords)].item() + label = "" + if feature_key == "labs" and len(coords) >= 1: + lab_idx = coords[-1] + if lab_idx < len(LAB_CATEGORY_NAMES): + label = f"{LAB_CATEGORY_NAMES[lab_idx]} " + print( + f" {rank:2d}. idx={coords} {label}value={actual_value:.4f} " + f"attr={attribution_value:+.6f}" + ) + else: + token_idx = int(feature_input[0][tuple(coords)].item()) + token = decode_token(token_idx, processor, feature_key) + print( + f" {rank:2d}. 
idx={coords} token='{token}' " + f"attr={attribution_value:+.6f}" + ) + + +# %% Run GIM on a held-out sample +gim = GIM(model, temperature=2.0) + +sample_batch = next(iter(test_loader)) +sample_batch_device = move_batch_to_device(sample_batch, device) + +with torch.no_grad(): + output = model(**sample_batch_device) + probs = output["y_prob"] + preds = torch.argmax(probs, dim=-1) + label_key = model.label_key + true_label = sample_batch_device[label_key] + + print("\nModel prediction for the sampled patient:") + print(f" True label: {int(true_label.item())}") + print(f" Predicted class: {int(preds.item())}") + print(f" Probabilities: {probs[0].cpu().numpy()}") + +attributions = gim.attribute(**sample_batch_device) +print_top_attributions(attributions, sample_batch_device, input_processors, top_k=10) + +# %% diff --git a/examples/gim_transformer_mimic4.py b/examples/gim_transformer_mimic4.py new file mode 100644 index 000000000..374d00d64 --- /dev/null +++ b/examples/gim_transformer_mimic4.py @@ -0,0 +1,230 @@ +# %% Loading MIMIC-IV dataset +from pathlib import Path + +import polars as pl +import torch + +from pyhealth.datasets import ( + MIMIC4EHRDataset, + get_dataloader, + load_processors, + split_by_patient, +) +from pyhealth.interpret.methods import GIM +from pyhealth.models import Transformer +from pyhealth.tasks import MortalityPredictionMIMIC4 + + +def maybe_load_processors(resource_dir: str, task): + """Load cached processors if they match the current task schema.""" + + try: + input_processors, output_processors = load_processors(resource_dir) + except Exception as exc: + print(f"Falling back to rebuilding processors: {exc}") + return None, None + + expected_inputs = set(task.input_schema.keys()) + expected_outputs = set(task.output_schema.keys()) + if set(input_processors.keys()) != expected_inputs: + print( + "Cached input processors do not match MortalityPredictionMIMIC4 schema; rebuilding." + ) + return None, None + if set(output_processors.keys()) != expected_outputs: + print( + "Cached output processors do not match MortalityPredictionMIMIC4 schema; rebuilding." 
+ ) + return None, None + return input_processors, output_processors + + +# Configure dataset location and optionally load cached processors +dataset = MIMIC4EHRDataset( + root="/home/logic/physionet.org/files/mimic-iv-demo/2.2/", + tables=[ + "patients", + "admissions", + "diagnoses_icd", + "procedures_icd", + "prescriptions", + ], +) + +task = MortalityPredictionMIMIC4() +input_processors, output_processors = maybe_load_processors("../resources/", task) + +sample_dataset = dataset.set_task( + task, + cache_dir="~/.cache/pyhealth/mimic4_transformer_mortality", + input_processors=input_processors, + output_processors=output_processors, +) +print(f"Total samples: {len(sample_dataset)}") + + +def load_icd_description_map(dataset_root: str) -> dict: + """Load ICD code → description mappings from reference tables.""" + + mapping = {} + root_path = Path(dataset_root).expanduser() + diag_path = root_path / "hosp" / "d_icd_diagnoses.csv.gz" + proc_path = root_path / "hosp" / "d_icd_procedures.csv.gz" + + icd_dtype = {"icd_code": pl.Utf8, "long_title": pl.Utf8} + + if diag_path.exists(): + diag_df = pl.read_csv( + diag_path, + columns=["icd_code", "long_title"], + dtypes=icd_dtype, + ) + mapping.update( + zip(diag_df["icd_code"].to_list(), diag_df["long_title"].to_list()) + ) + + if proc_path.exists(): + proc_df = pl.read_csv( + proc_path, + columns=["icd_code", "long_title"], + dtypes=icd_dtype, + ) + mapping.update( + zip(proc_df["icd_code"].to_list(), proc_df["long_title"].to_list()) + ) + + return mapping + + +ICD_CODE_TO_DESC = load_icd_description_map(dataset.root) + +# %% Loading Pretrained Transformer model +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +model = Transformer( + dataset=sample_dataset, + embedding_dim=128, + heads=2, + dropout=0.2, + num_layers=2, +) + +ckpt_path = Path("../resources/transformer_best.ckpt") +if not ckpt_path.exists(): + raise FileNotFoundError( + f"Missing pretrained weights at {ckpt_path}. " + "Train the Transformer model and place the checkpoint in ../resources/." 
+ ) +state_dict = torch.load(str(ckpt_path), map_location=device) +model.load_state_dict(state_dict) +model = model.to(device) +model.eval() +print(model) + +# %% Preparing dataloaders +_, _, test_data = split_by_patient(sample_dataset, [0.7, 0.1, 0.2], seed=42) +test_loader = get_dataloader(test_data, batch_size=1, shuffle=False) + + +def move_batch_to_device(batch, target_device): + moved = {} + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + moved[key] = value.to(target_device) + elif isinstance(value, tuple): + moved[key] = tuple(v.to(target_device) for v in value) + else: + moved[key] = value + return moved + + +def decode_token(idx: int, processor, feature_key: str): + if processor is None or not hasattr(processor, "code_vocab"): + return str(idx) + reverse_vocab = {index: token for token, index in processor.code_vocab.items()} + token = reverse_vocab.get(idx, f"") + + if feature_key in {"conditions", "procedures"} and token not in {"", ""}: + desc = ICD_CODE_TO_DESC.get(token) + if desc: + return f"{token}: {desc}" + + return token + + +def unravel(flat_index: int, shape: torch.Size): + coords = [] + remaining = flat_index + for dim in reversed(shape): + coords.append(remaining % dim) + remaining //= dim + return list(reversed(coords)) + + +def print_top_attributions( + attributions, + batch, + processors, + top_k: int = 10, +): + for feature_key, attr in attributions.items(): + attr_cpu = attr.detach().cpu() + if attr_cpu.dim() == 0 or attr_cpu.size(0) == 0: + continue + + feature_input = batch[feature_key] + if isinstance(feature_input, tuple): + feature_input = feature_input[1] + feature_input = feature_input.detach().cpu() + + flattened = attr_cpu[0].flatten() + if flattened.numel() == 0: + continue + + print(f"\nFeature: {feature_key}") + k = min(top_k, flattened.numel()) + top_values, top_indices = torch.topk(flattened.abs(), k=k) + processor = processors.get(feature_key) if processors else None + is_continuous = torch.is_floating_point(feature_input) + + for rank, (_, flat_idx) in enumerate(zip(top_values, top_indices), 1): + attribution_value = flattened[flat_idx].item() + coords = unravel(flat_idx.item(), attr_cpu[0].shape) + + if is_continuous: + actual_value = feature_input[0][tuple(coords)].item() + print( + f" {rank:2d}. idx={coords} value={actual_value:.4f} " + f"attr={attribution_value:+.6f}" + ) + else: + token_idx = int(feature_input[0][tuple(coords)].item()) + token = decode_token(token_idx, processor, feature_key) + print( + f" {rank:2d}. 
idx={coords} token='{token}' " + f"attr={attribution_value:+.6f}" + ) + + +# %% Run GIM on a held-out sample +gim = GIM(model, temperature=2.0) +processors_for_display = sample_dataset.input_processors + +sample_batch = next(iter(test_loader)) +sample_batch_device = move_batch_to_device(sample_batch, device) + +with torch.no_grad(): + output = model(**sample_batch_device) + probs = output["y_prob"] + preds = torch.argmax(probs, dim=-1) + label_key = model.label_key + true_label = sample_batch_device[label_key] + + print("\nModel prediction for the sampled patient:") + print(f" True label: {int(true_label.item())}") + print(f" Predicted class: {int(preds.item())}") + print(f" Probabilities: {probs[0].cpu().numpy()}") + +attributions = gim.attribute(**sample_batch_device) +print_top_attributions(attributions, sample_batch_device, processors_for_display, top_k=10) + +# %% diff --git a/pyhealth/interpret/methods/__init__.py b/pyhealth/interpret/methods/__init__.py index c877e372c..c6b6e461d 100644 --- a/pyhealth/interpret/methods/__init__.py +++ b/pyhealth/interpret/methods/__init__.py @@ -2,6 +2,14 @@ from pyhealth.interpret.methods.chefer import CheferRelevance from pyhealth.interpret.methods.basic_gradient import BasicGradientSaliencyMaps from pyhealth.interpret.methods.deeplift import DeepLift +from pyhealth.interpret.methods.gim import GIM from pyhealth.interpret.methods.integrated_gradients import IntegratedGradients -__all__ = ["BaseInterpreter", "BasicGradientSaliencyMaps", "CheferRelevance", "DeepLift", "IntegratedGradients"] \ No newline at end of file +__all__ = [ + "BaseInterpreter", + "BasicGradientSaliencyMaps", + "CheferRelevance", + "DeepLift", + "GIM", + "IntegratedGradients", +] diff --git a/pyhealth/interpret/methods/gim.py b/pyhealth/interpret/methods/gim.py new file mode 100644 index 000000000..f3ddb7fd7 --- /dev/null +++ b/pyhealth/interpret/methods/gim.py @@ -0,0 +1,338 @@ +from __future__ import annotations + +import contextlib +from typing import Dict, Optional, Tuple + +import torch + +from pyhealth.models import BaseModel + +from .base_interpreter import BaseInterpreter + + +class _TemperatureSoftmax(torch.autograd.Function): + """Custom autograd op implementing temperature-adjusted softmax gradients. + + Implements the Temperature-Scaled Gradients (TSG) rule from GIM Sec. 4.1 by + recomputing the backward Jacobian with a higher temperature while leaving + the forward softmax unchanged. 
+ """ + + @staticmethod + def forward( + ctx, + input_tensor: torch.Tensor, + dim: int, + temperature: float, + ) -> torch.Tensor: + ctx.dim = dim + ctx.temperature = float(temperature) + ctx.save_for_backward(input_tensor) + return torch.softmax(input_tensor, dim=dim) + + @staticmethod + def backward( + ctx, + grad_output: torch.Tensor, + ) -> Tuple[torch.Tensor, None, None]: + (input_tensor,) = ctx.saved_tensors + dim = ctx.dim + temperature = max(ctx.temperature, 1.0) + + if temperature == 1.0: + probs = torch.softmax(input_tensor, dim=dim) + dot = (grad_output * probs).sum(dim=dim, keepdim=True) + grad_input = probs * (grad_output - dot) + return grad_input, None, None + + adjusted = torch.softmax(input_tensor / temperature, dim=dim) + dot = (grad_output * adjusted).sum(dim=dim, keepdim=True) + grad_input = adjusted * (grad_output - dot) + grad_input = grad_input / temperature + return grad_input, None, None + + +class _GIMActivationHooks: + """Router that swaps selected activations for GIM-aware variants.""" + + def __init__(self, temperature: float = 2.0): + self.temperature = temperature + + def apply(self, name: str, tensor: torch.Tensor, **kwargs) -> torch.Tensor: + if name == "softmax" and self.temperature is not None: + dim = kwargs.get("dim", -1) + temp = max(float(self.temperature), 1.0) + return _TemperatureSoftmax.apply(tensor, dim, temp) + fn = getattr(torch, name) + return fn(tensor, **kwargs) + + +class _GIMHookContext(contextlib.AbstractContextManager): + """Context manager that wires GIM hooks if the model supports them. + + TSG needs to intercept every activation that calls ``torch.softmax``. + StageNet exposes DeepLIFT-style hook setters, so we reuse that surface + unless a dedicated ``set_gim_hooks`` is provided. + """ + + def __init__(self, model: BaseModel, temperature: float): + self.model = model + self.temperature = temperature + self.hooks: Optional[_GIMActivationHooks] = None + self._set_fn = None + self._clear_fn = None + + # Prefer explicit GIM hooks if the model exposes them, otherwise + # reuse the DeepLIFT hook wiring which StageNet already supports. + if hasattr(model, "set_gim_hooks") and hasattr(model, "clear_gim_hooks"): + self._set_fn = model.set_gim_hooks + self._clear_fn = model.clear_gim_hooks + elif hasattr(model, "set_deeplift_hooks") and hasattr(model, "clear_deeplift_hooks"): + self._set_fn = model.set_deeplift_hooks + self._clear_fn = model.clear_deeplift_hooks + + def __enter__(self) -> "_GIMHookContext": + if self._set_fn is not None and self.temperature > 1.0: + self.hooks = _GIMActivationHooks(temperature=self.temperature) + self._set_fn(self.hooks) + return self + + def __exit__(self, exc_type, exc, exc_tb) -> bool: + if self._clear_fn is not None and self.hooks is not None: + self._clear_fn() + self.hooks = None + return False + + +class GIM(BaseInterpreter): + """Gradient Interaction Modifications for StageNet-style and Transformer models. + + This interpreter adapts the Gradient Interaction Modifications (GIM) + technique (Edin et al., 2025) to PyHealth, focusing on StageNet where + cumulative softmax operations can exhibit self-repair. The implementation + follows three high-level ideas from the paper: + + 1. **Temperature-adjusted softmax gradients (TSG):** Backpropagated + gradients through cumulative softmax are recomputed with a higher + temperature, exposing interactions that are otherwise hidden by + softmax redistribution. + 2. 
**LayerNorm freeze:** Layer normalization parameters are treated as + constants during backpropagation. StageNet does not employ layer norm, + so this rule becomes a mathematical no-op, matching the paper when + σ is constant. + 3. **Gradient normalization:** When no multiplicative fan-in exists (as in + StageNet’s embedding → recurrent pipeline) the uniform division rule + effectively multiplies by 1, so propagating raw gradients remains + faithful to Section 4.2. + + Args: + model: Trained PyHealth model supporting ``forward_from_embedding`` + (StageNet is currently supported). + temperature: Softmax temperature used exclusively for the backward + pass. A value of ``2.0`` matches the paper's best setting. + + Examples: + >>> import torch + >>> from pyhealth.datasets import get_dataloader + >>> from pyhealth.interpret.methods.gim import GIM + >>> from pyhealth.models import StageNet + >>> + >>> # Assume ``sample_dataset`` and trained StageNet weights are available. + >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + >>> model = StageNet(dataset=sample_dataset, mode="binary") + >>> model = model.to(device).eval() + >>> test_loader = get_dataloader(sample_dataset, batch_size=1, shuffle=False) + >>> gim = GIM(model, temperature=2.0) + >>> + >>> batch = next(iter(test_loader)) + >>> batch_device = {} + >>> for key, value in batch.items(): + ... if isinstance(value, torch.Tensor): + ... batch_device[key] = value.to(device) + ... elif isinstance(value, tuple): + ... batch_device[key] = tuple(v.to(device) for v in value) + ... else: + ... batch_device[key] = value + >>> + >>> attributions = gim.attribute(**batch_device) + >>> print({k: v.shape for k, v in attributions.items()}) + """ + + def __init__( + self, + model: BaseModel, + temperature: float = 2.0, + ): + super().__init__(model) + if not hasattr(model, "forward_from_embedding"): + raise AssertionError( + "GIM requires models that implement `forward_from_embedding`." + ) + if not hasattr(model, "embedding_model"): + raise AssertionError( + "GIM requires access to the model's embedding_model." + ) + self.temperature = max(float(temperature), 1.0) + + def attribute( + self, + target_class_idx: Optional[int] = None, + **data, + ) -> Dict[str, torch.Tensor]: + """Compute GIM attributions for a StageNet batch.""" + device = next(self.model.parameters()).device + inputs, time_info, label_data = self._prepare_inputs(data, device) + embeddings, input_shapes = self._embed_inputs(inputs) + + # Clear stale gradients before the attribution pass. + self.model.zero_grad(set_to_none=True) + + # Step 1 (TSG): install the temperature-adjusted softmax hooks so all + # backward passes through StageNet's cumax operations use the higher τ. + with _GIMHookContext(self.model, self.temperature): + forward_kwargs = {**label_data} if label_data else {} + if time_info: + forward_kwargs["time_info"] = time_info + output = self.model.forward_from_embedding( + feature_embeddings=embeddings, + **forward_kwargs, + ) + + logits = output["logit"] + target = self._compute_target_output(logits, target_class_idx) + + # Step 2 (LayerNorm freeze): StageNet does not contain layer norms, so + # there are no σ parameters to freeze; the reset below ensures any + # hypothetical normalization buffers would stay constant as in Sec. 4.2. 
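+        # When the activation hooks are active (temperature > 1), the forward pass above
+        # already recorded _TemperatureSoftmax nodes in the autograd graph, so the
+        # target.backward() call below propagates the temperature-adjusted Jacobian into
+        # the detached embeddings; with temperature == 1 no hooks are installed and the
+        # standard softmax gradient applies.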
+ self.model.zero_grad(set_to_none=True) + for emb in embeddings.values(): + if emb.grad is not None: + emb.grad.zero_() + + target.backward() + + attributions = {} + for key, emb in embeddings.items(): + grad = emb.grad + if grad is None: + grad = torch.zeros_like(emb) + # Step 3 (Gradient normalization): StageNet lacks the multi-input + # products targeted by the uniform rule, so dividing by 1 (identity) + # yields the same gradients the paper would propagate. + token_attr = self._collapse_to_input_shape(grad, input_shapes[key]) + attributions[key] = token_attr.detach() + + return attributions + + # --------------------------------------------------------------------- + # Helpers + # --------------------------------------------------------------------- + def _prepare_inputs( + self, + data: Dict[str, torch.Tensor], + device: torch.device, + ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + """Split raw data into value tensors, time tensors, and labels.""" + inputs: Dict[str, torch.Tensor] = {} + time_info: Dict[str, torch.Tensor] = {} + + for key in getattr(self.model, "feature_keys", []): + if key not in data: + continue + value = data[key] + time_tensor = None + if isinstance(value, tuple) and len(value) == 2: + time_tensor, value = value + time_tensor = self._to_tensor(time_tensor, device) + inputs[key] = self._to_tensor(value, device) + if time_tensor is not None: + time_info[key] = time_tensor + + label_data = {} + for label_key in getattr(self.model, "label_keys", []): + if label_key in data: + label_data[label_key] = self._to_tensor(data[label_key], device) + + return inputs, time_info, label_data + + def _embed_inputs( + self, + inputs: Dict[str, torch.Tensor], + ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Size]]: + """Run the model's embedding stack and detach tensors for attribution.""" + embeddings: Dict[str, torch.Tensor] = {} + input_shapes: Dict[str, torch.Size] = {} + + for key, tensor in inputs.items(): + input_shapes[key] = tensor.shape + embedded = self.model.embedding_model({key: tensor}) + emb_tensor = embedded[key].detach() + emb_tensor.requires_grad_(True) + emb_tensor.retain_grad() + embeddings[key] = emb_tensor + + return embeddings, input_shapes + + def _compute_target_output( + self, + logits: torch.Tensor, + target_class_idx: Optional[int], + ) -> torch.Tensor: + """Select a scalar logit to backpropagate based on the target class.""" + if logits.dim() == 1: + logits = logits.unsqueeze(-1) + + if target_class_idx is None: + if logits.shape[-1] == 1: + selected = logits.squeeze(-1) + else: + indices = torch.argmax(logits, dim=-1) + selected = logits.gather(1, indices.unsqueeze(-1)).squeeze(-1) + else: + if isinstance(target_class_idx, torch.Tensor): + indices = target_class_idx.to(logits.device) + else: + indices = torch.full( + (logits.shape[0],), + int(target_class_idx), + device=logits.device, + dtype=torch.long, + ) + indices = indices.view(-1, 1) + if logits.shape[-1] == 1: + selected = logits.squeeze(-1) + else: + selected = logits.gather(1, indices).squeeze(-1) + + return selected.sum() + + def _collapse_to_input_shape( + self, + tensor: torch.Tensor, + orig_shape: torch.Size, + ) -> torch.Tensor: + """Sum the embedding dimension and reshape to match the raw inputs.""" + if tensor.dim() >= 2: + tensor = tensor.sum(dim=-1) + + if tensor.shape == orig_shape: + return tensor + + if len(orig_shape) > len(tensor.shape): + expanded = tensor + while len(expanded.shape) < len(orig_shape): + expanded = 
expanded.unsqueeze(-1) + expanded = expanded.expand(orig_shape) + return expanded + + try: + return tensor.reshape(orig_shape) + except RuntimeError: + return tensor + + @staticmethod + def _to_tensor(value, device: torch.device) -> torch.Tensor: + """Convert dataloader values (lists, numpy arrays) to tensors.""" + if isinstance(value, torch.Tensor): + return value.to(device) + return torch.as_tensor(value, device=device) diff --git a/pyhealth/models/transformer.py b/pyhealth/models/transformer.py index e40b3383a..d1e87abea 100644 --- a/pyhealth/models/transformer.py +++ b/pyhealth/models/transformer.py @@ -27,6 +27,21 @@ class Attention(nn.Module): """Scaled dot-product attention helper.""" + def __init__(self): + super().__init__() + self._activation_hooks = None + + def set_activation_hooks(self, hooks) -> None: + """Inject activation hooks for interpretability methods.""" + + self._activation_hooks = hooks + + def _apply_activation(self, name: str, tensor: torch.Tensor, **kwargs) -> torch.Tensor: + if self._activation_hooks is not None and hasattr(self._activation_hooks, "apply"): + return self._activation_hooks.apply(name, tensor, **kwargs) + fn = getattr(torch, name) + return fn(tensor, **kwargs) + def forward( self, query: torch.Tensor, @@ -52,12 +67,10 @@ def forward( Called inside :class:`MultiHeadedAttention` for each head. """ - scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt( - query.size(-1) - ) + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1)) if mask is not None: scores = scores.masked_fill(mask == 0, -1e9) - p_attn = torch.softmax(scores, dim=-1) + p_attn = self._apply_activation("softmax", scores, dim=-1) if mask is not None: p_attn = p_attn.masked_fill(mask == 0, 0) if dropout is not None: @@ -102,6 +115,12 @@ def __repr__(self) -> str: f"dropout={self.dropout.p})" ) + def set_activation_hooks(self, hooks) -> None: + """Propagate activation hooks to the underlying Attention module.""" + + if hasattr(self.attention, "set_activation_hooks"): + self.attention.set_activation_hooks(hooks) + # helper functions for interpretability def get_attn_map(self) -> Optional[torch.Tensor]: """Return the last computed attention weights.""" @@ -235,6 +254,12 @@ def __init__(self, hidden, attn_heads, dropout): self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout) self.dropout = nn.Dropout(p=dropout) + def set_activation_hooks(self, hooks) -> None: + """Forward activation hooks to the multi-head attention block.""" + + if hasattr(self.attention, "set_activation_hooks"): + self.attention.set_activation_hooks(hooks) + def forward(self, x, mask=None, register_hook = False): """Forward propagation. 
@@ -281,6 +306,13 @@ def __init__(self, feature_size, heads=1, dropout=0.5, num_layers=1): [TransformerBlock(feature_size, heads, dropout) for _ in range(num_layers)] ) + def set_activation_hooks(self, hooks) -> None: + """Attach activation hooks to every TransformerBlock in the layer.""" + + for transformer in self.transformer: + if hasattr(transformer, "set_activation_hooks"): + transformer.set_activation_hooks(hooks) + def forward( self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, register_hook: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -391,6 +423,23 @@ def __init__( output_size = self.get_output_size() self.fc = nn.Linear(len(self.feature_keys) * embedding_dim, output_size) + self._activation_hooks = None + + def set_deeplift_hooks(self, hooks) -> None: + """Attach activation hooks for interpretability algorithms.""" + + self._activation_hooks = hooks + for layer in self.transformer.values(): + if hasattr(layer, "set_activation_hooks"): + layer.set_activation_hooks(hooks) + + def clear_deeplift_hooks(self) -> None: + """Remove previously registered interpretability hooks.""" + + self._activation_hooks = None + for layer in self.transformer.values(): + if hasattr(layer, "set_activation_hooks"): + layer.set_activation_hooks(None) @staticmethod def _split_temporal(feature): @@ -523,6 +572,47 @@ def _pool_embedding(x: torch.Tensor) -> torch.Tensor: x = x.unsqueeze(1) return x + @staticmethod + def _mask_from_embeddings(x: torch.Tensor) -> torch.Tensor: + """Infer a boolean mask directly from embedded representations.""" + + mask = torch.any(torch.abs(x) > 0, dim=-1) + if mask.dim() == 1: + mask = mask.unsqueeze(1) + invalid_rows = ~mask.any(dim=1) + if invalid_rows.any(): + mask[invalid_rows, 0] = True + return mask.bool() + + def forward_from_embedding( + self, + feature_embeddings: Dict[str, torch.Tensor], + time_info: Optional[Dict[str, torch.Tensor]] = None, + **kwargs, + ) -> Dict[str, torch.Tensor]: + """Forward pass that consumes pre-computed embeddings.""" + + register_hook = bool(kwargs.get("register_hook", False)) + patient_emb = [] + + for feature_key in self.feature_keys: + x = feature_embeddings[feature_key].to(self.device) + x = self._pool_embedding(x) + mask = self._mask_from_embeddings(x).to(self.device) + _, cls_emb = self.transformer[feature_key](x, mask, register_hook) + patient_emb.append(cls_emb) + + patient_emb = torch.cat(patient_emb, dim=1) + logits = self.fc(patient_emb) + + y_true = kwargs[self.label_key].to(self.device) + loss = self.get_loss_function()(logits, y_true) + y_prob = self.prepare_y_prob(logits) + results = {"loss": loss, "y_prob": y_prob, "y_true": y_true, "logit": logits} + if kwargs.get("embed", False): + results["embed"] = patient_emb + return results + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: """Forward propagation with PyHealth 2.0 inputs. 
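The hunks above give ``Transformer`` the same interpretability surface that StageNet already exposes (``set_deeplift_hooks``/``clear_deeplift_hooks`` plus ``forward_from_embedding``), which is what ``GIM`` looks for. A minimal usage sketch, condensed from ``examples/gim_transformer_mimic4.py``; it assumes a ``sample_dataset`` built for MortalityPredictionMIMIC4 and trained weights are already available, and is illustrative only:

import torch

from pyhealth.datasets import get_dataloader
from pyhealth.interpret.methods import GIM
from pyhealth.models import Transformer

# Checkpoint loading omitted; see examples/gim_transformer_mimic4.py for the full flow.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Transformer(dataset=sample_dataset, embedding_dim=128, heads=2,
                    dropout=0.2, num_layers=2).to(device).eval()

gim = GIM(model, temperature=2.0)  # hooks are installed later, inside attribute()

batch = next(iter(get_dataloader(sample_dataset, batch_size=1, shuffle=False)))
batch = {
    key: (value.to(device) if isinstance(value, torch.Tensor)
          else tuple(v.to(device) for v in value) if isinstance(value, tuple)
          else value)
    for key, value in batch.items()
}

attributions = gim.attribute(**batch)  # one attribution tensor per feature key
print({key: tuple(attr.shape) for key, attr in attributions.items()})

Attribution tensors come back collapsed to the raw token shapes, so they can be decoded with the same processors that built the batch, as the example script does.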
diff --git a/tests/core/test_gim.py b/tests/core/test_gim.py new file mode 100644 index 000000000..aab946d09 --- /dev/null +++ b/tests/core/test_gim.py @@ -0,0 +1,230 @@ +import math +import unittest +from typing import Dict + +import torch +import torch.nn as nn + +from pyhealth.interpret.methods import GIM +from pyhealth.models import BaseModel + + +class _BinaryProcessor: + def size(self) -> int: + return 1 + + +class _DummyBinaryDataset: + """Minimal dataset stub that mimics the pieces BaseModel expects.""" + + def __init__(self): + self.input_schema = {"codes": "sequence"} + self.output_schema = {"label": "binary"} + self.output_processors = {"label": _BinaryProcessor()} + + +class _ToyEmbeddingModel(nn.Module): + """Deterministic embedding lookup ensuring reproducible gradients.""" + + def __init__(self, vocab_size: int = 32, embedding_dim: int = 4): + super().__init__() + self.embedding = nn.Embedding(vocab_size, embedding_dim) + with torch.no_grad(): + weights = torch.arange(vocab_size * embedding_dim).float() + weights = weights.view(vocab_size, embedding_dim) + self.embedding.weight.copy_(weights / float(vocab_size)) + + def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + return {key: self.embedding(val.long()) for key, val in inputs.items()} + + +class _ToyGIMModel(BaseModel): + """Small attention-style model exposing StageNet-compatible hooks.""" + + def __init__(self, vocab_size: int = 32, embedding_dim: int = 4): + super().__init__(dataset=_DummyBinaryDataset()) + self.feature_keys = ["codes"] + self.label_keys = ["label"] + self.mode = "binary" + + self.embedding_model = _ToyEmbeddingModel(vocab_size, embedding_dim) + self.query = nn.Linear(embedding_dim, embedding_dim, bias=False) + self.key = nn.Linear(embedding_dim, embedding_dim, bias=False) + self.value = nn.Linear(embedding_dim, embedding_dim, bias=False) + self.fc = nn.Linear(embedding_dim, 1, bias=False) + + self._activation_hooks = None + self.deeplift_hook_calls = 0 + + self._initialize_weights() + + def _initialize_weights(self) -> None: + with torch.no_grad(): + identity = torch.eye(self.query.in_features) + self.query.weight.copy_(identity) + self.key.weight.copy_(identity) + self.value.weight.copy_(identity) + self.fc.weight.copy_(torch.tensor([[0.2, -0.3, 0.4, 0.1]])) + + def set_deeplift_hooks(self, hooks) -> None: + self.deeplift_hook_calls += 1 + self._activation_hooks = hooks + + def clear_deeplift_hooks(self) -> None: + self._activation_hooks = None + + def _apply_activation(self, name: str, tensor: torch.Tensor, **kwargs) -> torch.Tensor: + if self._activation_hooks is not None and hasattr(self._activation_hooks, "apply"): + return self._activation_hooks.apply(name, tensor, **kwargs) + fn = getattr(torch, name) + return fn(tensor, **kwargs) + + def forward_from_embedding( + self, + feature_embeddings: Dict[str, torch.Tensor], + time_info: Dict[str, torch.Tensor] = None, + **kwargs, + ) -> Dict[str, torch.Tensor]: + emb = feature_embeddings["codes"] + q = self.query(emb) + k = self.key(emb) + v = self.value(emb) + + scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(q.size(-1)) + weights = self._apply_activation("softmax", scores, dim=-1) + context = torch.matmul(weights, v) + pooled = context.mean(dim=1) + logits = self.fc(pooled) + + label = kwargs.get("label") + if label is None: + label = torch.zeros_like(logits) + + return { + "logit": logits, + "y_prob": torch.sigmoid(logits), + "y_true": label, + "loss": torch.zeros((), device=logits.device), + } + + +class 
_ToyGIMModelWithCustomHooks(_ToyGIMModel): + """Variant that exposes dedicated set_gim_hooks/clear_gim_hooks.""" + + def __init__(self): + super().__init__() + self.gim_hook_calls = 0 + + def set_gim_hooks(self, hooks) -> None: + self.gim_hook_calls += 1 + self._activation_hooks = hooks + + def clear_gim_hooks(self) -> None: + self._activation_hooks = None + + +def _manual_token_attribution( + model: _ToyGIMModel, + tokens: torch.Tensor, + labels: torch.Tensor, +) -> torch.Tensor: + """Reference implementation mimicking GIM without temperature scaling.""" + + embeddings = model.embedding_model({"codes": tokens})["codes"].detach() + embeddings.requires_grad_(True) + embeddings.retain_grad() + + output = model.forward_from_embedding({"codes": embeddings}, label=labels) + logits = output["logit"].squeeze(-1) + target = logits.sum() + + model.zero_grad(set_to_none=True) + if embeddings.grad is not None: + embeddings.grad.zero_() + target.backward() + + grad = embeddings.grad.detach() + token_attr = grad.sum(dim=-1) + return token_attr + + +class TestGIM(unittest.TestCase): + """Unit tests validating the PyHealth GIM interpreter.""" + + def setUp(self): + torch.manual_seed(7) + self.tokens = torch.tensor([[2, 5, 3, 1]]) + self.labels = torch.zeros((1, 1)) + + def test_matches_manual_gradient_when_temperature_one(self): + """Temperature=1 should collapse to plain gradients.""" + + model = _ToyGIMModel() + gim = GIM(model, temperature=1.0) + + attributions = gim.attribute( + target_class_idx=0, + codes=self.tokens, + label=self.labels, + ) + manual = _manual_token_attribution(model, self.tokens, self.labels) + torch.testing.assert_close(attributions["codes"], manual, atol=1e-6, rtol=1e-5) + + def test_temperature_hooks_modify_gradients(self): + """Raising the temperature must both attach hooks and change attributions.""" + + model = _ToyGIMModel() + baseline_gim = GIM(model, temperature=1.0) + hot_gim = GIM(model, temperature=2.0) + + baseline_attr = baseline_gim.attribute( + target_class_idx=0, + codes=self.tokens, + label=self.labels, + )["codes"] + hot_attr = hot_gim.attribute( + target_class_idx=0, + codes=self.tokens, + label=self.labels, + )["codes"] + + self.assertEqual(model.deeplift_hook_calls, 1) + self.assertFalse(torch.allclose(baseline_attr, hot_attr)) + + def test_prefers_custom_gim_hooks(self): + """Models exposing set_gim_hooks should bypass the DeepLIFT surface.""" + + model = _ToyGIMModelWithCustomHooks() + gim = GIM(model, temperature=2.0) + gim.attribute(target_class_idx=0, codes=self.tokens, label=self.labels) + + self.assertEqual(model.gim_hook_calls, 1) + self.assertEqual(model.deeplift_hook_calls, 0) + + def test_attributions_match_input_shape(self): + """Collapsed gradients should align with the token tensor shape.""" + + model = _ToyGIMModel() + gim = GIM(model, temperature=1.0) + + attrs = gim.attribute(target_class_idx=0, codes=self.tokens, label=self.labels) + self.assertEqual(tuple(attrs["codes"].shape), tuple(self.tokens.shape)) + + def test_handles_temporal_tuple_inputs(self): + """StageNet-style (time, value) tuples should be processed seamlessly.""" + + model = _ToyGIMModel() + gim = GIM(model, temperature=1.0) + + time_indices = torch.arange(self.tokens.numel()).view_as(self.tokens).float() + attributions = gim.attribute( + target_class_idx=0, + codes=(time_indices, self.tokens), + label=self.labels, + ) + manual = _manual_token_attribution(model, self.tokens, self.labels) + torch.testing.assert_close(attributions["codes"], manual, atol=1e-6, rtol=1e-5) + + +if 
__name__ == "__main__": + unittest.main()