Skip to content

Conversation

rajveer43
Copy link
Contributor

@rajveer43 rajveer43 commented Sep 3, 2025

What does this PR fix?

This PR fixes the cluster loss (L_c) calculation in DMoNPooling. Previously, the implementation incorrectly normalized the loss using a wrongly shaped tensor derived from the node mask (mask.sum(dim=1) where mask had shape B × N × 1). This led to unintended broadcasting and incorrect values for the cluster loss.

Changes made

  • Introduced a proper 2D node mask (node_mask_2d) to compute valid node counts per graph.
  • Calculated the per-graph node count vector n_per_graph with shape B, ensuring correct normalization.
  • Corrected the cluster loss formula to:
    cluster_size = torch.einsum('ijk->ik', s)  # B x C
    cluster_norm = torch.norm(cluster_size, dim=1)  # B
    n_per_graph = node_mask_2d.sum(dim=1).to(cluster_norm.dtype)  # B
    cluster_loss = (cluster_norm / n_per_graph) * torch.norm(i_s) - 1
    cluster_loss = cluster_loss.mean()
  • Updated inline comments for clarity.

Impact

  • Corrects cluster loss computation for batched graphs.
  • Prevents silent broadcasting errors.
  • Brings implementation closer to the original paper definition of L_c.

Closes: #10148

Copy link

codecov bot commented Sep 3, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 85.97%. Comparing base (c211214) to head (677b7ec).
⚠️ Report is 98 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #10429      +/-   ##
==========================================
- Coverage   86.11%   85.97%   -0.14%     
==========================================
  Files         496      502       +6     
  Lines       33655    35210    +1555     
==========================================
+ Hits        28981    30272    +1291     
- Misses       4674     4938     +264     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant