@@ -532,39 +532,23 @@ def __init__(self, config, layer_number,
             config.num_attention_heads, world_size)
 
         # Per GQA head and per partition values
-        if self.use_gqa:
-            kv_projection_size = config.kv_channels * config.num_key_value_heads
-            self.num_key_value_heads_per_partition = core.utils.divide(
-                config.num_key_value_heads, world_size)
-            self.num_key_value_groups = core.utils.divide(
-                config.num_attention_heads, config.num_key_value_heads)
-            assert self.hidden_size_per_attention_head == core.utils.divide(
-                kv_projection_size, config.num_key_value_heads)
+        self.num_key_value_heads_per_partition = core.utils.divide(
+            config.num_key_value_heads, world_size)
+        self.num_key_value_groups = core.utils.divide(
+            config.num_attention_heads, config.num_key_value_heads)
+        kv_projection_size = config.kv_channels * config.num_key_value_heads
+        assert self.hidden_size_per_attention_head == core.utils.divide(
+            kv_projection_size, config.num_key_value_heads)
 
         # Strided linear layer.
-        if attention_type == AttnType.self_attn and not self.use_gqa:
+        if attention_type == AttnType.self_attn:
             self.query_key_value = tensor_parallel.ColumnParallelLinear(
                 config.hidden_size,
-                3 * projection_size,
+                projection_size + 2 * kv_projection_size,
                 config=config,
                 init_method=config.init_method,
                 bias=args.add_bias_linear,
                 gather_output=False)
-        elif attention_type == AttnType.self_attn and self.use_gqa:
-            self.query = tensor_parallel.ColumnParallelLinear(
-                config.hidden_size,
-                projection_size,
-                config=config,
-                init_method=config.init_method,
-                bias=config.add_bias_linear,
-                gather_output=False)
-            self.key_value = tensor_parallel.ColumnParallelLinear(
-                config.hidden_size,
-                2 * kv_projection_size,
-                config=config,
-                init_method=config.init_method,
-                bias=config.add_bias_linear,
-                gather_output=False)
         else:
             assert attention_type == AttnType.cross_attn
             self.query = tensor_parallel.ColumnParallelLinear(
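For a concrete sense of the sizing change in this hunk: with grouped-query attention the single fused query_key_value projection is (nq + 2 * nkv) * hn columns wide rather than 3 * nq * hn, which is what makes the separate query/key_value branch unnecessary. A minimal sketch of that arithmetic, using illustrative head counts that are assumptions and not values taken from this diff:

# Hypothetical config values, for illustration only (not from this PR).
num_attention_heads = 64       # nq: query heads
num_key_value_heads = 8        # nkv: key/value heads (GQA groups)
kv_channels = 128              # hn: per-head dimension

projection_size = kv_channels * num_attention_heads      # 8192
kv_projection_size = kv_channels * num_key_value_heads   # 1024

# Width of the fused column-parallel matmul in the new code path.
fused_width = projection_size + 2 * kv_projection_size   # 10240
assert fused_width == (num_attention_heads + 2 * num_key_value_heads) * kv_channels

# The old MHA-only sizing, for comparison.
mha_width = 3 * projection_size                           # 24576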
@@ -657,6 +641,13 @@ def repeat_kv(self, hidden_states, n_rep):
         return hidden_states.reshape(slen, batch,
                                      num_key_value_heads_per_partition * n_rep,
                                      head_dim)
+
+    def split_tensor(self, mixed_x_layer):
+        query_layer = mixed_x_layer[:, :, :, :-2, :].reshape(mixed_x_layer.shape[:2] + (-1, self.hidden_size_per_attention_head))
+        key_layer = mixed_x_layer[:, :, :, -2, :]
+        value_layer = mixed_x_layer[:, :, :, -1, :]
+
+        return query_layer, key_layer, value_layer
 
     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, inference_params=None,
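To make the new tensor layout concrete: the view and split_tensor above assume the fused projection output is grouped per KV head, i.e. each KV head's nq // nkv query slices followed by its key slice and its value slice. A self-contained sketch of the same slicing on a random tensor, with illustrative sizes that are assumptions rather than values from this diff:

import torch

sq, b, hn = 5, 2, 16   # sequence length, batch, head dim (illustrative)
nq, nkv = 8, 2         # query heads / KV heads per partition (illustrative)
ng = nq // nkv         # num_key_value_groups

# Fused projection output: [sq, b, (nq + 2 * nkv) * hn]
mixed = torch.randn(sq, b, (nq + 2 * nkv) * hn)

# View as [sq, b, nkv, ng + 2, hn]: ng query slices plus one key and one value slice per KV head.
mixed = mixed.view(sq, b, -1, ng + 2, hn)

# Same slicing as split_tensor in the hunk above.
query = mixed[:, :, :, :-2, :].reshape(mixed.shape[:2] + (-1, hn))  # [sq, b, nq, hn]
key = mixed[:, :, :, -2, :]                                         # [sq, b, nkv, hn]
value = mixed[:, :, :, -1, :]                                       # [sq, b, nkv, hn]

assert query.shape == (sq, b, nq, hn)
assert key.shape == value.shape == (sq, b, nkv, hn)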
@@ -686,45 +677,26 @@ def forward(self, hidden_states, attention_mask,
         # =====================
         # Query, Key, and Value
         # =====================
-        if self.attention_type == AttnType.self_attn and not self.use_gqa:
-            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
+        if self.attention_type == AttnType.self_attn:
+            # Attention heads [sq, b, h] --> [sq, b, ((nq + 2 * nkv) * hn)]
             mixed_x_layer, _ = self.query_key_value(hidden_states)
 
-            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
+            # [sq, b, ((nq + 2 * nkv) * hn)] --> [sq, b, nkv, (nq // nkv + 2), hn]
             new_tensor_shape = mixed_x_layer.size()[:-1] + \
-                (self.num_attention_heads_per_partition,
-                 3 * self.hidden_size_per_attention_head)
+                (-1, (self.num_key_value_groups + 2),
+                 self.hidden_size_per_attention_head)
             mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
 
-            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
-            (query_layer,
+            # [sq, b, nkv, (nq // nkv + 2), hn] --> 3 [sq, b, np, hn]
+            (query_layer,
              key_layer,
-             value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
-        elif self.attention_type == AttnType.self_attn and self.use_gqa:
-            # Attention head [sq, b, h] --> [sq, b, hp]
-            query_layer, _ = self.query(hidden_states)
-            # [sq, b, hp] --> [sq, b, np, hn]
-            new_tensor_shape = query_layer.size()[:-1] + \
-                (self.num_attention_heads_per_partition,
-                 self.hidden_size_per_attention_head)
-            query_layer = query_layer.view(*new_tensor_shape)
-
-            # Attention heads [sq, b, h] --> [sq, b, (np * 2 * hn)]
-            mixed_kv_layer, _ = self.key_value(hidden_states)
-            # [sq, b, (np * 2 * hn)] --> [sq, b, np, 2 * hn]
-            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
-                (self.num_key_value_heads_per_partition,
-                 2 * self.hidden_size_per_attention_head)
-            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
-            # [sq, b, np, 2 * hn] --> 2 [sq, b, np, hn]
-            (key_layer,
-             value_layer) = tensor_parallel.split_tensor_along_last_dim(
-                 mixed_kv_layer, 2)
+             value_layer) = self.split_tensor(mixed_x_layer)
 
             # Repeat kv
-            key_layer = self.repeat_kv(key_layer, self.num_key_value_groups)
-            value_layer = self.repeat_kv(value_layer,
-                                         self.num_key_value_groups)
+            if self.use_gqa:
+                key_layer = self.repeat_kv(key_layer, self.num_key_value_groups)
+                value_layer = self.repeat_kv(value_layer,
+                                             self.num_key_value_groups)
         else:
             assert not self.use_gqa, 'GQA + cross-attn not tested yet'
 
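For context on the `# Repeat kv` step now guarded by `use_gqa`: each KV head is repeated num_key_value_groups times so that key and value line up with the nq query heads before attention scores are computed. A standalone sketch of that expansion under assumed shapes (it mirrors the reshape visible in the repeat_kv context above, but is not the repository's exact helper):

import torch

def repeat_kv_sketch(hidden_states, n_rep):
    # [sq, b, nkv, hn] -> [sq, b, nkv * n_rep, hn]
    slen, batch, nkv, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states[:, :, :, None, :].expand(slen, batch, nkv, n_rep, head_dim)
    return expanded.reshape(slen, batch, nkv * n_rep, head_dim)

k = torch.randn(5, 2, 2, 16)                          # [sq, b, nkv, hn]
assert repeat_kv_sketch(k, 4).shape == (5, 2, 8, 16)  # matches 8 query heads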