Conversation

@Nimanui Nimanui commented Nov 2, 2025

Had to refresh the branch due to some GitHub shenanigans. This should be the same as the previous pull request, but I completed the adjustments so it renders and manages batches directly instead of taking a dataloader (I left the dataloader code in case someone wants to process a larger set, though it may be redundant). I've also updated the example Python notebook accordingly and refactored the tests. I also did some light refactoring to emphasize that this is just a basic PyTorch gradient-saliency visualization class.

Nimanui and others added 30 commits May 7, 2025 21:45
Applying and visualizing gradient saliency on a basic PyHealth model classifying image data
…h dataloader and models

(cherry picked from commit 858c4f7)
Trying to find a version of this that github can render without errors
Still trying to get something that displays correctly in github
Remove temporary branch references from python notebook example import
jhnwu3 commented Nov 4, 2025

Running unit tests; if they pass, this looks pretty good to me. We can always iterate further later. I think at some point we might need a standardized API for the interpretability module, haha, since I see a lot of similarities in our approaches here.

@jhnwu3 jhnwu3 (Collaborator) left a comment

lgtm

Collaborator

Ahh, sorry, I should've reviewed in more detail earlier; I was stuck in a meeting.

Some quick thoughts on a refactor that I think would make more sense here (sorry to make you do this again, but I think it would reduce the complexity of the approach):

  1. We shouldn't store the dataloader inside the class; it's really heavy to keep an entire dataloader inside an interpretability class. The idea is that we simply call this class's attribute() or .saliency_map(batch) as we iterate across the dataloader. That way the user can decide whether to run it across the entire dataset themselves, or how they want to sample for interpretation.

  2. If possible, .attribute() or .process_batch() would just return the attribution tensors themselves, as I think that is a straightforward way of letting the user work with these attribution features.

  3. The reason we want to return the raw attribution tensors is that many downstream interpretability methods (like adversarial-learning approaches) can use these attributions in other creative ways.
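A minimal sketch of the API these points suggest, assuming a plain PyTorch model; the class name `GradientSaliency` and the `attribute(x, target)` signature are placeholders for illustration, not the PR's actual interface:

```python
import torch
import torch.nn as nn

class GradientSaliency:
    """Hypothetical sketch: holds only the model, so no dataloader is
    stored inside the interpretability class. Batches are passed in
    per call, and raw attribution tensors are returned."""

    def __init__(self, model: nn.Module):
        self.model = model

    def attribute(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return |d logit_target / d input| for one batch."""
        x = x.clone().detach().requires_grad_(True)
        self.model.zero_grad()
        logits = self.model(x)
        # Sum the target-class logits so one backward pass yields
        # per-sample input gradients.
        logits.gather(1, target.unsqueeze(1)).sum().backward()
        return x.grad.detach().abs()

# The caller owns the iteration; nothing heavy is cached in the class.
model = nn.Sequential(nn.Flatten(), nn.Linear(16, 3))
saliency = GradientSaliency(model)
batch = torch.randn(4, 1, 4, 4)
targets = torch.tensor([0, 1, 2, 0])
attr = saliency.attribute(batch, targets)
print(attr.shape)  # torch.Size([4, 1, 4, 4])
```

Because `attribute()` returns the raw tensor, downstream code (metrics, adversarial methods) can consume it directly without knowing anything about this class.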

Collaborator

Can we also have it inherit from the new master's BaseInterpreterClass?

Contributor Author

I could try, but this is already far more over-engineered than the intent of a basic method that just does a very simple gradient on input image data from a CNN. Would it be possible to come back to that on a later pass? Then, if this proves helpful, I can expand it to support StageNet at that point.

Contributor Author

I would also like to try adding some minor image support to the LRP work, but if I reference any of this code or its example, that work becomes dependent on this PR, which is a bit of a problem. I'm happy to come back to this later, but I'm somewhat stuck while this is still open.

Collaborator

Yes exactly, I think we can definitely simplify this class substantially. There are some things I'm a little confused about. Let me know your thoughts.

  1. Can we remove the input_batch init argument here? Do we need an assertion check here?
  2. It feels like _process_batch and attribute() are redundant here?
  3. I also feel like _compute_saliency() isn't called anywhere? Or get_saliency_maps()?
  4. I really like the visualize function idea.
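On the visualize idea: one common pattern is collapsing the saliency tensor to a single [0, 1] heatmap before overlaying it on the input image. A hedged sketch (the helper name `to_heatmap` is hypothetical, not the PR's function):

```python
import torch

def to_heatmap(saliency: torch.Tensor) -> torch.Tensor:
    """Collapse channels and min-max scale a saliency map to [0, 1]
    so it can be overlaid on the input image. Illustrative only; the
    PR's own visualize function may differ."""
    heat = saliency.abs().sum(dim=0)  # (C, H, W) -> (H, W)
    heat = heat - heat.min()
    # clamp_min guards against an all-zero map dividing by zero.
    return heat / heat.max().clamp_min(1e-8)

heat = to_heatmap(torch.randn(3, 8, 8))
print(heat.min().item(), heat.max().item())  # 0.0 1.0
# Typical overlay: plt.imshow(image, cmap="gray");
#                  plt.imshow(heat, alpha=0.5, cmap="jet")
```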

@jhnwu3 jhnwu3 (Collaborator) Nov 17, 2025

  1. So attribute(**data_batch) should let users pass in any batch and get their attributions back. For reference, you can iterate through a PyTorch dataloader and filter for a specific input. I think my concern is that it doesn't make sense for this to live in the init function when other functions can clearly perform the attribution on any batch the user decides?

  2. (4) No worries, we can leave those be. Happy to have them as they are now.
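The iterate-and-filter pattern from point 1 might look like the following; the `interpreter.attribute` call is shown only as a comment since the actual class interface lives in the PR:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical setup: the user owns the iteration and the sampling
# policy; the interpreter only ever sees the batch it is handed.
images = torch.randn(10, 1, 4, 4)
labels = torch.tensor([0, 1] * 5)
loader = DataLoader(TensorDataset(images, labels), batch_size=2)

picked = []
for x, y in loader:
    mask = y == 1  # e.g. keep only the positive cases
    if mask.any():
        picked.append(x[mask])
        # attributions = interpreter.attribute(x[mask], y[mask])
selected = torch.cat(picked)
print(selected.shape)  # torch.Size([5, 1, 4, 4])
```

This keeps the interpretability class stateless with respect to the dataset: whole-dataset runs, subsampling, or single-example lookups are all just different loops on the caller's side.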

Contributor Author

  1. At that point, isn't it just a static function and class? I'm not sure what you would store in the class if you don't even have the processing batch at init. I could make it optional for the existing application (which only calls attribute as a helper method), and I'll adjust it so it updates the batch if one is passed with the attribute method. I think that lines it up a bit closer.
  2. Turns out it slimmed down really nicely, thank you for poking at this. I was able to clear about 100 lines and I'm pretty happy with the results.

Collaborator

  1. The main thing, I think (and feel free to disagree with me on this), is standardization: we can assume that every (local) interpretability method has an attribute() function that our interpretability metrics can leverage, since the metrics assume some attribute() function exists, and it's nice to centralize that logic in an abstract interpreter class.

There are also issues like not being able to compute gradients on discrete variables (see the DeepLift and IntegratedGradient examples), where you'll need some specialized logic. Imagine these gradient-based saliency maps being computed on something like discrete word tokens: it would require more complicated logic, whether that's storing different types of baseline inputs to compute some difference in activations, or supporting different modes (e.g., SHAP can be computed in a variety of ways, but at the end of the day they're all estimating the true Shapley values).

But honestly, I think object-oriented approaches are just a little cleaner in terms of abstraction and code readability (for me, personally). Technically, if we were to follow a purely Pythonic functional style, we could pass a bunch of functions around everywhere, but IMO that would be really hard to read.
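A rough sketch of what such a shared abstract interpreter could look like, assuming PyTorch; `BaseInterpreter` here is illustrative and not necessarily the `BaseInterpreterClass` on master:

```python
from abc import ABC, abstractmethod
import torch

class BaseInterpreter(ABC):
    """Hypothetical abstract interpreter: every local interpretability
    method exposes attribute(), so metrics can treat them uniformly."""

    def __init__(self, model: torch.nn.Module):
        self.model = model

    @abstractmethod
    def attribute(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return per-input attribution tensors for one batch."""

class GradientSaliency(BaseInterpreter):
    # One concrete method; DeepLift, IntegratedGradients, etc. would
    # subclass the same base and keep the same attribute() contract.
    def attribute(self, x, target):
        x = x.clone().detach().requires_grad_(True)
        self.model.zero_grad()
        self.model(x).gather(1, target.unsqueeze(1)).sum().backward()
        return x.grad.detach().abs()

model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(8, 2))
attr = GradientSaliency(model).attribute(
    torch.randn(3, 2, 2, 2), torch.tensor([0, 1, 1])
)
print(attr.shape)  # torch.Size([3, 2, 2, 2])
```

Any metric written against `BaseInterpreter.attribute()` then works unchanged across methods, which is the standardization being argued for above.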

Contributor Author

Ok, those changes are in as well; let me know if you find anything else. There are probably still a few minor refactoring pieces I can think of, but I think that covers all of the major ones. If there are additional examples we want to add, I'd be inclined to do that later.

Contributor Author

And the notebook is up to date with this code, although again I'm not sure whether we want to do some more notebooks with different core models. I only picked this one because it was an easy fit, but I think we could add it to the chestxray 14 multiclassifier as well. There are also plenty of fun options for adding visualization tools; again, thoughts for later work.


@jhnwu3 jhnwu3 (Collaborator) left a comment

Oh on the topic of documentation, can you add this file reference to: https://github.com/sunlabuiuc/PyHealth/blob/master/docs/api/interpret.rst

Collaborator

For the sake of vibe-coding purposes, can we move these details into the docstrings in the class implementation? Also, it seems the docs are out of date. The nice part about having the details specifically in the docstrings (which get rendered by autodoc) is that an LLM will update the docstrings and the logic on every iteration, so it'll be easy to keep things maintained.

I think the examples/notebook might be out of date too?

Otherwise, I think we're pretty much there. Thank you for taking the time to iterate on this.

Contributor Author

The notebook is updated (I've been testing with it for every change, in addition to the unit tests). I'll poke at the documentation and let you know; that probably is out of date. I don't mind moving it, and the other documentation I just added is up to date.

Contributor Author

The .rst is updated to resemble the other interpretability ones (sorry I didn't do that sooner), and I added the gist of what it had to the docstring in the code. Excellent call there, and thank you for helping me make this much better.

@jhnwu3 jhnwu3 (Collaborator) left a comment

lgtm, we can iterate further on this later if need be.

@jhnwu3 jhnwu3 merged commit 8eceb3a into sunlabuiuc:master Nov 18, 2025
1 check passed
@Nimanui Nimanui deleted the SaliencyMappingClass branch November 18, 2025 23:14
dalloliogm pushed a commit to dalloliogm/PyHealth that referenced this pull request Nov 26, 2025
* Adding a basic gradient saliency method that functions on the pyhealth dataloader and models

* Updating function name for consistency and adding to the __init__.py

* Created using Colab

* Example of using gradient saliency

Applying and visualizing gradient saliency on a basic PyHealth model classifying image data

* Delete Colab ChestXrayClassificationWithSaliency.ipynb

* Adding a basic gradient saliency method that functions on the pyhealth dataloader and models

(cherry picked from commit 858c4f7)

* Updating function name for consistency and adding to the __init__.py

(cherry picked from commit 3bd6860)

* Moving the saliency module to the "interpret\methods" folder

* Created using Colab

* Small bug fix and adjusting the ipynb example

* Add files via upload

* Created using Colab

* Locally run notebook example to remove Colab dependency

* Add files via upload

* Add files via upload

Trying to find a version of this that github can render without errors

* Add files via upload

Still trying to get something that displays correctly in github

* Delete examples/ChestXrayClassificationWithSaliencyMapping.ipynb

* Delete ChestXrayClassificationWithSaliencyMapping.ipynb

* Update repository references in notebook

Remove temporary branch references from python notebook example import

* Adding saliency class and adjusting code to use the class

* Adding saliency class and adjusting code to use the class

* Refactoring methods in saliency.py

* Some refactoring and adding better batch handling

* Adding some documentation

* Basic Gradient imagery: removing dataload options

* Adding base_interpreter inheritance and testing

* Removing redundant code and the lazy initialization

* Making the input_batch optional and adding support to update it with attribute

* Adding a few unit tests

* Adding some documentation

* Documentation updates
