From 612bf9137641c4d5eec241576c5ea6cb45adf16c Mon Sep 17 00:00:00 2001 From: cyita Date: Mon, 11 Nov 2024 14:55:57 +0800 Subject: [PATCH 1/5] merge qkv --- .../transformers/npu_models/mp_models_base.py | 76 +++++++++++-------- .../transformers/npu_models/qwen2_mp.py | 50 +++++++++--- 2 files changed, 86 insertions(+), 40 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py index 73666487333..2b84201f5a2 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py @@ -141,38 +141,54 @@ def attention(self, if self.n_splits_linear != 1: hidden_states = self.unsqueeze(hidden_states, axis=0) - query_states = self.linear( - hidden_states, - num_heads * head_dim, - hidden_size, - bias=False, - wt_dtype=self.dtype, - n_splits=self.n_splits_linear, - scale_factor=(self.group_size == 0), - is_prefill=(mode == "prefill") - ) + if mode == "prefill": + concat_linear = self.linear(hidden_states, + num_key_value_heads * head_dim * 2 + num_heads * head_dim, + hidden_size, + wt_dtype=self.dtype, + n_splits=self.n_splits_linear, + scale_factor=(self.group_size == 0), + is_prefill=(mode == "prefill")) + query_states = self.simple_slice(concat_linear, begin=[0, 0, 0], + end=[1, seq_len, num_heads * head_dim]) + key_states = self.simple_slice(concat_linear, begin=[0, 0, num_heads * head_dim], + end=[1, seq_len, num_heads * head_dim + num_key_value_heads * head_dim]) + value_states = self.simple_slice(concat_linear, + begin=[0, 0, num_heads * head_dim + num_key_value_heads * head_dim], + end=[1, seq_len, num_heads * head_dim + num_key_value_heads * head_dim * 2]) + else: + query_states = self.linear( + hidden_states, + num_heads * head_dim, + hidden_size, + bias=False, + wt_dtype=self.dtype, + n_splits=self.n_splits_linear, + scale_factor=(self.group_size == 0), + is_prefill=(mode == "prefill") + ) - key_states = self.linear( - hidden_states, - num_key_value_heads * head_dim, - hidden_size, - bias=False, - wt_dtype=self.dtype, - n_splits=self.n_splits_linear, - scale_factor=(self.group_size == 0), - is_prefill=(mode == "prefill") - ) + key_states = self.linear( + hidden_states, + num_key_value_heads * head_dim, + hidden_size, + bias=False, + wt_dtype=self.dtype, + n_splits=self.n_splits_linear, + scale_factor=(self.group_size == 0), + is_prefill=(mode == "prefill") + ) - value_states = self.linear( - hidden_states, - num_key_value_heads * head_dim, - hidden_size, - bias=False, - wt_dtype=self.dtype, - n_splits=self.n_splits_linear, - scale_factor=(self.group_size == 0), - is_prefill=(mode == "prefill") - ) + value_states = self.linear( + hidden_states, + num_key_value_heads * head_dim, + hidden_size, + bias=False, + wt_dtype=self.dtype, + n_splits=self.n_splits_linear, + scale_factor=(self.group_size == 0), + is_prefill=(mode == "prefill") + ) if q_bias is not None: query_states = query_states + q_bias diff --git a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py index ab11c27b665..217b2d00f60 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py @@ -241,6 +241,11 @@ def __init__( else: self.compile() print(f"{mode} end compiling") + qwen_size = "7b" if self.hidden_size == 3584 else "1.5b" + xml_path = f"gw/qwen-{qwen_size}-npu-qkv-{mode}-{num_layers}-{n_splits_linear}-{n_splits_down_proj}.xml" + + if not os.path.exists(xml_path): + self.save(xml_path) def build_decoder( self, @@ -815,16 +820,41 @@ def run_prefill( mlp_layer = curr_layer.mlp weights = [] - for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, - attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, - mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list, - mlp_layer.down_proj_dq_list]: - l_weights = [] - scales = [] - for l in layer_list: - l_weights.append(l.weight) - scales.append(l.scale) - weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) + if n_splits_linear == 1: + for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, + attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, + mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list, + mlp_layer.down_proj_dq_list]: + l_weights = [] + scales = [] + for l in layer_list: + l_weights.append(l.weight) + scales.append(l.scale) + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) + else: + qkv_weights = [] + qkv_scales = [] + for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, + attn_layer.v_proj_dq_list]: + l_weights = [] + scales = [] + for l in layer_list: + l_weights.append(l.weight) + scales.append(l.scale) + qkv_weights.append(torch.stack(l_weights, axis=0)) + qkv_scales.append(torch.stack(scales, axis=0)) + + weights.append((torch.cat(qkv_weights, dim=1), torch.cat(qkv_scales, dim=1))) + + for layer_list in [attn_layer.o_proj_dq_list, + mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list, + mlp_layer.down_proj_dq_list]: + l_weights = [] + scales = [] + for l in layer_list: + l_weights.append(l.weight) + scales.append(l.scale) + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16) From fb7c23e113211c1fe3e47e945faf336788cc0c6d Mon Sep 17 00:00:00 2001 From: cyita Date: Mon, 11 Nov 2024 19:49:31 +0800 Subject: [PATCH 2/5] variadic split --- .../transformers/npu_models/mp_models_base.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py index 2b84201f5a2..f3cac8c2ec2 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py @@ -149,13 +149,14 @@ def attention(self, n_splits=self.n_splits_linear, scale_factor=(self.group_size == 0), is_prefill=(mode == "prefill")) - query_states = self.simple_slice(concat_linear, begin=[0, 0, 0], - end=[1, seq_len, num_heads * head_dim]) - key_states = self.simple_slice(concat_linear, begin=[0, 0, num_heads * head_dim], - end=[1, seq_len, num_heads * head_dim + num_key_value_heads * head_dim]) - value_states = self.simple_slice(concat_linear, - begin=[0, 0, num_heads * head_dim + num_key_value_heads * head_dim], - end=[1, seq_len, num_heads * head_dim + num_key_value_heads * head_dim * 2]) + query_states, key_states, value_states = self.variadic_split(concat_linear, 2, [1536, 256, 256]) + # query_states = self.simple_slice(concat_linear, begin=[0, 0, 0], + # end=[1, seq_len, num_heads * head_dim]) + # key_states = self.simple_slice(concat_linear, begin=[0, 0, num_heads * head_dim], + # end=[1, seq_len, num_heads * head_dim + num_key_value_heads * head_dim]) + # value_states = self.simple_slice(concat_linear, + # begin=[0, 0, num_heads * head_dim + num_key_value_heads * head_dim], + # end=[1, seq_len, num_heads * head_dim + num_key_value_heads * head_dim * 2]) else: query_states = self.linear( hidden_states, From 744976ab310d5aad16e10a5f691f5ba66274e62c Mon Sep 17 00:00:00 2001 From: cyita Date: Tue, 12 Nov 2024 11:34:50 +0800 Subject: [PATCH 3/5] temp --- .../transformers/npu_models/mp_models_base.py | 13 +++--- .../transformers/npu_models/qwen2_mp.py | 40 +++++++++++++------ 2 files changed, 34 insertions(+), 19 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py index f3cac8c2ec2..8cdd61d2de5 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py @@ -149,6 +149,7 @@ def attention(self, n_splits=self.n_splits_linear, scale_factor=(self.group_size == 0), is_prefill=(mode == "prefill")) + concat_linear = concat_linear + q_bias query_states, key_states, value_states = self.variadic_split(concat_linear, 2, [1536, 256, 256]) # query_states = self.simple_slice(concat_linear, begin=[0, 0, 0], # end=[1, seq_len, num_heads * head_dim]) @@ -191,12 +192,12 @@ def attention(self, is_prefill=(mode == "prefill") ) - if q_bias is not None: - query_states = query_states + q_bias - if k_bias is not None: - key_states = key_states + k_bias - if v_bias is not None: - value_states = value_states + v_bias + if q_bias is not None: + query_states = query_states + q_bias + if k_bias is not None: + key_states = key_states + k_bias + if v_bias is not None: + value_states = value_states + v_bias query_states = self.reshape( query_states, [1, seq_len, num_heads, head_dim] diff --git a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py index 217b2d00f60..a1e0b58873c 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py @@ -173,13 +173,18 @@ def __init__( post_attn_layernorm_weights = [self.constant(w) for w in post_attn_layernorm_weights] if q_biases is None: - q_biases = [] - k_biases = [] - v_biases = [] - for i in range(num_layers): - q_biases.append(self.create_input_op((self.num_heads * self.head_dim,))) - k_biases.append(self.create_input_op((self.num_key_value_heads * self.head_dim,))) - v_biases.append(self.create_input_op((self.num_key_value_heads * self.head_dim,))) + if mode == "prefill": + q_biases = [] + for i in range(num_layers): + q_biases.append(self.create_input_op((self.num_heads * self.head_dim + self.num_key_value_heads * self.head_dim * 2,))) + else: + q_biases = [] + k_biases = [] + v_biases = [] + for i in range(num_layers): + q_biases.append(self.create_input_op((self.num_heads * self.head_dim,))) + k_biases.append(self.create_input_op((self.num_key_value_heads * self.head_dim,))) + v_biases.append(self.create_input_op((self.num_key_value_heads * self.head_dim,))) else: q_biases = [self.constant(w) for w in q_biases] k_biases = [self.constant(w) for w in k_biases] @@ -217,8 +222,8 @@ def __init__( input_layernorm_weight=input_layernorm_weights[i], post_attention_layernorm_weight=post_attn_layernorm_weights[i], q_bias=q_biases[i], - k_bias=k_biases[i], - v_bias=v_biases[i], + k_bias=k_biases[i] if mode == "decode" else None, + v_bias=v_biases[i] if mode == "decode" else None, past_key=past_keys[i], past_value=past_values[i], ) @@ -530,7 +535,9 @@ def forward( attention_mask.to(torch.int64), position_ids.to(torch.int64)) inputs += (self.layer_norm_0, self.layer_norm_1) - inputs += (self.q_bias, self.k_bias, self.v_bias) + inputs += (self.layer_norm_0, self.layer_norm_1, self.q_bias) + # inputs += (self.q_bias, self.k_bias, self.v_bias) + # inputs += (self.q_bias) hidden_states, past_key, past_value = run_model( inputs, self.op_parameters, backend_cls, self.op_id, replica=2 ) @@ -862,6 +869,10 @@ def run_prefill( layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16) layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16) + merge_bias = torch.cat([attn_layer.q_proj_dq_list.q_proj_dq_0.bias, + attn_layer.k_proj_dq_list.k_proj_dq_0.bias, + attn_layer.v_proj_dq_list.v_proj_dq_0.bias]).to(torch.float16) + new_decoderlayer = FusedQwenLowBitDecoderlayer( weights, num_heads=num_heads, @@ -870,9 +881,12 @@ def run_prefill( cached_sin=cached_sin, layer_norm_0=layer_norm_0, layer_norm_1=layer_norm_1, - q_bias=attn_layer.q_proj_dq_list.q_proj_dq_0.bias.to(torch.float16), - k_bias=attn_layer.k_proj_dq_list.k_proj_dq_0.bias.to(torch.float16), - v_bias=attn_layer.v_proj_dq_list.v_proj_dq_0.bias.to(torch.float16), + # q_bias=attn_layer.q_proj_dq_list.q_proj_dq_0.bias.to(torch.float16), + # k_bias=attn_layer.k_proj_dq_list.k_proj_dq_0.bias.to(torch.float16), + # v_bias=attn_layer.v_proj_dq_list.v_proj_dq_0.bias.to(torch.float16), + q_bias=merge_bias, + k_bias=None, + v_bias=None, layer_idx=layer_idx, rms_norm_eps=rms_norm_eps, intermediate_size=intermediate_size, From f9b9a943081a57f75b43a8301beb9ec9f3a49917 Mon Sep 17 00:00:00 2001 From: cyita Date: Tue, 12 Nov 2024 15:10:27 +0800 Subject: [PATCH 4/5] fix error --- .../ipex_llm/transformers/npu_models/mp_models_base.py | 8 ++++++-- .../llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py | 5 +---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py index 8cdd61d2de5..1e976b8ec4d 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py @@ -149,8 +149,12 @@ def attention(self, n_splits=self.n_splits_linear, scale_factor=(self.group_size == 0), is_prefill=(mode == "prefill")) - concat_linear = concat_linear + q_bias - query_states, key_states, value_states = self.variadic_split(concat_linear, 2, [1536, 256, 256]) + if q_bias is not None: + concat_linear = concat_linear + q_bias + query_states, key_states, value_states = self.variadic_split( + concat_linear, 2, + [num_heads * head_dim, num_key_value_heads * head_dim, num_key_value_heads * head_dim] + ) # query_states = self.simple_slice(concat_linear, begin=[0, 0, 0], # end=[1, seq_len, num_heads * head_dim]) # key_states = self.simple_slice(concat_linear, begin=[0, 0, num_heads * head_dim], diff --git a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py index a1e0b58873c..4d8afe6ee4c 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py @@ -247,7 +247,7 @@ def __init__( self.compile() print(f"{mode} end compiling") qwen_size = "7b" if self.hidden_size == 3584 else "1.5b" - xml_path = f"gw/qwen-{qwen_size}-npu-qkv-{mode}-{num_layers}-{n_splits_linear}-{n_splits_down_proj}.xml" + xml_path = f"gw/qwen-{qwen_size}-npu-qkv-split-{mode}-{num_layers}-{n_splits_linear}-{n_splits_down_proj}.xml" if not os.path.exists(xml_path): self.save(xml_path) @@ -534,10 +534,7 @@ def forward( inputs = (hidden_states.to(torch.float16), attention_mask.to(torch.int64), position_ids.to(torch.int64)) - inputs += (self.layer_norm_0, self.layer_norm_1) inputs += (self.layer_norm_0, self.layer_norm_1, self.q_bias) - # inputs += (self.q_bias, self.k_bias, self.v_bias) - # inputs += (self.q_bias) hidden_states, past_key, past_value = run_model( inputs, self.op_parameters, backend_cls, self.op_id, replica=2 ) From a6117d2101adb2b30a6447fca5dd6605a82387fc Mon Sep 17 00:00:00 2001 From: cyita Date: Fri, 15 Nov 2024 15:51:45 +0800 Subject: [PATCH 5/5] remove --- .../src/ipex_llm/transformers/npu_models/mp_models_base.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py index 1e976b8ec4d..28de549fe4d 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py @@ -155,13 +155,6 @@ def attention(self, concat_linear, 2, [num_heads * head_dim, num_key_value_heads * head_dim, num_key_value_heads * head_dim] ) - # query_states = self.simple_slice(concat_linear, begin=[0, 0, 0], - # end=[1, seq_len, num_heads * head_dim]) - # key_states = self.simple_slice(concat_linear, begin=[0, 0, num_heads * head_dim], - # end=[1, seq_len, num_heads * head_dim + num_key_value_heads * head_dim]) - # value_states = self.simple_slice(concat_linear, - # begin=[0, 0, num_heads * head_dim + num_key_value_heads * head_dim], - # end=[1, seq_len, num_heads * head_dim + num_key_value_heads * head_dim * 2]) else: query_states = self.linear( hidden_states,