Aligns Softmax masking behavior with JAX for fully masked axis #21149
Conversation
Force-pushed from 346a34c to ee7e979
Force-pushed from ee7e979 to 13bfabc
Codecov Report: All modified and coverable lines are covered by tests ✅

Coverage diff (master vs. #21149):
  Coverage:  82.68% → 82.69%
  Files:     564 → 564
  Lines:     54220 → 54223 (+3)
  Branches:  8423 → 8424 (+1)
  Hits:      44832 → 44837 (+5)
  Misses:    7311 → 7310 (-1)
  Partials:  2077 → 2076 (-1)
LGTM, thanks for the PR!
Fixes #21123
This change modifies the keras.layers.activations.Softmax layer to ensure consistent behavior when masking is applied, particularly for the edge case where all elements along the softmax axis are masked.
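As an illustrative sketch (the example values are not taken from the PR), this is roughly how the affected layer is called with a mask; the comments describe the before/after behavior stated in this description:

```python
import numpy as np
import keras

x = np.array([[1.0, 2.0, 3.0, 4.0]], dtype="float32")
mask = np.array([[False, False, False, False]])  # every position along the axis is masked

# keras.layers.Softmax accepts a boolean mask at call time.
out = keras.layers.Softmax(axis=-1)(x, mask=mask)
# Before this change: a uniform distribution such as [0.25, 0.25, 0.25, 0.25].
# After this change: all zeros, matching jax.nn.softmax(x, where=mask).
```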
Issue

Previously, the layer produced a uniform distribution (e.g. [0.25, 0.25, 0.25, 0.25]) when all elements along the softmax axis were masked. This behavior differs from the semantics of jax.nn.softmax(where=mask), which outputs zeros for all elements in a fully masked slice.
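For reference, a minimal sketch of the JAX behavior being matched; the input values are made up for illustration:

```python
import jax.numpy as jnp
from jax import nn

x = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [1.0, 2.0, 3.0, 4.0]])
mask = jnp.array([[True, True, False, False],
                  [False, False, False, False]])  # second row fully masked

out = nn.softmax(x, axis=-1, where=mask)
# Row 0: a softmax over the two unmasked positions, zeros at the masked ones.
# Row 1: fully masked, so every element is zero — the semantics this PR
# adopts for keras.layers.Softmax.
```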
Fix

The softmax output tensor is multiplied element-wise by the original boolean mask, ensuring that any masked position in the softmax output is zeroed out.
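A hedged NumPy sketch of the idea (not the actual Keras source, and the helper name is hypothetical): compute the masked softmax as before, then multiply by the boolean mask so a fully masked slice becomes all zeros instead of uniform:

```python
import numpy as np

def masked_softmax(inputs, mask, axis=-1):
    # Hypothetical helper for illustration only.
    # Push masked logits toward a large negative value so they contribute ~0.
    logits = np.where(mask, inputs, -1e9)
    exps = np.exp(logits - logits.max(axis=axis, keepdims=True))
    out = exps / exps.sum(axis=axis, keepdims=True)
    # Element-wise multiply by the mask: without this, a fully masked row
    # yields a uniform distribution; with it, the row is zeroed out.
    return out * mask.astype(out.dtype)

x = np.array([[1.0, 2.0, 3.0, 4.0]])
mask = np.array([[False, False, False, False]])
print(masked_softmax(x, mask))  # [[0. 0. 0. 0.]]
```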