This repository was archived by the owner on Oct 14, 2025. It is now read-only.
@@ -59,6 +59,7 @@ model:
hidden_size: 768
ffn_hidden_size: 3072 # Transformer FFN hidden size. For Llama it's 8/3*hidden_size
num_attention_heads: 12
num_query_groups: 12
init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.')
use_scaled_init_method: True # use scaled residuals initialization
hidden_dropout: 0 # Dropout probability for hidden state transformer.
1 change: 1 addition & 0 deletions nemo/examples/nlp/language_modeling/llama_13b.sh
@@ -5,6 +5,7 @@ export TP=8
export PP=4
export N_LAYERS=40
export N_AH=40
export N_QG=40
Contributor commented:
Can you post performance numbers and convergence curves for pretraining with your changes? Can you use the config below?

tensor_parallel:8
pipeline_parallel:8
data_parallel:1
global_batch_size:256
activation_checkpointing:full
precision:bf16+SR
dataset:bookcorpus
lrscheduler: cosineannealing
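
As a rough sanity check of this layout (assuming a 64-device job, so TP×PP×DP covers every device, and micro-batch size 1, matching UBS=1 in the scripts in this PR):

# Illustrative arithmetic for the requested parallel layout (assumed values).
tp, pp, dp = 8, 8, 1                      # tensor_parallel, pipeline_parallel, data_parallel
devices = tp * pp * dp                    # 64 devices hold exactly one model replica
global_batch_size = 256
micro_batch_size = 1                      # assumed, matching UBS=1 elsewhere in this PR
grad_accum_steps = global_batch_size // (dp * micro_batch_size)
print(devices, grad_accum_steps)          # -> 64 256 (256 micro-batches accumulated per step)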

Author replied:
I have confirmed that seconds per step on the 7B model decrease when num_query_groups is set smaller than the default setting (num_query_groups equal to num_attention_heads).

We plan to start training on 70B next week and will share the results through our contact at AWS Japan.

In preparation, we plan to run some experiments with smaller configurations, and I will share those results as soon as they are done.

export FFN_HS=13824
export GBS=1024
export UBS=1
1 change: 1 addition & 0 deletions nemo/examples/nlp/language_modeling/llama_7b.sh
@@ -6,6 +6,7 @@ export TP=8
export PP=1
export N_LAYERS=32
export N_AH=32
export N_QG=32
export FFN_HS=11008
export GBS=256

4 changes: 3 additions & 1 deletion nemo/examples/nlp/language_modeling/test_llama.sh
@@ -68,10 +68,11 @@ fi
: ${PP:=1}
: ${N_LAYERS:=32}
: ${N_AH:=32}
: ${N_QG:=32}
: ${UBS:=1}
: ${FFN_HS:=11008}
: ${GBS:=256}
echo "SEQ_LEN=$SEQ_LENGTH, HS=$HS, FFN_HS=$FFN_HS TP=$TP PP=$PP N_LAYERS=$N_LAYERS N_AH=$N_AH GBS=$GBS UBS=$UBS"
echo "SEQ_LEN=$SEQ_LENGTH, HS=$HS, FFN_HS=$FFN_HS TP=$TP PP=$PP N_LAYERS=$N_LAYERS N_AH=$N_AH N_QG=$N_QG GBS=$GBS UBS=$UBS"

LOG_PATH=logs/$SLURM_JOB_ID/$NODEID/
mkdir -p $LOG_PATH
@@ -100,6 +101,7 @@ $MAYBE_COMPILE torchrun $DISTRIBUTED_ARGS megatron_gpt_pretraining.py \
model.ffn_hidden_size=$FFN_HS \
model.num_layers=$N_LAYERS \
model.num_attention_heads=$N_AH \
model.num_query_groups=$N_QG \
model.init_method_std=0.021 \
model.hidden_dropout=0 \
model.layernorm_epsilon=1e-5 \
@@ -70,10 +70,11 @@ fi
: ${PP:=1}
: ${N_LAYERS:=32}
: ${N_AH:=32}
: ${N_QG:=32}
: ${UBS:=1}
: ${FFN_HS:=11008}
: ${GBS:=256}
echo "SEQ_LEN=$SEQ_LENGTH, HS=$HS, FFN_HS=$FFN_HS TP=$TP PP=$PP N_LAYERS=$N_LAYERS N_AH=$N_AH GBS=$GBS UBS=$UBS"
echo "SEQ_LEN=$SEQ_LENGTH, HS=$HS, FFN_HS=$FFN_HS TP=$TP PP=$PP N_LAYERS=$N_LAYERS N_AH=$N_AH N_QG=$N_QG GBS=$GBS UBS=$UBS"

LOG_PATH=logs/$SLURM_JOB_ID/$NODEID/
mkdir -p $LOG_PATH
@@ -102,6 +103,7 @@ $MAYBE_COMPILE torchrun $DISTRIBUTED_ARGS megatron_gpt_pretraining.py \
model.ffn_hidden_size=$FFN_HS \
model.num_layers=$N_LAYERS \
model.num_attention_heads=$N_AH \
model.num_query_groups=$N_QG \
model.init_method_std=0.02 \
model.hidden_dropout=0 \
model.layernorm_epsilon=1e-6 \
@@ -168,6 +168,7 @@ def __init__(
use_emha=False,
multi_query_attention=False,
save_logits=False,
num_query_groups=None,
):

super(GPTModel, self).__init__(share_token_embeddings=share_embeddings_and_output_weights)
@@ -249,6 +250,7 @@ def __init__(
reduce_amax=reduce_amax,
use_emha=use_emha,
multi_query_attention=multi_query_attention,
num_query_groups=num_query_groups,
)

if self.share_embeddings_and_output_weights:
@@ -250,6 +250,7 @@ def model_provider_func(self, pre_process, post_process):
fp8_amax_compute_algo=self.cfg.get('fp8_amax_compute_algo', 'most_recent'),
use_emha=self.cfg.get('use_emha', False),
save_logits=self.cfg.get('save_logits', False),
num_query_groups=self.cfg.get('num_query_groups', None),
)

return model
@@ -99,6 +99,7 @@ def get_language_model(
fp8_amax_compute_algo='most_recent',
reduce_amax=True,
use_emha=False,
num_query_groups=None,
):
"""Build language model and return along with the key to save."""

@@ -173,6 +174,7 @@ def get_language_model(
fp8_amax_compute_algo=fp8_amax_compute_algo,
reduce_amax=reduce_amax,
use_emha=use_emha,
num_query_groups=num_query_groups,
)
# key used for checkpoints.
language_model_key = 'language_model'
@@ -472,6 +474,7 @@ def __init__(
fp8_amax_compute_algo='most_recent',
reduce_amax=True,
use_emha=False,
num_query_groups=None,
):
super(TransformerLanguageModel, self).__init__(share_token_embeddings=share_embeddings_and_output_weights)

@@ -573,7 +576,8 @@ def __init__(
fp8_amax_compute_algo=fp8_amax_compute_algo,
reduce_amax=reduce_amax,
use_emha=use_emha,
position_embedding_type=self.position_embedding_type
position_embedding_type=self.position_embedding_type,
num_query_groups=num_query_groups,
)
self._encoder_key = 'encoder'

@@ -613,6 +617,7 @@ def __init__(
activations_checkpoint_granularity=activations_checkpoint_granularity,
activations_checkpoint_layers_per_pipeline=activations_checkpoint_layers_per_pipeline,
transformer_engine=transformer_engine,
num_query_groups=num_query_groups,
)
self._decoder_key = 'decoder'

83 changes: 67 additions & 16 deletions nemo/nemo/collections/nlp/modules/common/megatron/transformer.py
@@ -758,6 +758,7 @@ def __init__(
gradient_accumulation_fusion=False,
normalize_attention_scores=True,
transfer_with_static_ring=True,
num_query_groups=None,
):
super(ParallelAttention, self).__init__()

@@ -772,20 +773,29 @@ def __init__(

self.set_accepted_adapter_types([InfusedAdapterConfig._target_])

if num_query_groups is None:
num_query_groups = num_attention_heads

if kv_channels is None:
assert (
hidden_size % num_attention_heads == 0
), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
kv_channels = hidden_size // num_attention_heads
projection_size = kv_channels * num_attention_heads
query_projection_size = kv_channels * num_attention_heads
kv_projection_size = kv_channels * num_query_groups

# Per attention head and per partition values.
world_size = parallel_state.get_tensor_model_parallel_world_size()
self.hidden_size_per_attention_head = safe_divide(projection_size, num_attention_heads)
self.hidden_size_per_attention_head = safe_divide(query_projection_size, num_attention_heads)
self.num_attention_heads_per_partition = safe_divide(num_attention_heads, world_size)
self.num_attention_heads_partition_offset = (
self.num_attention_heads_per_partition * parallel_state.get_tensor_model_parallel_rank()
)
self.hidden_size_per_key_value_head = safe_divide(kv_projection_size, num_query_groups)
self.num_query_groups_per_partition = safe_divide(num_query_groups, world_size)
self.num_query_groups_partition_offset = (
self.num_query_groups_per_partition * parallel_state.get_tensor_model_parallel_rank()
)

no_async_tensor_model_parallel_allreduce = (
parallel_state.get_tensor_model_parallel_world_size() == 1 or sequence_parallel
@@ -795,7 +805,7 @@
if attention_type == AttnType.self_attn:
self.query_key_value = tensor_parallel.ColumnParallelLinear(
hidden_size,
3 * projection_size,
query_projection_size + 2 * kv_projection_size,
gather_output=False,
init_method=init_method,
use_cpu_initialization=use_cpu_initialization,
@@ -807,9 +817,12 @@
)
else:
assert attention_type == AttnType.cross_attn
if num_query_groups != num_attention_heads:
raise ValueError('Grouped-query attention is not currently supported in cross attention.')
assert query_projection_size == kv_projection_size
self.query = tensor_parallel.ColumnParallelLinear(
hidden_size,
projection_size,
query_projection_size,
gather_output=False,
init_method=init_method,
bias=bias,
@@ -821,7 +834,7 @@

self.key_value = tensor_parallel.ColumnParallelLinear(
hidden_size,
2 * projection_size,
2 * kv_projection_size,
gather_output=False,
init_method=init_method,
bias=bias,
@@ -850,7 +863,7 @@ def __init__(

# Output.
self.dense = tensor_parallel.RowParallelLinear(
projection_size,
query_projection_size,
hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
@@ -941,7 +954,7 @@ def _allocate_memory(self, inference_max_sequence_len, batch_size, dtype):
return torch.empty(
inference_max_sequence_len,
batch_size,
self.num_attention_heads_per_partition,
self.num_query_groups_per_partition,
self.hidden_size_per_attention_head,
dtype=dtype,
device=torch.cuda.current_device(),
@@ -1025,28 +1038,45 @@ def forward(
# =====================

if self.attention_type == AttnType.self_attn:
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
# Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
mixed_x_layer, _ = self.query_key_value(hidden_states)

# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
# [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, ng, (np/ng + 2) * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
self.num_query_groups_per_partition,
(
(self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2)
* self.hidden_size_per_attention_head
),
)
if self.megatron_legacy:
mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer, key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
# [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
(query_layer, key_layer, value_layer) = torch.split(
mixed_x_layer,
[
(
self.num_attention_heads_per_partition
// self.num_query_groups_per_partition
* self.hidden_size_per_attention_head
),
self.hidden_size_per_attention_head,
self.hidden_size_per_attention_head,
],
dim=3,
)
# [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
query_layer = query_layer.reshape(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head)
else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer, _ = self.key_value(encoder_output)

# [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
2 * self.hidden_size_per_attention_head,
self.num_query_groups_per_partition,
2 * self.hidden_size_per_key_value_head,
)
if self.megatron_legacy:
mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, True)
@@ -1116,6 +1146,19 @@ def forward(
if get_key_value:
present = (key_layer, value_layer)

# expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn]
# This is functionally a no-op for normal attention, where ng == np. When using grouped-query
# attention this repeats the keys and values along their head dimension so that they match
# the number of query heads.
key_layer = key_layer.repeat_interleave(
Contributor commented:
How is this repeat_interleave done? Does it use an explicit torch.view() operation?

Author replied:
@amithrm
Rather than reshaping a tensor the way torch.view does, repeat_interleave repeats elements of the tensor.

Unlike MHA, GQA does not have a one-to-one correspondence between query heads and key/value heads; instead, multiple query heads share a single key/value head. By virtually repeating the shared key/value heads until their count equals num_attention_heads, core_attention can treat MHA and GQA identically.

The middle and right layouts in Figure 2 are transformed to the same shape as the left one by this operation.

[Attached image: the Figure 2 referenced above.]
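
A minimal, self-contained PyTorch sketch of this expansion, with assumed illustrative sizes (32 query heads, 8 key/value groups):

import torch

sk, b = 16, 2                 # key sequence length, batch size (illustrative)
np_, ng, hn = 32, 8, 64       # query heads, key/value groups, head dim

key_layer = torch.randn(sk, b, ng, hn)                    # [sk, b, ng, hn]
# Repeat each of the ng key/value heads np/ng times along the head dimension,
# so core attention sees the same [sk, b, np, hn] layout as plain MHA.
key_layer = key_layer.repeat_interleave(np_ // ng, dim=2)
print(key_layer.shape)                                    # torch.Size([16, 2, 32, 64])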

Contributor commented:
Wouldn't this cause calculations to be duplicated/redundant across GQA groups?

Author (yasuhisa-nakashima) replied on Sep 27, 2023:
This operation does not increase the time complexity. GQA reduces cost by shrinking the output dimension of the projection layer that maps hidden states to key/value heads; the dot products inside core attention have the same complexity for GQA as for MHA.
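
A small worked comparison of where the savings come from (the sizes below are assumed, roughly 7B-like, and are not taken from this PR):

# QKV projection weights: GQA shrinks only the K/V part.
hidden_size, np_, ng, hn = 4096, 32, 8, 128             # assumed sizes
mha_qkv_params = hidden_size * (3 * np_ * hn)           # Q, K, V each with np heads -> 50,331,648
gqa_qkv_params = hidden_size * ((np_ + 2 * ng) * hn)    # K, V with only ng heads    -> 25,165,824

# Core attention (Q @ K^T and attn @ V) still runs over np query heads in both
# cases, because K/V are repeated back up to np heads, so its FLOP count is unchanged.
sq = sk = 2048
core_attn_flops = 2 * (2 * sq * sk * hn) * np_          # identical for MHA and GQA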

self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
dim = 2
)
value_layer = value_layer.repeat_interleave(
self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
dim = 2
)

if checkpoint_core_attention:
context_layer = self._checkpointed_attention_forward(
query_layer,
@@ -1377,6 +1420,7 @@ def __init__(
num_moe_experts=1,
moe_frequency=1,
moe_dropout=0.0,
num_query_groups=None,
):
super(ParallelTransformerLayer_, self).__init__()

@@ -1459,6 +1503,7 @@ def __init__(
gradient_accumulation_fusion=gradient_accumulation_fusion,
normalize_attention_scores=normalize_attention_scores,
transfer_with_static_ring=transfer_with_static_ring,
num_query_groups=num_query_groups,
)

if transformer_block_type == 'normformer':
@@ -1921,6 +1966,7 @@ def __init__(
num_moe_experts=1,
moe_frequency=1,
moe_dropout=0.0,
num_query_groups=None,
):
super(ParallelTransformerLayer, self).__init__(
init_method=init_method,
@@ -1962,6 +2008,7 @@ def __init__(
num_moe_experts=num_moe_experts,
moe_frequency=moe_frequency,
moe_dropout=moe_dropout,
num_query_groups=num_query_groups,
)

if precision == 32:
@@ -2184,6 +2231,7 @@ def __init__(
num_moe_experts=1,
moe_frequency=1,
moe_dropout=0.0,
num_query_groups=None,
):
super(ParallelTransformer, self).__init__()

@@ -2295,6 +2343,8 @@ def build_layer(layer_number):
lt = layer_type

if self.transformer_engine:
assert num_query_groups is None, 'num_query_groups currently not supported with transformer engine'

return AutocastTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
@@ -2359,7 +2409,8 @@ def build_layer(layer_number):
num_moe_experts=num_moe_experts,
moe_frequency=moe_frequency,
moe_dropout=moe_dropout,
position_embedding_type=self.position_embedding_type
position_embedding_type=self.position_embedding_type,
num_query_groups=num_query_groups,
)

if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: