Skip to content

Commit 73ca767

Browse files
Adds support for gemma_270m to checkpoint converter (#2380)
* Adds support for gemma_270m to checkpoint converter * format * removes assertion * format * removes launch.json
1 parent f0f83fd commit 73ca767

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

tools/checkpoint_conversion/convert_gemma3_checkpoints.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
Usage:
1111
```shell
1212
cd tools/checkpoint_conversion
13-
python convert_gemma_checkpoints.py --preset gemma3_instruct_1b
14-
python convert_gemma_checkpoints.py --preset gemma3_instruct_4b
13+
python convert_gemma3_checkpoints.py --preset gemma3_instruct_1b
14+
python convert_gemma3_checkpoints.py --preset gemma3_instruct_4b
1515
```
1616
"""
1717

@@ -43,6 +43,15 @@
4343

4444
PRESET_MAP = {
4545
# === Text ===
46+
# 270M
47+
"gemma3_instruct_270m": {
48+
"model": gm.nn.Gemma3_270M,
49+
"params": gm.ckpts.CheckpointPath.GEMMA3_270M_IT,
50+
},
51+
"gemma3_270m": {
52+
"model": gm.nn.Gemma3_270M,
53+
"params": gm.ckpts.CheckpointPath.GEMMA3_270M_PT,
54+
},
4655
# 1B
4756
"gemma3_1b": {
4857
"model": gm.nn.Gemma3_1B,
@@ -493,7 +502,6 @@ def validate_output(
493502
params=flax_params,
494503
multi_turn=False,
495504
cache_length=256 if length <= 256 else 512,
496-
# max_out_length=length,
497505
)
498506
flax_output = flax_sampler.chat(input_str, images=image)
499507
print("🔶 Flax output:", flax_output)
@@ -508,11 +516,11 @@ def main(_):
508516
assert preset in presets, (
509517
f"Invalid preset {preset}. Must be one of {','.join(presets)}"
510518
)
511-
text_only = "text" in preset or "1b" in preset
519+
text_only = "text" in preset or "1b" in preset or "270m" in preset
512520

513521
print("🏃 Loading Flax model and tokeniser")
514522
flax_kwargs = {}
515-
if text_only and "1b" not in preset:
523+
if text_only and "1b" not in preset and "270m" not in preset:
516524
flax_kwargs["text_only"] = True
517525
flax_model = PRESET_MAP[preset]["model"](**flax_kwargs)
518526
flax_config = flax_model.config

0 commit comments

Comments
 (0)