
model: Add support for GLM 4.5 family of models (#14921) #14939


Merged: 13 commits from the glm-4-5 branch into ggml-org:master on Aug 4, 2025

Conversation

@sammcj (Contributor) commented Jul 29, 2025

Add support for the newly released GLM 4.5 family of models.

Core Architecture

  • Architecture Registration: Added LLM_ARCH_GLM4_MOE enum and architecture mappings
  • Tensor Definitions: Complete tensor mappings for MoE components including 128 routed experts + 1 shared expert
  • Hybrid Layer Support: Added n_layer_dense_lead parameter to handle different dense/MoE layer patterns between variants

Model Loading (src/llama-model.cpp)

  • Multi-variant Support: Automatic detection and loading for both 47-layer (Air) and 93-layer (full) models
  • MoE Infrastructure: Complete expert weight loading with merged 3D tensor format
  • Graph Implementation: New llm_build_glm4_moe class with sigmoid-based expert routing and top-8 selection (see the routing sketch after this list)
  • Shared Expert Integration: Proper handling of shared expert computation alongside routed experts
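
As a rough illustration of the routing described in the two bullets above, here is a minimal NumPy sketch of sigmoid-based top-8 selection with a shared expert. This is conceptual only; the actual implementation is the C++ graph code in llm_build_glm4_moe / build_moe_ffn, and the function and argument names here are illustrative.

```python
import numpy as np

def glm4_moe_ffn_sketch(x, router_w, experts, shared_expert, n_used=8):
    """x: [hidden], router_w: [n_expert, hidden], experts/shared_expert: callables."""
    scores = 1.0 / (1.0 + np.exp(-(router_w @ x)))   # sigmoid gating (not softmax)
    top = np.argsort(scores)[-n_used:]               # indices of the top-8 experts
    weights = scores[top]
    weights = weights / weights.sum()                # renormalise selected weights (the real model controls this via expert_weights_norm)
    routed = sum(w * experts[i](x) for w, i in zip(weights, top))
    return routed + shared_expert(x)                 # shared expert is always applied
```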

Conversion Support (convert_hf_to_gguf.py)

  • HuggingFace Integration: Complete Glm4MoeModel converter class
  • Expert Tensor Merging: Logic to merge per-expert weights into the GGUF 3D tensor format (see the sketch after this list)
  • Metadata Handling: Proper extraction and conversion of MoE parameters from HuggingFace config
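
To make the merging step above concrete, here is a small sketch of stacking per-expert 2D weight matrices into a single 3D tensor with a leading expert dimension, which is the layout written to the GGUF. The function name and shapes are assumptions for illustration, not the actual code in convert_hf_to_gguf.py.

```python
import torch

def merge_expert_weights(per_expert: list[torch.Tensor]) -> torch.Tensor:
    """per_expert: one 2D weight matrix per routed expert (e.g. 128 of them)."""
    assert all(w.shape == per_expert[0].shape for w in per_expert)
    merged = torch.stack(per_expert, dim=0)   # -> [num_experts, *per_expert_shape]
    return merged.contiguous()
```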

Technical Details

MoE Architecture

  • Expert Count: 128 routed experts + 1 shared expert per MoE layer
  • Expert Selection: Top-8 experts per token with sigmoid-based routing (not softmax)
  • Hybrid Layers: Dense FFN for layer 0, MoE for the remaining layers (see the sketch after this list)
  • Weight Format: Expert weights stored as merged [num_experts, hidden_size, ffn_size] tensors
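
A minimal sketch of the hybrid dense/MoE layout listed above, assuming n_layer_dense_lead = 1 (dense FFN for layer 0, MoE everywhere else); the real per-layer decision is made in the C++ graph builder, and the names here are illustrative.

```python
def layer_uses_moe(il: int, n_layer_dense_lead: int = 1) -> bool:
    # Leading layers use a plain dense FFN; the rest use the 128+1 expert MoE block.
    return il >= n_layer_dense_lead

# e.g. the 47-layer Air variant: layer 0 is dense, layers 1..46 are MoE
assert not layer_uses_moe(0)
assert all(layer_uses_moe(il) for il in range(1, 47))
```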

Model Variants

  • GLM-4.5: 355B total parameters, 32B active, 93 layers, includes K/Q norm tensors
  • GLM-4.5-Air: 106B total parameters, 12B active, 47 layers, no K/Q norm tensors

The NextN/MTP prediction tensors are preserved during conversion but marked as unused since llama.cpp does not yet support multi-token prediction.

Testing

  • Builds successfully with no compilation errors.
  • convert_hf_to_gguf.py working.
  • llama-quantize working.

Running the CI scripts locally (CPU only) gives two failing tests that I believe are unrelated to this change (please tell me if that isn't the case!):

94% tests passed, 2 tests failed out of 35

Label Time Summary:
main    = 251.60 sec*proc (35 tests)

Total Test time (real) = 251.61 sec

The following tests FAILED:
	 14 - test-tokenizers-ggml-vocabs (Failed)
	 27 - test-thread-safety (Subprocess aborted)


Analysis of Test Failures

1. test-tokenizers-ggml-vocabs - Corrupted Test Files

gguf_init_from_file_impl: invalid magic characters: 'vers', expected 'GGUF'
- Issue: Corrupted GGUF vocabulary files (ggml-vocab-nomic-bert-moe.gguf, etc.)
- Cause: File corruption in test environment, not code changes
- Relation to GLM 4.5: None - this is about vocabulary files, not architecture definitions

2. test-thread-safety - CUDA Environment Issues

CUDA error: unspecified launch failure
current device: 1, in function ggml_backend_cuda_synchronize
- Issue: CUDA backend threading/synchronisation failure
- Cause: CUDA driver/environment issues in CI system
- Relation to GLM 4.5: None - our changes were all CPU-side model loading logic

gguf-dump output:
```plain
TODO when ready
```

Disclaimer:

  • I am certainly not an expert in this - I think this is my first attempt at contributing a new model architecture to llama.cpp.
  • The most useful feedback would be concrete suggestions for code changes to make.
  • I did leverage the smarts of AI to help with the changes.
  • If this is not up to standard or I am completely off track, please feel free to reject this PR; I totally understand if someone smarter than I am could do a better job of it.

Hopefully resolves #14921

@github-actions bot added the python (python script changes) label on Jul 29, 2025
@sammcj force-pushed the glm-4-5 branch 2 times, most recently from 5da3811 to ec5c193 on July 29, 2025 08:44
@CISC (Collaborator) commented Jul 29, 2025

Just a few quick notes from a glance:

  • Please name it GLM4_MOE, not GLM45
  • There's already LLM_KV_LEADING_DENSE_BLOCK_COUNT, no need for LLM_KV_FIRST_K_DENSE_REPLACE
  • Use GGML_ASSERT instead of throwing
  • Be mindful of whitespace and alignment

Will do a proper review when you are ready. :)

@sammcj (Contributor, Author) commented Jul 29, 2025

Hey @CISC, no worries on the naming etc., will do.
The whitespace issues will be fixed; I haven't run this through linting yet, but will hopefully get back to this later tonight.

@sammcj force-pushed the glm-4-5 branch 2 times, most recently from c4dbf69 to b4c60e1 on July 29, 2025 11:24
@AnneKitsune commented:

FYI when trying to run convert_hf_to_gguf.py on GLM4.5-Air-FP8, I get that some constants ending with _EXPS don't exist. If I replace these by _EXP, then I get a different error related to matrix mapping.
Thank you for working on this!

@CISC (Collaborator) commented Jul 29, 2025

FYI when trying to run convert_hf_to_gguf.py on GLM4.5-Air-FP8, I get that some constants ending with _EXPS don't exist. If I replace these by _EXP, then I get a different error related to matrix mapping. Thank you for working on this!

That's because converting FP8 weights isn't supported yet, see #14810

@sammcj force-pushed the glm-4-5 branch 2 times, most recently from 1957023 to 4397ccb on July 29, 2025 12:33
@sammcj (Contributor, Author) commented Jul 29, 2025

I'm close to having convert_hf_to_gguf.py and llama-quantize working (see the updated PR); conversion completes without error and I was then able to quantise to Q4_K_M.

gguf-dump worked, but llama-server picked up a tensor mapping issue with token_embd.weight, so I've just put a fix into convert_hf_to_gguf.py.

I'm going through the whole conversion-then-quantisation process again; it's getting late here (hi from Melbourne 👋), so I'll come back and see if it's finished in ~20 minutes.

@pwilkin (Contributor) commented Jul 29, 2025

The LLM_TYPE code is wrong; those models aren't (respectively) dense 12B and 32B models. You have to add new MoE constants for them (see the Qwen3 and Ernie MoEs as examples).

@pwilkin (Contributor) commented Jul 29, 2025

Also, you might want to include the nextn tensors instead of throwing them out - MTP support is not there yet, but that way you won't have to reconvert and requantize if/when it arrives.

@sammcj (Contributor, Author) commented Jul 29, 2025

Thanks @pwilkin, LLM_TYPE updated.

I've added the nextn tensors into the conversion, skipping mapping to avoid errors.

@sammcj (Contributor, Author) commented Jul 29, 2025

Note that preserving the nextn tensors does result in a larger GGUF (780 tensors -> 1184 & 214GB -> 221GB for the f16)

@sammcj (Contributor, Author) commented Jul 29, 2025

I can't replicate that error @Thireus

@pwilkin (Contributor) commented Jul 29, 2025

Note that preserving the nextn tensors does result in a larger GGUF (780 tensors -> 1184 & 214GB -> 221GB for the f16)

Obviously, but they won't get loaded since they're not supported 😄

Also, don't make my mistake:
"torch_dtype": "bfloat16"

Don't convert to f16; use --outtype bf16, or your model will probably have errors in the tensors.

@CISC (Collaborator) commented Jul 29, 2025

If you add unused tensors to the GGUF you must mark those tensors as unused (GGML_OP_NONE) in llama-arch.cpp, otherwise you will get an error when loading the model!

Just FYI, all other models with MTP so far have those tensors stripped.

@sammcj (Contributor, Author) commented Jul 29, 2025

If you add unused tensors to the GGUF you must mark those tensors as unused (GGML_OP_NONE) in llama-arch.cpp, otherwise you will get an error when loading the model!

Ah, that'd explain why I'm getting llama_model_load: error loading model: done_getting_tensors: wrong number of tensors; expected 1184, got 735! - I'll push a change for that shortly @CISC.

I'll have to come back to this in the morning as it's getting late here.

If anyone is keen to get this in ASAP and has improvements, feel free to either raise a PR against my branch or pull my commits into a PR of your own if you have a better approach; I'll review in the morning.

@CISC (Collaborator) commented Jul 29, 2025

I'll just put it out there right now; no-one should make GGUFs from this PR public yet, there will be changes! :)

@sammcj (Contributor, Author) commented Jul 29, 2025

I'll just put it out there right now; no-one should make GGUFs from this PR public yet, there will be changes! :)

Absolutely, I hope people do not do that - it's very much in draft and I'm learning as I go.

@sammcj force-pushed the glm-4-5 branch 2 times, most recently from 9d6ea41 to 7f026fb on July 29, 2025 14:31
@Thireus commented Jul 29, 2025

@sammcj, 7f026fb#diff-4f653096980bd7d10518aa909cb648452cd3aa380ff93cb9fb642dca48536526 fixed the issue thanks.

@ricyoung commented:

The fix seems to work, still testing -> INFO:hf-to-gguf:Model successfully exported to models/glm-45-air-f16.gguf

@ashirviskas commented:

I'm an AMD and Vulkan user with 88GB of VRAM, already downloading GLM-4.5-Air; I will report back after a few hours if I have any success with it.

@Noeda (Contributor) commented Aug 4, 2025

Unfortunately for the context issue... 90k context is coherent for me with the Air model, so it sounds like I can't reproduce it here.

I'm going to try the big model too, but I'm expecting it to be the same; unless it breaks visibly enough that I can troubleshoot it, I likely won't report back on that one. I'll also check the older GLM4 issues for anything I didn't see myself, e.g. something similar in spirit to the f32 fix I suggested in my last comment, but you likely won't hear from me on this if I don't think I have any information to contribute :)

I suggest that once this PR is merged, someone who is personally affected opens an issue with a reproducible scenario, and especially notes their backend/architecture (e.g. AMD+Vulkan+Windows); you should report your setup anyway, but here it's extra important. If this is at all similar to the previous GLM4-series problems, it seems to be platform-specific, and maybe it's an AMD/Vulkan/Windows thing (I don't think I ever found out exactly which combination was bad; that was just my best guess from asking people what their platforms were).

Edit: Found at least one PR that looks possibly relevant that I did not know happened: #13607 It unfortunately does not seem arch-specific (arch-specific as in GLM4 or GLM4_MOE), so possibly the current gibberish is of a different origin.

Edit2: I'll also test the Vulkan backend on macOS since that's also an option. If you don't see updates from me, assume I could not reproduce (which might mean Vulkan is not the issue, or is not the only required condition for the gibberish to happen).

@AesSedai commented Aug 4, 2025

Hi, sorry, just woke up. I don't use Vulkan; I have a pair of 3090s and use CUDA on a Fedora 41 VM. If the issue is specific to me, which it seems to be, then I can try re-downloading the HF safetensors and re-converting / re-quanting. Either way, I don't think it should block the PR because it's working for others. I say ship it :)

Worst case maybe it's something that's resolved by me downloading an unsloth quant in the future or something.

@CISC merged commit ef0144c into ggml-org:master on Aug 4, 2025
49 of 51 checks passed
@segmond commented Aug 4, 2025

Thank you very much to @sammcj for undertaking this effort, and of course special thanks to all that jumped in to help along the way. I'm about to have a very unproductive week.

@zRzRzRzRzRzRzR Can you please help us implement MTP? 🙏

@Mushoz commented Aug 4, 2025

Is there a way to disable thinking on this model through a parameter?

@CISC (Collaborator) commented Aug 4, 2025

Is there a way to disable thinking on this model through a parameter?

Yes, the template supports enable_thinking, but you can also just add /nothink to the end of your prompt.
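
For reference, a minimal sketch of the /nothink approach against llama-server's OpenAI-compatible chat endpoint might look like this; the host/port and model alias here are just assumptions for illustration.

```python
import requests

resp = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "model": "glm-4.5-air",  # alias; llama-server serves whichever model it loaded
        "messages": [
            # appending /nothink to the prompt suppresses the thinking block
            {"role": "user", "content": "Summarise this change in one sentence. /nothink"}
        ],
    },
)
print(resp.json()["choices"][0]["message"]["content"])
```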

@Mushoz commented Aug 4, 2025

For people having the same question as I did: Make sure you use --jinja, or the enable_thinking parameter won't work :)

@sammcj (Contributor, Author) commented Aug 4, 2025

A big thank you to @CISC for all your hard work on this one! 🙇

@jukofyork (Collaborator) commented:

For people having the same question as I did: Make sure you use --jinja, or the enable_thinking parameter won't work :)

I'm a bit confused now as @sammcj posted this on Reddit not long ago:

Also - please do not use --jinja when loading the model as the official template that comes from their huggingface is broken and will cause issues.

Is there a working jinja template somewhere?

@sammcj (Contributor, Author) commented Aug 4, 2025

@jukofyork
[screenshot]

@jukofyork (Collaborator) commented:

@jukofyork [screenshot]

Thanks!

@AesSedai commented Aug 5, 2025

@CISC I narrowed down the gibberish issue a bit. It requires setting --batch-size 4096 --ubatch-size 4096 and possibly having a long multi-turn chat going. When I removed the batch-size / ubatch-size, my 40k and 50k token chats began working again. Setting the sizes up to 2048 / 2048 also worked. Something about 4096 / 4096 combined with over 32k context across multiple turns leads to that gibberish edge case.

I also tried a needle in a haystack test with a 35k token prompt with a direction to answer a question from the text as a one-shot and that worked. So I don't have a reproducible smoking gun, but batch-size / ubatch-size is involved and for now I'm just scaling them back to make it work.

@CISC (Collaborator) commented Aug 5, 2025

I also tried a needle in a haystack test with a 35k token prompt with a direction to answer a question from the text as a one-shot and that worked. So I don't have a reproducible smoking gun, but batch-size / ubatch-size is involved and for now I'm just scaling them back to make it work.

Ah, ok, so that means it's not a model issue then, that's great!

Submit an issue though. :)

@CISC (Collaborator) commented Aug 5, 2025

Just FYI for anyone wanting to create i-quants: since the final layer will not get imatrix data until MTP is supported, it has to be overridden for lower quants to work, e.g. using --tensor-type 46=iq4_xs or --tensor-type 92=iq4_xs.

cc/ @bartowski1182 @danielhanchen @nicoboss

@jacekpoplawski (Contributor) commented:

I am getting over 45 t/s on three 3090s with the unsloth Q4 quant of GLM Air; here is the optimized command:

llama-server -ts 18/17/18 -ngl 99 -m ~/models/GLM-4.5-Air-UD-Q4_K_XL-00001-of-00002.gguf --n-cpu-moe 2 --jinja --host 0.0.0.0

@jukofyork (Collaborator) commented:

1. It still seems to be skipping warmup. It's loading the model into system RAM **after** receiving the first prompt.

I can confirm it's not warming up.

Manually setting --override-kv glm4moe.expert_used_count=int:160 to try to get it to warm up triggers:

ggml_new_object: not enough space in the context's memory pool (needed 5730848, available 5730480)

If I patch src/llama-context.cpp:

uint32_t llama_context::graph_max_nodes() const {
    //return std::max<uint32_t>(1024u, 8u*model.n_tensors());
    return std::max<uint32_t>(65536u, 8u*model.n_tensors());
} 

and then run with --override-kv glm4moe.expert_used_count=int:160 it warms up fine.

You then need to rerun without --override-kv glm4moe.expert_used_count=int:160.

I've got to go out so no more time to investigate until later.

@jukofyork (Collaborator) commented:

Actually, no, it's still not warming up properly; it's just a lot quicker because it's got the experts mmapped, I think... I'll see if I can figure it out later if nobody else has by then.

Nexesenex pushed a commit to Nexesenex/croco.cpp that referenced this pull request Aug 5, 2025
* model: Add GLM 4.5 (ggml-org#14921)

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Merge in PR suggestions

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* model: Add GLM 4.5 family of models (ggml-org#14921)

1. Updated tensor_mapping.py with NextN tensor mappings

- Added proper tensor mappings for all NextN/MTP tensors in /Users/samm/git/llama.cpp/gguf-py/gguf/tensor_mapping.py
- Added mappings for: eh_proj, embed_tokens, enorm, hnorm, shared_head.head, shared_head.norm

2. Added num_nextn_predict_layers configuration

- Added LLM_KV_NUM_NEXTN_PREDICT_LAYERS constant to llama-arch.h and llama-arch.cpp
- Added num_nextn_predict_layers field to llama_hparams struct
- Updated GLM4_MOE parameter loading in llama-model.cpp to read this parameter
- Modified tensor loading logic to conditionally load NextN tensors based on num_nextn_predict_layers
- Added GGUF writer support in gguf_writer.py with add_num_nextn_predict_layers() method
- Updated conversion script to extract and write this parameter from HuggingFace config

3. Added FIM tokens for GLM4_MOE

- Added GLM-4.5's FIM tokens to llama-vocab.cpp:
  - <|code_prefix|> for FIM_PRE
  - <|code_suffix|> for FIM_SUF
  - <|code_middle|> for FIM_MID

4. Removed manual NextN tensor handling

- Removed the special-case handling in convert_hf_to_gguf.py that manually mapped NextN tensors
- NextN tensors are now handled automatically through the proper tensor mapping system

* glm 4.5 update tensors names

* model: glm 4.5 apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* model: glm 4.5 apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* model: glm 4.5 apply suggestions from code review

* Apply suggestions from code review

* patch broken chat template

* typings fix

* add TENSOR_SKIP flag

Co-authored-by: Diego Devesa <[email protected]>

* Update src/llama-model-loader.h

Co-authored-by: Sigbjørn Skjæret <[email protected]>

---------

Co-authored-by: Sigbjørn Skjæret <[email protected]>
Co-authored-by: Diego Devesa <[email protected]>
@jukofyork (Collaborator) commented:

I've found it:

                // MoE layer with shared experts
                //const int64_t n_expert      = hparams.n_expert;
                //const int64_t n_expert_used = hparams.n_expert_used;

                // Process routed experts using existing MoE infrastructure
                ggml_tensor * routed_out = build_moe_ffn(cur,
                        model.layers[il].ffn_gate_inp,
                        model.layers[il].ffn_up_exps,
                        model.layers[il].ffn_gate_exps,
                        model.layers[il].ffn_down_exps,
                        model.layers[il].ffn_exp_probs_b,
                        n_expert, n_expert_used,
                        LLM_FFN_SILU, hparams.expert_weights_norm,
                        true, hparams.expert_weights_scale,
                        (llama_expert_gating_func_type) hparams.expert_gating_func,
                        il);
                cb(routed_out, "ffn_moe_out", il);

The local n_expert and n_expert_used were shadowing the members set here (note the warmup branch below, which bumps n_expert_used up to n_expert during warmup; the shadowed locals bypassed it):

llm_graph_context::llm_graph_context(const llm_graph_params & params) :
    arch             (params.arch),
    hparams          (params.hparams),
    cparams          (params.cparams),
    ubatch           (params.ubatch),
    n_embd           (hparams.n_embd),
    n_layer          (hparams.n_layer),
    n_rot            (hparams.n_rot),
    n_ctx            (cparams.n_ctx),
    n_head           (hparams.n_head()),
    n_head_kv        (hparams.n_head_kv()),
    n_embd_head_k    (hparams.n_embd_head_k),
    n_embd_k_gqa     (hparams.n_embd_k_gqa()),
    n_embd_head_v    (hparams.n_embd_head_v),
    n_embd_v_gqa     (hparams.n_embd_v_gqa()),
    n_expert         (hparams.n_expert),
    n_expert_used    (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
    freq_base        (cparams.rope_freq_base),
    freq_scale       (cparams.rope_freq_scale),
    ext_factor       (cparams.yarn_ext_factor),
    attn_factor      (cparams.yarn_attn_factor),
    beta_fast        (cparams.yarn_beta_fast),
    beta_slow        (cparams.yarn_beta_slow),
    norm_eps         (hparams.f_norm_eps),
    norm_rms_eps     (hparams.f_norm_rms_eps),
    n_tokens         (ubatch.n_tokens),
    n_outputs        (params.n_outputs),
    n_ctx_orig       (cparams.n_ctx_orig_yarn),
    pooling_type     (cparams.pooling_type),
    rope_type        (hparams.rope_type),
    sched            (params.sched),
    backend_cpu      (params.backend_cpu),
    cvec             (params.cvec),
    loras            (params.loras),
    mctx             (params.mctx),
    cross            (params.cross),
    cb_func          (params.cb),
    res              (params.res),
    ctx0             (res->get_ctx()),
    gf               (res->get_gf()) {
        res->set_params(params);
    }

@jukofyork (Collaborator) commented:

#15088

@createthis commented Aug 5, 2025

@jukofyork confirmed. This fixes warmup for me. It also restores the GLM-4.5 to the performance levels I've come to expect from llama.cpp:

[screenshot]

Startup command:

./build/bin/llama-server \
    --model /data/GLM-4.5-GGUF/q4_k_m/GLM-4.5-Q4_K_M.gguf \
    --alias GLM-4.5-GGUF:q4_k_m \
    --no-webui \
    --numa numactl \
    --threads 32 \
    --ctx-size 131072 \
    --n-gpu-layers 94 \
    -ot "blk\.(3|4|5|6|7|8|9|10|11|12|13|14|15|16|17)\.ffn_.*=CUDA0" \
    -ot exps=CPU \
    -ub 4096 -b 4096 \
    --seed 3407 \
    --temp 0.6 \
    --top-p 1.0 \
    --log-colors \
    --flash-attn \
    --host 0.0.0.0 \
    --jinja \
    --port 11434

I had GLM-4.5 write a poem for you:

Jukofyork, with skillful hand,
Commit c81de6e fixed the land.
GLM-4.5 warmup, once so slow,
Now performs with steady glow.
Removed those lines that caused the pain,
Llama.cpp runs fast again.

@jukofyork (Collaborator) commented:

No problem and I can confirm it's running as expected for me now too (~6.5 tokens/s generation).

I've managed to transplant the vocab into Qwen2.5-Coder-0.5B-Instruct:

Loading config from 'Qwen2.5-Coder-0.5B-Instruct'... Done.
Loading config from 'GLM-4.5'... Done.
Loading tokenizer from 'Qwen2.5-Coder-0.5B-Instruct'... Done.
Loading tokenizer from 'GLM-4.5'... Done.
Loading model from 'Qwen2.5-Coder-0.5B-Instruct'... Done.

Input model configuration:
- Target vocabulary size    : 151552 (used = 151365, unused = 187)
- Donor vocabulary size     : 151936
- Donor num layers          : 24 (tied embeddings = True)
- Donor hidden size         : 896
- Donor attention heads     : 14
- Donor intermediate size   : 4864 (ratio = 1:5.4)
- Donor total parameters    : 494032768 (0.49B)
-- Embedding parameters     : 136134656 (0.14B)
-- Non-embedding parameters : 357898112 (0.36B)

Processing 3 automatic token overrides:
✘ 'bos_token_id' : Not found for target model
✔ 'eos_token_id' : 151329 '<|endoftext|>' → [151645] '<|im_end|>'
✘ 'pad_token_id' : 151329 is already mapped to [151645]

Processing 14 manual token overrides:
✔ 151329 : '<|endoftext|>' → [151643] '<|endoftext|>'
✔ 151330 : '[MASK]' → [151643] '<|endoftext|>'
✔ 151331 : '[gMASK]' → [151643] '<|endoftext|>'
✔ 151332 : '[sMASK]' → [151643] '<|endoftext|>'
✔ 151333 : '<sop>' → [151643] '<|endoftext|>'
✔ 151334 : '<eop>' → [151643] '<|endoftext|>'
✔ 151335 : '<|system|>' → [151644, 8948] '<|im_start|>system'
✔ 151336 : '<|user|>' → [151644, 872] '<|im_start|>user'
✔ 151337 : '<|assistant|>' → [151644, 77091] '<|im_start|>assistant'
✔ 151338 : '<|observation|>' → [151644, 872] '<|im_start|>user'
✔ 151352 : '<tool_call>' → [151657] '<tool_call>'
✔ 151353 : '</tool_call>' → [151658] '</tool_call>'
✔ 151354 : '<tool_response>' → [151657] '<tool_call>'
✔ 151355 : '</tool_response>' → [151658] '</tool_call>'

NOTE: Using an "untied" copy of 'embed_tokens.weight' as new 'lm_head.weight' tensor...

Transplanting tokens: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 151365/151365 [00:42<00:00, 3558.20token/s]

Transplant mappings:
- 1 to 1  : 123102 (81%)
- 2 to 1  : 23944 (16%)
- 3 to 1  : 3262 (2.2%)
- 4 to 1  : 821 (0.54%)
- 5 to 1  : 181 (0.12%)
- 6 to 1  : 26 (0.017%)
- 7 to 1  : 21 (0.014%)
- 8 to 1  : 5 (0.0033%)
- 9 to 1  : 1 (0.00066%)
- 13 to 1 : 1 (0.00066%)
- 16 to 1 : 1 (0.00066%)

Head initialized with:
- Copies : 123102 (81%)
- Means  : 28263 (19%)
- Zeros  : 187 (0.12%)

Output model configuration:
- Output vocabulary size    : 151552
- Output num layers         : 24 (tied embeddings = False)
- Output hidden size        : 896
- Output attention heads    : 14
- Output intermediate size  : 4864 (ratio = 1:5.4)
- Output total parameters   : 629479296 (0.63B)
-- Embedding parameters     : 271581184 (0.27B)
-- Non-embedding parameters : 357898112 (0.36B)

Saving model and tokenizer to 'GLM-4.5-DRAFT-0.6B-UNTRAINED' folder

so assuming it trains OK, we should have a draft model in a day or so.

It actually looks to have transplanted very well, as even the untrained draft is getting a high acceptance rate for refactoring tasks:

prompt eval time =   59625.37 ms /  2339 tokens (   25.49 ms per token,    39.23 tokens per second)
       eval time =  288397.17 ms /  3170 tokens (   90.98 ms per token,    10.99 tokens per second)
      total time =  348022.54 ms /  5509 tokens
slot print_timing: id  0 | task 0 | 
draft acceptance rate = 0.74499 ( 2080 accepted /  2792 generated)

Labels: model (Model specific), python (python script changes)

Successfully merging this pull request may close these issues:

Feature Request: GLM 4.5 MoE support (#14921)