diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
index ccf6e242d90..42ef077a965 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
@@ -141,45 +141,60 @@ def attention(self,
         if self.n_splits_linear != 1:
             hidden_states = self.unsqueeze(hidden_states, axis=0)
 
-        query_states = self.linear(
-            hidden_states,
-            num_heads * head_dim,
-            hidden_size,
-            bias=False,
-            wt_dtype=self.dtype,
-            n_splits=self.n_splits_linear,
-            scale_factor=(self.group_size == 0),
-            is_prefill=(mode == "prefill")
-        )
+        if mode == "prefill":
+            concat_linear = self.linear(hidden_states,
+                                        num_key_value_heads * head_dim * 2 + num_heads * head_dim,
+                                        hidden_size,
+                                        wt_dtype=self.dtype,
+                                        n_splits=self.n_splits_linear,
+                                        scale_factor=(self.group_size == 0),
+                                        is_prefill=(mode == "prefill"))
+            if q_bias is not None:
+                concat_linear = concat_linear + q_bias
+            query_states, key_states, value_states = self.variadic_split(
+                concat_linear, 2,
+                [num_heads * head_dim, num_key_value_heads * head_dim, num_key_value_heads * head_dim]
+            )
+        else:
+            query_states = self.linear(
+                hidden_states,
+                num_heads * head_dim,
+                hidden_size,
+                bias=False,
+                wt_dtype=self.dtype,
+                n_splits=self.n_splits_linear,
+                scale_factor=(self.group_size == 0),
+                is_prefill=(mode == "prefill")
+            )
 
-        key_states = self.linear(
-            hidden_states,
-            num_key_value_heads * head_dim,
-            hidden_size,
-            bias=False,
-            wt_dtype=self.dtype,
-            n_splits=self.n_splits_linear,
-            scale_factor=(self.group_size == 0),
-            is_prefill=(mode == "prefill")
-        )
+            key_states = self.linear(
+                hidden_states,
+                num_key_value_heads * head_dim,
+                hidden_size,
+                bias=False,
+                wt_dtype=self.dtype,
+                n_splits=self.n_splits_linear,
+                scale_factor=(self.group_size == 0),
+                is_prefill=(mode == "prefill")
+            )
 
-        value_states = self.linear(
-            hidden_states,
-            num_key_value_heads * head_dim,
-            hidden_size,
-            bias=False,
-            wt_dtype=self.dtype,
-            n_splits=self.n_splits_linear,
-            scale_factor=(self.group_size == 0),
-            is_prefill=(mode == "prefill")
-        )
+            value_states = self.linear(
+                hidden_states,
+                num_key_value_heads * head_dim,
+                hidden_size,
+                bias=False,
+                wt_dtype=self.dtype,
+                n_splits=self.n_splits_linear,
+                scale_factor=(self.group_size == 0),
+                is_prefill=(mode == "prefill")
+            )
 
-        if q_bias is not None:
-            query_states = query_states + q_bias
-        if k_bias is not None:
-            key_states = key_states + k_bias
-        if v_bias is not None:
-            value_states = value_states + v_bias
+            if q_bias is not None:
+                query_states = query_states + q_bias
+            if k_bias is not None:
+                key_states = key_states + k_bias
+            if v_bias is not None:
+                value_states = value_states + v_bias
 
         query_states = self.reshape(
             query_states, [1, seq_len, num_heads, head_dim]
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
index 015efe10031..b2039c40fe2 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
@@ -173,13 +173,18 @@ def __init__(
             post_attn_layernorm_weights = [self.constant(w) for w in post_attn_layernorm_weights]
 
         if q_biases is None:
-            q_biases = []
-            k_biases = []
-            v_biases = []
-            for i in range(num_layers):
-                q_biases.append(self.create_input_op((self.num_heads * self.head_dim,)))
-                k_biases.append(self.create_input_op((self.num_key_value_heads * self.head_dim,)))
-                v_biases.append(self.create_input_op((self.num_key_value_heads * self.head_dim,)))
+            if mode == "prefill":
+                q_biases = []
+                for i in range(num_layers):
+                    q_biases.append(self.create_input_op((self.num_heads * self.head_dim + self.num_key_value_heads * self.head_dim * 2,)))
+            else:
+                q_biases = []
+                k_biases = []
+                v_biases = []
+                for i in range(num_layers):
+                    q_biases.append(self.create_input_op((self.num_heads * self.head_dim,)))
+                    k_biases.append(self.create_input_op((self.num_key_value_heads * self.head_dim,)))
+                    v_biases.append(self.create_input_op((self.num_key_value_heads * self.head_dim,)))
         else:
             q_biases = [self.constant(w) for w in q_biases]
             k_biases = [self.constant(w) for w in k_biases]
@@ -217,8 +222,8 @@ def __init__(
                 input_layernorm_weight=input_layernorm_weights[i],
                 post_attention_layernorm_weight=post_attn_layernorm_weights[i],
                 q_bias=q_biases[i],
-                k_bias=k_biases[i],
-                v_bias=v_biases[i],
+                k_bias=k_biases[i] if mode == "decode" else None,
+                v_bias=v_biases[i] if mode == "decode" else None,
                 past_key=past_keys[i],
                 past_value=past_values[i],
             )
@@ -241,6 +246,11 @@
         else:
             self.compile()
         print(f"{mode} end compiling")
+        qwen_size = "7b" if self.hidden_size == 3584 else "1.5b"
+        xml_path = f"gw/qwen-{qwen_size}-npu-qkv-split-{mode}-{num_layers}-{n_splits_linear}-{n_splits_down_proj}.xml"
+
+        if not os.path.exists(xml_path):
+            self.save(xml_path)
 
     def build_decoder(
         self,
@@ -524,8 +534,7 @@ def forward(
         inputs = (hidden_states.to(torch.float16),
                   attention_mask.to(torch.float16),
                   position_ids.to(torch.int64))
-        inputs += (self.layer_norm_0, self.layer_norm_1)
-        inputs += (self.q_bias, self.k_bias, self.v_bias)
+        inputs += (self.layer_norm_0, self.layer_norm_1, self.q_bias)
         hidden_states, past_key, past_value = run_model(
             inputs, self.op_parameters, backend_cls, self.op_id, replica=2
         )
@@ -815,16 +824,41 @@ def run_prefill(
         mlp_layer = curr_layer.mlp
 
         weights = []
-        for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
-                           attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
-                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
-                           mlp_layer.down_proj_dq_list]:
-            l_weights = []
-            scales = []
-            for l in layer_list:
-                l_weights.append(l.weight)
-                scales.append(l.scale)
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+        if n_splits_linear == 1:
+            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
+                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
+                               mlp_layer.down_proj_dq_list]:
+                l_weights = []
+                scales = []
+                for l in layer_list:
+                    l_weights.append(l.weight)
+                    scales.append(l.scale)
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+        else:
+            qkv_weights = []
+            qkv_scales = []
+            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                               attn_layer.v_proj_dq_list]:
+                l_weights = []
+                scales = []
+                for l in layer_list:
+                    l_weights.append(l.weight)
+                    scales.append(l.scale)
+                qkv_weights.append(torch.stack(l_weights, axis=0))
+                qkv_scales.append(torch.stack(scales, axis=0))
+
+            weights.append((torch.cat(qkv_weights, dim=1), torch.cat(qkv_scales, dim=1)))
+
+            for layer_list in [attn_layer.o_proj_dq_list,
+                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
+                               mlp_layer.down_proj_dq_list]:
+                l_weights = []
+                scales = []
+                for l in layer_list:
+                    l_weights.append(l.weight)
+                    scales.append(l.scale)
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 
         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
@@ -832,6 +866,10 @@
         layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
         layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)
 
+        merge_bias = torch.cat([attn_layer.q_proj_dq_list.q_proj_dq_0.bias,
+                                attn_layer.k_proj_dq_list.k_proj_dq_0.bias,
+                                attn_layer.v_proj_dq_list.v_proj_dq_0.bias]).to(torch.float16)
+
         new_decoderlayer = FusedQwenLowBitDecoderlayer(
             weights,
             num_heads=num_heads,
@@ -840,9 +878,12 @@ def run_prefill(
             cached_sin=cached_sin,
             layer_norm_0=layer_norm_0,
             layer_norm_1=layer_norm_1,
-            q_bias=attn_layer.q_proj_dq_list.q_proj_dq_0.bias.to(torch.float16),
-            k_bias=attn_layer.k_proj_dq_list.k_proj_dq_0.bias.to(torch.float16),
-            v_bias=attn_layer.v_proj_dq_list.v_proj_dq_0.bias.to(torch.float16),
+            # q_bias=attn_layer.q_proj_dq_list.q_proj_dq_0.bias.to(torch.float16),
+            # k_bias=attn_layer.k_proj_dq_list.k_proj_dq_0.bias.to(torch.float16),
+            # v_bias=attn_layer.v_proj_dq_list.v_proj_dq_0.bias.to(torch.float16),
+            q_bias=merge_bias,
+            k_bias=None,
+            v_bias=None,
             layer_idx=layer_idx,
             rms_norm_eps=rms_norm_eps,
             intermediate_size=intermediate_size,
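
Below is a minimal PyTorch sketch (illustration only, not part of the patch; all tensor names are made up) of the equivalence the prefill path relies on: projecting with the concatenated q/k/v weights plus a single merged bias and then splitting the result along the last axis reproduces the three separate biased projections that the decode path keeps.

# Illustration only -- not part of the patch. Names are hypothetical; the patch
# expresses the same computation with self.linear + self.variadic_split on the
# NPU graph and torch.cat over the per-layer weights/biases in run_prefill.
import torch

hidden_size, num_heads, num_kv_heads, head_dim = 64, 8, 2, 8
x = torch.randn(1, 4, hidden_size, dtype=torch.float64)

q_w = torch.randn(num_heads * head_dim, hidden_size, dtype=torch.float64)
k_w = torch.randn(num_kv_heads * head_dim, hidden_size, dtype=torch.float64)
v_w = torch.randn(num_kv_heads * head_dim, hidden_size, dtype=torch.float64)
q_b = torch.randn(num_heads * head_dim, dtype=torch.float64)
k_b = torch.randn(num_kv_heads * head_dim, dtype=torch.float64)
v_b = torch.randn(num_kv_heads * head_dim, dtype=torch.float64)

# Decode path: three separate projections, each with its own bias.
q_ref = x @ q_w.T + q_b
k_ref = x @ k_w.T + k_b
v_ref = x @ v_w.T + v_b

# Prefill path: one fused projection with a merged bias, then one split
# into [q, k, v]-sized chunks along the last axis.
qkv_w = torch.cat([q_w, k_w, v_w], dim=0)
merged_bias = torch.cat([q_b, k_b, v_b], dim=0)
qkv = x @ qkv_w.T + merged_bias
q, k, v = torch.split(
    qkv,
    [num_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim],
    dim=-1,
)

assert torch.allclose(q, q_ref) and torch.allclose(k, k_ref) and torch.allclose(v, v_ref)

Fusing the three projections trades three smaller NPU linear ops for one larger matmul plus a split during prefill, which is why run_prefill also concatenates the q/k/v weights, scales, and biases on the host side before building the fused decoder layer.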