Basic gradient saliency visualization #594
Merged
Commits (40)
All 40 commits were authored by Nimanui:

- 858c4f7  Adding a basic gradient saliency method that functions on the pyhealt…
- 3bd6860  Updating function name for consistency and adding to the _init_.py
- ce1e510  Created using Colab
- 8eb57eb  Example of using gradient saliency
- b376b8a  Delete Colab ChestXrayClassificationWithSaliency.ipynb
- 0e14d97  Merge branch 'sunlabuiuc:master' into master
- 57e2c3c  Adding a basic gradient saliency method that functions on the pyhealt…
- 997db4a  Updating function name for consistency and adding to the _init_.py
- 1c13529  Merge branch 'sunlabuiuc:master' into master
- 5c88fbc  Merge branch 'master' into SaliencyMappingGradient
- 9efd00e  Merge branch 'sunlabuiuc:master' into master
- 6aeeb69  Merge branch 'master' into SaliencyMappingGradient
- f3f905e  Moving the saliency module to the "interpret\methods" folder
- 15d62db  Created using Colab
- f98ecf8  Small bug fix and adjusting the ipynb example
- 83d9412  Add files via upload
- 470f54a  Created using Colab
- 1f1628c  Locally run notebook example to remove collab dependency
- 28c48fc  Add files via upload
- f6487f9  Add files via upload
- 3a97db7  Add files via upload
- 64a940a  Delete examples/ChestXrayClassificationWithSaliencyMapping.ipynb
- 0b45143  Merge branch 'sunlabuiuc:master' into SaliencyMappingGradient
- 3dbaa83  Delete ChestXrayClassificationWithSaliencyMapping.ipynb
- d3c320d  Update repository references in notebook
- e9d9d9e  Adding saliency class and adjusting code to use the class
- 5e5fcd6  Merge pyhealth main
- 1c1e912  Adding saliency class and adjusting code to use the class
- 862b1c4  Refactoring methods in saliency.py
- 1a5b8c9  Some refactoring and adding better batch handling
- 47b2c39  Adding some documentation
- 2d8e6d0  Merge remote-tracking branch 'upstream/master' into SaliencyMappingClass
- e159318  Basic Gradient imagery: removing dataload options
- 8041c09  Merge remote-tracking branch 'upstream/master' into SaliencyMappingClass
- 1240716  Adding base_interpreter inheritance and testing
- 755289b  Removing redundant code and the lazy initialization
- e051d30  Making the input_batch optional and adding support to update it with …
- e2b65e2  Adding a few unit tests
- b9accc0  Adding some documentation
- 26ebc71  Documentation updates
New file: docs/api/interpret/pyhealth.interpret.methods.basic_gradient.rst (122 additions, 0 deletions)
pyhealth.interpret.methods.basic_gradient
===========================================

Overview
--------

The ``BasicGradientSaliencyMaps`` class provides gradient-based saliency map visualization for
PyHealth's image classification models. This interpretability method helps identify which regions
of medical images most influenced the model's prediction by computing gradients of model outputs
with respect to input pixels.

This method is particularly useful for:

- **Clinical interpretability**: Understanding which image regions drove a particular diagnosis
- **Model debugging**: Identifying whether the model focuses on clinically relevant features
- **Trust and transparency**: Providing visual explanations for model predictions
- **Error analysis**: Comparing saliency maps for correct vs. incorrect predictions

The implementation computes saliency by taking the maximum absolute gradient across color channels
for each pixel, highlighting the most influential regions in the input image.

Key Features
------------

- **Dual input support**: Process batches from a DataLoader or direct batch inputs
- **Flexible visualization**: Matplotlib overlay with configurable transparency
- **Label comparison**: Display both true labels and model predictions
- **Efficient storage**: Separate caching for different data sources

Usage Notes
-----------

1. **Gradients required**: Do not use within a ``torch.no_grad()`` context (see the sketch after these notes)
2. **Model compatibility**: Works with PyHealth image classification models
3. **Memory usage**: Limit the batch count to control memory consumption
4. **Batch visualization**: Use ``batch_index`` for pre-computed maps; omit it for on-the-fly computation
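As a quick illustration of note 1, the sketch below (an editorial addition, not part of the PR's ``.rst`` file, using a placeholder PyTorch model rather than the PyHealth API) shows why gradient saliency cannot run under ``torch.no_grad()``: no computation graph is recorded, so the input gradients the method relies on are never populated.

.. code-block:: python

    import torch

    # Placeholder model and batch; these names and shapes are illustrative only.
    model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 8 * 8, 3))
    images = torch.rand(4, 3, 8, 8, requires_grad=True)

    with torch.no_grad():
        scores = model(images)
    print(scores.requires_grad)   # False: no graph was built, so backward() is impossible

    scores = model(images)        # outside no_grad(): the graph is recorded
    scores.gather(1, scores.argmax(dim=1, keepdim=True)).sum().backward()
    print(images.grad.shape)      # torch.Size([4, 3, 8, 8]); input gradients now exist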
Quick Start
-----------

.. code-block:: python

    from pyhealth.interpret.methods.basic_gradient import BasicGradientSaliencyMaps
    from pyhealth.datasets import get_dataloader
    import matplotlib.pyplot as plt

    # Assume you have a trained image model and dataset
    model = TorchvisionModel(dataset=sample_dataset, ...)
    # ... train the model ...

    # Create interpretability object with dataloader
    dataloader = get_dataloader(dataset, batch_size=32, shuffle=True)
    saliency_maps = BasicGradientSaliencyMaps(
        model=model,
        dataloader=dataloader,
        batches=3
    )
    saliency_maps.init_gradient_saliency_maps()

    # Visualize from pre-computed maps
    saliency_maps.visualize_saliency_map(
        plt,
        image_index=10,
        batch_index=0,
        title="Gradient Saliency",
        id2label={0: "Normal", 1: "COVID", 2: "Pneumonia"},
        alpha=0.6
    )

Custom Batch Example
--------------------

.. code-block:: python

    import torch

    # Create a custom batch (e.g., filter by diagnosis)
    covid_samples = [s for s in dataset.samples if s['disease'].item() == covid_label]
    covid_batch = {
        'image': torch.stack([covid_samples[i]['image'] for i in range(32)]),
        'disease': torch.stack([covid_samples[i]['disease'] for i in range(32)])
    }

    # Initialize with custom batch
    saliency_maps = BasicGradientSaliencyMaps(model=model, input_batch=covid_batch)
    saliency_maps.init_gradient_saliency_maps()

    # Visualize (no batch_index means use input_batch)
    saliency_maps.visualize_saliency_map(
        plt,
        image_index=0,
        title="COVID Sample",
        id2label=id2label,
        alpha=0.6
    )

API Reference
-------------

.. autoclass:: pyhealth.interpret.methods.basic_gradient.BasicGradientSaliencyMaps
    :members:
    :undoc-members:
    :show-inheritance:
    :member-order: bysource

Algorithm Details
-----------------

The saliency computation follows these steps:

1. **Forward pass**: Compute model predictions for the input batch
2. **Target selection**: Use predicted class (argmax of probabilities)
3. **Backward pass**: Compute gradients with respect to input pixels
4. **Saliency map**: Take absolute value and max across color channels

Mathematical formula:

.. math::

    \text{saliency}(x, y) = \max_{c} \left| \frac{\partial \text{score}_{\text{predicted}}}{\partial \text{pixel}_{x,y,c}} \right|

where :math:`c` iterates over color channels (RGB or grayscale).
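For concreteness, here is a minimal sketch of the four steps and the formula above. It is an editorial illustration against a plain PyTorch classifier, not the actual ``BasicGradientSaliencyMaps`` implementation; the ``model`` argument and tensor shapes are assumptions.

.. code-block:: python

    import torch

    def gradient_saliency(model: torch.nn.Module, images: torch.Tensor) -> torch.Tensor:
        """Return max_c |d score_predicted / d pixel_{x,y,c}| for each image in the batch."""
        images = images.clone().detach().requires_grad_(True)

        # 1. Forward pass: class scores for the batch, shape (N, num_classes).
        scores = model(images)

        # 2. Target selection: the predicted class of each image.
        predicted = scores.argmax(dim=1, keepdim=True)

        # 3. Backward pass: gradients of the predicted-class scores w.r.t. input pixels.
        scores.gather(1, predicted).sum().backward()

        # 4. Saliency map: absolute gradient, maximum over the channel dimension.
        return images.grad.abs().amax(dim=1)   # shape (N, H, W)

Overlaying the returned map on the input image with a matplotlib ``alpha`` blend mirrors the kind of output ``visualize_saliency_map`` produces in the examples above.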
For the sake of vibe-coding purposes, can we move these details into the docstrings in the class implementation? Also, it seems the docs are out of date. The nice part about having it specifically in the docstrings (which get rendered by autodoc) is that the LLM will update the docstrings and the logic every iteration, so it'll be easy to keep things maintained.
I think the examples/notebook might be out of date too?
Otherwise, I think we're pretty much there. Thank you for taking the time to iterate on this.
The notebook is updated (I've been testing with it for every change in addition to the unit tests). I'll poke the documentation and let you know; that probably is out of date. I don't mind moving it, and the other documentation I just added is up to date.
The .rst is updated to resemble the other interpretability docs (sorry I didn't do that sooner), and I added the gist of what it had to the docstring in the code. Excellent call there, and thank you for helping me make this much better.