GPT2 Model performance dismal with enable_lora #2162

Open

ashep29 opened this issue Mar 23, 2025 · 2 comments

Comments

@ashep29

ashep29 commented Mar 23, 2025

The following code trains fine - loss decreases, accuracy improves:

chosen_preset = "gpt2_base_en"

preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
    chosen_preset,
    sequence_length=128,
)

gpt2_causal_lm = keras_nlp.models.GPT2CausalLM.from_preset(
    chosen_preset, preprocessor=preprocessor
)

learning_rate = keras.optimizers.schedules.PolynomialDecay(
    5e-5,
    decay_steps=train_ds.cardinality() * num_epochs,
    end_learning_rate=0.0,
)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimiser = keras.optimizers.Adam(learning_rate)
metrics = [keras.metrics.SparseCategoricalAccuracy()]

gpt2_causal_lm.compile(
    optimizer=optimiser,
    loss=loss,
    weighted_metrics=metrics,
)
gpt2_causal_lm.fit(train_ds, epochs=num_epochs)


Unfortunately, when LoRA is enabled, training fails dismally. All settings remain the same apart from the .quantize and enable_lora calls, yet the loss increases and accuracy is extremely low (around 1%, as opposed to 46% without LoRA).

gpt2_causal_lm_with_qlora = keras_nlp.models.GPT2CausalLM.from_preset(
    chosen_preset, preprocessor=preprocessor
)

gpt2_causal_lm_with_qlora.quantize("int8")
gpt2_causal_lm_with_qlora.backbone.enable_lora(rank=4)
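For context on what enable_lora(rank=4) is supposed to do, here is a minimal numpy sketch of the LoRA idea: the pretrained weight W stays frozen and only a low-rank product A @ B is trained. The dimensions and initialization below are illustrative assumptions, not KerasNLP's internals.

```python
import numpy as np

rng = np.random.default_rng(0)

# Frozen pretrained weight, e.g. a 768x768 attention projection in GPT-2.
d = 768
W = rng.normal(size=(d, d)).astype(np.float32)

# LoRA adapters: only A and B are trained; rank 4 as in enable_lora(rank=4).
rank = 4
A = rng.normal(scale=0.01, size=(d, rank)).astype(np.float32)
B = np.zeros((rank, d), dtype=np.float32)  # B starts at zero, so W + A @ B == W at step 0

x = rng.normal(size=(1, d)).astype(np.float32)
y_base = x @ W
y_lora = x @ (W + A @ B)

# At initialization the LoRA path contributes nothing, so outputs match.
print(np.allclose(y_base, y_lora))  # True

# Trainable parameters shrink from d*d to 2*d*rank per adapted weight.
print(d * d, 2 * d * rank)
```

So a correct LoRA setup should start from the quantized model's behaviour and diverge only as A and B are trained; junk output from step 0 points at the quantization step rather than the adapters.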

The model produces junk output, e.g.:

This
12
L
12
This
12
[26262626
12
[12.
[12.
[[.
[[. the r)
[2

Is there anything wrong with how I'm using LoRA, or is it simply not implemented or available for use with GPT2?

@ashep29
Author

ashep29 commented Mar 31, 2025

Note: the same model without QLoRA works nicely, e.g.

Epoch 1/3
ptxas spam warnings
2/960 ━━━━━━━━━━━━━━━━━━━━ 1:15 79ms/step - loss: 0.6675 - sparse_categorical_accuracy: 0.1776
ptxas spam warnings
959/960 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step - loss: 0.6554 - sparse_categorical_accuracy: 0.3852
ptxas spam warnings
960/960 ━━━━━━━━━━━━━━━━━━━━ 172s 107ms/step - loss: 0.6554 - sparse_categorical_accuracy: 0.3853
Epoch 2/3
960/960 ━━━━━━━━━━━━━━━━━━━━ 60s 62ms/step - loss: 0.5610 - sparse_categorical_accuracy: 0.4341
Epoch 3/3
960/960 ━━━━━━━━━━━━━━━━━━━━ 61s 62ms/step - loss: 0.5279 - sparse_categorical_accuracy: 0.4554

<keras.src.callbacks.history.History at 0x7ff928290100>

output_finetuned = gpt2_causal_lm.generate(sentence_starters, max_length=max_length)

for sentence in output_finetuned:
    print(f"--{sentence}\n")

--I bought a pair of crocs for my daughter, and they were very good.

--Crocs are so comfortable and convenient, and I have worn them to work, and I am very happy with them.

--Received a pair of crocs in white, and I love them! I'm a fan since I started wearing them.

@james77777778
Collaborator

james77777778 commented Apr 21, 2025

Hi @ashep29, would you mind sharing the outputs right after quantize, before any fine-tuning?

Dynamic int8 quantization can fail if the value range of the weights is too large. Imagine trying to squeeze a huge range (such as -1e5 to 1e5) into an 8-bit capacity.
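That range problem can be illustrated with a toy symmetric per-tensor int8 scheme (scale = max|w| / 127). This is an assumption for illustration only, not KerasNLP's exact quantization code: a single extreme weight inflates the scale and wipes out the precision of all the ordinary weights.

```python
import numpy as np

def quantize_int8(w):
    """Symmetric per-tensor int8 quantization: scale = max|w| / 127."""
    scale = np.abs(w).max() / 127.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)

# Well-behaved weights: a narrow value range quantizes with small error.
w_small = rng.normal(scale=0.02, size=10_000).astype(np.float32)
err_small = np.abs(dequantize(*quantize_int8(w_small)) - w_small).max()

# One extreme outlier blows up the scale; the ordinary weights all round to 0.
w_outlier = w_small.copy()
w_outlier[0] = 1e5
err_outlier = np.abs(dequantize(*quantize_int8(w_outlier)) - w_outlier).max()

print(err_small, err_outlier)  # the outlier inflates the quantization error
```

If GPT2's weights contain such outliers, a quantize-then-generate check (before touching LoRA) should already show degraded output.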

Most modern models, such as the Gemma and Llama models, should be fine and show similar performance after quantize, but GPT2 was not tested during my implementation.
Ref:
