Crash on Gemma3 token_embedding Layer During Training #2205


Open
rlcauvin opened this issue Apr 5, 2025 · 1 comment
Labels: Gemma (Gemma model specific issues), keras-team-review-pending, type:Bug (Something isn't working)

Comments

rlcauvin commented Apr 5, 2025

Describe the bug
When training a classification model that uses the Gemma3 token_embedding layer, the kernel dies.

To Reproduce
https://colab.research.google.com/drive/12BAorKsFy_1651K7LLKPbglG0Pe951pI?usp=sharing

Here is the relevant code:

import keras
import keras_hub

class GemmaEncoder(keras.Layer):

  def __init__(
    self,
    preprocessor: keras_hub.models.Gemma3CausalLMPreprocessor,
    backbone: keras_hub.models.Gemma3Backbone,
    pooling_layer: keras.layers.Layer,
    **kwargs):

    super().__init__(**kwargs)

    self.preprocessor = preprocessor
    self.backbone = backbone
    self.pooling_layer = pooling_layer

  @classmethod
  def from_preset(
    cls,
    preset: str = "gemma3_1b",
    pooling_layer: keras.layers.Layer = None,
    name = "gemma_encoder",
    **kwargs):

    preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor.from_preset(preset, sequence_length = 128)
    backbone = keras_hub.models.Gemma3Backbone.from_preset(preset)
    pooling_layer = keras.layers.GlobalAveragePooling1D(name = name + "_global_average_pooling1d") if pooling_layer is None else pooling_layer

    return cls(preprocessor = preprocessor, backbone = backbone, pooling_layer = pooling_layer, name = name, **kwargs)

  def call(self, inputs):

    # Wrap raw string inputs in the prompts/responses dict that the
    # causal LM preprocessor expects.
    adapted = inputs if isinstance(inputs, dict) and "prompts" in inputs else \
      {
      "prompts": keras.ops.array(inputs),
      "responses": keras.ops.array([""])
      }
    tokenized = self.preprocessor(adapted)
    # Embed the token IDs, then pool over the sequence dimension.
    embedded = self.backbone.token_embedding(tokenized[0]["token_ids"])
    pooled = self.pooling_layer(embedded)

    return pooled

gse_layer = GemmaEncoder.from_preset(preset = "gemma3_1b")

gse_layer(inputs = ["oranges and lemons are sour", "lemons and oranges are tart"])

headline_input = keras.layers.Input(shape = (), dtype = "string", name = "headline")
headline_featurizer = gse_layer(headline_input)
dense_16 = keras.layers.Dense(16, activation = "relu", name = "dense_16")(headline_featurizer)
activation = keras.layers.Dense(1, activation = "sigmoid", name = "activation")(dense_16)

inputs = [headline_input]
outputs = [activation]
nn_model = keras.Model(inputs = inputs, outputs = outputs, name = "nn_model")

optimizer = keras.optimizers.Adam(learning_rate=0.001) # keras.optimizers.Nadam(learning_rate = 0.00007)
nn_model.compile(optimizer = optimizer, loss = "binary_crossentropy", metrics = ["accuracy"], run_eagerly = True)

x_train = {"headline" : keras.ops.array(["hello", "goodbye", "see you soon"])}
y_train = keras.ops.array([[1], [0], [0]])

nn_model_history = nn_model.fit(
  x = x_train,
  y = y_train,
  # batch_size = 1,
  epochs = 3,
  verbose = 1)

Expected behavior
The kernel shouldn't die.

Additional context
This code is a variation on another open issue I have that uses a Gemma (not Gemma3) model. In that case, the Gemma-based model trains without crashing but has some concerning warnings and doesn't work when deployed to an endpoint. In this case, with the Gemma3-based model, it crashes immediately after training begins.

Would you like to help us fix it?
I'm happy to provide any information I can to assist with fixing the issue, but I suspect it's a bug in KerasHub Gemma3 code.

@pctablet505 (Collaborator) commented:

adapted = inputs if isinstance(inputs, dict) and "prompts" in inputs else \
      {
        "prompts": keras.ops.array(inputs),
        "responses": keras.ops.array([""]*len(inputs))
      }

Using responses as an array of empty strings of the same length as prompts solves the problem for now. With "responses": keras.ops.array([""]), I get a segmentation fault in the tokenizer when it calls TensorFlow APIs.

There is some issue with tokenization.
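The workaround above can be captured in a small helper that builds the preprocessor input dict with one empty response per prompt. This is a minimal sketch; the helper name is hypothetical and not part of the KerasHub API:

```python
def build_preprocessor_inputs(prompts):
    """Pair each prompt with an empty response so "prompts" and
    "responses" have equal length (hypothetical helper, not a
    KerasHub API)."""
    prompts = list(prompts)
    return {
        "prompts": prompts,
        # One empty response per prompt avoids the length mismatch
        # that appears to trigger the tokenizer segfault.
        "responses": [""] * len(prompts),
    }
```

For example, `build_preprocessor_inputs(["hello", "goodbye"])` yields two prompts paired with two empty responses, which the preprocessor can tokenize without the mismatch.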
