@@ -78,51 +78,68 @@ def weight_process(name, quant_config, lora_config, state_dict, device):
         raise ValueError(f"quant_config.weight_quantize_algo {quant_config.weight_quantize_algo} is not supported.")


+def get_mixer(mixer, mixer_num, index=0):
+    if index == mixer_num - 1:
+        return mixer[index]
+    else:
+        return mixer[index] @ get_mixer(mixer, mixer_num, index + 1)
+
+
 def lora_process(name, layer, lora_config, state_dict, device, lora_state_dict=None):
+
     target_device = device if device == "cpu" else device + ":0"

     if (name + ".weight") not in state_dict.keys():
         return

     weight = state_dict.pop(name + ".weight")
     lora_use_mixer = lora_config.lora_use_mixer
+
+    mixer_num = lora_config.mixer_num
+    mixer = {}
     use_mora = lora_config.use_mora
+
     if lora_state_dict is None:
         lora_A = state_dict.pop(name + ".lora_A")
         if not use_mora:
             lora_B = state_dict.pop(name + ".lora_B")
             if lora_use_mixer:
-                lora_AB = state_dict.pop(name + ".lora_AB")
+                for i in range(mixer_num):
+                    mixer[i] = state_dict.pop(name + ".lora_mixer_" + str(i))
     else:
         lora_A = lora_state_dict.pop(name + ".lora_A")
         if not use_mora:
             lora_B = lora_state_dict.pop(name + ".lora_B")
             if lora_use_mixer:
-                lora_AB = lora_state_dict.pop(name + ".lora_AB")
+                for i in range(mixer_num):
+                    mixer[i] = lora_state_dict.pop(name + ".lora_mixer_" + str(i))
     if device != "cpu":
         weight = weight.to(target_device)
         lora_A = lora_A.to(target_device)
         if not use_mora:
             lora_B = lora_B.to(target_device)
             if lora_use_mixer:
-                lora_AB = lora_AB.to(target_device)
+                for key in mixer.keys():
+                    mixer[key] = mixer[key].to(target_device)

     if device == "cpu" and weight.dtype.name == "BF16":
         weight = weight.astype("float32")
         lora_A = lora_A.astype("float32")
         if not use_mora:
             lora_B = lora_B.astype("float32")
+
         if lora_use_mixer:
-            lora_AB = lora_AB.astype(lora_config.dtype)
-            delta_weight = layer.get_delta_weight(lora_A, lora_B, lora_AB)
+            for key in mixer.keys():
+                mixer[key] = mixer[key].astype(lora_config.dtype)
+            delta_weight = layer.get_delta_weight(lora_A, lora_B, get_mixer(mixer, mixer_num))
         elif use_mora:
             delta_weight = layer.get_delta_weight(lora_A)
         else:
             delta_weight = layer.get_delta_weight(lora_A, lora_B)
         out = (weight + delta_weight).astype(lora_config.dtype)
     else:
         if lora_use_mixer:
-            delta_weight = layer.get_delta_weight(lora_A, lora_B, lora_AB)
+            delta_weight = layer.get_delta_weight(lora_A, lora_B, get_mixer(mixer, mixer_num))
         elif use_mora:
             delta_weight = layer.get_delta_weight(lora_A)
         else:
             delta_weight = layer.get_delta_weight(lora_A, lora_B)
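The hunk replaces the single `lora_AB` mixer tensor with `mixer_num` separate matrices stored under the checkpoint keys `".lora_mixer_0"` through `".lora_mixer_" + str(mixer_num - 1)`; `get_mixer` folds them into one matrix product, which is passed to `layer.get_delta_weight` in place of the old tensor. Below is a minimal sketch of that fold, with NumPy arrays standing in for the Paddle tensors the merge script actually handles (`rank` and the random values are illustrative only):

```python
# Sketch of the fold performed by get_mixer, using NumPy stand-ins
# for the Paddle tensors in the real merge script.
import numpy as np
from functools import reduce

rank, mixer_num = 4, 3  # illustrative sizes, not from the PR
# Hypothetical stand-ins for the popped ".lora_mixer_i" tensors.
mixer = {i: np.random.rand(rank, rank) for i in range(mixer_num)}

def get_mixer(mixer, mixer_num, index=0):
    # Same recursion as the diff: right-associated chain of matmuls.
    if index == mixer_num - 1:
        return mixer[index]
    return mixer[index] @ get_mixer(mixer, mixer_num, index + 1)

# Matrix multiplication is associative, so the recursion equals the
# left-to-right product mixer[0] @ mixer[1] @ ... @ mixer[mixer_num - 1].
expected = reduce(np.matmul, (mixer[i] for i in range(mixer_num)))
assert np.allclose(get_mixer(mixer, mixer_num), expected)
```

Because the product is associative, an iterative `functools.reduce` gives the same result as the recursion and avoids deep call stacks for large `mixer_num`; the recursive form in the diff is fine for the small mixer counts LoRA configs typically use.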