Add support for grouped-query attention #9
base: main
@@ -758,6 +758,7 @@ def __init__(
     gradient_accumulation_fusion=False,
     normalize_attention_scores=True,
     transfer_with_static_ring=True,
+    num_query_groups=None,
 ):
     super(ParallelAttention, self).__init__()
@@ -772,20 +773,29 @@ def __init__(
     self.set_accepted_adapter_types([InfusedAdapterConfig._target_])

+    if num_query_groups is None:
+        num_query_groups = num_attention_heads
+
     if kv_channels is None:
         assert (
             hidden_size % num_attention_heads == 0
         ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
         kv_channels = hidden_size // num_attention_heads
-    projection_size = kv_channels * num_attention_heads
+    query_projection_size = kv_channels * num_attention_heads
+    kv_projection_size = kv_channels * num_query_groups

     # Per attention head and per partition values.
     world_size = parallel_state.get_tensor_model_parallel_world_size()
-    self.hidden_size_per_attention_head = safe_divide(projection_size, num_attention_heads)
+    self.hidden_size_per_attention_head = safe_divide(query_projection_size, num_attention_heads)
     self.num_attention_heads_per_partition = safe_divide(num_attention_heads, world_size)
     self.num_attention_heads_partition_offset = (
         self.num_attention_heads_per_partition * parallel_state.get_tensor_model_parallel_rank()
     )
+    self.hidden_size_per_key_value_head = safe_divide(kv_projection_size, num_query_groups)
+    self.num_query_groups_per_partition = safe_divide(num_query_groups, world_size)
+    self.num_query_groups_partition_offset = (
+        self.num_query_groups_per_partition * parallel_state.get_tensor_model_parallel_rank()
+    )

     no_async_tensor_model_parallel_allreduce = (
         parallel_state.get_tensor_model_parallel_world_size() == 1 or sequence_parallel
@@ -795,7 +805,7 @@ def __init__(
     if attention_type == AttnType.self_attn:
         self.query_key_value = tensor_parallel.ColumnParallelLinear(
             hidden_size,
-            3 * projection_size,
+            query_projection_size + 2 * kv_projection_size,
             gather_output=False,
             init_method=init_method,
             use_cpu_initialization=use_cpu_initialization,
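For illustration (not part of the PR), a minimal sketch of the projection-size arithmetic behind the fused QKV width above, assuming a hypothetical configuration of hidden_size=4096, 32 attention heads, and 8 query groups:

```python
# Hypothetical sizes, chosen only to illustrate the arithmetic in this hunk.
hidden_size = 4096
num_attention_heads = 32
num_query_groups = 8

kv_channels = hidden_size // num_attention_heads            # 128 (per-head dim)
query_projection_size = kv_channels * num_attention_heads   # 4096
kv_projection_size = kv_channels * num_query_groups         # 1024

# Output width of the fused QKV ColumnParallelLinear with grouped-query attention:
fused_width = query_projection_size + 2 * kv_projection_size
print(fused_width)                 # 6144
print(3 * query_projection_size)   # 12288, the old 3 * projection_size for standard MHA
```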
@@ -807,9 +817,12 @@ def __init__(
         )
     else:
         assert attention_type == AttnType.cross_attn
+        if num_query_groups != num_attention_heads:
+            raise ValueError('Grouped-query attention is not currently supported in cross attention.')
+        assert query_projection_size == kv_projection_size
         self.query = tensor_parallel.ColumnParallelLinear(
             hidden_size,
-            projection_size,
+            query_projection_size,
             gather_output=False,
             init_method=init_method,
             bias=bias,
@@ -821,7 +834,7 @@ def __init__(
         self.key_value = tensor_parallel.ColumnParallelLinear(
             hidden_size,
-            2 * projection_size,
+            2 * kv_projection_size,
             gather_output=False,
             init_method=init_method,
             bias=bias,
@@ -850,7 +863,7 @@ def __init__(
     # Output.
     self.dense = tensor_parallel.RowParallelLinear(
-        projection_size,
+        query_projection_size,
         hidden_size,
         input_is_parallel=True,
         init_method=output_layer_init_method,
@@ -941,7 +954,7 @@ def _allocate_memory(self, inference_max_sequence_len, batch_size, dtype):
     return torch.empty(
         inference_max_sequence_len,
         batch_size,
-        self.num_attention_heads_per_partition,
+        self.num_query_groups_per_partition,
         self.hidden_size_per_attention_head,
         dtype=dtype,
         device=torch.cuda.current_device(),
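As a rough, illustrative calculation (numbers assumed, not from the PR), the change above shrinks each inference key/value cache tensor by a factor of num_attention_heads / num_query_groups:

```python
# Assumed values for illustration: max sequence 4096, batch 4, head dim 128,
# bf16 (2 bytes per element), tensor-parallel size 1, 32 heads vs. 8 query groups.
max_seq, batch, head_dim, bytes_per_elem = 4096, 4, 128, 2

mha_cache_bytes = max_seq * batch * 32 * head_dim * bytes_per_elem  # 134217728 (~128 MiB)
gqa_cache_bytes = max_seq * batch * 8 * head_dim * bytes_per_elem   # 33554432 (~32 MiB)
print(mha_cache_bytes // gqa_cache_bytes)  # 4x smaller per key (or value) cache, per layer
```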
@@ -1025,28 +1038,45 @@ def forward(
     # =====================

     if self.attention_type == AttnType.self_attn:
-        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
+        # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
         mixed_x_layer, _ = self.query_key_value(hidden_states)

-        # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
+        # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, ng, (np/ng + 2) * hn]
         new_tensor_shape = mixed_x_layer.size()[:-1] + (
-            self.num_attention_heads_per_partition,
-            3 * self.hidden_size_per_attention_head,
+            self.num_query_groups_per_partition,
+            (
+                (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2)
+                * self.hidden_size_per_attention_head
+            ),
         )
         if self.megatron_legacy:
             mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
         mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

-        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
-        (query_layer, key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
+        # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
+        (query_layer, key_layer, value_layer) = torch.split(
+            mixed_x_layer,
+            [
+                (
+                    self.num_attention_heads_per_partition
+                    // self.num_query_groups_per_partition
+                    * self.hidden_size_per_attention_head
+                ),
+                self.hidden_size_per_attention_head,
+                self.hidden_size_per_attention_head,
+            ],
+            dim=3,
+        )
+        # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
+        query_layer = query_layer.reshape(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head)
     else:
         # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
         mixed_kv_layer, _ = self.key_value(encoder_output)

         # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
         new_tensor_shape = mixed_kv_layer.size()[:-1] + (
-            self.num_attention_heads_per_partition,
-            2 * self.hidden_size_per_attention_head,
+            self.num_query_groups_per_partition,
+            2 * self.hidden_size_per_key_value_head,
         )
         if self.megatron_legacy:
             mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, True)
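A standalone sketch of the split above, with assumed shapes (sq=16, b=2, 32 heads, 8 query groups, head dim 128); this is not the PR's module, just the same tensor operations on dummy data:

```python
import torch

sq, b = 16, 2                    # sequence length, micro-batch size (assumed)
heads, groups, hn = 32, 8, 128   # attention heads, query groups, head dim (assumed)

# Fused QKV output reshaped to [sq, b, ng, (np/ng + 2) * hn], as in the hunk above.
mixed_x_layer = torch.randn(sq, b, groups, (heads // groups + 2) * hn)

# Split into queries (np/ng * hn per group), keys (hn), and values (hn) along the last dim.
query, key, value = torch.split(mixed_x_layer, [heads // groups * hn, hn, hn], dim=3)

# Queries go back to one head per slot: [sq, b, np, hn].
query = query.reshape(sq, b, -1, hn)

print(query.shape, key.shape, value.shape)
# torch.Size([16, 2, 32, 128]) torch.Size([16, 2, 8, 128]) torch.Size([16, 2, 8, 128])
```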
@@ -1116,6 +1146,19 @@ def forward(
     if get_key_value:
         present = (key_layer, value_layer)

+    # expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn]
+    # This is a noop for normal attention where ng == np. When using grouped-query attention this
+    # creates a view that has the keys and values virtually repeated along their dimension to
+    # match the number of queries.
+    key_layer = key_layer.repeat_interleave(
+        self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
+        dim=2,
+    )
+    value_layer = value_layer.repeat_interleave(
+        self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
+        dim=2,
+    )

     if checkpoint_core_attention:
         context_layer = self._checkpointed_attention_forward(
             query_layer,

Review comment: How is this repeat_interleaving done? Does it use an explicit torch.view() operation?

Reply: @amithrm Unlike MHA, GQA does not have a one-to-one correspondence between query heads and key/value heads. This operation transforms the grouped (middle) and multi-query (right) layouts in Figure 2 into the same shape as the multi-head layout on the left.

Review comment: Wouldn't this cause calculations to be duplicated/redundant across GQA groups?

Reply: This operation does not increase the time complexity.
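To make the thread above concrete, here is a standalone sketch (assumed shapes, not the PR's module) of what repeat_interleave does to the key/value layout: each group's single key head is repeated np/ng times so it lines up with the query heads.

```python
import torch

sk, b, groups, hn = 16, 2, 8, 128  # key sequence length, batch, query groups, head dim (assumed)
heads = 32                         # query heads (assumed)

key_layer = torch.randn(sk, b, groups, hn)                       # [sk, b, ng, hn]
expanded = key_layer.repeat_interleave(heads // groups, dim=2)   # [sk, b, np, hn]

print(expanded.shape)                                      # torch.Size([16, 2, 32, 128])
# Expanded heads 0..3 are copies of group 0, heads 4..7 of group 1, and so on.
print(torch.equal(expanded[:, :, 0], key_layer[:, :, 0]))  # True
print(torch.equal(expanded[:, :, 3], key_layer[:, :, 0]))  # True
print(torch.equal(expanded[:, :, 4], key_layer[:, :, 1]))  # True
```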
@@ -1377,6 +1420,7 @@ def __init__(
     num_moe_experts=1,
     moe_frequency=1,
     moe_dropout=0.0,
+    num_query_groups=None,
 ):
     super(ParallelTransformerLayer_, self).__init__()
@@ -1459,6 +1503,7 @@ def __init__(
     gradient_accumulation_fusion=gradient_accumulation_fusion,
     normalize_attention_scores=normalize_attention_scores,
     transfer_with_static_ring=transfer_with_static_ring,
+    num_query_groups=num_query_groups,
 )

 if transformer_block_type == 'normformer':
@@ -1921,6 +1966,7 @@ def __init__(
     num_moe_experts=1,
     moe_frequency=1,
     moe_dropout=0.0,
+    num_query_groups=None,
 ):
     super(ParallelTransformerLayer, self).__init__(
         init_method=init_method,
@@ -1962,6 +2008,7 @@ def __init__(
     num_moe_experts=num_moe_experts,
     moe_frequency=moe_frequency,
     moe_dropout=moe_dropout,
+    num_query_groups=num_query_groups,
 )

 if precision == 32:
@@ -2184,6 +2231,7 @@ def __init__(
     num_moe_experts=1,
     moe_frequency=1,
     moe_dropout=0.0,
+    num_query_groups=None,
 ):
     super(ParallelTransformer, self).__init__()
@@ -2295,6 +2343,8 @@ def build_layer(layer_number):
     lt = layer_type

     if self.transformer_engine:
+        assert num_query_groups is None, 'num_query_groups currently not supported with transformer engine'
+
         return AutocastTransformerLayer(
             hidden_size=hidden_size,
             ffn_hidden_size=ffn_hidden_size,
@@ -2359,7 +2409,8 @@ def build_layer(layer_number):
     num_moe_experts=num_moe_experts,
     moe_frequency=moe_frequency,
     moe_dropout=moe_dropout,
-    position_embedding_type=self.position_embedding_type
+    position_embedding_type=self.position_embedding_type,
+    num_query_groups=num_query_groups,
 )

 if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:

Review comment: Can you post performance numbers and convergence curves for pretraining with your changes? Can you use the config below?

- tensor_parallel: 8
- pipeline_parallel: 8
- data_parallel: 1
- global_batch_size: 256
- activation_checkpointing: full
- precision: bf16 + SR
- lr_scheduler: cosine annealing
- dataset: bookcorpus
Reply: I have confirmed that the number of seconds per step decreases on a 7B model when num_query_groups is set lower than the standard (MHA) setting. We plan to start training a 70B model next week and will share the results through our contact at AWS Japan. In preparation, we plan to run some experiments with smaller settings and will share those results as soon as they are done.