
Conversation

einsyang723

Fixes #7334 .

Description

This PR adds batch processing support for saliency methods (e.g., VanillaGrad) in MONAI.

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
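
Conceptually, the change removes the batch-size==1 restriction so a gradient saliency map can be computed for every sample in one pass. A minimal plain-PyTorch sketch of batched vanilla-gradient saliency (using a toy stand-in model, not MONAI's actual VanillaGrad API):

```python
import torch

# Hypothetical stand-in network; MONAI's VanillaGrad wraps a real model.
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(2 * 8 * 8, 4))

def vanilla_grad(model, x, class_idx):
    """Gradient of the chosen class score w.r.t. a batched input."""
    x = x.clone().requires_grad_(True)
    logits = model(x)              # shape [B, num_classes]
    score = logits[:, class_idx]   # shape [B]
    score.sum().backward()         # sum() keeps per-sample gradients independent
    return x.grad

x = torch.randn(3, 2, 8, 8)        # batch of 3 samples
saliency = vanilla_grad(model, x, class_idx=1)
```

The returned saliency has the same shape as the input batch, one map per sample.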


coderabbitai bot commented Sep 16, 2025

Walkthrough

  • monai/visualize/class_activation_maps.py: Extended ModelWithHooks.class_score to accept int or torch.Tensor, including per-batch class index selection via torch.gather. Updated call to handle Tensor class_idx (moving to logits device) and to pass class_idx directly to class_score without int casting. Forward/backward hook flow unchanged.
  • monai/visualize/gradient_based.py: Removed batch-size==1 validation in VanillaGrad.get_grad, enabling gradients for batched inputs; computation path unchanged otherwise.
  • tests/integration/test_vis_gradbased.py: Added parameterized cases with larger/batched input shapes across multiple model types to cover batch processing paths.
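
The per-batch class index selection mentioned above reduces to torch.gather along the class dimension; a small standalone illustration:

```python
import torch

logits = torch.tensor([[0.1, 0.9, 0.0],
                       [0.7, 0.2, 0.1]])   # shape [B=2, C=3]
class_idx = torch.tensor([1, 0])           # one class index per sample

# gather along dim=1 picks logits[b, class_idx[b]] for each row b
scores = torch.gather(logits, 1, class_idx.unsqueeze(1)).squeeze(1)
print(scores)  # tensor([0.9000, 0.7000])
```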

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage — ⚠️ Warning: Docstring coverage is 33.33%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve coverage.
  • Out of Scope Changes Check — ❓ Inconclusive: Most changes are directly related to enabling batch saliency support (gradient_based.py, ModelWithHooks, and expanded tests). However, the raw summary contains a contradictory note about ModelWithHooks.class_score's declaration: the narrative says it was extended to accept int | Tensor, while the "Alterations" line indicates a signature of class_idx: int, creating ambiguity about a possible unintended public API change. Because of this inconsistency, a conclusive determination about a breaking or out-of-scope public API modification cannot be made.
✅ Passed checks (3 passed)

  • Title Check — ✅ Passed: The title "Add batch support for saliency maps" is concise, focused, and accurately summarizes the primary intent of the PR: enabling batch processing for saliency methods (e.g., removing single-sample enforcement and updating scoring behavior). It is clear and free of noise, so teammates scanning history will understand the main change.
  • Linked Issues Check — ✅ Passed: The changes address issue #7334 by removing the single-sample restriction in VanillaGrad.get_grad, enabling per-sample class_idx handling in ModelWithHooks, and adding parameterized tests that exercise multiple batch shapes and gradient-based methods. These modifications implement the requested batch-processing capability for saliency map generation across multiple input samples. Based on the provided summaries, the implementation aligns with the linked issue objectives.
  • Description Check — ✅ Passed: The PR description follows the repository template: it includes "Fixes #7334", a short description of the change (adding batch processing support for saliency methods), and the "Types of changes" checklist with tests and local runs marked as passed. The author also states that integration and quick tests passed locally. The description is sufficient for review but could be improved with a brief implementation note or a list of modified files.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
monai/visualize/class_activation_maps.py (1)

141-147: Ensure class_idx tensor is Long on correct device; fix attribute type.

Move and cast once to the logits device and Long dtype.

-        elif isinstance(class_idx, torch.Tensor):
-            self.class_idx = class_idx.to(logits.device)
+        elif isinstance(class_idx, torch.Tensor):
+            self.class_idx = class_idx.to(device=logits.device, dtype=torch.long)
         else:
             self.class_idx = class_idx

Outside this hunk, update the attribute annotation to reflect tensor support:

# near Line 76
self.class_idx: int | torch.Tensor | None = None  # previously: int | None
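
The rationale for the cast: torch.gather requires int64 indices on the same device as the source tensor, so coercing once up front avoids a runtime failure later. A standalone illustration:

```python
import torch

logits = torch.randn(4, 3)
idx = torch.tensor([0, 2, 1, 0], dtype=torch.float32)  # wrong dtype on purpose

failed = False
try:
    torch.gather(logits, 1, idx.unsqueeze(1))
except RuntimeError:
    failed = True  # gather() expects int64 indices

# casting once up front avoids the failure (and any device mismatch)
idx_long = idx.to(device=logits.device, dtype=torch.long)
scores = torch.gather(logits, 1, idx_long.unsqueeze(1)).squeeze(1)
```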
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 401ea4a and 82f52c5.

📒 Files selected for processing (3)
  • monai/visualize/class_activation_maps.py (1 hunks)
  • monai/visualize/gradient_based.py (0 hunks)
  • tests/integration/test_vis_gradbased.py (1 hunks)
💤 Files with no reviewable changes (1)
  • monai/visualize/gradient_based.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • tests/integration/test_vis_gradbased.py
  • monai/visualize/class_activation_maps.py
🪛 Ruff (0.12.2)
monai/visualize/class_activation_maps.py

135-135: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: packaging
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: build-docs
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-os (macOS-latest)
🔇 Additional comments (2)
tests/integration/test_vis_gradbased.py (1)

41-41: Expanded batch shapes: good coverage.

Nice additions; they exercise 2D/3D and multi‑channel cases.

Consider adding a small set of cases that pass class_idx explicitly to validate the new per‑batch behavior and device/dtype handling:

  • class_idx as int
  • class_idx as 0‑D tensor on CPU while model runs on CUDA
  • class_idx as per‑batch LongTensor of length B

I can draft a test method if desired.

Also applies to: 44-44, 47-47, 50-50, 53-53
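
A sketch of what such a per-batch class_idx case might look like, using a hypothetical stand-in linear model and a local reimplementation of the scoring logic rather than MONAI's actual test helpers:

```python
import torch

# stand-in model: flatten a (B, 1, 4, 4) input and classify into 3 classes
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(16, 3))

def get_grad(x, class_idx):
    """Batched vanilla gradient with per-sample class selection (illustrative)."""
    x = x.clone().requires_grad_(True)
    logits = model(x)
    idx = torch.as_tensor(class_idx, device=logits.device, dtype=torch.long).view(-1)
    score = torch.gather(logits, 1, idx.unsqueeze(1)).squeeze(1)  # shape [B]
    score.sum().backward()
    return x.grad

x = torch.randn(4, 1, 4, 4)
grad_same = get_grad(x, [2, 2, 2, 2])                   # same class for every sample
grad_batch = get_grad(x, torch.tensor([0, 1, 2, 0]))    # per-sample class index
```

Both calls should return gradients shaped like the input batch.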

monai/visualize/class_activation_maps.py (1)

151-152: Using sum() before backward is correct for batched scores.

This keeps the graph simple and supports per‑sample indices. LGTM.

Comment on lines +127 to +135

    def class_score(self, logits: torch.Tensor, class_idx: int | torch.Tensor) -> torch.Tensor:
        if isinstance(class_idx, int):
            return logits[:, class_idx].squeeze()
        elif class_idx.numel() == 1:
            return logits[:, class_idx.item()]
        elif len(class_idx.view(-1)) == logits.shape[0]:
            return torch.gather(logits, 1, class_idx.unsqueeze(1)).squeeze(1)
        else:
            raise ValueError("expect length of class_idx equal to batch size")

⚠️ Potential issue

❓ Verification inconclusive

Batch indexing: enforce dtype/device; avoid squeeze; tighten validation.

Current behavior can return a 0‑D tensor for batch==1 and doesn’t enforce Long dtype for gather. Recommend unifying outputs to shape [B] and validating lengths.

Apply within this hunk:

-    def class_score(self, logits: torch.Tensor, class_idx: int | torch.Tensor) -> torch.Tensor:
-        if isinstance(class_idx, int):
-            return logits[:, class_idx].squeeze()
-        elif class_idx.numel() == 1:
-            return logits[:, class_idx.item()]
-        elif len(class_idx.view(-1)) == logits.shape[0]:
-            return torch.gather(logits, 1, class_idx.unsqueeze(1)).squeeze(1)
-        else:
-            raise ValueError("expect length of class_idx equal to batch size")
+    def class_score(self, logits: torch.Tensor, class_idx: int | torch.Tensor) -> torch.Tensor:
+        if isinstance(class_idx, int):
+            # keep shape [B], avoid scalar squeeze when B==1
+            return logits[:, class_idx]
+        idx = class_idx.to(device=logits.device, dtype=torch.long).view(-1)
+        if idx.numel() == 1:
+            return logits[:, idx.item()]
+        if idx.numel() == logits.shape[0]:
+            return torch.gather(logits, 1, idx.unsqueeze(1)).squeeze(1)
+        raise ValueError("class_idx must be scalar or have length equal to batch size")

Also, GradCAM++ multiplies self.nn_module.score.exp() by grad. With batched scores this will broadcast incorrectly unless reshaped. See suggestion under verification below.

To prevent a broadcast error in GradCAM++ for B>1, reshape score before multiplication:

# In GradCAMpp.compute_map(...) before multiplying with grad:
score = cast(torch.Tensor, self.nn_module.score).exp().view(b, *([1] * (grad.dim() - 1)))
relu_grad = F.relu(score * grad)

Please run the GradCAM++ tests with B>1 to confirm no shape/broadcast errors.
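
To illustrate why the reshape matters: a batched score of shape (B,) cannot broadcast against a gradient of shape (B, C, H, W), whereas a (B, 1, 1, 1) view scales each sample's gradient map correctly. A standalone sketch of the suggested fix:

```python
import torch
import torch.nn.functional as F

b, c, h, w = 3, 8, 4, 4
grad = torch.randn(b, c, h, w)   # feature-map gradients from the hook
score = torch.randn(b)           # one class score per sample (batched)

# score.exp() * grad would try to broadcast (3,) against (3, 8, 4, 4) and fail;
# reshaping to (b, 1, 1, 1) scales each sample's gradient map instead
score_r = score.exp().view(b, *([1] * (grad.dim() - 1)))
relu_grad = F.relu(score_r * grad)
```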
🧰 Tools
🪛 Ruff (0.12.2)

135-135: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In monai/visualize/class_activation_maps.py around lines 127 to 135, the
class_score method must return a tensor shaped [B], validate that class_idx
length equals batch size, and ensure class_idx is a torch.long tensor on the
same device as logits instead of using Python int checks and squeeze calls;
coerce class_idx via class_idx = torch.as_tensor(class_idx,
device=logits.device, dtype=torch.long), check class_idx.numel() ==
logits.shape[0] and then use torch.gather(logits, 1,
class_idx.view(-1,1)).view(-1) for a stable [B] output, raising a ValueError
otherwise. Also in GradCAM++ compute_map, reshape the scalar score to broadcast
across grad dims before multiplying: set score =
self.nn_module.score.exp().view(b, *([1] * (grad.dim() - 1))) and then use
relu_grad = F.relu(score * grad) so shapes align for subsequent reductions.
