Skip to content

Commit 2111fbc

Browse files
authoredApr 11, 2025··
Aligns Softmax masking behavior with JAX for fully masked axis (#21149)
* Fixes softmax masking logic to match JAX behavior * fix comment * use backend.numpy.multipy for element-wise multiplication
1 parent 9f8bbca commit 2111fbc

File tree

2 files changed

+50
-3
lines changed

2 files changed

+50
-3
lines changed
 

‎keras/src/layers/activations/softmax.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,25 @@ def call(self, inputs, mask=None):
5858
inputs += adder
5959
if isinstance(self.axis, (tuple, list)):
6060
if len(self.axis) > 1:
61-
return backend.numpy.exp(
61+
outputs = backend.numpy.exp(
6262
inputs
6363
- backend.math.logsumexp(
6464
inputs, axis=self.axis, keepdims=True
6565
)
6666
)
6767
else:
68-
return activations.softmax(inputs, axis=self.axis[0])
69-
return activations.softmax(inputs, axis=self.axis)
68+
outputs = activations.softmax(inputs, axis=self.axis[0])
69+
else:
70+
outputs = activations.softmax(inputs, axis=self.axis)
71+
72+
if mask is not None:
73+
# Apply the mask to the softmax output to ensure that masked
74+
# values are set to 0 in case the entire axis is masked.
75+
outputs = backend.numpy.multiply(
76+
outputs, backend.cast(mask, outputs.dtype)
77+
)
78+
79+
return outputs
7080

7181
def get_config(self):
7282
config = super().get_config()

‎keras/src/layers/activations/softmax_test.py

+37
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,40 @@ def test_softmax_correctness_with_axis(self):
4949
)
5050
result = softmax_layer(input)
5151
self.assertAllClose(result, expected_output)
52+
53+
def test_softmax_masked_values_are_zero_including_fully_masked(self):
54+
"""
55+
Tests softmax with mask on default axis (-1).
56+
Ensures output is 0 where mask is False.
57+
Includes a row where all elements are masked.
58+
"""
59+
softmax_layer = softmax.Softmax() # Default axis = -1
60+
61+
input = np.array(
62+
[
63+
[1.0, 2.0, 5.0, 1.0],
64+
[1.0, 1.0, 1.0, 1.0],
65+
[3.0, 1.0, 2.0, 4.0],
66+
],
67+
dtype=np.float32,
68+
)
69+
mask = np.array(
70+
[
71+
[True, True, False, False], # Partially masked
72+
[False, False, False, False], # Fully masked
73+
[True, True, True, True], # Not masked
74+
],
75+
dtype=bool,
76+
)
77+
78+
expected_output = np.array(
79+
[
80+
[0.268941, 0.731059, 0.0, 0.0], # last two masked
81+
[0.0, 0.0, 0.0, 0.0], # Fully masked row should be all zeros
82+
[0.236883, 0.032059, 0.087144, 0.643914],
83+
]
84+
)
85+
86+
result = softmax_layer(input, mask=mask)
87+
88+
self.assertAllClose(result, expected_output)

0 commit comments

Comments
 (0)
Please sign in to comment.