@@ -49,3 +49,40 @@ def test_softmax_correctness_with_axis(self):
49
49
)
50
50
result = softmax_layer (input )
51
51
self .assertAllClose (result , expected_output )
52
+
53
+ def test_softmax_masked_values_are_zero_including_fully_masked (self ):
54
+ """
55
+ Tests softmax with mask on default axis (-1).
56
+ Ensures output is 0 where mask is False.
57
+ Includes a row where all elements are masked.
58
+ """
59
+ softmax_layer = softmax .Softmax () # Default axis = -1
60
+
61
+ input = np .array (
62
+ [
63
+ [1.0 , 2.0 , 5.0 , 1.0 ],
64
+ [1.0 , 1.0 , 1.0 , 1.0 ],
65
+ [3.0 , 1.0 , 2.0 , 4.0 ],
66
+ ],
67
+ dtype = np .float32 ,
68
+ )
69
+ mask = np .array (
70
+ [
71
+ [True , True , False , False ], # Partially masked
72
+ [False , False , False , False ], # Fully masked
73
+ [True , True , True , True ], # Not masked
74
+ ],
75
+ dtype = bool ,
76
+ )
77
+
78
+ expected_output = np .array (
79
+ [
80
+ [0.268941 , 0.731059 , 0.0 , 0.0 ], # last two masked
81
+ [0.0 , 0.0 , 0.0 , 0.0 ], # Fully masked row should be all zeros
82
+ [0.236883 , 0.032059 , 0.087144 , 0.643914 ],
83
+ ]
84
+ )
85
+
86
+ result = softmax_layer (input , mask = mask )
87
+
88
+ self .assertAllClose (result , expected_output )
0 commit comments