
Bug for _can_use_flash_attention #21507

@pass-lin


The source code on GitHub defines a function, _can_use_flash_attention(...), that checks whether Flash Attention is available.
However, starting with JAX 0.6.2 (the version recommended by the requirements), the signature of the internal helper check_layout has changed to:

def check_layout(query, key, value, bias, q_seqlen, kv_seqlen,
                 q_offsets, kv_offsets, page_table_k, page_table_v, layout):

In the current implementation, _can_use_flash_attention still calls check_layout with the old argument list:

check_layout(
    query,
    key,
    value,
    bias,
    q_seqlen=None,
    kv_seqlen=None,
    layout=_normalize_layout("BTNH"),
)
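Replaying that argument list against a function with the new parameter list shows exactly which parameters go unfilled. This is a minimal sketch that needs neither a GPU nor JAX: the stub check_layout below is hypothetical and copies only the post-0.6.2 signature quoted above, the helper missing_required is illustrative, and the string "BTNH" stands in for _normalize_layout("BTNH").

```python
import inspect


# Hypothetical stub that copies only the post-0.6.2 parameter list of
# check_layout; the real helper validates tensors, this one does nothing.
def check_layout(query, key, value, bias, q_seqlen, kv_seqlen,
                 q_offsets, kv_offsets, page_table_k, page_table_v, layout):
    return None


def missing_required(fn, *args, **kwargs):
    """Names of parameters without a default that a call would leave unfilled."""
    params = inspect.signature(fn).parameters
    # Positional arguments fill parameters in declaration order.
    supplied = set(list(params)[:len(args)]) | set(kwargs)
    return [name for name, p in params.items()
            if p.default is inspect.Parameter.empty and name not in supplied]


# The pre-0.6.2 call pattern; "BTNH" stands in for _normalize_layout("BTNH").
old_args = ("query", "key", "value", None)
old_kwargs = {"q_seqlen": None, "kv_seqlen": None, "layout": "BTNH"}

print(missing_required(check_layout, *old_args, **old_kwargs))
# → ['q_offsets', 'kv_offsets', 'page_table_k', 'page_table_v']

try:
    check_layout(*old_args, **old_kwargs)
    error = None
except TypeError as exc:
    error = exc  # the same kind of error _can_use_flash_attention swallows
print(type(error).__name__, error)
```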

Because the required positional arguments q_offsets, kv_offsets, page_table_k, and page_table_v are missing, this call always raises a TypeError.
As a result, _can_use_flash_attention catches the exception and always returns False, which silently prevents the JAX backend from ever using Flash Attention.
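One possible direction for a fix, sketched under assumptions: build the keyword arguments from the signature of whichever check_layout is installed, pass None for the four new parameters, and stop treating a TypeError as "Flash Attention unavailable". Everything here is hypothetical: the two stub signatures, the names call_check_layout and can_use_flash_attention, and especially the assumption that None is an acceptable value for q_offsets, kv_offsets, page_table_k and page_table_v in JAX's real check_layout, which should be verified against the installed JAX source.

```python
import inspect


# Hypothetical stand-ins for the old and new check_layout signatures; real
# code would import the helper from JAX instead.
def check_layout_old(query, key, value, bias, q_seqlen, kv_seqlen, layout):
    return None


def check_layout_new(query, key, value, bias, q_seqlen, kv_seqlen,
                     q_offsets, kv_offsets, page_table_k, page_table_v, layout):
    return None


# Parameters added in the newer signature. Passing None for them is an
# assumption ("no offsets, no paged KV cache"), not confirmed JAX behavior.
NEW_PARAMS = ("q_offsets", "kv_offsets", "page_table_k", "page_table_v")


def call_check_layout(check_layout, query, key, value, bias, layout):
    """Call check_layout with whichever keywords its signature accepts."""
    params = inspect.signature(check_layout).parameters
    kwargs = {"q_seqlen": None, "kv_seqlen": None, "layout": layout}
    for name in NEW_PARAMS:
        if name in params:
            kwargs[name] = None
    return check_layout(query, key, value, bias, **kwargs)


def can_use_flash_attention(check_layout, query, key, value, bias=None,
                            layout="BTNH"):
    try:
        call_check_layout(check_layout, query, key, value, bias, layout)
    except TypeError:
        # A signature mismatch is a bug in the caller, not "Flash Attention
        # unavailable", so let it surface instead of returning False.
        raise
    except Exception:
        # Any other validation failure means the inputs are not supported.
        return False
    return True


for fn in (check_layout_old, check_layout_new):
    print(fn.__name__, can_use_flash_attention(fn, "q", "k", "v"))
```

Inspecting the signature keeps a single code path working on both sides of the change; an explicit comparison against JAX 0.6.2 would be an equally reasonable alternative, at the cost of hard-coding the version boundary.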
