64 changes: 63 additions & 1 deletion fbgemm_gpu/experimental/hstu/CMakeLists.txt
100644 → 100755
@@ -31,6 +31,67 @@ if(NOT _resvar EQUAL 0)
message(FATAL_ERROR "generate_kernels.py failed:\n${_errvar}")
endif()

# Optional HSTU compile-time feature flags, controlled by environment variables at build time (see the README for usage).
set(HSTU_CUDA_FLAGS "")
if (DEFINED ENV{HSTU_DISABLE_BACKWARD} AND "$ENV{HSTU_DISABLE_BACKWARD}" STREQUAL "TRUE")
set(HSTU_CUDA_FLAGS "${HSTU_CUDA_FLAGS} -DHSTU_DISABLE_BACKWARD")
endif()
if (DEFINED ENV{HSTU_DISABLE_DETERMINISTIC} AND "$ENV{HSTU_DISABLE_DETERMINISTIC}" STREQUAL "TRUE")
set(HSTU_CUDA_FLAGS "${HSTU_CUDA_FLAGS} -DHSTU_DISABLE_DETERMINISTIC")
endif()
if (DEFINED ENV{HSTU_DISABLE_LOCAL} AND "$ENV{HSTU_DISABLE_LOCAL}" STREQUAL "TRUE")
set(HSTU_CUDA_FLAGS "${HSTU_CUDA_FLAGS} -DHSTU_DISABLE_LOCAL")
endif()
if (DEFINED ENV{HSTU_DISABLE_CAUSAL} AND "$ENV{HSTU_DISABLE_CAUSAL}" STREQUAL "TRUE")
set(HSTU_CUDA_FLAGS "${HSTU_CUDA_FLAGS} -DHSTU_DISABLE_CAUSAL")
endif()
if (DEFINED ENV{HSTU_DISABLE_CONTEXT} AND "$ENV{HSTU_DISABLE_CONTEXT}" STREQUAL "TRUE")
set(HSTU_CUDA_FLAGS "${HSTU_CUDA_FLAGS} -DHSTU_DISABLE_CONTEXT")
endif()
if (DEFINED ENV{HSTU_DISABLE_TARGET} AND "$ENV{HSTU_DISABLE_TARGET}" STREQUAL "TRUE")
set(HSTU_CUDA_FLAGS "${HSTU_CUDA_FLAGS} -DHSTU_DISABLE_TARGET")
endif()
if (DEFINED ENV{HSTU_DISABLE_ARBITRARY} AND "$ENV{HSTU_DISABLE_ARBITRARY}" STREQUAL "TRUE")
set(HSTU_CUDA_FLAGS "${HSTU_CUDA_FLAGS} -DHSTU_DISABLE_ARBITRARY")
endif()
if (DEFINED ENV{HSTU_ARBITRARY_NFUNC})
set(HSTU_CUDA_FLAGS "${HSTU_CUDA_FLAGS} -DHSTU_ARBITRARY_NFUNC=$ENV{HSTU_ARBITRARY_NFUNC}")
else()
set(HSTU_CUDA_FLAGS "${HSTU_CUDA_FLAGS} -DHSTU_ARBITRARY_NFUNC=1")
endif()
if (DEFINED ENV{HSTU_DISABLE_RAB} AND "$ENV{HSTU_DISABLE_RAB}" STREQUAL "TRUE")
set(HSTU_CUDA_FLAGS "${HSTU_CUDA_FLAGS} -DHSTU_DISABLE_RAB")
endif()
if (DEFINED ENV{HSTU_DISABLE_DRAB} AND "$ENV{HSTU_DISABLE_DRAB}" STREQUAL "TRUE")
set(HSTU_CUDA_FLAGS "${HSTU_CUDA_FLAGS} -DHSTU_DISABLE_DRAB")
endif()
if (DEFINED ENV{HSTU_DISABLE_BF16} AND "$ENV{HSTU_DISABLE_BF16}" STREQUAL "TRUE")
set(HSTU_CUDA_FLAGS "${HSTU_CUDA_FLAGS} -DHSTU_DISABLE_BF16")
endif()
if (DEFINED ENV{HSTU_DISABLE_FP16} AND "$ENV{HSTU_DISABLE_FP16}" STREQUAL "TRUE")
set(HSTU_CUDA_FLAGS "${HSTU_CUDA_FLAGS} -DHSTU_DISABLE_FP16")
endif()
if (DEFINED ENV{HSTU_DISABLE_FP8} AND "$ENV{HSTU_DISABLE_FP8}" STREQUAL "TRUE")
set(HSTU_CUDA_FLAGS "${HSTU_CUDA_FLAGS} -DHSTU_DISABLE_FP8")
endif()
if (DEFINED ENV{HSTU_USE_E5M2_BWD} AND "$ENV{HSTU_USE_E5M2_BWD}" STREQUAL "TRUE")
set(HSTU_CUDA_FLAGS "${HSTU_CUDA_FLAGS} -DHSTU_USE_E5M2_BWD")
endif()
if (DEFINED ENV{HSTU_DISABLE_HDIM32} AND "$ENV{HSTU_DISABLE_HDIM32}" STREQUAL "TRUE")
set(HSTU_CUDA_FLAGS "${HSTU_CUDA_FLAGS} -DHSTU_DISABLE_HDIM32")
endif()
if (DEFINED ENV{HSTU_DISABLE_HDIM64} AND "$ENV{HSTU_DISABLE_HDIM64}" STREQUAL "TRUE")
set(HSTU_CUDA_FLAGS "${HSTU_CUDA_FLAGS} -DHSTU_DISABLE_HDIM64")
endif()
if (DEFINED ENV{HSTU_DISABLE_HDIM128} AND "$ENV{HSTU_DISABLE_HDIM128}" STREQUAL "TRUE")
set(HSTU_CUDA_FLAGS "${HSTU_CUDA_FLAGS} -DHSTU_DISABLE_HDIM128")
endif()
if (DEFINED ENV{HSTU_DISABLE_HDIM256} AND "$ENV{HSTU_DISABLE_HDIM256}" STREQUAL "TRUE")
set(HSTU_CUDA_FLAGS "${HSTU_CUDA_FLAGS} -DHSTU_DISABLE_HDIM256")
endif()
if (DEFINED ENV{HSTU_DISABLE_86OR89} AND "$ENV{HSTU_DISABLE_86OR89}" STREQUAL "TRUE")
set(HSTU_CUDA_FLAGS "${HSTU_CUDA_FLAGS} -DHSTU_DISABLE_86OR89")
endif()


# Collect HSTU Ampere source files
file(GLOB hstu_ampere_cpp_source_files
@@ -80,7 +141,8 @@ if(NOT hstu_cpp_source_files AND NOT hstu_cpp_source_files_gpu)
endif()

# Add specific NVCC flags for HSTU compilation
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --use_fast_math")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --use_fast_math ${HSTU_CUDA_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${HSTU_CUDA_FLAGS}")

################################################################################
# Build Shared Library
128 changes: 97 additions & 31 deletions fbgemm_gpu/experimental/hstu/README.md
100644 → 100755
@@ -12,26 +12,27 @@ Both HSTU-2 and HSTU-3 share the same API:

```python
def hstu_attn_varlen_func(
q, # (total_q, nheads, headdim)
k, # (total_k, nheads_k, headdim), nheads should be equal to nhead_k
v, # (total_k, nheads_k, headdim)
seq_offsets_q, # (batch_size + 1,), cumulative sequence lengths for q
seq_offsets_k, # (batch_size + 1,), cumulative sequence lengths for k/v
max_seqlen_q, # Maximum query sequence length
max_seqlen_k, # Maximum k/v sequence length
num_contexts=None, # (batch_size,), context tokens per batch
num_targets=None, # (batch_size,), target tokens per batch
target_group_size=1, # Number of target tokens per group
window_size=(-1, -1), # (left, right) for sliding window, -1 means infinite window size
alpha=1.0, # Scaling factor between add rab and silu
rab=None, # (batch_size, nhead_rab, max_seqlen_k, max_seqlen_k), relative attention bias
# nheads should be divisible by nhead_rab
has_drab=False, # Whether to apply drab
is_delta_q=False, # Whether to apply delta_q
descale_q=None, # (1,), descaling factor for query
descale_k=None, # (1,), descaling factor for key
descale_v=None, # (1,), descaling factor for value
descale_do=None, # (1,), descaling factor for do
q, # (total_q, nheads, headdim)
k, # (total_k, nheads_k, headdim), nheads should be equal to nheads_k
v, # (total_k, nheads_k, headdim)
cu_seqlens_q, # (batch_size + 1,), cumulative sequence lengths for q
cu_seqlens_k, # (batch_size + 1,), cumulative sequence lengths for k/v
max_seqlen_q, # Maximum query sequence length
max_seqlen_k, # Maximum k/v sequence length
num_contexts=None, # (batch_size,), context tokens per batch
num_targets=None, # (batch_size,), target tokens per batch
target_group_size=1, # Number of target tokens per group
window_size=(-1, -1), # (left, right) for sliding window, -1 means infinite window size
alpha=1.0, # Scaling factor between add rab and silu
rab=None, # (batch_size, nhead_rab, max_seqlen_k, max_seqlen_k), relative attention bias
# nheads should be divisible by nhead_rab
has_drab=False, # Whether to apply drab
kv_cache=None, # (page_num, 2, page_size, nheads, headdim), key and value paged cache.
page_offsets=None, # (batch_size + 1,). The cumulative sequence lengths of the page_ptr in the batch, used to index into kv_cache.
page_ids=None, # (page_offsets[-1],). The ids of the pages in the batch.
last_page_lens=None, # (batch_size,). The lengths of the last pages in the batch.
func=None, # (nheads, total_q + 256). Function to describe the mask shape in arbitrary mask.
quant_mode=-1, # int. Quantization mode.
)
```
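
For orientation, here is a minimal sketch of calling the function with variable-length sequences. The import path, the int32 dtype for the cumulative-length tensors, and the bf16 inputs are illustrative assumptions, not requirements stated above.

```python
import torch

# Assumed import path; adjust to wherever hstu_attn_varlen_func is exposed in your install.
from hstu import hstu_attn_varlen_func

batch_size, nheads, headdim = 2, 4, 128
seqlens = torch.tensor([48, 80], dtype=torch.int32, device="cuda")   # per-sequence lengths
cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)                        # (batch_size + 1,)
total = int(cu_seqlens[-1])

q = torch.randn(total, nheads, headdim, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

out = hstu_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=int(seqlens.max()),
    max_seqlen_k=int(seqlens.max()),
    window_size=(-1, -1),   # no sliding window
    alpha=1.0,              # default scaling factor
)
```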

@@ -44,12 +45,65 @@
cd fbgemm_gpu/

# Install HSTU-Ampere
export HSTU_DISABLE_BACKWARD=FALSE; \
export HSTU_DISABLE_LOCAL=FALSE; \
export HSTU_DISABLE_CAUSAL=FALSE; \
export HSTU_DISABLE_CONTEXT=FALSE; \
export HSTU_DISABLE_TARGET=FALSE; \
export HSTU_DISABLE_ARBITRARY=FALSE; \
export HSTU_ARBITRARY_NFUNC=3; \
export HSTU_DISABLE_RAB=FALSE; \
export HSTU_DISABLE_DRAB=FALSE; \
export HSTU_DISABLE_BF16=FALSE; \
export HSTU_DISABLE_FP16=TRUE; \
export HSTU_DISABLE_HDIM32=FALSE; \
export HSTU_DISABLE_HDIM64=FALSE; \
export HSTU_DISABLE_HDIM128=FALSE; \
export HSTU_DISABLE_HDIM256=FALSE; \
export HSTU_DISABLE_DETERMINISTIC=TRUE; \
export HSTU_DISABLE_86OR89=TRUE; \
python setup.py install --build-target=hstu -DTORCH_CUDA_ARCH_LIST="8.0"

# Install HSTU-Hopper
export HSTU_DISABLE_BACKWARD=FALSE; \
export HSTU_DISABLE_LOCAL=FALSE; \
export HSTU_DISABLE_CAUSAL=FALSE; \
export HSTU_DISABLE_CONTEXT=FALSE; \
export HSTU_DISABLE_TARGET=FALSE; \
export HSTU_DISABLE_ARBITRARY=FALSE; \
export HSTU_ARBITRARY_NFUNC=3; \
export HSTU_DISABLE_RAB=FALSE; \
export HSTU_DISABLE_DRAB=FALSE; \
export HSTU_DISABLE_BF16=FALSE; \
export HSTU_DISABLE_FP16=TRUE; \
export HSTU_DISABLE_FP8=FALSE; \
export HSTU_USE_E5M2_BWD=FALSE; \
export HSTU_DISABLE_HDIM32=FALSE; \
export HSTU_DISABLE_HDIM64=FALSE; \
export HSTU_DISABLE_HDIM128=FALSE; \
export HSTU_DISABLE_HDIM256=FALSE; \
python setup.py install --build-target=hstu -DTORCH_CUDA_ARCH_LIST="9.0"

# Install both
export HSTU_DISABLE_BACKWARD=FALSE; \
export HSTU_DISABLE_LOCAL=FALSE; \
export HSTU_DISABLE_CAUSAL=FALSE; \
export HSTU_DISABLE_CONTEXT=FALSE; \
export HSTU_DISABLE_TARGET=FALSE; \
export HSTU_DISABLE_ARBITRARY=FALSE; \
export HSTU_ARBITRARY_NFUNC=3; \
export HSTU_DISABLE_RAB=FALSE; \
export HSTU_DISABLE_DRAB=FALSE; \
export HSTU_DISABLE_BF16=FALSE; \
export HSTU_DISABLE_FP16=TRUE; \
export HSTU_DISABLE_FP8=FALSE; \
export HSTU_USE_E5M2_BWD=FALSE; \
export HSTU_DISABLE_HDIM32=FALSE; \
export HSTU_DISABLE_HDIM64=FALSE; \
export HSTU_DISABLE_HDIM128=FALSE; \
export HSTU_DISABLE_HDIM256=FALSE; \
export HSTU_DISABLE_DETERMINISTIC=TRUE; \
export HSTU_DISABLE_86OR89=TRUE; \
python setup.py install --build-target=hstu -DTORCH_CUDA_ARCH_LIST="8.0 9.0"

# If you don't add -DTORCH_CUDA_ARCH_LIST, the default is "8.0 9.0".
@@ -61,6 +115,12 @@ python setup.py install --build-target=hstu -DTORCH_CUDA_ARCH_LIST="8.0 9.0"
- **Supported GPUs**: Ampere, Ada, Hopper (without Hopper-specific features)
- **Data types**: FP16, BF16
- **Head dimensions**: 32, 64, 128, 256
- **Paged attention** (a metadata construction sketch follows the figure below):
* Only supports the single mask shape shown in the following figure.
* The sequence length of each `k` and `v` is the same as that of `q`, so the first few items of `k` and `v` (which represent new history) are invalid; the actual values of those elements live in the paged `kv_cache`.
* Only page sizes of 32 and 64 are supported.

![paged_kv](img/paged_kv.png)
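
To make the paged-cache arguments concrete, here is a construction sketch. Only the shapes follow the argument descriptions above; the sequence lengths, the identity page allocation, and the zero-filled cache are illustrative assumptions.

```python
import torch

page_size, nheads, headdim = 32, 4, 128
kv_lens = torch.tensor([70, 100], dtype=torch.int32)       # cached k/v length per batch entry
pages_per_seq = (kv_lens + page_size - 1) // page_size     # ceil-divide each sequence into pages

# (batch_size + 1,): cumulative page counts, used to slice page_ids per sequence
page_offsets = torch.zeros(kv_lens.numel() + 1, dtype=torch.int32)
page_offsets[1:] = torch.cumsum(pages_per_seq, dim=0)

# (page_offsets[-1],): physical page index in kv_cache for each logical page
# (identity mapping here; a real allocator would hand out free pages)
page_ids = torch.arange(int(page_offsets[-1]), dtype=torch.int32)

# (batch_size,): number of valid rows in each sequence's last page
last_page_lens = kv_lens - (pages_per_seq - 1) * page_size

# (page_num, 2, page_size, nheads, headdim): physical pool holding key and value pages
kv_cache = torch.zeros(int(page_offsets[-1]), 2, page_size, nheads, headdim,
                       dtype=torch.bfloat16)
```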
- **Attention masks**:
* No mask
* Local mask (0 <= window_size_left < max_seqlen_k or 0 <= window_size_right < max_seqlen_k)
@@ -69,23 +129,29 @@ python setup.py install --build-target=hstu -DTORCH_CUDA_ARCH_LIST="8.0 9.0"
* Target mask + causal mask
* Context mask + causal mask + target mask

![Context+causal+target mask](context_causal_target.png)
* Delta_q ('seqlen_q <= seqlen_k')
![Context+causal+target mask](img/context_causal_target.png)
* Delta_q (delta_q is applied automatically when seqlen_q < seqlen_k)
* Delta_q + local mask

![Delta_q+local mask](deltaq_local.png)
![Delta_q+local mask](img/deltaq_local.png)
* Delta_q + causal mask

![Delta_q+causal mask](deltaq_causal.png)
![Delta_q+causal mask](img/deltaq_causal.png)

* Arbitrary mask

Use an array to describe the mask for each row of S. We provide three examples; please refer to [line 398](test/hstu_test.py#L398), [line 430](test/hstu_test.py#L430), and [line 463](test/hstu_test.py#L463) of hstu_test.py.

### HSTU-Hopper
- **Supported GPUs**: Hopper only (H100, H20)
- **Data types**: FP16, BF16 (forward and backward), FP8 (forward only)
- **Data types**: FP16, BF16, and FP8
- **Head dimensions**: 32, 64, 128, 256 for FP16/BF16; 64, 128, 256 for FP8
- **Attention masks**:
* For FP16/BF16, same as HSTU-2
* For FP8, only supports:
+ No mask
+ Causal mask
+ Local mask
- **Attention masks**: Same as HSTU-Ampere
- **Quantization mode**: Six modes (a descale-factor sketch follows this list).
* quant_mode == 0: Cast to fp8 directly.
* quant_mode == 1: 1xDIM and 128x1.
* quant_mode == 2: Per-block quantization. Block sizes are given by [get_bm_and_bn_block_size_fwd](hstu/cuda_hstu_attention.py#L93) and [get_bm_and_bn_block_size_bwd](hstu/cuda_hstu_attention.py#L112).
* quant_mode == 3: Per-head quantization. Shape of q_descale is (batch_size, nheads).
* quant_mode == 4: Per-tensor quantization. Shape of q_descale is (batch_size,).
* quant_mode == 5: Per-batch quantization. Shape of q_descale is (1,).
- **Note**: Only a non-deterministic backward implementation is provided
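
For intuition about the q_descale shapes listed above, the sketch below computes amax-based factors for quant_mode 3, 4, and 5 from a varlen q. The FP8 E4M3 range of 448 and the amax formula are assumptions about the recipe; the repository exports helpers such as quantize_for_head_batch_tensor (see the __init__.py diff below) that implement the actual behavior.

```python
import torch

FP8_E4M3_MAX = 448.0  # assumed dynamic range used for scaling

def descale_factors(q: torch.Tensor, cu_seqlens_q: torch.Tensor, quant_mode: int) -> torch.Tensor:
    """Sketch of descale-factor shapes for quant_mode 3/4/5 (amax-based; an assumption)."""
    batch_size = cu_seqlens_q.numel() - 1
    nheads = q.shape[1]
    if quant_mode == 3:   # per-head: (batch_size, nheads)
        out = torch.empty(batch_size, nheads)
        for b in range(batch_size):
            seg = q[cu_seqlens_q[b]:cu_seqlens_q[b + 1]]   # (seqlen_b, nheads, headdim)
            out[b] = seg.abs().amax(dim=(0, 2)).float() / FP8_E4M3_MAX
        return out
    if quant_mode == 4:   # per-tensor, one factor per batch entry: (batch_size,)
        out = torch.empty(batch_size)
        for b in range(batch_size):
            seg = q[cu_seqlens_q[b]:cu_seqlens_q[b + 1]]
            out[b] = seg.abs().amax().float() / FP8_E4M3_MAX
        return out
    if quant_mode == 5:   # per-batch: (1,)
        return (q.abs().amax().float() / FP8_E4M3_MAX).reshape(1)
    raise ValueError("sketch only covers quant_mode 3, 4, and 5")
```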
6 changes: 6 additions & 0 deletions fbgemm_gpu/experimental/hstu/hstu/__init__.py
100644 → 100755
@@ -15,4 +15,10 @@
cuda_hstu_attn_varlen,
hstu_attn_varlen_func,
HstuAttnVarlenFunc,
hstu_attn_qkvpacked_func,
quantize_for_two_directions,
quantize_for_block_scale,
get_bm_and_bn_block_size_fwd,
get_bm_and_bn_block_size_bwd,
quantize_for_head_batch_tensor,
)