Add batch support for saliency maps #8573
base: dev
Conversation
Signed-off-by: einsyang723 <[email protected]>
Walkthrough
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks: ❌ failed checks (1 warning, 1 inconclusive); ✅ passed checks (3 passed)
Actionable comments posted: 1
🧹 Nitpick comments (1)
monai/visualize/class_activation_maps.py (1)
141-147: Ensure the class_idx tensor is Long on the correct device; fix the attribute type. Move and cast once to the logits device and Long dtype:

```diff
-        elif isinstance(class_idx, torch.Tensor):
-            self.class_idx = class_idx.to(logits.device)
+        elif isinstance(class_idx, torch.Tensor):
+            self.class_idx = class_idx.to(device=logits.device, dtype=torch.long)
         else:
             self.class_idx = class_idx
```

Outside this hunk, update the attribute annotation to reflect tensor support:

```python
# near line 76
self.class_idx: int | torch.Tensor | None = None  # previously: int | None
```
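As a quick sanity check of the cast above, here is a standalone sketch (toy tensors, not MONAI code):

```python
import torch

# Standalone sketch: cast a user-supplied index tensor once to the logits'
# device and Long dtype, as the suggestion above does.
logits = torch.randn(4, 3)  # [B, num_classes]
class_idx = torch.tensor([0, 2, 1, 0], dtype=torch.int32)

idx = class_idx.to(device=logits.device, dtype=torch.long)
assert idx.dtype == torch.long  # torch.gather requires int64 indices
```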
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (3)
- monai/visualize/class_activation_maps.py (1 hunks)
- monai/visualize/gradient_based.py (0 hunks)
- tests/integration/test_vis_gradbased.py (1 hunks)
💤 Files with no reviewable changes (1)
- monai/visualize/gradient_based.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides and are sensible and informative with regard to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful with regard to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definitions, describing each variable, return value, and raised exception in the appropriate section of Google-style docstrings. Examine code for logical errors or inconsistencies, and suggest what may be changed to address these. Suggest any enhancements improving the code's efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
tests/integration/test_vis_gradbased.py
monai/visualize/class_activation_maps.py
🪛 Ruff (0.12.2)
monai/visualize/class_activation_maps.py
135-135: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: packaging
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: build-docs
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-os (macOS-latest)
🔇 Additional comments (2)
tests/integration/test_vis_gradbased.py (1)
41-41: Expanded batch shapes: good coverage. Nice additions; they exercise 2D/3D and multi‑channel cases.
Consider adding a small set of cases that pass class_idx explicitly to validate the new per‑batch behavior and device/dtype handling:
- class_idx as int
- class_idx as 0‑D tensor on CPU while model runs on CUDA
- class_idx as per‑batch LongTensor of length B
I can draft a test method if desired.
Also applies to: 44-44, 47-47, 50-50, 53-53
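The three suggested cases can be sketched as follows (hypothetical toy model and illustrative names, not MONAI's API; the CUDA variant is omitted so the sketch runs anywhere):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 3)
x = torch.randn(4, 8)  # B = 4
logits = model(x)

cases = [
    1,                           # class_idx as int
    torch.tensor(1),             # class_idx as 0-D tensor (move to CUDA if available)
    torch.tensor([0, 1, 2, 0]),  # per-batch LongTensor of length B
]
for class_idx in cases:
    # coerce once to the logits' device and Long dtype, then select scores
    idx = torch.as_tensor(class_idx, device=logits.device, dtype=torch.long).view(-1)
    if idx.numel() == 1:
        scores = logits[:, idx.item()]
    else:
        scores = torch.gather(logits, 1, idx.unsqueeze(1)).squeeze(1)
    assert scores.shape == (4,)  # one score per batch element in every case
```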
monai/visualize/class_activation_maps.py (1)
151-152: Using sum() before backward is correct for batched scores. This keeps the graph simple and supports per‑sample indices. LGTM.
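To see why summing works, here is a minimal standalone sketch (toy tensors, not MONAI code): the gradient of the summed score with respect to each sample is unaffected by the other samples.

```python
import torch

# d(sum_i s_i)/dx_j = ds_j/dx_j: each sample's gradient is independent.
x = torch.randn(3, 5, requires_grad=True)
w = torch.randn(5)
scores = x @ w           # per-sample scores, shape [B]
scores.sum().backward()  # one backward pass for the whole batch

# Each row's gradient equals d(scores[i])/d(x[i]) = w.
assert torch.allclose(x.grad, w.expand(3, 5))
```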
```python
def class_score(self, logits: torch.Tensor, class_idx: int | torch.Tensor) -> torch.Tensor:
    if isinstance(class_idx, int):
        return logits[:, class_idx].squeeze()
    elif class_idx.numel() == 1:
        return logits[:, class_idx.item()]
    elif len(class_idx.view(-1)) == logits.shape[0]:
        return torch.gather(logits, 1, class_idx.unsqueeze(1)).squeeze(1)
    else:
        raise ValueError("expect length of class_idx equal to batch size")
```
❓ Verification inconclusive
Batch indexing: enforce dtype/device; avoid squeeze; tighten validation.
Current behavior can return a 0‑D tensor for batch==1 and doesn’t enforce Long dtype for gather. Recommend unifying outputs to shape [B] and validating lengths.
Apply within this hunk:
```diff
-    def class_score(self, logits: torch.Tensor, class_idx: int | torch.Tensor) -> torch.Tensor:
-        if isinstance(class_idx, int):
-            return logits[:, class_idx].squeeze()
-        elif class_idx.numel() == 1:
-            return logits[:, class_idx.item()]
-        elif len(class_idx.view(-1)) == logits.shape[0]:
-            return torch.gather(logits, 1, class_idx.unsqueeze(1)).squeeze(1)
-        else:
-            raise ValueError("expect length of class_idx equal to batch size")
+    def class_score(self, logits: torch.Tensor, class_idx: int | torch.Tensor) -> torch.Tensor:
+        if isinstance(class_idx, int):
+            # keep shape [B], avoid scalar squeeze when B == 1
+            return logits[:, class_idx]
+        idx = class_idx.to(device=logits.device, dtype=torch.long).view(-1)
+        if idx.numel() == 1:
+            return logits[:, idx.item()]
+        if idx.numel() == logits.shape[0]:
+            return torch.gather(logits, 1, idx.unsqueeze(1)).squeeze(1)
+        raise ValueError("class_idx must be scalar or have length equal to batch size")
```
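A standalone re-implementation of the suggested logic lets us sanity-check all three branches (illustrative free function, not a patch to MONAI):

```python
import torch

def class_score(logits: torch.Tensor, class_idx) -> torch.Tensor:
    # mirrors the suggested method body above
    if isinstance(class_idx, int):
        return logits[:, class_idx]
    idx = class_idx.to(device=logits.device, dtype=torch.long).view(-1)
    if idx.numel() == 1:
        return logits[:, idx.item()]
    if idx.numel() == logits.shape[0]:
        return torch.gather(logits, 1, idx.unsqueeze(1)).squeeze(1)
    raise ValueError("class_idx must be scalar or have length equal to batch size")

logits = torch.arange(6.0).view(2, 3)  # [[0, 1, 2], [3, 4, 5]]
assert class_score(logits, 1).tolist() == [1.0, 4.0]                  # int index
assert class_score(logits, torch.tensor([2, 0])).tolist() == [2.0, 3.0]  # per-batch
assert class_score(logits[:1], 0).shape == (1,)  # B == 1 no longer collapses to 0-D
```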
Also, GradCAM++ multiplies `self.nn_module.score.exp()` by `grad`. With batched scores this will broadcast incorrectly unless reshaped. See the suggestion under verification below.
To prevent a broadcast error in GradCAM++ for B>1, reshape score before multiplication:
```python
# In GradCAMpp.compute_map(...) before multiplying with grad:
score = cast(torch.Tensor, self.nn_module.score).exp().view(b, *([1] * (grad.dim() - 1)))
relu_grad = F.relu(score * grad)
```
Please run the GradCAM++ tests with B>1 to confirm no shape/broadcast errors.
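The broadcasting issue can be demonstrated in isolation (assumed shapes: score is `[B]`, grad is `[B, C, H, W]`; these shapes are illustrative, not read from MONAI):

```python
import torch
import torch.nn.functional as F

b = 2
score = torch.tensor([0.5, 2.0])  # one score per sample, shape [B]
grad = torch.randn(b, 3, 4, 4)    # assumed activation-gradient shape [B, C, H, W]

# A bare `score * grad` would align score's single dim with grad's trailing
# dim (2 vs 4) and fail; reshaping to [B, 1, 1, 1] broadcasts per sample.
score_r = score.exp().view(b, *([1] * (grad.dim() - 1)))
relu_grad = F.relu(score_r * grad)
assert relu_grad.shape == grad.shape
```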
🤖 Prompt for AI Agents
In monai/visualize/class_activation_maps.py around lines 127 to 135, the
class_score method must return a tensor shaped [B], validate that class_idx
length equals batch size, and ensure class_idx is a torch.long tensor on the
same device as logits instead of using Python int checks and squeeze calls;
coerce class_idx via class_idx = torch.as_tensor(class_idx,
device=logits.device, dtype=torch.long), check class_idx.numel() ==
logits.shape[0] and then use torch.gather(logits, 1,
class_idx.view(-1,1)).view(-1) for a stable [B] output, raising a ValueError
otherwise. Also in GradCAM++ compute_map, reshape the scalar score to broadcast
across grad dims before multiplying: set score =
self.nn_module.score.exp().view(b, *([1] * (grad.dim() - 1))) and then use
relu_grad = F.relu(score * grad) so shapes align for subsequent reductions.
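The `torch.as_tensor` coercion route described in the prompt can be sketched as follows (illustrative names; a list, tensor, or plain int all take the same path):

```python
import torch

logits = torch.randn(2, 5)
for class_idx in (3, [3, 1], torch.tensor([3, 1])):  # scalar, list, tensor
    idx = torch.as_tensor(class_idx, device=logits.device, dtype=torch.long)
    if idx.numel() == 1:
        scores = logits[:, idx.item()]                      # broadcast one class index
    else:
        scores = torch.gather(logits, 1, idx.view(-1, 1)).view(-1)
    assert scores.shape == (2,)                             # stable [B] output
```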
Fixes #7334.
Description
This PR adds batch processing support for saliency methods (e.g., VanillaGrad) in MONAI.
Types of changes
- `./runtests.sh -f -u --net --coverage`
- `./runtests.sh --quick --unittests --disttests`