Describe the bug
When training a classification model that uses the Gemma3 token_embedding layer, the kernel dies.
To Reproduce
https://colab.research.google.com/drive/12BAorKsFy_1651K7LLKPbglG0Pe951pI?usp=sharing
Here is the relevant code:
import keras
import keras_hub


class GemmaEncoder(keras.Layer):
    """Embeds string prompts with a Gemma3 backbone's token embedding and pools them."""

    def __init__(
        self,
        preprocessor: keras_hub.models.Gemma3CausalLMPreprocessor,
        backbone: keras_hub.models.Gemma3Backbone,
        pooling_layer: keras.layers.Layer,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.preprocessor = preprocessor
        self.backbone = backbone
        self.pooling_layer = pooling_layer

    @classmethod
    def from_preset(
        cls,
        preset: str = "gemma3_1b",
        pooling_layer: keras.layers.Layer = None,
        name="gemma_encoder",
        **kwargs,
    ):
        preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor.from_preset(
            preset, sequence_length=128
        )
        backbone = keras_hub.models.Gemma3Backbone.from_preset(preset)
        if pooling_layer is None:
            pooling_layer = keras.layers.GlobalAveragePooling1D(
                name=name + "_global_average_pooling1d"
            )
        return cls(
            preprocessor=preprocessor,
            backbone=backbone,
            pooling_layer=pooling_layer,
            name=name,
            **kwargs,
        )

    def call(self, inputs):
        # Accept either a preprocessor-style dict or a plain batch of strings.
        adapted = inputs if isinstance(inputs, dict) and "prompts" in inputs else \
            {
                "prompts": keras.ops.array(inputs),
                "responses": keras.ops.array([""]),
            }
        # The preprocessor returns (x, y, sample_weight); take token_ids from x.
        tokenized = self.preprocessor(adapted)
        embedded = self.backbone.token_embedding(tokenized[0]["token_ids"])
        pooled = self.pooling_layer(embedded)
        return pooled
# Build the encoder from a preset and run a quick smoke test on two strings;
# this should return a (batch, embedding_dim) tensor of pooled embeddings.
gse_layer = GemmaEncoder.from_preset(preset="gemma3_1b")
gse_layer(inputs=["oranges and lemons are sour", "lemons and oranges are tart"])
# Functional model: string headline -> Gemma encoder -> small binary classification head.
headline_input = keras.layers.Input(shape=(), dtype="string", name="headline")
headline_featurizer = gse_layer(headline_input)
dense_16 = keras.layers.Dense(16, activation="relu", name="dense_16")(headline_featurizer)
activation = keras.layers.Dense(1, activation="sigmoid", name="activation")(dense_16)

inputs = [headline_input]
outputs = [activation]
nn_model = keras.Model(inputs=inputs, outputs=outputs, name="nn_model")

optimizer = keras.optimizers.Adam(learning_rate=0.001)  # keras.optimizers.Nadam(learning_rate=0.00007)
nn_model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"], run_eagerly=True)
# Tiny dummy dataset; the kernel dies as soon as fit() starts.
x_train = {"headline": keras.ops.array(["hello", "goodbye", "see you soon"])}
y_train = keras.ops.array([[1], [0], [0]])
nn_model_history = nn_model.fit(
    x=x_train,
    y=y_train,
    # batch_size=1,
    epochs=3,
    verbose=1,
)
Expected behavior
The kernel shouldn't die.
Additional context
This code is a variation on another open issue I have that uses a Gemma (not Gemma3) model. In that case, the Gemma-based model trains without crashing but has some concerning warnings and doesn't work when deployed to an endpoint. In this case, with the Gemma3-based model, it crashes immediately after training begins.
Would you like to help us fix it?
I'm happy to provide any information I can to assist with fixing the issue, but I suspect it's a bug in KerasHub Gemma3 code.
adapted = inputs if isinstance(inputs, dict) and "prompts" in inputs else \
    {
        "prompts": keras.ops.array(inputs),
        "responses": keras.ops.array([""] * len(inputs)),
    }
Using responses as an array of empty strings with the same length as prompts solves the problem for now. With "responses": keras.ops.array([""]), I get a segmentation fault in the tokenizer when it calls the TensorFlow APIs.
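For context, here is a minimal sketch of how that workaround slots into the call method of the GemmaEncoder class above (assuming inputs is a plain Python list of strings, so len(inputs) gives the batch size); it only pads responses to match the prompt batch and does not address the underlying crash:

    def call(self, inputs):
        # Pass dict inputs through untouched; otherwise wrap the raw strings.
        if isinstance(inputs, dict) and "prompts" in inputs:
            adapted = inputs
        else:
            adapted = {
                "prompts": keras.ops.array(inputs),
                # Workaround: one empty response per prompt, so the batch sizes of
                # prompts and responses agree and the tokenizer does not segfault.
                "responses": keras.ops.array([""] * len(inputs)),
            }
        tokenized = self.preprocessor(adapted)
        embedded = self.backbone.token_embedding(tokenized[0]["token_ids"])
        return self.pooling_layer(embedded)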