diff --git a/gpt_builders.py b/gpt_builders.py index 4fe832028b..9fa1aff72c 100644 --- a/gpt_builders.py +++ b/gpt_builders.py @@ -118,6 +118,7 @@ def _get_transformer_layer_spec(use_te, config): moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm, qk_l2_norm=args.qk_l2_norm, use_kitchen=config.use_kitchen, + fallback_to_eager_attn=config.fallback_to_eager_attn, ) else: return get_gpt_layer_local_spec( diff --git a/megatron/core/extensions/transformer_engine_spec_provider.py b/megatron/core/extensions/transformer_engine_spec_provider.py index c630671ad0..6f8947078b 100644 --- a/megatron/core/extensions/transformer_engine_spec_provider.py +++ b/megatron/core/extensions/transformer_engine_spec_provider.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. import warnings from typing import Optional, Tuple @@ -17,6 +17,7 @@ from megatron.core.fusions.fused_layer_norm import FusedLayerNorm from megatron.core.models.backends import BackendSpecProvider from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer.dot_product_attention import DotProductAttention from megatron.core.transformer.mlp import MLPSubmodules from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP, TEGroupedMLP from megatron.core.utils import get_te_version, is_te_min_version @@ -25,6 +26,10 @@ class TESpecProvider(BackendSpecProvider): """A protocol for providing the submodules used in Spec building.""" + def __init__(self, fallback_to_eager_attn: bool = False): + super().__init__() + self.fallback_to_eager_attn = fallback_to_eager_attn + def linear(self) -> type: """Which linear module TE backend uses""" return TELinear @@ -56,6 +61,8 @@ def layer_norm(self, rms_norm: bool = False, for_qk: bool = False) -> type: def core_attention(self) -> type: """Which module to use for attention""" + if self.fallback_to_eager_attn: + return DotProductAttention return TEDotProductAttention def grouped_mlp_modules( diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py index 196c21ebe4..c5c9caa3d6 100755 --- a/megatron/core/models/gpt/gpt_layer_specs.py +++ b/megatron/core/models/gpt/gpt_layer_specs.py @@ -10,6 +10,7 @@ ) from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec_for_backend from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.dot_product_attention import DotProductAttention from megatron.core.transformer.enums import AttnMaskType, LayerType from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP, MLPSubmodules @@ -85,6 +86,7 @@ def get_gpt_layer_with_transformer_engine_spec( use_te_op_fuser: Optional[bool] = False, use_kitchen: bool = False, use_te_activation_func: bool = False, + fallback_to_eager_attn: bool = False, ) -> ModuleSpec: """Use this spec to use lower-level Transformer Engine modules (required for fp8 training). 
@@ -116,13 +118,15 @@ def get_gpt_layer_with_transformer_engine_spec( if use_kitchen: assert HAVE_KITCHEN - backend: BackendSpecProvider = KitchenSpecProvider(fallback=TESpecProvider()) + backend: BackendSpecProvider = KitchenSpecProvider( + fallback=TESpecProvider(fallback_to_eager_attn=fallback_to_eager_attn) + ) if use_te_op_fuser: raise AssertionError("use_te_op_fuser not compatible with using kitchen in mlp.") if use_te_activation_func: raise AssertionError("use_te_activation_func not compatible with using kitchen.") else: - backend = TESpecProvider() + backend = TESpecProvider(fallback_to_eager_attn=fallback_to_eager_attn) sharded_state_dict_keys_map = {} @@ -135,6 +139,7 @@ def get_gpt_layer_with_transformer_engine_spec( multi_latent_attention=multi_latent_attention, mla_down_proj_use_column_parallel=False, normalization=normalization, + fallback_to_eager_attn=fallback_to_eager_attn, ) mlp = get_mlp_module_spec_for_backend( @@ -214,6 +219,7 @@ def get_gpt_layer_local_spec( multi_latent_attention=multi_latent_attention, mla_down_proj_use_column_parallel=True, normalization=normalization, + fallback_to_eager_attn=False, ) mlp = get_mlp_module_spec_for_backend( @@ -278,6 +284,7 @@ def get_attention_module_spec_for_backend( multi_latent_attention: Optional[bool] = False, mla_down_proj_use_column_parallel: Optional[bool] = False, normalization: Optional[str] = None, + fallback_to_eager_attn: Optional[bool] = False, ) -> ModuleSpec: """Helper function to get module spec for Attention""" @@ -292,6 +299,7 @@ def get_attention_module_spec_for_backend( rms_norm = normalization == "RMSNorm" qk_norm = backend.layer_norm(rms_norm=rms_norm, for_qk=True) + core_attention = backend.core_attention() if not fallback_to_eager_attn else DotProductAttention if multi_latent_attention: assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA." 
linear_q_down_proj = ( @@ -328,7 +336,7 @@ def get_attention_module_spec_for_backend( linear_q_up_proj=linear_q_up_proj, linear_kv_down_proj=linear_kv_down_proj, linear_kv_up_proj=linear_kv_up_proj, - core_attention=backend.core_attention(), + core_attention=core_attention, linear_proj=backend.row_parallel_linear(), q_layernorm=qk_norm, kv_layernorm=qk_norm, @@ -352,7 +360,7 @@ def get_attention_module_spec_for_backend( params={"attn_mask_type": AttnMaskType.causal}, submodules=SelfAttentionSubmodules( linear_qkv=linear_qkv, - core_attention=backend.core_attention(), + core_attention=core_attention, linear_proj=backend.row_parallel_linear(), q_layernorm=qk_norm, k_layernorm=qk_norm, @@ -499,6 +507,7 @@ def get_gpt_decoder_layer_specs( if use_transformer_engine: layer_norm_impl = TENorm get_layer_spec_kwargs["use_te_activation_func"] = config.use_te_activation_func + get_layer_spec_kwargs['fallback_to_eager_attn'] = config.fallback_to_eager_attn get_layer_spec_fn = get_gpt_layer_with_transformer_engine_spec else: layer_norm_impl = LNImpl @@ -652,9 +661,11 @@ def get_gpt_mtp_block_spec( """GPT Multi-Token Prediction (MTP) block spec.""" if use_transformer_engine: backend: BackendSpecProvider = ( - KitchenSpecProvider(fallback=TESpecProvider()) + KitchenSpecProvider( + fallback=TESpecProvider(fallback_to_eager_attn=config.fallback_to_eager_attn) + ) if config.use_kitchen - else TESpecProvider() + else TESpecProvider(fallback_to_eager_attn=config.fallback_to_eager_attn) ) else: backend = ( diff --git a/megatron/core/transformer/dot_product_attention.py b/megatron/core/transformer/dot_product_attention.py index f3711c86eb..bef82c1028 100644 --- a/megatron/core/transformer/dot_product_attention.py +++ b/megatron/core/transformer/dot_product_attention.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. import math @@ -12,6 +12,9 @@ from megatron.core.fusions.fused_softmax import FusedScaleMaskSoftmax from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.transformer.dot_product_attention_context_parallel import ( + AttentionFuncionWithContextParallel, +) from megatron.core.transformer.enums import AttnMaskType from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.transformer_config import TransformerConfig @@ -54,9 +57,12 @@ def __init__( self.config: TransformerConfig = config - assert ( - self.config.context_parallel_size == 1 - ), "Context parallelism is only supported by TEDotProductAttention!" + if self.config.context_parallel_size > 1: + assert attention_dropout is None and self.config.attention_dropout == 0.0, ( + f'DotProductAttention with context parallelism does not support attention dropout,' + f' but got {self.config.context_parallel_size=},' + f' {attention_dropout=}, and {self.config.attention_dropout=}.' 
+ ) self.layer_number = max(1, layer_number) self.attn_mask_type = attn_mask_type @@ -172,6 +178,19 @@ def forward( self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2 ) + if self.config.context_parallel_size > 1: + output = AttentionFuncionWithContextParallel.apply( + query, + key, + value, + attention_mask, + self.config.attention_dropout, + self.softmax_scale, + parallel_state.get_context_parallel_group(), + ) + output = output.view(query.shape[0], query.shape[1], self.hidden_size_per_partition) + return output + # [b, np, sq, sk] output_size = (query.size(1), query.size(2), query.size(0), key.size(0)) diff --git a/megatron/core/transformer/dot_product_attention_context_parallel.py b/megatron/core/transformer/dot_product_attention_context_parallel.py new file mode 100644 index 0000000000..89659a1d74 --- /dev/null +++ b/megatron/core/transformer/dot_product_attention_context_parallel.py @@ -0,0 +1,342 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +# Some of this code was adopted from https://github.com/zhuzilin/ring-flash-attention/ +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch.nn import functional as F + +try: + import einops + + HAVE_EINOPS = True +except ImportError: + HAVE_EINOPS = False + + +@torch.no_grad +def eager_attn_fwd(q, k, v, attn_bias, sinks, scale, dropout): + """Forward pass for eager attention""" + + # Rearrange query, key, value to (b, h, s, d) + b, sq, h, d = q.shape + sk = k.shape[1] + _q = einops.rearrange(q, 'b s h d -> b h s d') + _k = einops.rearrange(k, 'b s h d -> b h d s') + _v = einops.rearrange(v, 'b s h d -> b h s d') + + # Compute attention weights + attn_w = torch.matmul(_q, _k) * scale + attn_w = attn_w + attn_bias + + # Add sinks to attention weights + if sinks is None: + logits = attn_w + else: + _sinks = sinks.reshape(1, h, 1, 1).expand(b, -1, sq, 1) + logits = torch.cat([attn_w, _sinks], dim=-1) + + # Compute attention scores + probs = F.softmax(logits, dim=-1, dtype=logits.dtype) + if sinks is None: + attn_w = probs + else: + attn_w = probs[..., :-1] # Drop the sink + + # Compute attention output + attn_output = torch.matmul(attn_w, _v) + attn_output = einops.rearrange(attn_output, 'b h s d -> b s h d') + attn_output = attn_output.contiguous() + + return attn_output, probs + + +@torch.no_grad +def eager_attn_bwd(q, k, v, attn_bias, sinks, scale, dropout, attn_output, probs, grad_output): + """Backward pass for eager attention""" + + # Rearrange query, key, value to (b, h, s, d) + b, sq, h, d = q.shape + sk = k.shape[1] + _q_T = einops.rearrange(q, 'b s h d -> b h d s') + _k_T = einops.rearrange(k, 'b s h d -> b h s d') + _v_T = einops.rearrange(v, ' b s h d -> b h d s') + + # Backward pass for score @ value + if sinks is None: + attn_w = probs + else: + attn_w = probs[..., :-1] # Drop the sink + grad_output = einops.rearrange(grad_output, 'b s h d -> b h s d') + attn_w_T = einops.rearrange(attn_w, ' b h sq sk -> b h sk sq') + grad__v = torch.matmul(attn_w_T, grad_output) + grad_attn_w = torch.matmul(grad_output, _v_T) + + # Backward pass for softmax + if sinks is None: + grad_probs = grad_attn_w + else: + dummy = torch.zeros((b, h, sq, 1), device=q.device, dtype=q.dtype) + grad_probs = torch.cat([grad_attn_w, dummy], dim=3) + del grad_attn_w + grad_logits = torch._softmax_backward_data( + grad_probs, probs, -1, probs.dtype + ) # [b, h, sq, sk+1] + + # Backward pass for adding 
sinks + if sinks is None: + grad_sinks = None + grad_attn_w = grad_logits + else: + grad__sinks = grad_logits[:, :, :, -1] # [b, h, sq] + grad_sinks = einops.rearrange(grad__sinks, 'b h s -> h (b s)').sum(-1) + grad_attn_w = grad_logits[:, :, :, :-1].contiguous() # [b, h, sq, sk] + + # Backward pass for q @ K^T + grad_attn_w *= scale + grad__q = torch.matmul(grad_attn_w, _k_T) + grad__k = torch.matmul(_q_T, grad_attn_w) + + # Rearrange grads to (b, s, h, d) + grad_v = einops.rearrange(grad__v, 'b h s d -> b s h d') + grad_k = einops.rearrange(grad__k, 'b h d s -> b s h d') + grad_q = einops.rearrange(grad__q, 'b h s d -> b s h d') + return grad_q, grad_k, grad_v, grad_sinks + + +class AllGatherComm: + """All gather communication with async operations""" + + def __init__(self, group=None) -> None: + self.group = group + self.handles = [] + + def all_gather(self, output_tensor: torch.Tensor, input_tensor: torch.Tensor): + '''All gather the input tensor to the output tensor''' + + if self.group is None: + output_tensor.copy_(input_tensor) + else: + handle = torch.distributed.all_gather_into_tensor( + output_tensor, input_tensor, group=self.group, async_op=True + ) + self.handles.append(handle) + + def wait(self): + '''Wait for all gather operations to complete''' + + if self.group is not None: + for handle in self.handles: + handle.wait() + self.handles = [] + + +def to_zz_mask_attn_bias(attention_mask, cp_size, nheads, nheads_k, heads_k_stride, device, dtype): + '''Convert the attention mask to the attention bias''' + + if cp_size == 1: + zz_mask = attention_mask + else: + chunked = attention_mask.chunk(dim=3, chunks=cp_size * 2) + zz_mask = [_x for _p in zip(chunked[:cp_size], reversed(chunked[cp_size:])) for _x in _p] + zz_mask = torch.cat(zz_mask, dim=3) + attn_bias = torch.zeros(zz_mask.shape, device=device, dtype=dtype) + attn_bias.masked_fill_(zz_mask, float('-inf')) + attn_bias = attn_bias.expand(-1, heads_k_stride * (nheads // nheads_k), -1, -1) + return attn_bias + + +class AttentionFuncionWithContextParallel(torch.autograd.Function): + """Native attention function with context parallelism.""" + + @staticmethod + def forward(ctx, q, k, v, attention_mask, attention_dropout, softmax_scale, pg): + '''Forward pass for the native attention function with context parallelism''' + + # Assert einops exists + if not HAVE_EINOPS: + raise ImportError("einops is required by the attention CP but cannot be imported.") + + # Initialize communication group and constants + cp_size = 1 + if pg is not None: + cp_size = torch.distributed.get_world_size(pg) + comm = AllGatherComm(group=pg) + nheads = q.shape[2] + nheads_k = k.shape[2] + heads_k_stride = 1 + assert nheads % nheads_k == 0 and nheads_k % heads_k_stride == 0 + outs = [] + probs = [] + + # Initialize KV buffers + kv_buffer = torch.empty( + (2, k.shape[0] * cp_size, k.shape[1], heads_k_stride, k.shape[3]), + dtype=k.dtype, + device=k.device, + ) + kv_buffer_copy = torch.empty_like(kv_buffer) + + # All-gather first chunk of KV buffers + k_0 = k[:, :, :heads_k_stride].contiguous() + v_0 = v[:, :, :heads_k_stride].contiguous() + comm.all_gather(kv_buffer_copy[0], k_0) + comm.all_gather(kv_buffer_copy[1], v_0) + + # Prepare attention bias + attn_bias = to_zz_mask_attn_bias( + attention_mask, cp_size, nheads, nheads_k, heads_k_stride, q.device, q.dtype + ) + + # Iterate over heads + for i in range(0, nheads_k, heads_k_stride): + # Wait for previous all-gather to complete + comm.wait() + kv_buffer, kv_buffer_copy = kv_buffer_copy, kv_buffer + # 
All-gather the next portion of KV buffers if not the last iteration + if i < nheads_k - heads_k_stride: + kvsl = i + heads_k_stride + kvsr = kvsl + heads_k_stride + send_k = k[:, :, kvsl:kvsr].contiguous() + send_v = v[:, :, kvsl:kvsr].contiguous() + comm.all_gather(kv_buffer_copy[0], send_k) + comm.all_gather(kv_buffer_copy[1], send_v) + + # Prepare query, key, value for attention + q_i = q[:, :, i * nheads // nheads_k : (i + heads_k_stride) * nheads // nheads_k] + k_i = kv_buffer[0] + v_i = kv_buffer[1] + + # Rearrange query, key, value to (b, s, h, d) + q_i = einops.rearrange(q_i, 's b h d -> b s h d') + k_i = einops.rearrange(k_i, 's b h d -> b s h d') + v_i = einops.rearrange(v_i, 's b h d -> b s h d') + + # Forward pass + out_i, probs_i = eager_attn_fwd( + q_i, k_i, v_i, attn_bias, None, softmax_scale, attention_dropout + ) + outs.append(out_i) + probs.append(probs_i) + + # Concatenate outputs and rearrange to (s, b, h, d) + out = torch.cat(outs, dim=2) + out = einops.rearrange(out, 'b s h d -> s b h d') + + # Save contexts for backward pass + ctx.save_for_backward(q, k, v, attention_mask, *outs, *probs) + ctx.dropout = attention_dropout + ctx.scale = softmax_scale + ctx.heads_k_stride = heads_k_stride # TODO make it configurable + ctx.pg = pg + + return out + + @staticmethod + def backward(ctx, dout): + '''Backward pass for the native attention function with context parallelism''' + + # Initialize or resume constants and communication group + q, k, v, attention_mask, *rest = ctx.saved_tensors + nheads = q.shape[2] + nheads_k = k.shape[2] + heads_k_stride = ctx.heads_k_stride + assert nheads_k % heads_k_stride == 0 + outs = rest[: nheads_k // heads_k_stride] + probs = rest[nheads_k // heads_k_stride :] + pg = ctx.pg + cp_size = 1 + if pg is not None: + cp_size = torch.distributed.get_world_size(pg) + comm = AllGatherComm(group=pg) + + # Initialize KV buffers + kv_buffer = torch.empty( + (2, k.shape[0] * cp_size, k.shape[1], heads_k_stride, k.shape[3]), + dtype=k.dtype, + device=k.device, + ) + kv_buffer_copy = torch.empty_like(kv_buffer) + + # All-gather first chunk of KV buffers + dq = [] + dk = [] + dv = [] + k_0 = k[:, :, :heads_k_stride].contiguous() + v_0 = v[:, :, :heads_k_stride].contiguous() + comm.all_gather(kv_buffer_copy[0], k_0) + comm.all_gather(kv_buffer_copy[1], v_0) + + # Prepare attention bias + attn_bias = to_zz_mask_attn_bias( + attention_mask, cp_size, nheads, nheads_k, heads_k_stride, q.device, q.dtype + ) + + # Iterate over heads + for i in range(0, nheads_k, heads_k_stride): + # Slice query and output for this iteration + q_slice = slice(i * nheads // nheads_k, (i + heads_k_stride) * nheads // nheads_k) + q_i = q[:, :, q_slice] + dout_i = dout[:, :, q_slice] + + # Wait for previous all-gather to complete + comm.wait() + kv_buffer, kv_buffer_copy = kv_buffer_copy, kv_buffer + + # All-gather the next portion of KV buffers if not the last iteration + if i < nheads_k - heads_k_stride: + kvsl = i + heads_k_stride + kvsr = kvsl + heads_k_stride + send_k = k[:, :, kvsl:kvsr].contiguous() + send_v = v[:, :, kvsl:kvsr].contiguous() + comm.all_gather(kv_buffer_copy[0], send_k) + comm.all_gather(kv_buffer_copy[1], send_v) + + # Prepare key, value for attention + k_i = kv_buffer[0] + v_i = kv_buffer[1] + + # Rearrange query, key, value to (b, s, h, d) + q_i = einops.rearrange(q_i, 's b h d -> b s h d') + k_i = einops.rearrange(k_i, 's b h d -> b s h d') + v_i = einops.rearrange(v_i, 's b h d -> b s h d') + dout_i = einops.rearrange(dout_i, 's b h d -> b s h d') + + # 
Backward pass + dq_i, _dk_i, _dv_i, _ = eager_attn_bwd( + q_i, k_i, v_i, attn_bias, None, ctx.scale, ctx.dropout, outs[i], probs[i], dout_i + ) + + # Rearrange gradients to (s, b, h, d) + dq_i = einops.rearrange(dq_i, 'b s h d -> s b h d') + _dk_i = einops.rearrange(_dk_i, 'b s h d -> s b h d') + _dv_i = einops.rearrange(_dv_i, 'b s h d -> s b h d') + if pg is None: + dk_i = _dk_i + dv_i = _dv_i + else: + # Reduce-scatter gradients if CP > 1 + dk_i = torch.zeros( + (k_i.shape[1] // cp_size, k_i.shape[0], k_i.shape[2], k_i.shape[3]), + device=k_i.device, + dtype=k_i.dtype, + ) + dv_i = torch.zeros( + (v_i.shape[1] // cp_size, v_i.shape[0], v_i.shape[2], v_i.shape[3]), + device=v_i.device, + dtype=v_i.dtype, + ) + torch.distributed.reduce_scatter_tensor(dk_i, _dk_i, group=pg) + torch.distributed.reduce_scatter_tensor(dv_i, _dv_i, group=pg) + + # Collect gradients + dq.append(dq_i) + dk.append(dk_i) + dv.append(dv_i) + + # Concatenate gradients and return + dq = torch.cat(dq, dim=2) + dk = torch.cat(dk, dim=2) + dv = torch.cat(dv, dim=2) + return dq, dk, dv, None, None, None, None diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 6b8209ef6a..df5eb9036b 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -771,6 +771,10 @@ class TransformerConfig(ModelParallelConfig): """Transformer implementation to use. Options are 'transformer_engine' for Transformer Engine and 'local' for MCore.""" + fallback_to_eager_attn: bool = False + """Whether to fallback to eager attention in TE implementation. + Suggested for when desired features are not available in TE implementation.""" + ##################################### # Fine-grained Activation Offloading ##################################### @@ -1813,6 +1817,25 @@ def __post_init__(self): f"the number of layers ({self.num_layers})" ) + if self.fallback_to_eager_attn: + assert self.transformer_impl == "transformer_engine", ( + f"fallback_to_eager_attn is only available with transformer_engine implementation," + f" but got {self.transformer_impl=}." + ) + + if self.fallback_to_eager_attn or self.transformer_impl == "local": + if self.context_parallel_size > 1 and self.cp_comm_type is not None: + all_cp_comm_types_are_all_gather = ( + all(item == "all_gather" for item in self.cp_comm_type) + if isinstance(self.cp_comm_type, list) + else self.cp_comm_type == "all_gather" + ) + if not all_cp_comm_types_are_all_gather: + raise ValueError( + f"fallback_to_eager_attn only supports all_gather communication type " + f"for context parallelism, but got {self.cp_comm_type=} instead." + ) + @dataclass class MLATransformerConfig(TransformerConfig): diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 7ecb1e7100..604c8414b5 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1382,6 +1382,9 @@ def _add_transformer_engine_args(parser): group.add_argument('--transformer-impl', default='transformer_engine', choices=['local', 'transformer_engine'], help='Which Transformer implementation to use.') + group.add_argument('--fallback-to-eager-attn', action='store_true', + help='Fallback to eager attention in TE implementation. 
' + 'Suggested for when desired features are not available in TE implementation.') group.add_argument('--fp8-param-gather', action='store_true', help='Keep the compute param in fp8 (do not use any other intermediate ' 'dtype) and perform the param all-gather in fp8.') diff --git a/tests/unit_tests/transformer/test_attention.py b/tests/unit_tests/transformer/test_attention.py index 23858937c7..be7e89bf6f 100644 --- a/tests/unit_tests/transformer/test_attention.py +++ b/tests/unit_tests/transformer/test_attention.py @@ -1,20 +1,45 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. import copy +from functools import partial +from unittest import mock +import einops import pytest import torch from packaging import version +from torch.nn import functional as F import megatron.core.parallel_state as parallel_state from megatron.core.hyper_comm_grid import HyperCommGrid -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.models.common.embeddings.rope_utils import ( + get_pos_emb_on_this_cp_rank as get_tensor_on_this_cp_rank, +) +from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_layer_local_spec, + get_gpt_layer_with_transformer_engine_spec, +) +from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer import TransformerConfig from megatron.core.transformer.attention import SelfAttention +from megatron.core.transformer.dot_product_attention_context_parallel import ( + AttentionFuncionWithContextParallel, + to_zz_mask_attn_bias, +) from megatron.core.transformer.enums import AttnMaskType from megatron.core.utils import is_te_min_version +from megatron.training.arguments import parse_args +from megatron.training.checkpointing import load_checkpoint, save_checkpoint +from megatron.training.global_vars import set_args +from megatron.training.training import get_model +from megatron.training.utils import unwrap_model +from tests.unit_tests.dist_checkpointing import ( + TempNamedDir, + init_basic_mock_args, + init_checkpointing_mock_args, +) from tests.unit_tests.test_utilities import Utils try: @@ -26,10 +51,19 @@ @pytest.mark.parametrize("output_gate", [False, True]) +@pytest.mark.parametrize( + ("transformer_impl", "fallback_to_eager_attn"), + [("transformer_engine", False), ("transformer_engine", True), ("native", False)], +) class TestParallelAttention: @pytest.fixture(scope='function', autouse=True) - def setup_method(self, output_gate): + def setup_method(self, output_gate, transformer_impl, fallback_to_eager_attn): + if output_gate: + if transformer_impl == "native": + pytest.skip("Native implementation does not support output gate.") + if fallback_to_eager_attn: + pytest.skip("No need to test output gate for fallback_to_eager_attn = True.") Utils.initialize_model_parallel(1, 1) model_parallel_cuda_manual_seed(123) self.transformer_config = TransformerConfig( @@ -40,11 +74,18 @@ def setup_method(self, output_gate): bf16=True, params_dtype=torch.bfloat16, attention_output_gate=output_gate, + transformer_impl=transformer_impl, + fallback_to_eager_attn=fallback_to_eager_attn, ) + if transformer_impl == "transformer_engine": + layer_spec = get_gpt_layer_with_transformer_engine_spec( + fallback_to_eager_attn=fallback_to_eager_attn + ) + else: + 
layer_spec = get_gpt_layer_local_spec() + attn_layer_spec = layer_spec.submodules.self_attention.submodules self.parallel_attention = SelfAttention( - self.transformer_config, - get_gpt_layer_with_transformer_engine_spec().submodules.self_attention.submodules, - layer_number=1, + self.transformer_config, attn_layer_spec, layer_number=1 ) def teardown_method(self): @@ -55,10 +96,19 @@ def test_constructor(self): assert self.parallel_attention.layer_number == 1 num_weights = sum([p.numel() for p in self.parallel_attention.parameters()]) + + hidden_size = self.transformer_config.hidden_size + standard_num_weights = ( + hidden_size * hidden_size * 4 + hidden_size * 4 # QKVO weight # QKVO bias + ) if self.transformer_config.attention_output_gate: - assert num_weights == 82816 - else: - assert num_weights == 66304 + standard_num_weights += hidden_size * hidden_size + hidden_size # Gate weight and bias + if self.transformer_config.transformer_impl == "transformer_engine": + standard_num_weights += hidden_size * 2 # fused pre layernorm weight and bias + + assert ( + num_weights == standard_num_weights + ), f"{num_weights=} does not match {standard_num_weights=}." def test_cpu_forward(self): # we can't currently do this because the global memory buffer is on GPU @@ -93,6 +143,8 @@ def test_gpu_forward(self): @pytest.mark.parametrize("rotary_interleaved", [True, False]) @pytest.mark.parametrize("fused_qkv_rope", [True, False]) def test_fused_rope_gpu_forward(self, rotary_interleaved, fused_qkv_rope): + if self.transformer_config.fallback_to_eager_attn: + pytest.skip("No need to test fused RoPE for fallback_to_eager_attn = True.") self.parallel_attention.config.apply_rope_fusion = True if rotary_interleaved and not is_te_min_version("2.3.0"): pytest.skip("Only TE >= 2.3.0 supports interleaved fused RoPE.") @@ -166,10 +218,15 @@ def test_checkpointed_gpu_forward(self): @pytest.mark.parametrize("output_gate", [False, True]) +@pytest.mark.parametrize("transformer_impl", ["transformer_engine", "native"]) class TestSelfAttention: @pytest.fixture(scope='function', autouse=True) - def setup_method(self, output_gate): + def setup_method(self, output_gate, transformer_impl): + if transformer_impl == "native": + if output_gate: + pytest.skip("Native implementation does not support output gate.") + self.transformer_impl = transformer_impl self.output_gate = output_gate Utils.destroy_model_parallel() @@ -185,10 +242,15 @@ def run_self_attention(self, pg_collection): attention_output_gate=self.output_gate, tensor_model_parallel_size=tensor_model_parallel_size, use_cpu_initialization=False, + transformer_impl=self.transformer_impl, ) + if self.transformer_impl == "transformer_engine": + get_gpt_layer_spec_fn = get_gpt_layer_with_transformer_engine_spec + else: + get_gpt_layer_spec_fn = get_gpt_layer_local_spec self.self_attention = SelfAttention( self.transformer_config, - get_gpt_layer_with_transformer_engine_spec().submodules.self_attention.submodules, + get_gpt_layer_spec_fn().submodules.self_attention.submodules, layer_number=1, attn_mask_type=AttnMaskType.causal, pg_collection=pg_collection, @@ -261,3 +323,374 @@ def test_self_attention_independent_pg_smoke(self): pg_collection = ProcessGroupCollection(tp=tp_group, cp=cp_group) self.run_self_attention(pg_collection) + + +def _test_parallel_attention_correctness( + transformer_config, + transformer_layer_spec, + tmp_path_dist_ckpt, + atol, + rtol, + tp=1, + sp=False, + cp=1, + seed=123, + sequence_length=256, + micro_batch_size=4, +): + # Model initialization 
function + def initialize_gpt_model(config, pre_process=True, post_process=True, vp_stage=None): + gpt_model = GPTModel( + config=config, + transformer_layer_spec=transformer_layer_spec, + vocab_size=128, + max_sequence_length=sequence_length, + pre_process=pre_process, + post_process=post_process, + vp_stage=vp_stage, + ) + return gpt_model + + # Initialize baseline parallel state + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1, context_parallel_size=1 + ) + + # Initialize input hidden states + torch.manual_seed(seed) + model_parallel_cuda_manual_seed(seed) + input_hidden_states = ( + torch.rand((sequence_length, micro_batch_size, transformer_config.hidden_size)) + .cuda() + .bfloat16() + .requires_grad_(True) + ) + + with TempNamedDir(tmp_path_dist_ckpt / 'test_parallel_attn', sync=True) as ckpt_dir: + # Set argument + mock_args = parse_args(ignore_unknown_args=True) + set_args(mock_args) + + # Initialize baseline model + init_basic_mock_args(mock_args, 1, 1, bf16=True) + mock_args.context_parallel_size = 1 + mock_args.sequence_parallel = 1 + gpt_model = unwrap_model( + get_model(partial(initialize_gpt_model, config=transformer_config)) + ) + + # Initialize args and save checkpoint + init_checkpointing_mock_args(mock_args, ckpt_dir, False) + mock_args.no_save_optim = True + mock_args.no_save_rng = True + mock_args.no_load_optim = True + mock_args.no_load_rng = True + save_checkpoint(10, gpt_model, None, None, 0) + + # Calculate baseline output + attention = gpt_model[0].decoder.layers[0].self_attention + output_hidden_states_baseline, bias_hidden_states_baseline = attention( + input_hidden_states, attention_mask=None + ) + output_hidden_states_baseline.sum().backward() + + # Save baseline output + input_grad_baseline = input_hidden_states.grad.detach() + output_hidden_states_baseline = output_hidden_states_baseline.detach() + bias_hidden_states_baseline = bias_hidden_states_baseline + if bias_hidden_states_baseline is not None: + bias_hidden_states_baseline = bias_hidden_states_baseline.detach() + has_bias = True + else: + has_bias = False + + # Initialize parallel model + Utils.destroy_model_parallel() + Utils.initialize_model_parallel( + tensor_model_parallel_size=tp, pipeline_model_parallel_size=1, context_parallel_size=cp + ) + torch.manual_seed(seed) + model_parallel_cuda_manual_seed(seed) + transformer_config.context_parallel_size = cp + transformer_config.tensor_model_parallel_size = tp + transformer_config.sequence_parallel = sp + init_basic_mock_args(mock_args, tp, 1, bf16=True) + mock_args.context_parallel_size = cp + mock_args.sequence_parallel = sp + gpt_model = unwrap_model( + get_model(partial(initialize_gpt_model, config=transformer_config)) + ) + with mock.patch('megatron.training.checkpointing.check_checkpoint_args'): + with mock.patch('megatron.training.checkpointing.update_num_microbatches'): + load_checkpoint(gpt_model, None, None) + + # Function to get tensor on this tp and cp rank + cp_group = parallel_state.get_context_parallel_group() + tp_rank = parallel_state.get_tensor_model_parallel_rank() + + def get_tensor_on_this_rank(tensor): + if cp > 1: + tensor = get_tensor_on_this_cp_rank(tensor, 0, cp_group) + if tp > 1 and sp: + sp_seg = sequence_length // tp // cp + tensor = tensor[tp_rank * sp_seg : (tp_rank + 1) * sp_seg] + return tensor + + # Calculate parallel model output + input_hidden_states = get_tensor_on_this_rank(input_hidden_states) + input_hidden_states = input_hidden_states.detach().requires_grad_(True) 
+ parallel_attention = gpt_model[0].decoder.layers[0].self_attention + output_hidden_states_parallel, bias_hidden_states_parallel = parallel_attention( + input_hidden_states, attention_mask=None + ) + output_hidden_states_parallel.sum().backward() + input_grad_parallel = input_hidden_states.grad.detach() + + # Check if the output is close + output_hidden_states_baseline = get_tensor_on_this_rank(output_hidden_states_baseline) + input_grad_baseline = get_tensor_on_this_rank(input_grad_baseline) + + assert torch.all( + ~torch.isnan(output_hidden_states_baseline) + ), "output_hidden_states_baseline contains nan" + assert torch.all( + ~torch.isinf(output_hidden_states_baseline) + ), "output_hidden_states_baseline contains inf" + assert torch.all(~torch.isnan(input_grad_baseline)), "input_grad_baseline contains nan" + assert torch.all(~torch.isinf(input_grad_baseline)), "input_grad_baseline contains inf" + assert torch.all( + ~torch.isnan(output_hidden_states_parallel) + ), "output_hidden_states_parallel contains nan" + assert torch.all( + ~torch.isinf(output_hidden_states_parallel) + ), "output_hidden_states_parallel contains inf" + assert torch.all(~torch.isnan(input_grad_parallel)), "input_grad_parallel contains nan" + assert torch.all(~torch.isinf(input_grad_parallel)), "input_grad_parallel contains inf" + if has_bias: + assert torch.all( + ~torch.isnan(bias_hidden_states_baseline) + ), "bias_hidden_states_baseline contains nan" + assert torch.all( + ~torch.isinf(bias_hidden_states_baseline) + ), "bias_hidden_states_baseline contains inf" + assert torch.all( + ~torch.isnan(bias_hidden_states_parallel) + ), "bias_hidden_states_parallel contains nan" + assert torch.all( + ~torch.isinf(bias_hidden_states_parallel) + ), "bias_hidden_states_parallel contains inf" + + torch.testing.assert_close( + output_hidden_states_baseline, + output_hidden_states_parallel, + atol=atol, + rtol=rtol, + msg=lambda msg: f"Mismatch in output_hidden_states: {msg}", + ) + torch.testing.assert_close( + input_grad_baseline, + input_grad_parallel, + atol=atol, + rtol=rtol, + msg=lambda msg: f"Mismatch in input_grad: {msg}", + ) + if has_bias: + torch.testing.assert_close( + bias_hidden_states_baseline, + bias_hidden_states_parallel, + atol=atol, + rtol=rtol, + msg=lambda msg: f"Mismatch in bias_hidden_states: {msg}", + ) + + Utils.destroy_model_parallel() + + +@pytest.mark.parametrize("apply_rope_fusion", [False, True]) +@pytest.mark.parametrize( + ("tp", "sp", "cp"), + [ + (4, False, 1), # TP w/o SP + (4, True, 1), # TP w/ SP + (1, False, 4), # CP + (2, False, 2), # CP + TP w/o SP + (2, True, 2), # CP + TP w/ SP + ], +) +@pytest.mark.parametrize("qk_layernorm", [False, True]) +@pytest.mark.parametrize("fallback_to_eager_attn", [False, True]) +@pytest.mark.parametrize("output_gate", [False, True]) +def test_parallel_attention_correctness( + tmp_path_dist_ckpt, + apply_rope_fusion, + tp, + sp, + cp, + qk_layernorm, + fallback_to_eager_attn, + output_gate, +): + transformer_config = TransformerConfig( + num_layers=1, + hidden_size=128, + num_attention_heads=4, + context_parallel_size=1, + tensor_model_parallel_size=1, + sequence_parallel=False, + bf16=True, + qk_layernorm=qk_layernorm, + apply_rope_fusion=apply_rope_fusion, + attention_output_gate=output_gate, + hidden_dropout=0.0, + attention_dropout=0.0, + ) + + transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( + fallback_to_eager_attn=fallback_to_eager_attn, + normalization="RMSNorm", + qk_layernorm=qk_layernorm, + ) + if cp > 1: + if qk_layernorm: 
+            atol, rtol = 2e-2, 2e-2
+        else:
+            atol, rtol = 5e-3, 5e-3
+    else:
+        if qk_layernorm:
+            atol, rtol = 1e-2, 1e-2
+        else:
+            atol, rtol = 2e-3, 2e-3
+
+    _test_parallel_attention_correctness(
+        transformer_config,
+        transformer_layer_spec,
+        tmp_path_dist_ckpt,
+        atol,
+        rtol,
+        tp=tp,
+        sp=sp,
+        cp=cp,
+    )
+
+
+def _torch_native_attention(query, key, value, attention_mask, sinks, scaling: float):
+    """Torch native attention implementation.
+
+    This was not in the original implementation and slightly affects results;
+    to prevent overflow in BF16/FP16 when training with batch size > 1, we clamp max values.
+    """
+    # Rearrange query, key, value to (b, h, s, d)
+    query = einops.rearrange(query, 's b h d -> b h s d')
+    key = einops.rearrange(key, 's b h d -> b h s d')
+    value = einops.rearrange(value, 's b h d -> b h s d')
+
+    # Compute attention weights
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        nheads = query.shape[1]
+        nheads_k = key.shape[1]
+        heads_k_stride = 1
+        mask_bias = to_zz_mask_attn_bias(
+            attention_mask, 1, nheads, nheads_k, heads_k_stride, query.device, query.dtype
+        )
+        attn_weights = attn_weights + mask_bias
+
+    # Add sinks to attention weights
+    if sinks is None:
+        combined_logits = attn_weights
+    else:
+        sinks = sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
+        combined_logits = torch.cat([attn_weights, sinks], dim=-1)
+
+    # Compute attention scores
+    probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype)
+    if sinks is None:
+        scores = probs
+    else:
+        scores = probs[..., :-1]
+
+    # Compute attention output
+    attn_output = torch.matmul(scores, value)
+    attn_output = einops.rearrange(attn_output, 'b h s d -> s b h d')
+    attn_output = attn_output.contiguous()
+    return attn_output
+
+
+def test_eager_attention_function_correctness():
+    """Test the correctness of the context parallel eager attention function"""
+
+    # Configuration
+    batch_size = 4
+    num_heads = 2
+    head_dim = 256
+    seq_len_q = 512
+    seq_len_k = 2048
+    scale = 1 / (head_dim**2)
+
+    # Initialize inputs
+    q = torch.rand(
+        (seq_len_q, batch_size, num_heads, head_dim),
+        device='cuda',
+        dtype=torch.bfloat16,
+        requires_grad=True,
+    )
+    k = torch.rand(
+        (seq_len_k, batch_size, num_heads, head_dim),
+        device='cuda',
+        dtype=torch.bfloat16,
+        requires_grad=True,
+    )
+    v = torch.rand(
+        (seq_len_k, batch_size, num_heads, head_dim),
+        device='cuda',
+        dtype=torch.bfloat16,
+        requires_grad=True,
+    )
+
+    def randbool(shape, **kwargs):
+        return torch.randn(shape, **kwargs) > 0
+
+    attn_bias = randbool((batch_size, 1, seq_len_q, seq_len_k), device='cuda')
+    sinks = None
+
+    # Torch native attention forward and backward pass
+    out_torch = _torch_native_attention(
+        query=q, key=k, value=v, attention_mask=attn_bias, sinks=sinks, scaling=scale
+    )
+    loss_torch = out_torch.sum()
+    loss_torch.backward()
+    torch_q_grad = q.grad.clone()
+    torch_k_grad = k.grad.clone()
+    torch_v_grad = v.grad.clone()
+    q.grad.zero_()
+    k.grad.zero_()
+    v.grad.zero_()
+    if sinks is not None:
+        torch_sinks_grad = sinks.grad.clone()
+        sinks.grad.zero_()
+    else:
+        torch_sinks_grad = None
+
+    # Custom attention forward and backward pass
+    out_custom = AttentionFuncionWithContextParallel.apply(
+        q, k, v, attn_bias, 0.0, scale, None  # attention_dropout=0.0, pg=None (no CP group)
+    )
+    loss_custom = out_custom.sum()
+    loss_custom.backward()
+    custom_q_grad = q.grad.clone()
+    custom_k_grad = k.grad.clone()
+    custom_v_grad = v.grad.clone()
+    q.grad.zero_()
+    k.grad.zero_()
+    v.grad.zero_()
+    if sinks is not None:
+        custom_sinks_grad = sinks.grad.clone()
+        sinks.grad.zero_()
+    else:
+        custom_sinks_grad = None
+
+    # Check attention output and gradients
+    assert torch.equal(out_custom, out_torch), "Mismatch in attention output"
+    tol = {"atol": 1e-4, "rtol": 1e-4}
+    for tensor_name, tensor_torch, tensor_custom in [
+        ("q_grad", torch_q_grad, custom_q_grad),
+        ("k_grad", torch_k_grad, custom_k_grad),
+        ("v_grad", torch_v_grad, custom_v_grad),
+        ("sinks_grad", torch_sinks_grad, custom_sinks_grad),
+    ]:
+        if (tensor_torch is not None) and (tensor_custom is not None):
+            torch.testing.assert_close(
+                tensor_custom,
+                tensor_torch,
+                **tol,
+                msg=lambda msg: f"Mismatch in {tensor_name}: {msg}",
+            )
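
The all-gather context-parallel path added above relies on the zigzag ("zz") load-balanced sequence layout; to_zz_mask_attn_bias reorders the key dimension of the attention mask into that layout before the per-head-group all-gather. Below is a minimal standalone sketch of the reordering, assuming plain PyTorch and toy values (the cp_size, seq_len, and integer tensor are made up; the integers merely stand in for a real [b, 1, sq, sk] boolean mask):

import torch

cp_size = 2  # hypothetical context-parallel world size
seq_len = 8  # full (unsharded) key sequence length
mask = torch.arange(seq_len).view(1, 1, 1, seq_len)  # stand-in for a [b, 1, sq, sk] mask

# Rank i holds sequence chunks i and (2 * cp_size - 1 - i), so an all-gather over ranks
# returns keys in the order [c0, c3, c1, c2] for cp_size=2. Reordering the mask's key
# axis the same way lines it up with the gathered keys.
chunks = mask.chunk(dim=3, chunks=cp_size * 2)
zigzag = [c for pair in zip(chunks[:cp_size], reversed(chunks[cp_size:])) for c in pair]
zigzag = torch.cat(zigzag, dim=3)

print(zigzag.flatten().tolist())  # [0, 1, 6, 7, 2, 3, 4, 5]

This chunk pairing matches the sharding applied by get_pos_emb_on_this_cp_rank (imported in the new test to slice per-rank inputs) and keeps per-rank work roughly balanced under a causal mask. Because the eager CP path gathers full-sequence K/V one head group at a time rather than using a P2P ring, the new TransformerConfig check restricts cp_comm_type to 'all_gather' whenever fallback_to_eager_attn or the local backend is combined with context parallelism.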