Hi @ashep29, would you mind sharing the outputs after only `quantize`?
Dynamic int8 quantization can fail if the value range of the weights is too large. Imagine trying to squeeze a huge range (such as -1e5 to 1e5) into 8-bit capacity.
Most modern models, such as the Gemma and Llama families, should be fine and perform similarly after quantization, but GPT-2 was untested during my implementation.
Ref:
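To make that failure mode concrete, here is a minimal sketch in plain NumPy (not the Keras internals) of symmetric per-tensor int8 quantization: a single large outlier dictates the scale, and the many small weights all collapse to zero.

```python
import numpy as np

def quantize_int8(w):
    # One scale for the whole tensor, chosen so the largest magnitude maps to 127.
    scale = np.abs(w).max() / 127.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    return q.astype(np.float32) * scale

# Three small weights plus one extreme outlier.
w = np.array([0.01, -0.02, 0.03, 1e5], dtype=np.float32)
q, scale = quantize_int8(w)
w_hat = dequantize(q, scale)
# The outlier forces scale ~= 787, so the small weights all round to 0
# and their information is lost entirely.
```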
The following code trains fine - loss decreases, accuracy improves:
```python
import keras
import keras_nlp

chosen_preset = "gpt2_base_en"

preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
    chosen_preset,
    sequence_length=128,
)
gpt2_causal_lm = keras_nlp.models.GPT2CausalLM.from_preset(
    chosen_preset, preprocessor=preprocessor
)

# train_ds and num_epochs are defined earlier in my script
learning_rate = keras.optimizers.schedules.PolynomialDecay(
    5e-5,
    decay_steps=train_ds.cardinality() * num_epochs,
    end_learning_rate=0.0,
)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimiser = keras.optimizers.Adam(learning_rate)
metrics = [keras.metrics.SparseCategoricalAccuracy()]
gpt2_causal_lm.compile(
    optimizer=optimiser,
    loss=loss,
    weighted_metrics=metrics,
)
gpt2_causal_lm.fit(train_ds, epochs=num_epochs)
```
Unfortunately, when LoRA is enabled, training fails dismally. All settings stay the same apart from the added `.quantize` and `enable_lora` calls, yet the loss increases and accuracy is extremely low (around 1%, versus 46% without LoRA).
```python
gpt2_causal_lm_with_qlora = keras_nlp.models.GPT2CausalLM.from_preset(
    chosen_preset, preprocessor=preprocessor
)
gpt2_causal_lm_with_qlora.quantize("int8")
gpt2_causal_lm_with_qlora.backbone.enable_lora(rank=4)
```
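For context on what `enable_lora(rank=4)` is meant to do, here is a minimal NumPy sketch of the LoRA idea (not the Keras implementation): the base weight `W` stays frozen (here, it would be the quantized weight), and only a rank-r update `A @ B` is trained. The dimensions below assume GPT-2 base's hidden size of 768.

```python
import numpy as np

rng = np.random.default_rng(0)
d, r = 768, 4  # hidden size (GPT-2 base), LoRA rank

W = rng.standard_normal((d, d)).astype(np.float32)          # frozen base weight
A = (rng.standard_normal((d, r)) * 0.01).astype(np.float32)  # trainable down-projection
B = np.zeros((r, d), dtype=np.float32)                       # trainable up-projection, init 0

def lora_forward(x):
    # Frozen base path plus low-rank trainable path: x @ (W + A @ B).
    return x @ W + (x @ A) @ B

x = rng.standard_normal((1, d)).astype(np.float32)
# Because B starts at zero, the LoRA path contributes nothing initially,
# so the adapted layer should match the base layer exactly before training.
```

If the quantized model already generates garbage before `enable_lora` and `fit`, the problem is the int8 quantization of GPT-2's weights rather than the LoRA adapters themselves, which is why the outputs after only `quantize` are the interesting thing to check.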
The model produces junk output, e.g.:
```
This
12
L
12
This
12
[26262626
12
[12.
[12.
[[.
[[. the r)
[2
```
Is there anything wrong with how I'm using LoRA, or is it simply not implemented or available for use with GPT-2?