Closed
Description
The source code on GitHub defines a function _can_use_flash_attention(...) that attempts to verify whether Flash Attention is available.
However, starting with JAX 0.6.2 (the version recommended by the requirements), the signature of the internal helper check_layout was changed to:
def check_layout(query, key, value, bias, q_seqlen, kv_seqlen,
                 q_offsets, kv_offsets, page_table_k, page_table_v, layout):
In the current implementation, _can_use_flash_attention still calls check_layout with the old argument list:
check_layout(
    query,
    key,
    value,
    bias,
    q_seqlen=None,
    kv_seqlen=None,
    layout=_normalize_layout("BTNH"),
)
Because the required positional arguments q_offsets, kv_offsets, page_table_k, and page_table_v are missing, this call always raises a TypeError.
As a result, _can_use_flash_attention catches the exception and always returns False, so Flash Attention is never used with the JAX backend.
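The failure mode described above can be reproduced without JAX installed. The following is a minimal sketch, not the actual implementation: check_layout here is a hypothetical stub that only copies the newer parameter list quoted above, the two can_use_flash_attention_* functions are illustrative names, and the plain string "BTNH" stands in for _normalize_layout("BTNH"). Passing None for the four new parameters is an assumption about what a fix might look like, not confirmed behaviour of the real helper.

```python
import inspect


# Hypothetical stub: mirrors only the newer parameter list quoted in this issue.
# The real helper validates tensor layouts; this one just accepts its arguments.
def check_layout(query, key, value, bias, q_seqlen, kv_seqlen,
                 q_offsets, kv_offsets, page_table_k, page_table_v, layout):
    return True


def can_use_flash_attention_old(query, key, value, bias=None):
    """Calls check_layout with the old argument list, as described in the issue."""
    try:
        check_layout(query, key, value, bias,
                     q_seqlen=None, kv_seqlen=None, layout="BTNH")
        return True
    except Exception:
        # The TypeError for the four missing arguments is swallowed here,
        # so the caller only ever sees False.
        return False


def can_use_flash_attention_adaptive(query, key, value, bias=None):
    """Supplies the newer parameters only if the installed signature has them."""
    params = inspect.signature(check_layout).parameters
    kwargs = {"q_seqlen": None, "kv_seqlen": None, "layout": "BTNH"}
    for name in ("q_offsets", "kv_offsets", "page_table_k", "page_table_v"):
        if name in params:
            kwargs[name] = None  # assumption: None is an acceptable default
    try:
        check_layout(query, key, value, bias, **kwargs)
        return True
    except Exception:
        return False


old_result = can_use_flash_attention_old(1, 2, 3)
new_result = can_use_flash_attention_adaptive(1, 2, 3)
print(old_result, new_result)
```

With the stub above, the old-style call returns False purely because of the signature mismatch, while the signature-aware call reaches check_layout. Catching a broad Exception is what hides the TypeError; narrowing the except clause or logging the exception would have surfaced the problem sooner.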