
Aligns Softmax masking behavior with JAX for fully masked axis #21149


Merged: 3 commits into keras-team:master on Apr 11, 2025

Conversation

@JyotinderSingh (Collaborator) commented Apr 10, 2025

Fixes #21123

This change modifies the keras.layers.activations.Softmax layer to ensure consistent behavior when masking is applied, particularly for the edge case where all elements along the softmax axis are masked.

Issue

  1. The masking mechanism worked by adding a large negative number to the logits corresponding to masked positions before applying the softmax activation.
  2. While this effectively suppressed the contribution of masked elements when at least one element was unmasked, it produced a non-zero, uniform distribution (e.g., [0.25, 0.25, 0.25, 0.25]) when every element along the softmax axis was masked. This differs from the semantics of jax.nn.softmax(where=mask), which outputs zeros for all elements in a fully masked slice (see the sketch below).
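
A minimal sketch of the divergence (assuming JAX is installed; the logits and the -1e9 offset below are illustrative, not the exact constant Keras uses):

import jax
import jax.numpy as jnp

logits = jnp.array([1.0, 2.0, 3.0, 4.0])
mask = jnp.array([False, False, False, False])  # entire softmax axis masked

# JAX zeros out a fully masked slice.
print(jax.nn.softmax(logits, where=mask))
# [0. 0. 0. 0.]

# The pre-fix Keras masking added a large negative number to masked
# logits. With every position masked, the float32 offset absorbs the
# original values, all logits become equal, and softmax is uniform:
print(jax.nn.softmax(logits - 1e9))
# [0.25 0.25 0.25 0.25]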

Fix

The softmax output tensor is multiplied element-wise by the original boolean mask, ensuring that any masked position in the softmax output is zeroed out.

if mask is not None:
    # Apply the mask to the softmax output to ensure that masked
    # values are set to 0 in case the entire axis is masked.
    outputs = outputs * backend.cast(mask, outputs.dtype)
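
A quick check of the resulting behavior (a sketch, assuming a Keras build that includes this change):

import numpy as np
import keras

layer = keras.layers.Softmax()
x = np.array([[1.0, 2.0, 3.0, 4.0]])

# Partially masked axis: unmasked positions renormalize as before,
# masked positions are now exactly zero.
print(layer(x, mask=np.array([[True, True, False, False]])))
# ~[[0.269 0.731 0.    0.   ]]

# Fully masked axis: all zeros, matching jax.nn.softmax(where=mask).
print(layer(x, mask=np.array([[False, False, False, False]])))
# [[0. 0. 0. 0.]]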

@JyotinderSingh changed the title from "Fixes softmax masking logic to match JAX behavior" to "Aligns Softmax masking behavior with JAX for fully masked axis" on Apr 10, 2025
codecov-commenter commented Apr 10, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 82.69%. Comparing base (9f8bbca) to head (051d18e).

Additional details and impacted files
@@           Coverage Diff           @@
##           master   #21149   +/-   ##
=======================================
  Coverage   82.68%   82.69%           
=======================================
  Files         564      564           
  Lines       54220    54223    +3     
  Branches     8423     8424    +1     
=======================================
+ Hits        44832    44837    +5     
+ Misses       7311     7310    -1     
+ Partials     2077     2076    -1     
Flag Coverage Δ
keras 82.50% <100.00%> (+<0.01%) ⬆️
keras-jax 63.92% <100.00%> (+<0.01%) ⬆️
keras-numpy 59.03% <100.00%> (+<0.01%) ⬆️
keras-openvino 32.98% <83.33%> (+<0.01%) ⬆️
keras-tensorflow 64.30% <100.00%> (+<0.01%) ⬆️
keras-torch 63.99% <100.00%> (+<0.01%) ⬆️

@JyotinderSingh JyotinderSingh requested a review from fchollet April 10, 2025 05:43
@JyotinderSingh JyotinderSingh requested review from hertschuh and removed request for fchollet April 10, 2025 16:17

@fchollet (Collaborator) left a comment

LGTM, thanks for the PR!

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Apr 10, 2025
@google-ml-butler google-ml-butler bot removed the ready to pull Ready to be merged into the codebase label Apr 11, 2025
@fchollet fchollet merged commit 2111fbc into keras-team:master Apr 11, 2025
7 checks passed
Development

Successfully merging this pull request may close these issues:

Softmax layer diverges from jax.nn.softmax (#21123)