13 changes: 13 additions & 0 deletions nemo_automodel/components/models/glm4_moe/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
147 changes: 147 additions & 0 deletions nemo_automodel/components/models/glm4_moe/layers.py
@@ -0,0 +1,147 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

import torch
from torch import nn
from transformers.models.glm4_moe.configuration_glm4_moe import Glm4MoeConfig

from nemo_automodel.components.attention.utils import (
initialize_attn_module_and_func,
postprocess_output_for_attn,
preprocess_args_and_kwargs_for_attn,
)
from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb
from nemo_automodel.components.moe.utils import (
BackendConfig,
initialize_linear_module,
initialize_rms_norm_module,
)


class Glm4MoeAttention(nn.Module):
"""GLM4 MoE attention with optional query/key per-head RMSNorm + partial RoPE.

Key differences from Qwen3 MoE:
- Optional QK normalization (controlled by use_qk_norm)
- Partial rotary embeddings (controlled by partial_rotary_factor)
- o_proj bias is False (unlike Qwen3 which has configurable attention_bias)

Shapes:
- Input: x -> [B, S, H]
- Projections:
q: [B, S, n_heads, head_dim]
k/v: [B, S, n_kv_heads, head_dim] -> repeated to n_heads via groups
- Output: [B, S, H]
"""

def __init__(self, config: Glm4MoeConfig, backend: BackendConfig):
super().__init__()
self.backend = backend
self.config = config

self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.head_dim = getattr(config, "head_dim", config.hidden_size // self.num_heads)
self.use_qk_norm = config.use_qk_norm
self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)

self.q_proj = initialize_linear_module(
backend.linear, config.hidden_size, self.num_heads * self.head_dim, config.attention_bias
)
self.k_proj = initialize_linear_module(
backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, config.attention_bias
)
self.v_proj = initialize_linear_module(
backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, config.attention_bias
)
self.o_proj = initialize_linear_module(
backend.linear, self.num_heads * self.head_dim, config.hidden_size, False
)

# Optional per-head RMSNorm for Q and K
if self.use_qk_norm:
self.q_norm = initialize_rms_norm_module(backend.rms_norm, self.head_dim, eps=config.rms_norm_eps)
self.k_norm = initialize_rms_norm_module(backend.rms_norm, self.head_dim, eps=config.rms_norm_eps)

        # Backend-selected attention implementation (GQA: num_gqa_groups is set to the number of KV heads)
softmax_scale = self.head_dim**-0.5
self.attn_module, self.attn_func = initialize_attn_module_and_func(
attn_impl=backend.attn,
num_attention_heads=self.num_heads,
num_qk_channels=self.head_dim,
num_v_channels=self.head_dim,
softmax_scale=softmax_scale,
num_gqa_groups=self.num_kv_heads,
)

def forward(
self,
x: torch.Tensor,
*,
freqs_cis: torch.Tensor,
attention_mask: torch.Tensor | None = None,
**attn_kwargs: Any,
) -> torch.Tensor:
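        # 2-D input is the packed "thd" layout [total_tokens, hidden]; 3-D input is the batched "bshd" layout [B, S, hidden].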
if len(x.shape) == 2:
qkv_format = "thd"
num_tokens = x.shape[0]
else:
qkv_format = "bshd"
bsz, seqlen, _ = x.size()

q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)

if qkv_format == "thd":
q = q.view(num_tokens, self.num_heads, self.head_dim)
k = k.view(num_tokens, self.num_kv_heads, self.head_dim)
v = v.view(num_tokens, self.num_kv_heads, self.head_dim)
else:
q = q.view(bsz, seqlen, self.num_heads, self.head_dim)
k = k.view(bsz, seqlen, self.num_kv_heads, self.head_dim)
v = v.view(bsz, seqlen, self.num_kv_heads, self.head_dim)

# Optional per-head RMSNorm
if self.use_qk_norm:
q = self.q_norm(q)
k = self.k_norm(k)

        # Partial RoPE: only the first head_dim * partial_rotary_factor channels of each head are rotated;
        # freqs_cis packs the matching cos and sin halves along its last dimension.
rotary_dim = int(self.head_dim * self.partial_rotary_factor)
cos, sin = freqs_cis.split(rotary_dim // 2, dim=-1)

q = apply_rotary_emb(q, cos, sin)
k = apply_rotary_emb(k, cos, sin)

# Backend-specific attention
q, k, v, _attn_kwargs = preprocess_args_and_kwargs_for_attn(
q, k, v, attention_mask, self.backend.attn, **attn_kwargs
)
out = self.attn_func(q, k, v, **_attn_kwargs)
out = postprocess_output_for_attn(out, self.backend.attn)

flatten_dim = 2 if qkv_format == "bshd" else 1
out = self.o_proj(out.flatten(flatten_dim))
return out

def init_weights(self, buffer_device: torch.device, init_std: float = 0.02):
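        # Truncated-normal init for the q/k/v/o projection weights; optional Q/K RMSNorms reset to their defaults.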
linear_list = [self.q_proj, self.k_proj, self.v_proj, self.o_proj]
for linear in linear_list:
nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
if self.use_qk_norm:
for norm in (self.q_norm, self.k_norm):
norm.reset_parameters()
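Note (not part of the diff): the forward pass above relies on partial RoPE, where only the first head_dim * partial_rotary_factor channels of each head are rotated and the remainder pass through unchanged. The sketch below illustrates that idea with a generic rotate-half formulation in plain PyTorch; it is an illustrative reference under stated assumptions, not the apply_rotary_emb imported from gpt_oss.rope_utils, whose exact channel layout may differ.

import torch

def partial_rope_sketch(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate only the first 2 * cos.shape[-1] channels of each head; leave the rest untouched."""
    rotary_dim = 2 * cos.shape[-1]               # mirrors the freqs_cis.split(rotary_dim // 2) above
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_rot.chunk(2, dim=-1)              # rotate-half: pair up the two halves of the rotary channels
    rotated = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    return torch.cat((rotated, x_pass), dim=-1)  # non-rotary channels pass through unchanged

# Example with head_dim=128 and partial_rotary_factor=0.5 -> rotary_dim=64, cos/sin of width 32.
q = torch.randn(2, 16, 8, 128)                   # [B, S, n_heads, head_dim]
angles = torch.rand(16, 32)                      # hypothetical per-position rotation angles
cos = angles.cos()[None, :, None, :]             # broadcast over batch and heads
sin = angles.sin()[None, :, None, :]
print(partial_rope_sketch(q, cos, sin).shape)    # torch.Size([2, 16, 8, 128])

With partial_rotary_factor=0.5 (the fallback read via getattr in __init__), half of each head carries positional phase while the other half remains position-agnostic.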