6 changes: 3 additions & 3 deletions .github/actions/setup-build-cuda/action.yml
@@ -24,10 +24,10 @@ runs:
print(sys.version)
cushort = "${{ inputs.toolkit_short_version }}"
# Version uploaded to pypi (rather than PyTorch s3)
TORCH_CUDA_DEFAULT = "128" # since pytorch 2.8.0
TORCH_CUDA_DEFAULT = "130" # since pytorch 2.9.0
# https://github.com/Jimver/cuda-toolkit/blob/master/src/links/linux-links.ts
full_version, install_script = {
"129": ("12.9.0", "https://developer.download.nvidia.com/compute/cuda/12.9.1/local_installers/cuda_12.9.1_575.57.08_linux.run"),
"130": ("13.0.1", "https://developer.download.nvidia.com/compute/cuda/13.0.1/local_installers/cuda_13.0.1_580.82.07_linux.run"),
"128": ("12.8.1", "https://developer.download.nvidia.com/compute/cuda/12.8.1/local_installers/cuda_12.8.1_570.124.06_linux.run"),
# (Build with nvcc 12.8 on linux even when building for 12.6 to avoid seg fault in Flash3 build)
"126": ("12.8.1", "https://developer.download.nvidia.com/compute/cuda/12.8.1/local_installers/cuda_12.8.1_570.124.06_linux.run"),
@@ -52,7 +52,7 @@ runs:
- name: Install cuda
if: runner.os == 'Windows' && inputs.toolkit_type == 'cuda'
id: cuda-toolkit
uses: Jimver/[email protected].24
uses: Jimver/[email protected].27
with:
cuda: ${{ steps.cuda_info.outputs.CUDA_VERSION }}
method: network
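For reference, the action's embedded Python uses the mapping above to resolve the short toolkit version it receives (e.g. "130") into a full CUDA version plus a Linux installer URL. A minimal sketch of that lookup under the same mapping as the hunk above (the error handling is illustrative, not the action's actual behavior):

```python
# Minimal sketch of the short-version -> (full_version, installer_url) lookup.
# Mapping copied from the hunk above; the KeyError message is illustrative.
CUDA_INSTALLERS = {
    "130": ("13.0.1", "https://developer.download.nvidia.com/compute/cuda/13.0.1/local_installers/cuda_13.0.1_580.82.07_linux.run"),
    "128": ("12.8.1", "https://developer.download.nvidia.com/compute/cuda/12.8.1/local_installers/cuda_12.8.1_570.124.06_linux.run"),
    "126": ("12.8.1", "https://developer.download.nvidia.com/compute/cuda/12.8.1/local_installers/cuda_12.8.1_570.124.06_linux.run"),
}

def resolve_cuda(cushort: str) -> tuple[str, str]:
    if cushort not in CUDA_INSTALLERS:
        raise KeyError(f"no installer registered for CUDA {cushort}")
    return CUDA_INSTALLERS[cushort]

full_version, install_script = resolve_cuda("130")  # ("13.0.1", "https://...")
```
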
2 changes: 1 addition & 1 deletion .github/actions/setup-env-build/action.yml
@@ -26,7 +26,7 @@ runs:

CONDA_INSTALL_CMD = "micromamba create python=${{ inputs.python }} zlib pip ninja ccache=4.8 -c conda-forge -q -y"

conda_env_key = CONDA_INSTALL_CMD + "[cu129][v2]"
conda_env_key = CONDA_INSTALL_CMD + "[cu130][v2]"
for file in sorted(glob.glob("requirement*.txt")):
conda_env_key += f"\n########## {file}\n"
conda_env_key += Path(file).read_text()
8 changes: 4 additions & 4 deletions .github/workflows/wheels.yml
@@ -33,12 +33,12 @@ jobs:
PYTHON_VERSION = "3.9"
# NOTE: Don't forget to update `upload_pt`'s matrix
# when changing the CUDA/ROCM versions below!
CU_VERSIONS = ['126', '128', '129']
CU_VERSIONS = ['126', '128', '130']
ROCM_VERSIONS = ['6.4']

include = []
for os in ['8-core-ubuntu', 'windows-8-core']:
for torch_version in ['2.8.0']:
for torch_version in ['2.9.0']:
# CUDA builds
for cuda_short_version in CU_VERSIONS:
if cuda_short_version < "124" and "windows" in os:
@@ -88,7 +88,7 @@ jobs:
uses: ./.github/workflows/wheels_upload_pip.yml
with:
twine_username: __token__
filter: "*torch2.8.0+cu128*"
filter: "*torch2.9.0+cu130*"
execute: ${{ github.repository == 'facebookresearch/xformers' && github.event_name != 'pull_request' }}
secrets:
twine_password: ${{ secrets.PYPI_TOKEN }}
@@ -101,7 +101,7 @@
suffix:
- cu126
- cu128
- cu129
- cu130
- rocm6.4
uses: ./.github/workflows/wheels_upload_s3.yml
with:
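The matrix above is generated by embedded Python that pairs each runner with each torch version and each CUDA/ROCm version; the `< "124"` guard is a lexicographic string comparison, which behaves numerically here because every short version has three digits. A compact sketch of how those combinations expand (the dict field names are illustrative, the real workflow sets more fields):

```python
# Sketch: how the build-matrix combinations expand (field names are illustrative).
CU_VERSIONS = ["126", "128", "130"]

include = []
for os_name in ["8-core-ubuntu", "windows-8-core"]:
    for torch_version in ["2.9.0"]:
        for cu in CU_VERSIONS:
            # Lexicographic string compare: all entries are 3 digits, so this
            # behaves like a numeric compare ("126" < "124" is False).
            if cu < "124" and "windows" in os_name:
                continue
            include.append({"os": os_name, "torch": torch_version, "cuda": cu})

print(len(include))  # 6 combinations
```
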
8 changes: 6 additions & 2 deletions .github/workflows/wheels_build.yml
@@ -53,9 +53,13 @@ jobs:
run:
shell: bash
steps:
- if: contains(inputs.toolkit_type, 'cuda') && fromJSON(inputs.toolkit_short_version) >= 120
- if: contains(inputs.toolkit_type, 'cuda') && fromJSON(inputs.toolkit_short_version) >= 120 && fromJSON(inputs.toolkit_short_version) < 130
run: |
echo "TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST 9.0a" >> ${GITHUB_ENV}
echo "TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST 8.0 9.0a" >> ${GITHUB_ENV}

- if: contains(inputs.toolkit_type, 'cuda') && fromJSON(inputs.toolkit_short_version) >= 130
run: |
echo "TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST 8.0 9.0a 10.0f 11.0f 12.0f" >> ${GITHUB_ENV}

- if: runner.os == 'Windows'
run: git config --system core.longpaths true
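The new arch-list entries follow NVCC's suffix convention: "9.0a" targets the architecture-specific sm_90a feature set, while the "f" entries added for CUDA 13 ("10.0f", "11.0f", "12.0f") target family-specific code that runs across minor variants of a compute family. A rough sketch of how such entries are typically turned into `-gencode` flags (a simplified illustration, not PyTorch's exact `torch.utils.cpp_extension` logic):

```python
# Sketch: turn TORCH_CUDA_ARCH_LIST-style entries into nvcc -gencode flags.
# Simplified illustration; the real cpp_extension logic handles more cases (e.g. "+PTX").
def gencode_flags(arch_list: str) -> list[str]:
    flags = []
    for entry in arch_list.split():
        # "8.0" -> "80", "9.0a" -> "90a", "10.0f" -> "100f"
        num = entry.replace(".", "")
        flags.append(f"-gencode=arch=compute_{num},code=sm_{num}")
    return flags

print(gencode_flags("8.0 9.0a 10.0f 11.0f 12.0f"))
# ['-gencode=arch=compute_80,code=sm_80',
#  '-gencode=arch=compute_90a,code=sm_90a',
#  '-gencode=arch=compute_100f,code=sm_100f', ...]
```
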
14 changes: 11 additions & 3 deletions setup.py
@@ -176,7 +176,9 @@ def get_flash_attention2_nvcc_archs_flags(cuda_version: int):
return []
# Figure out default archs to target
DEFAULT_ARCHS_LIST = ""
if cuda_version >= 1208:
if cuda_version >= 1300:
DEFAULT_ARCHS_LIST = "8.0;8.6;9.0;10.0f;11.0f;12.0f"
elif cuda_version >= 1208:
DEFAULT_ARCHS_LIST = "8.0;8.6;9.0;10.0;12.0"
elif cuda_version >= 1108:
DEFAULT_ARCHS_LIST = "8.0;8.6;9.0"
@@ -283,7 +285,7 @@ def get_flash_attention3_nvcc_archs_flags(cuda_version: int):
return []
archs_list = os.environ.get("TORCH_CUDA_ARCH_LIST")
if archs_list is None:
if torch.cuda.get_device_capability("cuda") != (9, 0):
if torch.cuda.get_device_capability("cuda") != (9, 0) and torch.cuda.get_device_capability("cuda") != (8, 0):
return []
archs_list = "8.0 9.0a"
nvcc_archs_flags = []
@@ -328,6 +330,12 @@ def get_flash_attention3_extensions(cuda_version: int, extra_compile_args):
# for explicit values hence causing us to build these kernels twice.
sources = [s for s in sources if ("hdimall" not in s and "softcapall" not in s)]

# Avoid duplicate pybind module definitions (PyInit__C)
# Keep the stable ABI entrypoint when building with py_limited_api,
# and exclude the non-stable one.
if any(s.endswith("hopper/flash_api_stable.cpp") for s in sources):
sources = [s for s in sources if not s.endswith("hopper/flash_api.cpp")]

# We don't care/expose softcap and fp8 and paged attention,
# hence we disable them for faster builds.
DISABLED_CAPABILITIES = (
@@ -512,7 +520,7 @@ def get_extensions():
if cuda_version >= 1102:
nvcc_flags += [
"--threads",
"4",
os.getenv("NVCC_THREADS", "1"),
"--ptxas-options=-v",
]
if sys.platform == "win32":
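The last setup.py hunk replaces the hardcoded `--threads 4` with an `NVCC_THREADS` environment variable (defaulting to 1), which controls how many threads nvcc itself uses per compilation unit. A small sketch of the resulting flag construction, with an illustrative invocation in the comments (the command line is an example, not a documented build recipe):

```python
import os

# Sketch: nvcc flag construction matching the hunk above.
# NVCC_THREADS trades per-file compile parallelism against peak memory use;
# combined with many parallel build jobs, a high value can exhaust RAM.
nvcc_flags = [
    "--threads",
    os.getenv("NVCC_THREADS", "1"),
    "--ptxas-options=-v",
]

# Illustrative usage (example only):
#   NVCC_THREADS=4 MAX_JOBS=8 pip install -v -e .
print(nvcc_flags)
```
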
2 changes: 1 addition & 1 deletion third_party/cutlass
Submodule cutlass updated 1225 files
2 changes: 1 addition & 1 deletion third_party/flash-attention
Submodule flash-attention updated 88 files
+225 −0 .github/workflows/_build.yml
+47 −0 .github/workflows/build.yml
+30 −146 .github/workflows/publish.yml
+411 −0 benchmarks/benchmark_attn.py
+0 −30 benchmarks/benchmark_causal.py
+46 −31 benchmarks/benchmark_flash_attention.py
+1 −1 csrc/composable_kernel
+1 −1 csrc/cutlass
+1 −1 csrc/flash_attn/flash_api.cpp
+3 −0 csrc/flash_attn_ck/mha_fwd.cpp
+2 −1 csrc/flash_attn_ck/mha_fwd_kvcache.cpp
+8 −4 csrc/flash_attn_ck/mha_varlen_fwd.cpp
+0 −14 csrc/ft_attention/README.md
+0 −257 csrc/ft_attention/cuda_bf16_fallbacks.cuh
+0 −23 csrc/ft_attention/cuda_bf16_wrapper.h
+0 −149 csrc/ft_attention/decoder_masked_multihead_attention.cu
+0 −192 csrc/ft_attention/decoder_masked_multihead_attention.h
+0 −1,619 csrc/ft_attention/decoder_masked_multihead_attention_template.hpp
+0 −2,017 csrc/ft_attention/decoder_masked_multihead_attention_utils.h
+0 −231 csrc/ft_attention/ft_attention.cpp
+0 −153 csrc/ft_attention/setup.py
+0 −148 csrc/fused_softmax/fused_softmax.cpp
+0 −528 csrc/fused_softmax/scaled_masked_softmax.h
+0 −121 csrc/fused_softmax/scaled_masked_softmax_cuda.cu
+0 −529 csrc/fused_softmax/scaled_upper_triang_masked_softmax.h
+0 −98 csrc/fused_softmax/scaled_upper_triang_masked_softmax_cuda.cu
+0 −50 csrc/fused_softmax/setup.py
+0 −20 csrc/fused_softmax/type_shim.h
+0 −40 csrc/rotary/rotary.cpp
+0 −45 csrc/rotary/rotary_cuda.cu
+0 −126 csrc/rotary/setup.py
+0 −14 csrc/xentropy/README.md
+0 −59 csrc/xentropy/interface.cpp
+0 −139 csrc/xentropy/setup.py
+0 −758 csrc/xentropy/xentropy_kernel.cu
+1 −1 flash_attn/__init__.py
+0 −0 flash_attn/cute/README.md
+13 −0 flash_attn/cute/__init__.py
+41 −17 flash_attn/cute/ampere_helpers.py
+606 −0 flash_attn/cute/blackwell_helpers.py
+63 −32 flash_attn/cute/block_info.py
+97 −0 flash_attn/cute/fast_math.py
+56 −56 flash_attn/cute/flash_bwd.py
+242 −22 flash_attn/cute/flash_bwd_postprocess.py
+46 −19 flash_attn/cute/flash_bwd_preprocess.py
+1,392 −0 flash_attn/cute/flash_bwd_sm90.py
+728 −456 flash_attn/cute/flash_fwd.py
+644 −0 flash_attn/cute/flash_fwd_combine.py
+1,909 −0 flash_attn/cute/flash_fwd_sm100.py
+31 −4 flash_attn/cute/hopper_helpers.py
+382 −58 flash_attn/cute/interface.py
+223 −37 flash_attn/cute/mask.py
+289 −0 flash_attn/cute/mma_sm100_desc.py
+13 −0 flash_attn/cute/named_barrier.py
+18 −11 flash_attn/cute/pack_gqa.py
+23 −33 flash_attn/cute/pipeline.py
+50 −0 flash_attn/cute/pyproject.toml
+33 −2 flash_attn/cute/seqlen_info.py
+214 −23 flash_attn/cute/softmax.py
+404 −0 flash_attn/cute/testing.py
+503 −0 flash_attn/cute/tile_scheduler.py
+378 −124 flash_attn/cute/utils.py
+1 −1 flash_attn/flash_attn_interface.py
+0 −201 flash_attn/fused_softmax.py
+4 −1 flash_attn/pyproject.toml
+17 −6 flash_attn/utils/testing.py
+1 −1 hopper/benchmark_attn.py
+7 −4 hopper/epilogue_bwd.hpp
+7 −1 hopper/flash.h
+78 −26 hopper/flash_api.cpp
+1,978 −0 hopper/flash_api_stable.cpp
+17 −8 hopper/flash_attn_interface.py
+8 −6 hopper/flash_bwd_launch_template.h
+8 −3 hopper/flash_fwd_combine_kernel.h
+1 −1 hopper/flash_fwd_combine_launch_template.h
+10 −7 hopper/flash_fwd_launch_template.h
+165 −39 hopper/flash_prepare_scheduler.cu
+11 −3 hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp
+39 −3 hopper/setup.py
+1 −1 hopper/sm90_pipeline_no_cluster.hpp
+23 −0 hopper/static_switch.h
+57 −27 hopper/test_flash_attn.py
+706 −0 hopper/test_flash_attn_bwd_determinism.py
+180 −72 hopper/tile_scheduler.hpp
+4 −3 hopper/tile_size.h
+94 −43 setup.py
+565 −40 tests/cute/test_flash_attn.py
+1 −1 tests/test_flash_attn_ck.py
8 changes: 4 additions & 4 deletions xformers/csrc/sparse24/sparse24_gemm_sm90.cu
@@ -96,10 +96,10 @@ struct SparseRowwiseKernel<cutlass::float_e4m3_t> {
float,
ElementOut,
cutlass::layout::RowMajor,
1,
8,
ElementOut,
cutlass::layout::RowMajor,
1,
8,
cutlass::epilogue::TmaWarpSpecializedCooperative,
EpilogueEVT>::CollectiveOp;

@@ -176,10 +176,10 @@ struct SparseRowwiseKernel<cutlass::bfloat16_t> {
float,
ElementOut,
cutlass::layout::RowMajor,
1,
8,
ElementOut,
cutlass::layout::RowMajor,
1,
8,
cutlass::epilogue::TmaWarpSpecializedCooperative,
EpilogueEVT>::CollectiveOp;

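In the CUTLASS collective epilogue builder, the changed values are alignments counted in elements of the output type, not bytes; raising them from 1 to 8 allows the epilogue to issue full 128-bit vectorized stores when the output element is 16-bit. A small worked check of that arithmetic (assuming a 2-byte output element such as bf16):

```python
# Sketch: epilogue access width, assuming a 2-byte output element (bf16/fp16).
elements_per_access = 8
bytes_per_element = 2                      # bf16 / fp16
access_bytes = elements_per_access * bytes_per_element
print(access_bytes)                        # 16 bytes == 128-bit vectorized store
```
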
2 changes: 1 addition & 1 deletion xformers/ops/fmha/dispatch.py
@@ -31,7 +31,7 @@ def _get_use_fa3() -> bool:

def fa3_available() -> bool:
has_cuda = torch.version.cuda is not None
is_90a = has_cuda and torch.cuda.get_device_capability() >= (9, 0)
is_90a = has_cuda and (8, 0) <= torch.cuda.get_device_capability() <= (9, 0)
has_valid_flash3 = flash3._C_flashattention3 is not None # pyre-ignore[16]
return is_90a and has_valid_flash3

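The widened check in `fa3_available` relies on Python's lexicographic comparison of `(major, minor)` capability tuples: anything from (8, 0) up to and including (9, 0) now passes, while older (e.g. sm_75) and newer (sm_100 and above) devices do not. A quick illustration of that comparison (the capability values are examples):

```python
# Sketch: which (major, minor) capabilities pass the (8, 0) <= cap <= (9, 0) range check.
def fa3_capability_ok(cap: tuple[int, int]) -> bool:
    return (8, 0) <= cap <= (9, 0)

for cap in [(7, 5), (8, 0), (8, 6), (9, 0), (10, 0), (12, 0)]:
    print(cap, fa3_capability_ok(cap))
# (7, 5) False, (8, 0) True, (8, 6) True, (9, 0) True, (10, 0) False, (12, 0) False
```
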
2 changes: 1 addition & 1 deletion xformers/ops/fmha/flash.py
@@ -71,7 +71,7 @@

FLASH_VERSION = flash_attn.__version__
FLASH_VER_MIN = parse_version("2.7.1")
FLASH_VER_LAST = parse_version("2.8.3") # last supported, inclusive
FLASH_VER_LAST = parse_version("2.8.4") # last supported, inclusive
flash_ver_parsed = parse_version(FLASH_VERSION)
if (
flash_ver_parsed < FLASH_VER_MIN or flash_ver_parsed > FLASH_VER_LAST
16 changes: 16 additions & 0 deletions xformers/ops/fmha/flash3.py
@@ -642,6 +642,14 @@ class FwOp(AttentionFwOpBase):
@classmethod
def not_supported_reasons(cls, d: Inputs) -> List[str]:
reasons = super(FwOp, cls).not_supported_reasons(d)
device_type = d.query.device.type
if device_type == "cuda" and (torch.version.hip is None):
device_capability = torch.cuda.get_device_capability(d.device)
if device_capability > cls.CUDA_MINIMUM_COMPUTE_CAPABILITY:
reasons.append(
f"requires device with capability == {cls.CUDA_MINIMUM_COMPUTE_CAPABILITY} "
f"but your GPU has capability {device_capability} (too new)"
)
check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
check_lastdim_alignment_stride1(reasons, "key", d.value, 8)
check_lastdim_alignment_stride1(reasons, "value", d.value, 8)
@@ -796,6 +804,14 @@ class BwOp(AttentionBwOpBase):
@classmethod
def not_supported_reasons(cls, d: Inputs) -> List[str]:
reasons = super(BwOp, cls).not_supported_reasons(d)
device_type = d.query.device.type
if device_type == "cuda" and (torch.version.hip is None):
device_capability = torch.cuda.get_device_capability(d.device)
if device_capability > cls.CUDA_MINIMUM_COMPUTE_CAPABILITY:
reasons.append(
f"requires device with capability == {cls.CUDA_MINIMUM_COMPUTE_CAPABILITY} "
f"but your GPU has capability {device_capability} (too new)"
)
check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
check_lastdim_alignment_stride1(reasons, "key", d.value, 8)
check_lastdim_alignment_stride1(reasons, "value", d.value, 8)