
Add ESM model #2177


Open
pass-lin opened this issue Mar 30, 2025 · 3 comments
Labels
type:feature New feature or request

Comments

@pass-lin
Contributor

ESM is a pre-trained model for bioinformatics.
Sources: the Hugging Face (hf) documentation and the paper.

At present, keras_hub does not seem to have a bioinformatics model. Would you welcome my submission of ESM as the first pre-trained bio model?

This model is similar to BERT, so the overall implementation structure will follow BERT's. On top of the BERT-style components, we will add an hf weight conversion script.
Although hf also supports this model in legacy tf.keras, it might be better for our weight conversion script to support only torch weights.
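
As a small illustration of one thing a torch-only conversion script typically has to handle: a PyTorch `nn.Linear` stores its weight with shape `(out_features, in_features)`, while a Keras `Dense` kernel has shape `(input_dim, units)`, so dense weights are transposed on the way in. The sketch below uses only the standard library and plain nested lists; the function name and the toy values are hypothetical and are not part of any existing KerasHub converter.

```python
# Hypothetical helper, for illustration only: not the KerasHub converter API.
def torch_linear_to_keras_dense(weight, bias):
    """Turn a torch-style (out, in) weight into a Keras-style (in, out) kernel."""
    # zip(*weight) walks the columns of the matrix, i.e. it transposes it.
    kernel = [list(column) for column in zip(*weight)]
    return {"kernel": kernel, "bias": list(bias)}


# Toy 3 -> 2 linear layer in torch layout: 2 rows (out), 3 columns (in).
torch_weight = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
torch_bias = [0.1, 0.2]

converted = torch_linear_to_keras_dense(torch_weight, torch_bias)
```

A real script would do the same with array transposes and would also need a mapping from hf parameter names to Keras variable names, which is not sketched here.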

@pass-lin
Contributor Author

pass-lin commented Apr 2, 2025

It seems nobody is against it, so I will open a PR in the near future.

@divyashreepathihalli
Collaborator

Hi @pass-lin - this is a new kind of model compared to the ones we have in KerasHub. It looks like we might need a new task class for this. Can you provide a quick code usage example to show what the overall API would look like?

For example, the code usage for ResNet looks like this:

# Built-in, pretrained weights.
keras_hub.models.ResNetBackbone.from_preset("resnet_18")

# Construct with the base class.
keras_hub.models.Backbone.from_preset("resnet_50")
# Specify dtype.
keras_hub.models.ResNetBackbone.from_preset("resnet_50", dtype="mixed_bfloat16")
# Convert from timm.
keras_hub.models.ResNetBackbone.from_preset("hf://timm/resnet18.a1_in1k")
# Just the config.
keras_hub.models.ResNetBackbone.from_preset("resnet_50", load_weights=False)
# Minimal custom backbone (only args with no defaults).
keras_hub.models.ResNetBackbone(
    stackwise_num_filters=[64, 64, 64],
    stackwise_num_blocks=[2, 2, 2],
    stackwise_num_strides=[1, 2, 2],
    block_type="basic_block",
)

# Direct backbone usage.
resized_batch = keras.layers.Resizing(224, 224)(image_batch)
backbone = keras_hub.models.ResNetBackbone.from_preset("resnet_18")
outputs = backbone(resized_batch)
# Feature pyramid usage.
# TODO

# Classification task usage.
task = keras_hub.models.ResNetImageClassifier.from_preset(
    "resnet_50", num_classes=2, activation="softmax",
)
task.fit(classification_dataset) # Resizes all images.
task.predict(image_batch) # Resizes all images.
task.preprocessing = None
task.predict(resized_batch)

@pass-lin
Contributor Author

pass-lin commented Apr 5, 2025


Okay, we can give it a try. One advantage of starting with ESM2 is that it is essentially BERT for proteins, so we can try to follow the BERT API:

keras_hub.models.ESM2Backbone.from_preset()
keras_hub.models.ESM2MaskedLM.from_preset()
keras_hub.models.ESM2MaskedLMPreprocessor.from_preset()
keras_hub.models.ESM2ProteinClassifier.from_preset()
keras_hub.models.ESM2ProteinClassifierPreprocessor.from_preset()
keras_hub.models.ESM2Tokenizer.from_preset()
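
To make the BERT analogy a bit more concrete, here is a rough, standard-library-only sketch of the kind of output a protein tokenizer plus preprocessor could produce: one token per residue, start and end tokens, then padding to a fixed length. Everything in it is an assumption for illustration. The vocabulary is a toy one (not the real ESM2 vocabulary), the `tokenize` function is hypothetical, and the `token_ids` / `padding_mask` keys simply mirror what KerasHub's BERT preprocessors return.

```python
# Toy illustration only: NOT the real ESM2 vocabulary or the KerasHub API.
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"  # the 20 standard residues
SPECIAL_TOKENS = ["<pad>", "<cls>", "<eos>", "<unk>", "<mask>"]

# Map each token to an integer id; special tokens take the lowest ids.
VOCAB = {tok: i for i, tok in enumerate(SPECIAL_TOKENS + list(AMINO_ACIDS))}


def tokenize(sequence, sequence_length):
    """Encode one protein sequence as fixed-length ids plus a padding mask."""
    # Reserve two positions for the <cls> and <eos> tokens, truncate the rest.
    residues = sequence.upper()[: sequence_length - 2]
    ids = [VOCAB["<cls>"]]
    # Residues outside the toy vocabulary fall back to <unk>.
    ids += [VOCAB.get(residue, VOCAB["<unk>"]) for residue in residues]
    ids.append(VOCAB["<eos>"])
    # 1 marks a real token, 0 marks padding.
    padding_mask = [1] * len(ids)
    num_pad = sequence_length - len(ids)
    ids += [VOCAB["<pad>"]] * num_pad
    padding_mask += [0] * num_pad
    return {"token_ids": ids, "padding_mask": padding_mask}


example = tokenize("MKT", sequence_length=8)
```

Under these assumptions, a backbone would consume `token_ids` and `padding_mask` much as a BERT backbone does, minus segment ids, since a protein input here is a single sequence.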


4 participants