Merged
1 change: 1 addition & 0 deletions gpt_builders.py
@@ -118,6 +118,7 @@ def _get_transformer_layer_spec(use_te, config):
moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm,
qk_l2_norm=args.qk_l2_norm,
use_kitchen=config.use_kitchen,
fallback_to_eager_attn=config.fallback_to_eager_attn,
)
else:
return get_gpt_layer_local_spec(
@@ -1,4 +1,4 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

import warnings
from typing import Optional, Tuple
@@ -17,6 +17,7 @@
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
from megatron.core.models.backends import BackendSpecProvider
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.mlp import MLPSubmodules
from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP, TEGroupedMLP
from megatron.core.utils import get_te_version, is_te_min_version
@@ -25,6 +26,10 @@
class TESpecProvider(BackendSpecProvider):
"""A protocol for providing the submodules used in Spec building."""

def __init__(self, fallback_to_eager_attn: bool = False):
super().__init__()
self.fallback_to_eager_attn = fallback_to_eager_attn

def linear(self) -> type:
"""Which linear module TE backend uses"""
return TELinear
@@ -56,6 +61,8 @@ def layer_norm(self, rms_norm: bool = False, for_qk: bool = False) -> type:

def core_attention(self) -> type:
"""Which module to use for attention"""
if self.fallback_to_eager_attn:
return DotProductAttention
return TEDotProductAttention

def grouped_mlp_modules(
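A quick usage sketch of the new constructor flag (illustrative only; the import path for TESpecProvider is assumed here, since the defining file's path is not shown in this diff):

# Assumption: TESpecProvider's module path is not visible in this diff; adjust as needed.
from megatron.core.extensions.transformer_engine_spec_provider import TESpecProvider
from megatron.core.transformer.dot_product_attention import DotProductAttention

# With the flag set, the TE backend resolves core attention to the eager implementation.
eager_backend = TESpecProvider(fallback_to_eager_attn=True)
assert eager_backend.core_attention() is DotProductAttention

# The default still returns TEDotProductAttention, so existing specs are unaffected.
fused_backend = TESpecProvider()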
23 changes: 17 additions & 6 deletions megatron/core/models/gpt/gpt_layer_specs.py
@@ -10,6 +10,7 @@
)
from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec_for_backend
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType, LayerType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
@@ -85,6 +86,7 @@ def get_gpt_layer_with_transformer_engine_spec(
use_te_op_fuser: Optional[bool] = False,
use_kitchen: bool = False,
use_te_activation_func: bool = False,
fallback_to_eager_attn: bool = False,
) -> ModuleSpec:
"""Use this spec to use lower-level Transformer Engine modules (required for fp8 training).

Expand Down Expand Up @@ -116,13 +118,15 @@ def get_gpt_layer_with_transformer_engine_spec(

if use_kitchen:
assert HAVE_KITCHEN
backend: BackendSpecProvider = KitchenSpecProvider(fallback=TESpecProvider())
backend: BackendSpecProvider = KitchenSpecProvider(
fallback=TESpecProvider(fallback_to_eager_attn=fallback_to_eager_attn)
)
if use_te_op_fuser:
raise AssertionError("use_te_op_fuser not compatible with using kitchen in mlp.")
if use_te_activation_func:
raise AssertionError("use_te_activation_func not compatible with using kitchen.")
else:
backend = TESpecProvider()
backend = TESpecProvider(fallback_to_eager_attn=fallback_to_eager_attn)
@hxbai (Contributor) commented on Nov 4, 2025:

I think it is better to handle this in get_attention_module_spec_for_backend rather than modifying TESpecProvider, since we have other backends such as Kitchen. Similar to the code here:

module = TEFusedMLP if use_te_op_fuser else MLP
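For illustration, the suggested refactor would resolve the attention class inside the spec helper itself, mirroring the MLP pattern above, so backend providers such as Kitchen need no changes. A minimal sketch; the helper name below is hypothetical:

from megatron.core.transformer.dot_product_attention import DotProductAttention

def _resolve_core_attention(backend, fallback_to_eager_attn: bool = False) -> type:
    # Hypothetical helper: pick the eager class here instead of inside TESpecProvider,
    # the same way the MLP spec picks TEFusedMLP vs. MLP based on use_te_op_fuser.
    return DotProductAttention if fallback_to_eager_attn else backend.core_attention()

The merged diff takes a middle path: it adds the flag to TESpecProvider and also short-circuits to DotProductAttention inside get_attention_module_spec_for_backend (see the core_attention assignment below).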


sharded_state_dict_keys_map = {}

@@ -135,6 +139,7 @@ def get_gpt_layer_with_transformer_engine_spec(
multi_latent_attention=multi_latent_attention,
mla_down_proj_use_column_parallel=False,
normalization=normalization,
fallback_to_eager_attn=fallback_to_eager_attn,
)

mlp = get_mlp_module_spec_for_backend(
@@ -214,6 +219,7 @@ def get_gpt_layer_local_spec(
multi_latent_attention=multi_latent_attention,
mla_down_proj_use_column_parallel=True,
normalization=normalization,
fallback_to_eager_attn=False,
)

mlp = get_mlp_module_spec_for_backend(
@@ -278,6 +284,7 @@ def get_attention_module_spec_for_backend(
multi_latent_attention: Optional[bool] = False,
mla_down_proj_use_column_parallel: Optional[bool] = False,
normalization: Optional[str] = None,
fallback_to_eager_attn: Optional[bool] = False,
) -> ModuleSpec:
"""Helper function to get module spec for Attention"""

Expand All @@ -292,6 +299,7 @@ def get_attention_module_spec_for_backend(
rms_norm = normalization == "RMSNorm"
qk_norm = backend.layer_norm(rms_norm=rms_norm, for_qk=True)

core_attention = backend.core_attention() if not fallback_to_eager_attn else DotProductAttention
if multi_latent_attention:
assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA."
linear_q_down_proj = (
@@ -328,7 +336,7 @@
linear_q_up_proj=linear_q_up_proj,
linear_kv_down_proj=linear_kv_down_proj,
linear_kv_up_proj=linear_kv_up_proj,
core_attention=backend.core_attention(),
core_attention=core_attention,
linear_proj=backend.row_parallel_linear(),
q_layernorm=qk_norm,
kv_layernorm=qk_norm,
@@ -352,7 +360,7 @@ def get_attention_module_spec_for_backend(
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=linear_qkv,
core_attention=backend.core_attention(),
core_attention=core_attention,
linear_proj=backend.row_parallel_linear(),
q_layernorm=qk_norm,
k_layernorm=qk_norm,
@@ -499,6 +507,7 @@ def get_gpt_decoder_layer_specs(
if use_transformer_engine:
layer_norm_impl = TENorm
get_layer_spec_kwargs["use_te_activation_func"] = config.use_te_activation_func
get_layer_spec_kwargs['fallback_to_eager_attn'] = config.fallback_to_eager_attn
get_layer_spec_fn = get_gpt_layer_with_transformer_engine_spec
else:
layer_norm_impl = LNImpl
@@ -652,9 +661,11 @@ def get_gpt_mtp_block_spec(
"""GPT Multi-Token Prediction (MTP) block spec."""
if use_transformer_engine:
backend: BackendSpecProvider = (
KitchenSpecProvider(fallback=TESpecProvider())
KitchenSpecProvider(
fallback=TESpecProvider(fallback_to_eager_attn=config.fallback_to_eager_attn)
)
if config.use_kitchen
else TESpecProvider()
else TESpecProvider(fallback_to_eager_attn=config.fallback_to_eager_attn)
)
else:
backend = (
27 changes: 23 additions & 4 deletions megatron/core/transformer/dot_product_attention.py
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.


import math
@@ -12,6 +12,9 @@
from megatron.core.fusions.fused_softmax import FusedScaleMaskSoftmax
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.dot_product_attention_context_parallel import (
AttentionFuncionWithContextParallel,
)
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
@@ -54,9 +57,12 @@ def __init__(

self.config: TransformerConfig = config

assert (
self.config.context_parallel_size == 1
), "Context parallelism is only supported by TEDotProductAttention!"
if self.config.context_parallel_size > 1:
assert attention_dropout is None and self.config.attention_dropout == 0.0, (
f'DotProductAttention with context parallelism does not support attention dropout,'
f' but got {self.config.context_parallel_size=},'
f' {attention_dropout=}, and {self.config.attention_dropout=}.'
)

self.layer_number = max(1, layer_number)
self.attn_mask_type = attn_mask_type
@@ -172,6 +178,19 @@ def forward(
self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2
)

if self.config.context_parallel_size > 1:
output = AttentionFuncionWithContextParallel.apply(
query,
key,
value,
attention_mask,
self.config.attention_dropout,
self.softmax_scale,
parallel_state.get_context_parallel_group(),
)
output = output.view(query.shape[0], query.shape[1], self.hidden_size_per_partition)
return output

# [b, np, sq, sk]
output_size = (query.size(1), query.size(2), query.size(0), key.size(0))

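End to end, the new flag is exercised from the transformer config. A minimal sketch, assuming fallback_to_eager_attn is a TransformerConfig field (it is read as config.fallback_to_eager_attn above) and that the remaining config values are placeholders:

from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(
    num_layers=2,
    hidden_size=1024,
    num_attention_heads=16,
    context_parallel_size=2,      # eager attention now supports CP via AttentionFuncionWithContextParallel
    attention_dropout=0.0,        # required: the eager CP path asserts attention dropout is disabled
    fallback_to_eager_attn=True,  # assumed TransformerConfig field added by this PR
)

# The TE layer spec now wires DotProductAttention instead of TEDotProductAttention.
layer_spec = get_gpt_layer_with_transformer_engine_spec(
    fallback_to_eager_attn=config.fallback_to_eager_attn
)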