Skip to content

Commit 54bdd19

Browse files
hertschuhrtg0795
authored andcommittedMar 26, 2025
Fix incorrect argument in JAX flash attention. (#21014)
The mask is named `array` in `NumpyMask`.
1 parent 537c524 commit 54bdd19

File tree

1 file changed

+1
-1
lines changed
  • keras/src/backend/jax

1 file changed

+1
-1
lines changed
 

‎keras/src/backend/jax/nn.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1124,7 +1124,7 @@ def wrap_flash_attention(
11241124
)
11251125

11261126
if custom_mask is not None:
1127-
mask = splash_attention_mask.NumpyMask(mask=custom_mask)
1127+
mask = splash_attention_mask.NumpyMask(array=custom_mask)
11281128

11291129
else:
11301130
mask = splash_attention_mask.CausalMask(

0 commit comments

Comments
 (0)
Please sign in to comment.