From 50ef2f56b38e17aafa8913bbefa760274ffaf7d9 Mon Sep 17 00:00:00 2001 From: hemildesai Date: Thu, 23 Oct 2025 00:11:11 +0000 Subject: [PATCH 1/2] feat: Add GLM 4/4.5/4.6 MoE Signed-off-by: hemildesai --- .../components/models/glm4_moe/__init__.py | 13 + .../components/models/glm4_moe/layers.py | 147 ++++++ .../components/models/glm4_moe/model.py | 285 ++++++++++ .../models/glm4_moe/state_dict_adapter.py | 82 +++ .../components/models/gpt_oss/model.py | 4 +- .../components/models/qwen3_moe/model.py | 4 +- .../components/models/qwen3_next/model.py | 4 +- nemo_automodel/components/moe/layers.py | 4 +- tests/unit_tests/models/glm4_moe/__init__.py | 13 + .../models/glm4_moe/test_glm4_moe_layers.py | 377 ++++++++++++++ .../models/glm4_moe/test_glm4_moe_model.py | 493 ++++++++++++++++++ 11 files changed, 1417 insertions(+), 9 deletions(-) create mode 100644 nemo_automodel/components/models/glm4_moe/__init__.py create mode 100644 nemo_automodel/components/models/glm4_moe/layers.py create mode 100644 nemo_automodel/components/models/glm4_moe/model.py create mode 100644 nemo_automodel/components/models/glm4_moe/state_dict_adapter.py create mode 100644 tests/unit_tests/models/glm4_moe/__init__.py create mode 100644 tests/unit_tests/models/glm4_moe/test_glm4_moe_layers.py create mode 100644 tests/unit_tests/models/glm4_moe/test_glm4_moe_model.py diff --git a/nemo_automodel/components/models/glm4_moe/__init__.py b/nemo_automodel/components/models/glm4_moe/__init__.py new file mode 100644 index 000000000..070b8c0d7 --- /dev/null +++ b/nemo_automodel/components/models/glm4_moe/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo_automodel/components/models/glm4_moe/layers.py b/nemo_automodel/components/models/glm4_moe/layers.py new file mode 100644 index 000000000..a42b08aa5 --- /dev/null +++ b/nemo_automodel/components/models/glm4_moe/layers.py @@ -0,0 +1,147 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
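+# Illustrative usage sketch (comments only, not executed; assumes a toy Glm4MoeConfig and a
+# torch/sdpa BackendConfig like the one built in the unit tests):
+#
+#   attn = Glm4MoeAttention(config, backend)
+#   # x: [B, S, hidden_size]; freqs_cis: [B, S, head_dim * partial_rotary_factor] (cos/sin halves)
+#   out = attn(x, freqs_cis=freqs_cis, attention_mask=None)   # -> [B, S, hidden_size]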
+ +from typing import Any + +import torch +from torch import nn +from transformers.models.glm4_moe.configuration_glm4_moe import Glm4MoeConfig + +from nemo_automodel.components.attention.utils import ( + initialize_attn_module_and_func, + postprocess_output_for_attn, + preprocess_args_and_kwargs_for_attn, +) +from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb +from nemo_automodel.components.moe.utils import ( + BackendConfig, + initialize_linear_module, + initialize_rms_norm_module, +) + + +class Glm4MoeAttention(nn.Module): + """GLM4 MoE attention with optional query/key per-head RMSNorm + partial RoPE. + + Key differences from Qwen3 MoE: + - Optional QK normalization (controlled by use_qk_norm) + - Partial rotary embeddings (controlled by partial_rotary_factor) + - o_proj bias is False (unlike Qwen3 which has configurable attention_bias) + + Shapes: + - Input: x -> [B, S, H] + - Projections: + q: [B, S, n_heads, head_dim] + k/v: [B, S, n_kv_heads, head_dim] -> repeated to n_heads via groups + - Output: [B, S, H] + """ + + def __init__(self, config: Glm4MoeConfig, backend: BackendConfig): + super().__init__() + self.backend = backend + self.config = config + + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.head_dim = getattr(config, "head_dim", config.hidden_size // self.num_heads) + self.use_qk_norm = config.use_qk_norm + self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5) + + self.q_proj = initialize_linear_module( + backend.linear, config.hidden_size, self.num_heads * self.head_dim, config.attention_bias + ) + self.k_proj = initialize_linear_module( + backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, config.attention_bias + ) + self.v_proj = initialize_linear_module( + backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, config.attention_bias + ) + self.o_proj = initialize_linear_module( + backend.linear, self.num_heads * self.head_dim, config.hidden_size, False + ) + + # Optional per-head RMSNorm for Q and K + if self.use_qk_norm: + self.q_norm = initialize_rms_norm_module(backend.rms_norm, self.head_dim, eps=config.rms_norm_eps) + self.k_norm = initialize_rms_norm_module(backend.rms_norm, self.head_dim, eps=config.rms_norm_eps) + + # Attention implementation + softmax_scale = self.head_dim**-0.5 + self.attn_module, self.attn_func = initialize_attn_module_and_func( + attn_impl=backend.attn, + num_attention_heads=self.num_heads, + num_qk_channels=self.head_dim, + num_v_channels=self.head_dim, + softmax_scale=softmax_scale, + num_gqa_groups=self.num_kv_heads, + ) + + def forward( + self, + x: torch.Tensor, + *, + freqs_cis: torch.Tensor, + attention_mask: torch.Tensor | None = None, + **attn_kwargs: Any, + ) -> torch.Tensor: + if len(x.shape) == 2: + qkv_format = "thd" + num_tokens = x.shape[0] + else: + qkv_format = "bshd" + bsz, seqlen, _ = x.size() + + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + + if qkv_format == "thd": + q = q.view(num_tokens, self.num_heads, self.head_dim) + k = k.view(num_tokens, self.num_kv_heads, self.head_dim) + v = v.view(num_tokens, self.num_kv_heads, self.head_dim) + else: + q = q.view(bsz, seqlen, self.num_heads, self.head_dim) + k = k.view(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = v.view(bsz, seqlen, self.num_kv_heads, self.head_dim) + + # Optional per-head RMSNorm + if self.use_qk_norm: + q = self.q_norm(q) + k = self.k_norm(k) + + # Partial RoPE (only apply to first 
partial_rotary_factor of head_dim) + rotary_dim = int(self.head_dim * self.partial_rotary_factor) + cos, sin = freqs_cis.split(rotary_dim // 2, dim=-1) + + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + + # Backend-specific attention + q, k, v, _attn_kwargs = preprocess_args_and_kwargs_for_attn( + q, k, v, attention_mask, self.backend.attn, **attn_kwargs + ) + out = self.attn_func(q, k, v, **_attn_kwargs) + out = postprocess_output_for_attn(out, self.backend.attn) + + flatten_dim = 2 if qkv_format == "bshd" else 1 + out = self.o_proj(out.flatten(flatten_dim)) + return out + + def init_weights(self, buffer_device: torch.device, init_std: float = 0.02): + linear_list = [self.q_proj, self.k_proj, self.v_proj, self.o_proj] + for linear in linear_list: + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + if self.use_qk_norm: + for norm in (self.q_norm, self.k_norm): + norm.reset_parameters() diff --git a/nemo_automodel/components/models/glm4_moe/model.py b/nemo_automodel/components/models/glm4_moe/model.py new file mode 100644 index 000000000..57511545b --- /dev/null +++ b/nemo_automodel/components/models/glm4_moe/model.py @@ -0,0 +1,285 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
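+# Illustrative usage sketch (comments only; "<glm4-moe-checkpoint>" is a placeholder id):
+#
+#   cfg = Glm4MoeConfig.from_pretrained("<glm4-moe-checkpoint>")
+#   model = Glm4MoeForCausalLM.from_config(cfg)   # backend defaults to BackendConfig()
+#   model.initialize_weights()                    # trunc-normal init + zeroed MoE gate correction bias
+#   logits = model(input_ids)                     # [B, S] token ids -> [B, S, vocab_size] logits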
+ +from typing import Any + +import torch +import torch.nn as nn +from transformers.models.glm4_moe.configuration_glm4_moe import Glm4MoeConfig + +from nemo_automodel.components.models.glm4_moe.layers import Glm4MoeAttention +from nemo_automodel.components.models.glm4_moe.state_dict_adapter import Glm4MoeStateDictAdapter +from nemo_automodel.components.models.gpt_oss.rope_utils import RotaryEmbedding, position_ids_to_freqs_cis +from nemo_automodel.components.moe.fsdp_mixin import MoEFSDPSyncMixin +from nemo_automodel.components.moe.layers import MLP, MoE, MoEConfig +from nemo_automodel.components.moe.utils import BackendConfig, initialize_linear_module, initialize_rms_norm_module +from nemo_automodel.components.utils.model_utils import squeeze_input_for_thd +from nemo_automodel.shared.utils import dtype_from_str as get_dtype + + +class Block(nn.Module): + def __init__(self, layer_idx: int, config: Glm4MoeConfig, moe_config: MoEConfig, backend: BackendConfig): + super().__init__() + self.self_attn = Glm4MoeAttention(config, backend) + + # GLM4-MoE uses dense layers for first_k_dense_replace layers, then MoE + is_moe_layer = layer_idx >= config.first_k_dense_replace + if is_moe_layer: + self.mlp = MoE(moe_config, backend) + else: + self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear) + + self.input_layernorm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps + ) + self.layer_idx = layer_idx + + def forward( + self, + x: torch.Tensor, + *, + freqs_cis: torch.Tensor, + attention_mask: torch.Tensor | None = None, + padding_mask: torch.Tensor | None = None, + **attn_kwargs: Any, + ) -> torch.Tensor: + if attention_mask is not None and padding_mask is None: + padding_mask = attention_mask.bool().logical_not() + + attn_out = self.self_attn( + x=self.input_layernorm(x), + freqs_cis=freqs_cis, + attention_mask=attention_mask, + **attn_kwargs, + ) + x = x + attn_out + + mlp_out = self._mlp(x=self.post_attention_layernorm(x), padding_mask=padding_mask) + x = x + mlp_out + return x + + def _mlp(self, x: torch.Tensor, padding_mask: torch.Tensor | None) -> torch.Tensor: + if isinstance(self.mlp, MLP): + return self.mlp(x) + else: + assert isinstance(self.mlp, MoE) + return self.mlp(x, padding_mask) + + def init_weights(self, buffer_device: torch.device): + for norm in (self.input_layernorm, self.post_attention_layernorm): + norm.reset_parameters() + self.self_attn.init_weights(buffer_device) + self.mlp.init_weights(buffer_device) + + +class Glm4MoeModel(nn.Module): + def __init__(self, config: Glm4MoeConfig, backend: BackendConfig, *, moe_config: MoEConfig | None = None): + super().__init__() + self.backend = backend + self.config = config + + # Map HF GLM4 MoE config -> our MoE wrapper + # GLM4 MoE config fields: + # - hidden_size, intermediate_size, moe_intermediate_size + # - n_routed_experts, n_shared_experts, num_experts_per_tok + # - n_group, topk_group, routed_scaling_factor, norm_topk_prob + self.moe_config = moe_config or MoEConfig( + dim=config.hidden_size, + inter_dim=config.intermediate_size, + moe_inter_dim=config.moe_intermediate_size, + n_routed_experts=config.n_routed_experts, + n_shared_experts=config.n_shared_experts, + n_activated_experts=config.num_experts_per_tok, + n_expert_groups=config.n_group, + n_limited_groups=config.topk_group, + train_gate=True, + gate_bias_update_factor=0.001, + 
score_func="sigmoid", # GLM4 MoE uses sigmoid scoring with groups + route_scale=config.routed_scaling_factor, + aux_loss_coeff=0.0, # GLM4 MoE doesn't use aux loss in the HF implementation + norm_topk_prob=config.norm_topk_prob, + expert_bias=False, + router_bias=False, + expert_activation="swiglu", + softmax_before_topk=False, # GLM4 uses sigmoid, not softmax + ) + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, dtype=get_dtype(config.torch_dtype, torch.bfloat16) + ) + self.layers = torch.nn.ModuleDict() + for layer_id in range(config.num_hidden_layers): + self.layers[str(layer_id)] = Block(layer_id, config, self.moe_config, backend) + self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps) + + # Rotary embedding cache compatible with our rope_utils functions + # GLM4 MoE uses partial rotary embeddings + self.max_seq_len = config.max_position_embeddings + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + + self.rotary_emb = RotaryEmbedding( + head_dim=self.head_dim, + base=config.rope_theta, + dtype=torch.float32, + scaling_factor=1.0, + device=torch.device(f"cuda:{torch.cuda.current_device()}"), + partial_rotary_factor=config.partial_rotary_factor, + ) + + def forward( + self, + input_ids: torch.Tensor, + *, + position_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + padding_mask: torch.Tensor | None = None, + **attn_kwargs: Any, + ) -> torch.Tensor: + if position_ids is None: + position_ids = ( + torch.arange(0, input_ids.shape[1], device=input_ids.device).unsqueeze(0).expand(input_ids.shape[0], -1) + ) + + # Compute freqs_cis from RotaryEmbedding inv_freq and current position_ids; then concat [cos, sin] + freqs_cis = position_ids_to_freqs_cis( + self.rotary_emb, position_ids, qkv_format=attn_kwargs.get("qkv_format", "bshd") + ) + + h = self.embed_tokens(input_ids) if self.embed_tokens is not None else input_ids + + for layer in self.layers.values(): + h = layer( + x=h, + freqs_cis=freqs_cis, + attention_mask=attention_mask, + padding_mask=padding_mask, + **attn_kwargs, + ) + + h = self.norm(h) if self.norm else h + return h + + @torch.no_grad() + def init_weights(self, buffer_device: torch.device | None = None) -> None: + buffer_device = buffer_device or torch.device(f"cuda:{torch.cuda.current_device()}") + + with buffer_device: + if self.embed_tokens is not None: + nn.init.normal_(self.embed_tokens.weight) + if self.norm is not None: + self.norm.reset_parameters() + # Ensure rotary embedding uses correct device after dtype move + self.rotary_emb.device = buffer_device + + for layer in self.layers.values(): + if layer is not None: + layer.init_weights(buffer_device=buffer_device) + + +class Glm4MoeForCausalLM(nn.Module, MoEFSDPSyncMixin): + @classmethod + def from_config( + cls, + config: Glm4MoeConfig, + moe_config: MoEConfig | None = None, + backend: BackendConfig | None = None, + **kwargs, + ): + return cls(config, moe_config, backend, **kwargs) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + *model_args, + **kwargs, + ): + config = Glm4MoeConfig.from_pretrained(pretrained_model_name_or_path) + return cls.from_config(config, *model_args, **kwargs) + + def __init__( + self, + config: Glm4MoeConfig, + moe_config: MoEConfig | None = None, + backend: BackendConfig | None = None, + **kwargs, + ): + super().__init__() + self.config = config + self.backend = backend or BackendConfig() + self.model = 
Glm4MoeModel(config, backend=self.backend, moe_config=moe_config) + self.lm_head = initialize_linear_module(self.backend.linear, config.hidden_size, config.vocab_size, bias=False) + if self.backend.enable_hf_state_dict_adapter: + self.state_dict_adapter = Glm4MoeStateDictAdapter( + self.config, self.model.moe_config, self.backend, dtype=get_dtype(config.torch_dtype, torch.bfloat16) + ) + + def forward( + self, + input_ids: torch.Tensor, + *, + position_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + padding_mask: torch.Tensor | None = None, + **attn_kwargs: Any, + ) -> torch.Tensor: + if "qkv_format" in attn_kwargs and attn_kwargs["qkv_format"] == "thd": + input_ids, position_ids, padding_mask, attn_kwargs = squeeze_input_for_thd( + input_ids, position_ids, padding_mask, attn_kwargs + ) + attention_mask = None + + hidden = self.model( + input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + padding_mask=padding_mask, + **attn_kwargs, + ) + logits = self.lm_head(hidden) if self.lm_head else hidden + if "qkv_format" in attn_kwargs and attn_kwargs["qkv_format"] == "thd": + logits = logits.unsqueeze(0) + return logits + + @torch.no_grad() + def initialize_weights( + self, buffer_device: torch.device | None = None, dtype: torch.dtype = torch.bfloat16 + ) -> None: + buffer_device = buffer_device or torch.device(f"cuda:{torch.cuda.current_device()}") + with buffer_device: + self.model.init_weights(buffer_device=buffer_device) + final_out_std = self.config.hidden_size**-0.5 + cutoff_factor = 3 + if self.lm_head is not None: + nn.init.trunc_normal_( + self.lm_head.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + self.to(dtype) + for layer in self.model.layers.values(): + if isinstance(layer.mlp, MoE): + layer.mlp.gate.e_score_correction_bias = torch.zeros( + (self.config.n_routed_experts), dtype=torch.float32 + ).to(buffer_device) + with buffer_device: + # Ensure rotary embedding uses correct device after dtype move + self.model.rotary_emb.device = buffer_device + + +ModelClass = Glm4MoeForCausalLM diff --git a/nemo_automodel/components/models/glm4_moe/state_dict_adapter.py b/nemo_automodel/components/models/glm4_moe/state_dict_adapter.py new file mode 100644 index 000000000..461266010 --- /dev/null +++ b/nemo_automodel/components/models/glm4_moe/state_dict_adapter.py @@ -0,0 +1,82 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
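+# Illustrative round trip (comments only; assumes a Glm4MoeForCausalLM built with
+# BackendConfig.enable_hf_state_dict_adapter=True so that model.state_dict_adapter exists):
+#
+#   adapter = model.state_dict_adapter
+#   hf_sd = adapter.to_hf(model.state_dict())   # grouped gate_and_up_projs/down_projs -> per-expert HF keys
+#   native_sd = adapter.from_hf(hf_sd)          # per-expert HF keys -> grouped expert tensors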
+ +import logging +from typing import Any, Optional + +import torch +from torch.distributed.device_mesh import DeviceMesh + +from nemo_automodel.components.checkpoint.state_dict_adapter import StateDictAdapter +from nemo_automodel.components.moe.layers import MoEConfig +from nemo_automodel.components.moe.state_dict_mixin import MoESplitExpertsStateDictMixin +from nemo_automodel.components.moe.utils import BackendConfig + +logger = logging.getLogger(__name__) + + +class Glm4MoeStateDictAdapter(MoESplitExpertsStateDictMixin, StateDictAdapter): + """Converts between HF GLM4-MoE checkpoints and our grouped-experts native format. + + GLM4-MoE HF experts use keys: + model.layers.{L}.mlp.experts.{E}.gate_proj.weight + model.layers.{L}.mlp.experts.{E}.up_proj.weight + model.layers.{L}.mlp.experts.{E}.down_proj.weight + model.layers.{L}.mlp.shared_experts.gate_proj.weight + model.layers.{L}.mlp.shared_experts.up_proj.weight + model.layers.{L}.mlp.shared_experts.down_proj.weight + + Our native format groups them into: + model.layers.{L}.mlp.experts.gate_and_up_projs # [n_experts, dim, 2*moe_inter_dim] + model.layers.{L}.mlp.experts.down_projs # [n_experts, moe_inter_dim, dim] + model.layers.{L}.mlp.shared_expert.gate_proj.weight + model.layers.{L}.mlp.shared_expert.up_proj.weight + model.layers.{L}.mlp.shared_expert.down_proj.weight + """ + + def __init__( + self, + config: Any, + moe_config: MoEConfig, + backend: BackendConfig, + dtype: torch.dtype = torch.float32, + ): + self.config = config + self.moe_config = moe_config + self.backend = backend + self.dtype = dtype + self._uses_model_prefix = True + + def to_hf( + self, state_dict: dict[str, Any], exclude_key_regex: Optional[str] = None, quantization: bool = False, **kwargs + ) -> dict[str, Any]: + hf_state_dict = self._to_hf_w_split_experts(state_dict) + if exclude_key_regex: + import re + + hf_state_dict = {k: v for k, v in hf_state_dict.items() if not re.match(exclude_key_regex, k)} + return hf_state_dict + + def from_hf( + self, + hf_state_dict: dict[str, Any], + device_mesh: Optional["DeviceMesh"] = None, + **kwargs, + ) -> dict[str, Any]: + # Detect whether HF checkpoints use the "model." prefix + for key in hf_state_dict.keys(): + if ".mlp.experts." 
in key and key.endswith(".weight"): + self._uses_model_prefix = key.startswith("model.") + break + return self._from_hf_w_merged_experts(hf_state_dict, device_mesh) diff --git a/nemo_automodel/components/models/gpt_oss/model.py b/nemo_automodel/components/models/gpt_oss/model.py index 1387e44f7..508797df8 100644 --- a/nemo_automodel/components/models/gpt_oss/model.py +++ b/nemo_automodel/components/models/gpt_oss/model.py @@ -89,8 +89,8 @@ def __init__(self, config: GptOssConfig, backend: BackendConfig, *, moe_config: n_routed_experts=config.num_local_experts, n_shared_experts=0, n_activated_experts=config.num_experts_per_tok, - n_expert_groups=getattr(config, "n_group", 1), - n_limited_groups=getattr(config, "topk_group", 1), + n_expert_groups=0, + n_limited_groups=0, train_gate=True, gate_bias_update_factor=0, score_func="softmax", diff --git a/nemo_automodel/components/models/qwen3_moe/model.py b/nemo_automodel/components/models/qwen3_moe/model.py index ec13d84ad..16af3f155 100644 --- a/nemo_automodel/components/models/qwen3_moe/model.py +++ b/nemo_automodel/components/models/qwen3_moe/model.py @@ -104,8 +104,8 @@ def __init__(self, config: Qwen3MoeConfig, backend: BackendConfig, *, moe_config n_routed_experts=getattr(config, "num_experts", 0), n_shared_experts=0, n_activated_experts=getattr(config, "num_experts_per_tok", 1), - n_expert_groups=1, - n_limited_groups=1, + n_expert_groups=0, + n_limited_groups=0, train_gate=True, gate_bias_update_factor=0.0, score_func="softmax", # Qwen3 uses softmax topk routing diff --git a/nemo_automodel/components/models/qwen3_next/model.py b/nemo_automodel/components/models/qwen3_next/model.py index 6b0f3c736..995f3905c 100644 --- a/nemo_automodel/components/models/qwen3_next/model.py +++ b/nemo_automodel/components/models/qwen3_next/model.py @@ -120,8 +120,8 @@ def __init__(self, config: Qwen3NextConfig, backend: BackendConfig, *, moe_confi n_routed_experts=config.num_experts, n_shared_experts=1, n_activated_experts=config.num_experts_per_tok, - n_expert_groups=1, - n_limited_groups=1, + n_expert_groups=0, + n_limited_groups=0, train_gate=True, gate_bias_update_factor=0.0, score_func="softmax", # Qwen3Next uses softmax topk routing diff --git a/nemo_automodel/components/moe/layers.py b/nemo_automodel/components/moe/layers.py index ba75d00da..441bc2f41 100644 --- a/nemo_automodel/components/moe/layers.py +++ b/nemo_automodel/components/moe/layers.py @@ -610,9 +610,7 @@ def __init__(self, config: MoEConfig): self.bias = None if self.bias_update_factor > 0: - self.e_score_correction_bias = nn.Parameter( - torch.empty(config.n_routed_experts, dtype=config.dtype), requires_grad=False - ) + self.register_buffer("e_score_correction_bias", torch.zeros((self.n_experts), dtype=config.dtype)) else: self.e_score_correction_bias = None diff --git a/tests/unit_tests/models/glm4_moe/__init__.py b/tests/unit_tests/models/glm4_moe/__init__.py new file mode 100644 index 000000000..070b8c0d7 --- /dev/null +++ b/tests/unit_tests/models/glm4_moe/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit_tests/models/glm4_moe/test_glm4_moe_layers.py b/tests/unit_tests/models/glm4_moe/test_glm4_moe_layers.py new file mode 100644 index 000000000..583df7c7f --- /dev/null +++ b/tests/unit_tests/models/glm4_moe/test_glm4_moe_layers.py @@ -0,0 +1,377 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from unittest.mock import MagicMock, patch + +import pytest +import torch +from transformers.models.glm4_moe.configuration_glm4_moe import Glm4MoeConfig + +from nemo_automodel.components.attention.utils import postprocess_output_for_attn, preprocess_args_and_kwargs_for_attn +from nemo_automodel.components.models.glm4_moe.layers import Glm4MoeAttention +from nemo_automodel.components.moe.utils import BackendConfig + + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + + +@pytest.fixture +def config(): + cfg = Glm4MoeConfig( + vocab_size=128, + hidden_size=64, + intermediate_size=128, + num_hidden_layers=1, + num_attention_heads=4, + num_key_value_heads=2, + max_position_embeddings=256, + rms_norm_eps=1e-6, + attention_dropout=0.0, + use_qk_norm=True, + partial_rotary_factor=0.5, + attention_bias=False, + rope_theta=10000.0, + n_routed_experts=4, + n_shared_experts=1, + num_experts_per_tok=2, + first_k_dense_replace=1, + moe_intermediate_size=64, + n_group=1, + topk_group=1, + routed_scaling_factor=1.0, + norm_topk_prob=False, + ) + cfg.head_dim = 16 + return cfg + + +@pytest.fixture +def config_without_qk_norm(): + cfg = Glm4MoeConfig( + vocab_size=128, + hidden_size=64, + intermediate_size=128, + num_hidden_layers=1, + num_attention_heads=4, + num_key_value_heads=2, + max_position_embeddings=256, + rms_norm_eps=1e-6, + attention_dropout=0.0, + use_qk_norm=False, + partial_rotary_factor=0.5, + attention_bias=False, + rope_theta=10000.0, + n_routed_experts=4, + n_shared_experts=1, + num_experts_per_tok=2, + first_k_dense_replace=1, + moe_intermediate_size=64, + n_group=1, + topk_group=1, + routed_scaling_factor=1.0, + norm_topk_prob=False, + ) + cfg.head_dim = 16 + return cfg + + +@pytest.fixture +def sdpa_backend(): + return BackendConfig( + linear="torch", + attn="sdpa", + rms_norm="torch", + enable_deepep=False, + fake_balanced_gate=False, + enable_hf_state_dict_adapter=False, + ) + + +@pytest.fixture +def te_backend(): + return BackendConfig( + linear="torch", + attn="te", + rms_norm="torch", + enable_deepep=False, + fake_balanced_gate=False, + enable_hf_state_dict_adapter=False, + ) + + +class TestPreprocessForAttn: + def 
test_te_backend_without_mask_keeps_layout(self, te_backend): + q = torch.randn(2, 4, 2, 8) + k = torch.randn_like(q) + v = torch.randn_like(q) + + q_out, k_out, v_out, kwargs = preprocess_args_and_kwargs_for_attn(q, k, v, attention_mask=None, attn_impl=te_backend.attn) + + torch.testing.assert_close(q_out, q) + torch.testing.assert_close(k_out, k) + torch.testing.assert_close(v_out, v) + assert kwargs == {} + + def test_te_backend_with_mask_builds_padding_kwargs(self, te_backend): + q = torch.randn(1, 3, 2, 4) + k = torch.randn_like(q) + v = torch.randn_like(q) + attention_mask = torch.tensor([[1, 1, 0]], dtype=torch.bool) + + _, _, _, kwargs = preprocess_args_and_kwargs_for_attn(q, k, v, attention_mask=attention_mask, attn_impl=te_backend.attn) + + assert kwargs["attn_mask_type"] == "padding_causal" + assert kwargs["window_size"] == (-1, 0) + mask = kwargs["attention_mask"] + assert mask.shape == (1, 1, 1, 3) + expected = attention_mask.logical_not().unsqueeze(1).unsqueeze(2) + torch.testing.assert_close(mask, expected) + + def test_sdpa_backend_transposes_qkv(self, sdpa_backend): + q = torch.randn(2, 5, 3, 6) + k = torch.randn_like(q) + v = torch.randn_like(q) + + q_out, k_out, v_out, kwargs = preprocess_args_and_kwargs_for_attn(q, k, v, attention_mask=None, attn_impl=sdpa_backend.attn) + + assert q_out.shape == (2, 3, 5, 6) + assert k_out.shape == (2, 3, 5, 6) + assert v_out.shape == (2, 3, 5, 6) + assert kwargs == {"is_causal": True} + + +class TestPostprocessFromAttn: + def test_sdpa_backend_transposes_back(self, sdpa_backend): + x = torch.randn(2, 4, 6, 8) + + out = postprocess_output_for_attn(x, sdpa_backend.attn) + + assert out.shape == (2, 6, 4, 8) + torch.testing.assert_close(out.transpose(1, 2), x) + + def test_other_backend_returns_input(self, te_backend): + x = torch.randn(1, 2, 3, 4) + out = postprocess_output_for_attn(x, te_backend.attn) + torch.testing.assert_close(out, x) + + +class TestGlm4MoeAttention: + def test_initialization_populates_projections(self, config, sdpa_backend): + attention = Glm4MoeAttention(config, sdpa_backend) + + assert attention.num_heads == config.num_attention_heads + assert attention.num_kv_heads == config.num_key_value_heads + assert attention.head_dim == config.head_dim + assert attention.q_proj.in_features == config.hidden_size + assert attention.q_proj.out_features == config.num_attention_heads * config.head_dim + assert attention.k_proj.out_features == config.num_key_value_heads * config.head_dim + assert attention.v_proj.out_features == config.num_key_value_heads * config.head_dim + assert attention.o_proj.in_features == config.num_attention_heads * config.head_dim + assert attention.o_proj.out_features == config.hidden_size + + def test_initialization_creates_qk_norm_when_enabled(self, config, sdpa_backend): + config.use_qk_norm = True + attention = Glm4MoeAttention(config, sdpa_backend) + + assert hasattr(attention, "q_norm") + assert hasattr(attention, "k_norm") + assert attention.use_qk_norm is True + + def test_initialization_skips_qk_norm_when_disabled(self, config_without_qk_norm, sdpa_backend): + attention = Glm4MoeAttention(config_without_qk_norm, sdpa_backend) + + assert not hasattr(attention, "q_norm") + assert not hasattr(attention, "k_norm") + assert attention.use_qk_norm is False + + def test_partial_rotary_factor_is_stored(self, config, sdpa_backend): + config.partial_rotary_factor = 0.75 + attention = Glm4MoeAttention(config, sdpa_backend) + + assert attention.partial_rotary_factor == 0.75 + + def 
test_forward_shape_is_preserved_bshd_format(self, config, sdpa_backend): + attention = Glm4MoeAttention(config, sdpa_backend) + batch_size, seq_len = 2, 5 + hidden = torch.randn(batch_size, seq_len, config.hidden_size).to(torch.bfloat16) + freqs_cis = torch.randn(batch_size, seq_len, int(config.head_dim * config.partial_rotary_factor)) + + fake_attn = torch.zeros(batch_size, config.num_attention_heads, seq_len, config.head_dim) + attention.attn_func = MagicMock(return_value=fake_attn.to(torch.bfloat16)) + + with patch("nemo_automodel.components.models.glm4_moe.layers.apply_rotary_emb", side_effect=lambda x, *_: x): + out = attention(hidden, freqs_cis=freqs_cis) + + assert out.shape == (batch_size, seq_len, config.hidden_size) + attention.attn_func.assert_called_once() + + def test_forward_shape_is_preserved_thd_format(self, config, sdpa_backend): + attention = Glm4MoeAttention(config, sdpa_backend) + num_tokens = 10 + hidden = torch.randn(num_tokens, config.hidden_size).to(torch.bfloat16) + freqs_cis = torch.randn(num_tokens, int(config.head_dim * config.partial_rotary_factor)) + + fake_attn = torch.zeros(num_tokens, config.num_attention_heads, config.head_dim) + attention.attn_func = MagicMock(return_value=fake_attn.to(torch.bfloat16)) + + with patch("nemo_automodel.components.models.glm4_moe.layers.apply_rotary_emb", side_effect=lambda x, *_: x): + out = attention(hidden, freqs_cis=freqs_cis) + + assert out.shape == (num_tokens, config.hidden_size) + attention.attn_func.assert_called_once() + + def test_forward_applies_qk_norm_when_enabled(self, config, sdpa_backend): + """Test that q_norm and k_norm are applied when use_qk_norm=True""" + config.use_qk_norm = True + attention = Glm4MoeAttention(config, sdpa_backend) + batch_size, seq_len = 1, 3 + hidden = torch.randn(batch_size, seq_len, config.hidden_size).to(torch.bfloat16) + freqs_cis = torch.randn(batch_size, seq_len, int(config.head_dim * config.partial_rotary_factor)) + + fake_attn = torch.zeros(batch_size, config.num_attention_heads, seq_len, config.head_dim).to(torch.bfloat16) + attention.attn_func = MagicMock(return_value=fake_attn) + + with patch("nemo_automodel.components.models.glm4_moe.layers.apply_rotary_emb", side_effect=lambda x, *_: x): + with patch.object(attention.q_norm, "forward", wraps=attention.q_norm.forward) as mock_q_norm, \ + patch.object(attention.k_norm, "forward", wraps=attention.k_norm.forward) as mock_k_norm: + attention(hidden, freqs_cis=freqs_cis) + + mock_q_norm.assert_called_once() + mock_k_norm.assert_called_once() + + def test_forward_skips_qk_norm_when_disabled(self, config_without_qk_norm, sdpa_backend): + """Test that q_norm and k_norm are skipped when use_qk_norm=False""" + attention = Glm4MoeAttention(config_without_qk_norm, sdpa_backend) + batch_size, seq_len = 1, 3 + hidden = torch.randn(batch_size, seq_len, config_without_qk_norm.hidden_size).to(torch.bfloat16) + freqs_cis = torch.randn(batch_size, seq_len, int(config_without_qk_norm.head_dim * config_without_qk_norm.partial_rotary_factor)) + + fake_attn = torch.zeros(batch_size, config_without_qk_norm.num_attention_heads, seq_len, config_without_qk_norm.head_dim).to(torch.bfloat16) + attention.attn_func = MagicMock(return_value=fake_attn) + + with patch("nemo_automodel.components.models.glm4_moe.layers.apply_rotary_emb", side_effect=lambda x, *_: x): + out = attention(hidden, freqs_cis=freqs_cis) + + # Should complete successfully without QK norm + assert out.shape == (batch_size, seq_len, config_without_qk_norm.hidden_size) + + def 
test_forward_applies_partial_rotary_embedding(self, config, sdpa_backend): + """Test that rotary embedding is applied with partial_rotary_factor""" + config.partial_rotary_factor = 0.5 + attention = Glm4MoeAttention(config, sdpa_backend) + batch_size, seq_len = 1, 2 + hidden = torch.randn(batch_size, seq_len, config.hidden_size).to(torch.bfloat16) + rotary_dim = int(config.head_dim * config.partial_rotary_factor) + freqs_cis = torch.randn(batch_size, seq_len, rotary_dim) + + attention.attn_func = MagicMock( + return_value=torch.zeros(batch_size, config.num_attention_heads, seq_len, config.head_dim).to(torch.bfloat16) + ) + + with patch("nemo_automodel.components.models.glm4_moe.layers.apply_rotary_emb") as mock_rotary: + mock_rotary.side_effect = lambda x, *_: x + attention(hidden, freqs_cis=freqs_cis) + + # Should apply rotary to both q and k + assert mock_rotary.call_count == 2 + # Verify that cos and sin are split correctly based on partial_rotary_factor + for call_args in mock_rotary.call_args_list: + cos = call_args[0][1] + sin = call_args[0][2] + # cos and sin should be half of rotary_dim + assert cos.shape[-1] == rotary_dim // 2 + assert sin.shape[-1] == rotary_dim // 2 + + def test_forward_passes_preprocessed_kwargs(self, config, sdpa_backend): + attention = Glm4MoeAttention(config, sdpa_backend) + batch, seq_len = 1, 3 + hidden = torch.randn(batch, seq_len, config.hidden_size).to(torch.bfloat16) + freqs_cis = torch.randn(batch, seq_len, int(config.head_dim * config.partial_rotary_factor)) + attention_mask = torch.ones(batch, seq_len, dtype=torch.bool) + + fake_attn = torch.zeros(batch, config.num_attention_heads, seq_len, config.head_dim).to(torch.bfloat16) + attention.attn_func = MagicMock(return_value=fake_attn.to(torch.bfloat16)) + + with patch("nemo_automodel.components.models.glm4_moe.layers.apply_rotary_emb", side_effect=lambda x, *_: x): + attention(hidden, freqs_cis=freqs_cis, attention_mask=attention_mask) + + _, kwargs = attention.attn_func.call_args + assert kwargs.get("is_causal") is True + + def test_init_weights_resets_linears(self, config, sdpa_backend): + attention = Glm4MoeAttention(config, sdpa_backend) + + with patch("torch.nn.init.trunc_normal_") as mock_trunc: + attention.init_weights(torch.device("cpu"), init_std=0.05) + + # Should initialize q_proj, k_proj, v_proj, o_proj + assert mock_trunc.call_count == 4 + + def test_init_weights_resets_norms_when_qk_norm_enabled(self, config, sdpa_backend): + config.use_qk_norm = True + attention = Glm4MoeAttention(config, sdpa_backend) + + with patch("torch.nn.init.trunc_normal_") as mock_trunc, \ + patch.object(attention.q_norm, "reset_parameters") as mock_q_reset, \ + patch.object(attention.k_norm, "reset_parameters") as mock_k_reset: + attention.init_weights(torch.device("cpu"), init_std=0.05) + + assert mock_trunc.call_count == 4 + mock_q_reset.assert_called_once() + mock_k_reset.assert_called_once() + + def test_init_weights_skips_norms_when_qk_norm_disabled(self, config_without_qk_norm, sdpa_backend): + attention = Glm4MoeAttention(config_without_qk_norm, sdpa_backend) + + with patch("torch.nn.init.trunc_normal_") as mock_trunc: + attention.init_weights(torch.device("cpu"), init_std=0.05) + + # Should still initialize 4 linear layers + assert mock_trunc.call_count == 4 + + def test_forward_with_te_backend_supports_attention_mask(self, config, te_backend): + batch, seq_len = 1, 3 + fake_out = torch.zeros(batch, seq_len, config.num_attention_heads, config.head_dim) + fake_module = MagicMock() + fake_func = 
MagicMock(return_value=fake_out.to(torch.bfloat16)) + with patch( + "nemo_automodel.components.models.glm4_moe.layers.initialize_attn_module_and_func", + return_value=(fake_module, fake_func), + ): + attention = Glm4MoeAttention(config, te_backend) + + hidden = torch.randn(batch, seq_len, config.hidden_size).to(torch.bfloat16) + freqs_cis = torch.randn(batch, seq_len, int(config.head_dim * config.partial_rotary_factor)) + attention_mask = torch.tensor([[1, 0, 1]], dtype=torch.bool) + + with patch("nemo_automodel.components.models.glm4_moe.layers.apply_rotary_emb", side_effect=lambda x, *_: x): + attention(hidden, freqs_cis=freqs_cis, attention_mask=attention_mask) + + _, kwargs = attention.attn_func.call_args + assert "attention_mask" in kwargs + mask = kwargs["attention_mask"] + assert mask.shape == (batch, 1, 1, seq_len) + + def test_softmax_scale_matches_head_dim(self, config, sdpa_backend): + attention = Glm4MoeAttention(config, sdpa_backend) + keywords = getattr(attention.attn_func, "keywords", {}) or {} + scale = keywords.get("scale") + assert scale is not None + assert math.isclose(scale, config.head_dim ** -0.5, rel_tol=1e-6) + + def test_o_proj_has_no_bias(self, config, sdpa_backend): + """GLM4 MoE uses o_proj without bias""" + attention = Glm4MoeAttention(config, sdpa_backend) + assert attention.o_proj.bias is None diff --git a/tests/unit_tests/models/glm4_moe/test_glm4_moe_model.py b/tests/unit_tests/models/glm4_moe/test_glm4_moe_model.py new file mode 100644 index 000000000..a0ef6cffa --- /dev/null +++ b/tests/unit_tests/models/glm4_moe/test_glm4_moe_model.py @@ -0,0 +1,493 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
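+# These tests are skipped without CUDA (see the module-level pytestmark below). A typical local
+# invocation, assuming the repository root as the working directory, might be:
+#   pytest tests/unit_tests/models/glm4_moe/test_glm4_moe_model.py -q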
+ +from unittest.mock import MagicMock, patch + +import pytest +import torch +from transformers.models.glm4_moe.configuration_glm4_moe import Glm4MoeConfig + +from nemo_automodel.components.models.glm4_moe.model import Block, Glm4MoeForCausalLM, Glm4MoeModel +from nemo_automodel.components.moe.layers import MLP, MoE, MoEConfig +from nemo_automodel.components.moe.utils import BackendConfig + + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + + +@pytest.fixture +def device(): + if torch.cuda.is_available(): + return torch.device(f"cuda:{torch.cuda.current_device()}") + return torch.device("cpu") + + +@pytest.fixture +def glm_config(): + return Glm4MoeConfig( + vocab_size=256, + hidden_size=64, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=16, + num_hidden_layers=4, + intermediate_size=128, + moe_intermediate_size=64, + n_routed_experts=4, + n_shared_experts=1, + num_experts_per_tok=2, + first_k_dense_replace=2, # First 2 layers are dense, rest are MoE + max_position_embeddings=256, + rms_norm_eps=1e-6, + rope_theta=10000.0, + use_qk_norm=True, + partial_rotary_factor=0.5, + attention_bias=False, + n_group=1, + topk_group=1, + routed_scaling_factor=1.0, + norm_topk_prob=False, + ) + + +@pytest.fixture +def backend_config(): + return BackendConfig( + linear="torch", + attn="sdpa", + rms_norm="torch", + enable_deepep=False, + fake_balanced_gate=False, + enable_hf_state_dict_adapter=False, + ) + + +@pytest.fixture +def moe_config(): + return MoEConfig( + dim=64, + inter_dim=128, + moe_inter_dim=64, + n_routed_experts=4, + n_shared_experts=1, + n_activated_experts=2, + n_expert_groups=1, + n_limited_groups=1, + train_gate=True, + gate_bias_update_factor=0.001, + score_func="sigmoid", + route_scale=1.0, + aux_loss_coeff=0.0, + norm_topk_prob=False, + expert_bias=False, + router_bias=False, + expert_activation="swiglu", + softmax_before_topk=False, + ) + + +class TestBlock: + def test_block_uses_mlp_for_dense_layers(self, glm_config, moe_config, backend_config): + """First k layers (< first_k_dense_replace) should use MLP""" + block = Block(layer_idx=0, config=glm_config, moe_config=moe_config, backend=backend_config) + + assert isinstance(block.mlp, MLP) + assert hasattr(block, "self_attn") + assert hasattr(block, "input_layernorm") + assert hasattr(block, "post_attention_layernorm") + + def test_block_uses_moe_for_sparse_layers(self, glm_config, moe_config, backend_config): + """Layers >= first_k_dense_replace should use MoE""" + block = Block(layer_idx=2, config=glm_config, moe_config=moe_config, backend=backend_config) + + assert isinstance(block.mlp, MoE) + + def test_block_stores_layer_idx(self, glm_config, moe_config, backend_config): + layer_idx = 3 + block = Block(layer_idx=layer_idx, config=glm_config, moe_config=moe_config, backend=backend_config) + + assert block.layer_idx == layer_idx + + def test_forward_pass_calls_attention_and_mlp(self, glm_config, backend_config, device): + block = Block(layer_idx=0, config=glm_config, moe_config=magic_moe_config(glm_config), backend=backend_config) + block = block.to(device) + + batch, seq_len = 2, 4 + x = torch.randn(batch, seq_len, glm_config.hidden_size, device=device) + freqs_cis = torch.randn(batch, seq_len, int(glm_config.head_dim * glm_config.partial_rotary_factor), device=device) + + with patch.object(block.self_attn, "forward", return_value=torch.zeros_like(x)) as mock_attn, \ + patch.object(block, "_mlp", return_value=torch.zeros_like(x)) as mock_mlp: + out = block(x, 
freqs_cis=freqs_cis) + + assert out.shape == x.shape + mock_attn.assert_called_once() + mock_mlp.assert_called_once() + + def test_forward_builds_padding_mask_from_attention(self, glm_config, backend_config, device): + block = Block(layer_idx=0, config=glm_config, moe_config=magic_moe_config(glm_config), backend=backend_config) + block = block.to(device) + + x = torch.randn(1, 3, glm_config.hidden_size, device=device) + freqs_cis = torch.randn(1, 3, int(glm_config.head_dim * glm_config.partial_rotary_factor), device=device) + attention_mask = torch.tensor([[1, 1, 0]], dtype=torch.bool, device=device) + + with patch.object(block.self_attn, "forward", return_value=torch.zeros_like(x)) as mock_attn, \ + patch.object(block, "_mlp", return_value=torch.zeros_like(x)) as mock_mlp: + block(x, freqs_cis=freqs_cis, attention_mask=attention_mask) + + mock_attn.assert_called_once() + _, kwargs = mock_mlp.call_args + padding_mask = kwargs.get("padding_mask") + assert padding_mask is not None + torch.testing.assert_close(padding_mask, attention_mask.logical_not()) + + def test_forward_uses_provided_padding_mask(self, glm_config, backend_config, device): + """Test that if padding_mask is provided, it's used directly""" + block = Block(layer_idx=0, config=glm_config, moe_config=magic_moe_config(glm_config), backend=backend_config) + block = block.to(device) + + x = torch.randn(1, 3, glm_config.hidden_size, device=device) + freqs_cis = torch.randn(1, 3, int(glm_config.head_dim * glm_config.partial_rotary_factor), device=device) + attention_mask = torch.tensor([[1, 1, 0]], dtype=torch.bool, device=device) + padding_mask = torch.tensor([[0, 0, 1]], dtype=torch.bool, device=device) + + with patch.object(block.self_attn, "forward", return_value=torch.zeros_like(x)) as mock_attn, \ + patch.object(block, "_mlp", return_value=torch.zeros_like(x)) as mock_mlp: + block(x, freqs_cis=freqs_cis, attention_mask=attention_mask, padding_mask=padding_mask) + + _, kwargs = mock_mlp.call_args + received_padding_mask = kwargs.get("padding_mask") + torch.testing.assert_close(received_padding_mask, padding_mask) + + def test_mlp_wrapper_handles_mlp_instance(self, glm_config, backend_config): + block = Block(layer_idx=0, config=glm_config, moe_config=magic_moe_config(glm_config), backend=backend_config) + x = torch.randn(2, 4, glm_config.hidden_size).to(torch.bfloat16) + + out = block._mlp(x, padding_mask=None) + + assert out.shape == x.shape + + def test_mlp_wrapper_handles_moe_instance(self, glm_config, backend_config): + block = Block(layer_idx=2, config=glm_config, moe_config=magic_moe_config(glm_config), backend=backend_config) + x = torch.randn(2, 4, glm_config.hidden_size).to(torch.bfloat16) + padding_mask = torch.zeros(2, 4, dtype=torch.bool) + + with patch.object(block.mlp, "forward", return_value=torch.zeros_like(x)) as mock_moe: + out = block._mlp(x, padding_mask=padding_mask) + + mock_moe.assert_called_once_with(x, padding_mask) + assert out.shape == x.shape + + def test_init_weights_resets_sublayers(self, glm_config, backend_config): + block = Block(layer_idx=0, config=glm_config, moe_config=magic_moe_config(glm_config), backend=backend_config) + + with patch.object(block.input_layernorm, "reset_parameters") as mock_in, \ + patch.object(block.post_attention_layernorm, "reset_parameters") as mock_post, \ + patch.object(block.self_attn, "init_weights") as mock_attn, \ + patch.object(block.mlp, "init_weights") as mock_mlp: + block.init_weights(torch.device("cpu")) + + mock_in.assert_called_once() + 
mock_post.assert_called_once() + mock_attn.assert_called_once() + mock_mlp.assert_called_once() + + +class TestGlm4MoeModel: + def test_model_initialization_sets_components(self, glm_config, backend_config): + model = Glm4MoeModel(glm_config, backend=backend_config) + + assert model.config == glm_config + assert model.backend == backend_config + assert len(model.layers) == glm_config.num_hidden_layers + assert model.embed_tokens.num_embeddings == glm_config.vocab_size + assert model.rotary_emb.head_dim == glm_config.head_dim + + def test_model_initializes_moe_config_with_sigmoid_scoring(self, glm_config, backend_config): + model = Glm4MoeModel(glm_config, backend=backend_config) + + assert hasattr(model, "moe_config") + assert model.moe_config.dim == glm_config.hidden_size + assert model.moe_config.n_routed_experts == glm_config.n_routed_experts + assert model.moe_config.n_shared_experts == glm_config.n_shared_experts + assert model.moe_config.n_activated_experts == glm_config.num_experts_per_tok + assert model.moe_config.score_func == "sigmoid" # GLM4 uses sigmoid + assert model.moe_config.softmax_before_topk is False + assert model.moe_config.route_scale == glm_config.routed_scaling_factor + + def test_model_initializes_moe_config_with_expert_groups(self, glm_config, backend_config): + model = Glm4MoeModel(glm_config, backend=backend_config) + + assert model.moe_config.n_expert_groups == glm_config.n_group + assert model.moe_config.n_limited_groups == glm_config.topk_group + + def test_model_accepts_custom_moe_config(self, glm_config, backend_config, moe_config): + model = Glm4MoeModel(glm_config, backend=backend_config, moe_config=moe_config) + + assert model.moe_config == moe_config + + def test_model_uses_partial_rotary_factor(self, glm_config, backend_config): + glm_config.partial_rotary_factor = 0.75 + model = Glm4MoeModel(glm_config, backend=backend_config) + + assert model.rotary_emb.partial_rotary_factor == 0.75 + + def test_forward_runs_all_layers(self, glm_config, backend_config): + model = Glm4MoeModel(glm_config, backend=backend_config) + + batch, seq_len = 2, 5 + input_ids = torch.randint(0, glm_config.vocab_size, (batch, seq_len)) + freqs_mock = MagicMock(return_value=(1.0, torch.ones(glm_config.head_dim // 2))) + + with patch.object(model.rotary_emb, "_compute_concentration_and_inv_freq", freqs_mock): + with patch.object(Block, "forward", side_effect=lambda *_, **__: torch.randn(batch, seq_len, glm_config.hidden_size)) as mock_block: + out = model(input_ids) + + assert out.shape == (batch, seq_len, glm_config.hidden_size) + assert mock_block.call_count == glm_config.num_hidden_layers + + def test_forward_generates_position_ids_if_not_provided(self, glm_config, backend_config): + model = Glm4MoeModel(glm_config, backend=backend_config) + batch, seq_len = 2, 4 + input_ids = torch.randint(0, glm_config.vocab_size, (batch, seq_len)) + + with patch.object(model.rotary_emb, "_compute_concentration_and_inv_freq", return_value=(1.0, torch.ones(glm_config.head_dim // 2))): + with patch.object(Block, "forward", side_effect=lambda *_, **kwargs: torch.randn(batch, seq_len, glm_config.hidden_size)): + with patch("nemo_automodel.components.models.glm4_moe.model.position_ids_to_freqs_cis") as mock_freqs: + mock_freqs.return_value = torch.randn(batch, seq_len, int(glm_config.head_dim * glm_config.partial_rotary_factor)) + out = model(input_ids) + + # Verify position_ids_to_freqs_cis was called + mock_freqs.assert_called_once() + call_args = mock_freqs.call_args + position_ids = 
call_args[0][1] + assert position_ids.shape == (batch, seq_len) + expected_pos_ids = torch.arange(0, seq_len).unsqueeze(0).expand(batch, -1) + torch.testing.assert_close(position_ids, expected_pos_ids) + + def test_forward_accepts_position_ids(self, glm_config, backend_config): + model = Glm4MoeModel(glm_config, backend=backend_config) + batch, seq_len = 1, 4 + input_ids = torch.randint(0, glm_config.vocab_size, (batch, seq_len)) + position_ids = torch.arange(seq_len).unsqueeze(0) + + with patch.object(model.rotary_emb, "_compute_concentration_and_inv_freq", return_value=(1.0, torch.ones(glm_config.head_dim // 2))): + with patch.object(Block, "forward", return_value=torch.zeros(batch, seq_len, glm_config.hidden_size)): + out = model(input_ids, position_ids=position_ids) + + assert out.shape == (batch, seq_len, glm_config.hidden_size) + + def test_forward_computes_freqs_cis_from_rotary_emb(self, glm_config, backend_config): + model = Glm4MoeModel(glm_config, backend=backend_config) + batch, seq_len = 1, 3 + input_ids = torch.randint(0, glm_config.vocab_size, (batch, seq_len)) + + with patch.object(model.rotary_emb, "_compute_concentration_and_inv_freq", return_value=(1.0, torch.ones(glm_config.head_dim // 2))): + with patch("nemo_automodel.components.models.glm4_moe.model.position_ids_to_freqs_cis") as mock_freqs: + mock_freqs.return_value = torch.randn(batch, seq_len, int(glm_config.head_dim * glm_config.partial_rotary_factor)) + with patch.object(Block, "forward", return_value=torch.zeros(batch, seq_len, glm_config.hidden_size)): + model(input_ids) + + mock_freqs.assert_called_once() + assert mock_freqs.call_args[0][0] == model.rotary_emb + + def test_init_weights_updates_embeddings_and_layers(self, glm_config, backend_config): + model = Glm4MoeModel(glm_config, backend=backend_config) + original = model.embed_tokens.weight.clone() + + with patch.object(model.norm, "reset_parameters") as mock_norm, \ + patch.object(Block, "init_weights") as mock_layer_init: + model.init_weights(torch.device("cpu")) + + mock_norm.assert_called_once() + assert not torch.equal(model.embed_tokens.weight, original) + assert mock_layer_init.call_count == glm_config.num_hidden_layers + + def test_init_weights_updates_rotary_emb_device(self, glm_config, backend_config): + model = Glm4MoeModel(glm_config, backend=backend_config) + device = torch.device("cpu") + + with patch.object(model.norm, "reset_parameters"), \ + patch.object(Block, "init_weights"): + model.init_weights(buffer_device=device) + + assert model.rotary_emb.device == device + + +class TestGlm4MoeForCausalLM: + def test_forward_returns_logits(self, glm_config, backend_config, device): + model = Glm4MoeForCausalLM(glm_config, backend=backend_config) + model = model.to(device) + + batch, seq_len = 2, 6 + input_ids = torch.randint(0, glm_config.vocab_size, (batch, seq_len), device=device) + + with patch.object(model.model, "forward", return_value=torch.randn(batch, seq_len, glm_config.hidden_size, device=device).to(torch.bfloat16)): + logits = model(input_ids) + + assert logits.shape == (batch, seq_len, glm_config.vocab_size) + + def test_forward_with_thd_format_squeezes_input(self, glm_config, backend_config, device): + model = Glm4MoeForCausalLM(glm_config, backend=backend_config) + model = model.to(device) + + batch, seq_len = 1, 5 + input_ids = torch.randint(0, glm_config.vocab_size, (batch, seq_len), device=device) + + with patch("nemo_automodel.components.models.glm4_moe.model.squeeze_input_for_thd") as mock_squeeze, \ + patch.object(model.model, 
"forward", return_value=torch.randn(seq_len, glm_config.hidden_size, device=device).to(torch.bfloat16)): + mock_squeeze.return_value = (input_ids.squeeze(0), None, None, {"qkv_format": "thd"}) + logits = model(input_ids, qkv_format="thd") + + mock_squeeze.assert_called_once() + # Output should be unsqueezed back to batch dimension + assert logits.shape == (batch, seq_len, glm_config.vocab_size) + + def test_initialize_weights_invokes_submodules(self, glm_config, backend_config): + model = Glm4MoeForCausalLM(glm_config, backend=backend_config) + original = model.lm_head.weight.clone() + + with patch.object(model.model, "init_weights") as mock_init: + model.initialize_weights(buffer_device=torch.device("cpu"), dtype=torch.float32) + + mock_init.assert_called_once() + assert not torch.equal(model.lm_head.weight, original) + assert model.lm_head.weight.dtype == torch.float32 + + def test_initialize_weights_uses_scaled_std_for_lm_head(self, glm_config, backend_config): + model = Glm4MoeForCausalLM(glm_config, backend=backend_config) + + with patch.object(model.model, "init_weights"), \ + patch("torch.nn.init.trunc_normal_") as mock_trunc: + model.initialize_weights(buffer_device=torch.device("cpu"), dtype=torch.float32) + + # Check that trunc_normal_ was called with scaled std + mock_trunc.assert_called() + call_args = mock_trunc.call_args + assert call_args[1]["std"] == glm_config.hidden_size ** -0.5 + + def test_initialize_weights_sets_e_score_correction_bias_for_moe_layers(self, glm_config, backend_config): + """GLM4 MoE initializes e_score_correction_bias for MoE layers""" + model = Glm4MoeForCausalLM(glm_config, backend=backend_config) + device = torch.device("cpu") + + with patch.object(model.model, "init_weights"): + model.initialize_weights(buffer_device=device, dtype=torch.float32) + + # Check that MoE layers (>= first_k_dense_replace) have e_score_correction_bias + for layer_idx, layer in enumerate(model.model.layers.values()): + if isinstance(layer.mlp, MoE): + assert layer_idx >= glm_config.first_k_dense_replace + assert hasattr(layer.mlp.gate, "e_score_correction_bias") + assert layer.mlp.gate.e_score_correction_bias.shape == (glm_config.n_routed_experts,) + assert layer.mlp.gate.e_score_correction_bias.dtype == torch.float32 + torch.testing.assert_close( + layer.mlp.gate.e_score_correction_bias, + torch.zeros(glm_config.n_routed_experts, dtype=torch.float32) + ) + + def test_initialize_weights_updates_rotary_emb_device_after_dtype_move(self, glm_config, backend_config): + model = Glm4MoeForCausalLM(glm_config, backend=backend_config) + device = torch.device("cpu") + + with patch.object(model.model, "init_weights"): + model.initialize_weights(buffer_device=device, dtype=torch.float32) + + assert model.model.rotary_emb.device == device + + def test_state_dict_adapter_created_when_enabled(self, glm_config, backend_config): + backend_config.enable_hf_state_dict_adapter = True + model = Glm4MoeForCausalLM(glm_config, backend=backend_config) + + assert hasattr(model, "state_dict_adapter") + + def test_state_dict_adapter_not_created_when_disabled(self, glm_config, backend_config): + backend_config.enable_hf_state_dict_adapter = False + model = Glm4MoeForCausalLM(glm_config, backend=backend_config) + + assert not hasattr(model, "state_dict_adapter") + + +class TestGlm4MoeModelClassmethods: + def test_from_config_creates_model(self, glm_config, backend_config): + model = Glm4MoeForCausalLM.from_config(glm_config, backend=backend_config) + + assert isinstance(model, Glm4MoeForCausalLM) + 
assert model.config == glm_config + assert model.backend == backend_config + + def test_from_pretrained_classmethod(self): + """Ensure classmethod from_pretrained builds config then delegates to from_config.""" + cfg = Glm4MoeConfig( + vocab_size=128, + hidden_size=64, + num_attention_heads=4, + num_key_value_heads=2, + num_hidden_layers=2, + intermediate_size=128, + head_dim=16, + n_routed_experts=2, + n_shared_experts=1, + num_experts_per_tok=1, + first_k_dense_replace=1, + moe_intermediate_size=64, + n_group=1, + topk_group=1, + routed_scaling_factor=1.0, + norm_topk_prob=False, + use_qk_norm=True, + partial_rotary_factor=0.5, + attention_bias=False, + ) + + with patch("transformers.models.glm4_moe.configuration_glm4_moe.Glm4MoeConfig.from_pretrained") as mock_from_pretrained: + mock_from_pretrained.return_value = cfg + + with patch.object(Glm4MoeForCausalLM, "from_config", wraps=Glm4MoeForCausalLM.from_config) as mock_from_config: + model = Glm4MoeForCausalLM.from_pretrained("glm4_moe/model") + assert isinstance(model, Glm4MoeForCausalLM) + mock_from_pretrained.assert_called_once_with("glm4_moe/model") + called_cfg = mock_from_config.call_args[0][0] + assert called_cfg is cfg + + def test_modelclass_export_exists(self): + """Ensure ModelClass pointer is defined and points to class.""" + from nemo_automodel.components.models.glm4_moe import model as glm_mod + + assert hasattr(glm_mod, "ModelClass") + assert glm_mod.ModelClass is Glm4MoeForCausalLM + + +def magic_moe_config(config: Glm4MoeConfig) -> MoEConfig: + return MoEConfig( + dim=config.hidden_size, + inter_dim=config.intermediate_size, + moe_inter_dim=config.moe_intermediate_size, + n_routed_experts=config.n_routed_experts, + n_shared_experts=config.n_shared_experts, + n_activated_experts=config.num_experts_per_tok, + n_expert_groups=config.n_group, + n_limited_groups=config.topk_group, + train_gate=True, + gate_bias_update_factor=0.001, + score_func="sigmoid", + route_scale=config.routed_scaling_factor, + aux_loss_coeff=0.0, + norm_topk_prob=config.norm_topk_prob, + expert_bias=False, + router_bias=False, + expert_activation="swiglu", + softmax_before_topk=False, + ) From 392331f9423260e8f43b6ae77045234dddf1fe12 Mon Sep 17 00:00:00 2001 From: hemildesai Date: Thu, 23 Oct 2025 00:28:01 +0000 Subject: [PATCH 2/2] unit tests Signed-off-by: hemildesai --- .../components/utils/flops_utils.py | 86 +++++++++++++++++++ tests/unit_tests/utils/test_flops_utils.py | 18 ++++ 2 files changed, 104 insertions(+) diff --git a/nemo_automodel/components/utils/flops_utils.py b/nemo_automodel/components/utils/flops_utils.py index a3bcdba2b..6fbd95ad5 100644 --- a/nemo_automodel/components/utils/flops_utils.py +++ b/nemo_automodel/components/utils/flops_utils.py @@ -715,6 +715,90 @@ def gpt_oss_flops(config, gbs=1, seq_len=None): ) +def glm4_moe_flops(config, gbs=1, seq_len=None): + if seq_len is None: + seq_len = config.max_position_embeddings if hasattr(config, "max_position_embeddings") else 2048 + + layers = config.num_hidden_layers + hs = config.hidden_size + attention_heads = config.num_attention_heads + query_groups = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else attention_heads + vocab_size = config.vocab_size + + # GLM4 MoE attention config + head_dim = getattr(config, "head_dim", hs // attention_heads) + query_projection_to_hidden_size_ratio = (head_dim * attention_heads) / hs + + # MoE config + ffn_hs = config.intermediate_size # for dense layers + moe_intermediate_size = config.moe_intermediate_size if 
hasattr(config, "moe_intermediate_size") else ffn_hs + moe_router_topk = config.num_experts_per_tok if hasattr(config, "num_experts_per_tok") else 1 + n_shared_experts = config.n_shared_experts if hasattr(config, "n_shared_experts") else 0 + first_k_dense_replace = config.first_k_dense_replace if hasattr(config, "first_k_dense_replace") else 0 + + causal_self_attn = True + hidden_size = hs + gated_linear_multiplier = 2 # SwiGLU + + # Attention flops for GQA (Qwen3-style) + attention_flops = ( + 3 + * 2 + * gbs + * layers + * seq_len + * hidden_size + * hidden_size + * query_projection_to_hidden_size_ratio + * ( + (query_groups / attention_heads * 2 + 1) # QKV gemm + + (seq_len / hidden_size * 2 * (0.5 if causal_self_attn else 1)) # attention + + 1 # attention proj gemm + ) + ) + + # MLP flops (DeepSeek V3-style MoE) + # Dense layers: first_k_dense_replace layers + dense_mlp_flops = ( + 3 * 2 * gbs * first_k_dense_replace * seq_len * hidden_size * (1 + gated_linear_multiplier) * ffn_hs + ) + + # MoE layers: (layers - first_k_dense_replace) layers + # Each MoE layer has: shared experts + routed experts (topk selected) + num_moe_layers = layers - first_k_dense_replace + + # Shared expert flops (always computed) + shared_expert_flops = ( + 3 + * 2 + * gbs + * num_moe_layers + * seq_len + * hidden_size + * (1 + gated_linear_multiplier) + * (moe_intermediate_size * n_shared_experts) + ) + + # Routed expert flops (topk selected) + routed_expert_flops = ( + 3 + * 2 + * gbs + * num_moe_layers + * seq_len + * hidden_size + * (1 + gated_linear_multiplier) + * (moe_intermediate_size * moe_router_topk) + ) + + mlp_flops = dense_mlp_flops + shared_expert_flops + routed_expert_flops + + # Vocab flops + vocab_flops = 3 * 2 * gbs * seq_len * hidden_size * vocab_size + + return attention_flops + mlp_flops + vocab_flops + + def get_flops_formula_for_hf_config(config: Any) -> Optional[Callable]: """ Get the appropriate FLOPs formula function for a given HuggingFace config. 
@@ -752,6 +836,8 @@ def get_flops_formula_for_hf_config(config: Any) -> Optional[Callable]: "DeepseekV3Config": deepseekv3_flops, # GPT-OSS "GptOssConfig": gpt_oss_flops, + # GLM4 MoE + "Glm4MoeConfig": glm4_moe_flops, # T5 family (encoder-decoder) "T5Config": transformer_flops, "MT5Config": transformer_flops, diff --git a/tests/unit_tests/utils/test_flops_utils.py b/tests/unit_tests/utils/test_flops_utils.py index 4bef785b1..17027d535 100644 --- a/tests/unit_tests/utils/test_flops_utils.py +++ b/tests/unit_tests/utils/test_flops_utils.py @@ -149,6 +149,23 @@ def _gpt_oss_cfg() -> SimpleNamespace: ) +def _glm4_moe_cfg() -> SimpleNamespace: + return SimpleNamespace( + hidden_size=4096, + num_hidden_layers=46, + num_attention_heads=96, + num_key_value_heads=8, + intermediate_size=10944, + vocab_size=151552, + moe_intermediate_size=1408, + num_experts_per_tok=8, + n_shared_experts=1, + n_routed_experts=128, + first_k_dense_replace=1, + max_position_embeddings=131072, + ) + + @pytest.mark.parametrize( "name, func, cfg_factory, kwargs, expected", [ @@ -161,6 +178,7 @@ def _gpt_oss_cfg() -> SimpleNamespace: ("bert", flops_utils.bert_flops, _bert_cfg, dict(gbs=1, seq_len=512), 361920724992), ("transformer", flops_utils.transformer_flops, _transformer_cfg, dict(gbs=1, seq_len=1024), 8363320541184), ("gpt_oss", flops_utils.gpt_oss_flops, _gpt_oss_cfg, dict(gbs=1, seq_len=1024), 7356800827392), + ("glm4_moe", flops_utils.glm4_moe_flops, _glm4_moe_cfg, dict(gbs=1, seq_len=2048), 120277337899008), ("deepseekv3_moonlight", flops_utils.deepseekv3_flops, _moonlight_16b_config, dict(gbs=1, seq_len=2048), 30625801175040), ("deepseekv3_dsv3", flops_utils.deepseekv3_flops, _deepseek_v3_config, dict(gbs=1, seq_len=1024), 233225179889664), ],
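
The parametrized expectation of 120,277,337,899,008 for the glm4_moe case can be reproduced directly from the formula added above: since _glm4_moe_cfg does not set head_dim, the formula falls back to hidden_size // num_attention_heads = 42, and the total splits into attention, dense MLP, shared-expert, routed-expert, and vocab terms. A standalone arithmetic check (illustrative only, not part of the patch; exact fractions are used to avoid float rounding):

    from fractions import Fraction

    hs, layers, heads, kv = 4096, 46, 96, 8
    seq, vocab = 2048, 151552
    ffn, moe_ffn, topk, shared, dense_layers = 10944, 1408, 8, 1, 1

    head_dim = hs // heads                       # 42: the fixture omits head_dim
    ratio = Fraction(head_dim * heads, hs)       # query projection / hidden size = 63/64
    inner = Fraction(kv, heads) * 2 + 1 + Fraction(seq, hs) + 1  # QKV gemm + causal attention + out proj

    # Leading 6 = the formula's 3 * 2 factor with gbs=1.
    attention = 6 * layers * seq * hs * hs * ratio * inner                      # 24_893_630_447_616
    dense_mlp = 6 * dense_layers * seq * hs * 3 * ffn                           # 1_652_488_667_136
    shared_moe = 6 * (layers - dense_layers) * seq * hs * 3 * moe_ffn * shared  # 9_567_039_651_840
    routed_moe = 6 * (layers - dense_layers) * seq * hs * 3 * moe_ffn * topk    # 76_536_317_214_720
    vocab_proj = 6 * seq * hs * vocab                                           # 7_627_861_917_696

    assert attention + dense_mlp + shared_moe + routed_moe + vocab_proj == 120_277_337_899_008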