diff --git a/QEfficient/base/pytorch_transforms.py b/QEfficient/base/pytorch_transforms.py index a20fc4cb3..e503a057f 100644 --- a/QEfficient/base/pytorch_transforms.py +++ b/QEfficient/base/pytorch_transforms.py @@ -120,61 +120,109 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: class SplitGateUpWeightsTransform(PytorchTransform): """ - split fused Gate+Up weights and copy into the model + Split fused Gate+Up weights and copy into the model. + Handles both standard MoE models and GptOss models. For every transformer layer inside `model`: - • expects .experts.gate_up_proj in the *source* `sd` - • copies halves into - .experts.gate_proj <-- Gate [E,H,I] - .experts.up_proj <-- Up [E,H,I] + • expects .experts.gate_up_proj in the *source* `sd` + • copies halves into + .experts.gate_proj <-- Gate [E,H,I] + .experts.up_proj <-- Up [E,H,I] + + Handles both interleaved weights (GptOss) and concatenated weights (standard MoE). + Also handles bias terms when present. """ @classmethod def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: transformed = False model_class = model.__class__.__name__ if hasattr(model, "model") else model.__class__.__name__ - if model_class not in VLM_SPLIT_GATE_UP_WEIGHTS: return model, transformed model_tmp = model.language_model if hasattr(model, "language_model") else model - num_layers = len(model_tmp.model.layers) delete_fused_key = True sd = model_tmp.state_dict() + for layer_idx in range(num_layers): + # Determine if this is a GptOss model or standard MoE model + is_gpt_oss = hasattr(model_tmp.model.layers[layer_idx], "mlp") + # ---- build the textual prefix once per layer ---------- - prefix = f"model.layers.{layer_idx}.feed_forward.experts." + if is_gpt_oss: + prefix = f"model.layers.{layer_idx}.mlp.experts." + experts = model_tmp.model.layers[layer_idx].mlp.experts + else: + prefix = f"model.layers.{layer_idx}.feed_forward.experts." + experts = model_tmp.model.layers[layer_idx].feed_forward.experts fused_key = prefix + "gate_up_proj" gate_key = prefix + "gate_proj" up_key = prefix + "up_proj" - # ---- split [E,H,2I] → two [E,H,I] tensors ---------------------- - fused = sd[fused_key] # [E, H, 2I] (no .weight here) + # Check if we have bias terms (GptOss case) + has_bias = fused_key + "_bias" in sd + if has_bias: + fused_bias_key = fused_key + "_bias" + gate_bias_key = gate_key + "_bias" + up_bias_key = up_key + "_bias" + + # ---- split weights based on model type ---------------------- + fused = sd[fused_key] # [E, H, 2I] E, H, two_I = fused.shape - ffn_dim = two_I // 2 - gate, up = fused.split(ffn_dim, dim=-1) # views – no copy - experts = model_tmp.model.layers[layer_idx].feed_forward.experts + if is_gpt_oss: + # For GptOss, gate/up are interleaved: [gate0, up0, gate1, up1, ...] 
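+                # e.g. one fused row [g0, u0, g1, u1, ...] de-interleaves to gate [g0, g1, ...] and up [u0, u1, ...],
+                # i.e. the single [E, H, 2I] tensor becomes two [E, H, I] tensors (E experts, H hidden, I intermediate).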
+ gate = fused[..., ::2] # [E, H, I] - even indices + up = fused[..., 1::2] # [E, H, I] - odd indices + else: + # For standard MoE, gate/up are concatenated: [gate, up] + ffn_dim = two_I // 2 + gate, up = fused.split(ffn_dim, dim=-1) # views – no copy + + # Copy weights to model experts.gate_proj.data.copy_(gate) experts.up_proj.data.copy_(up) + # Handle bias if present + if has_bias: + fused_bias = sd[fused_bias_key] # [E, 2I] + + if is_gpt_oss: + gate_bias = fused_bias[..., ::2] # [E, I] - even indices + up_bias = fused_bias[..., 1::2] # [E, I] - odd indices + else: + ffn_dim = fused_bias.shape[-1] // 2 + gate_bias, up_bias = fused_bias.split(ffn_dim, dim=-1) + + experts.gate_proj_bias.data.copy_(gate_bias) + experts.up_proj_bias.data.copy_(up_bias) + # ---- update the state-dict so load_state_dict sees the right keys sd[gate_key] = gate sd[up_key] = up + if has_bias: + sd[gate_bias_key] = gate_bias + sd[up_bias_key] = up_bias + + # Delete fused keys if delete_fused_key: del sd[fused_key] + if has_bias: + del sd[fused_bias_key] - logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})") + logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})") transformed = True if hasattr(model, "language_model"): model.language_model = model_tmp else: model = model_tmp + return model, transformed -VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration", "QEffLlama4ForCausalLM"} +# Keep the existing list of supported models +VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration", "QEffLlama4ForCausalLM", "QEffGptOssForCausalLM"} diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 16767fbe2..69970e5d8 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Tuple import torch -from transformers.cache_utils import DynamicCache, EncoderDecoderCache, HybridCache, HybridChunkedCache +from transformers.cache_utils import EncoderDecoderCache, HybridCache, HybridChunkedCache from QEfficient.customop import ( CtxGatherFunc, @@ -23,18 +23,142 @@ ) -class QEffDynamicCache(DynamicCache): - """ - A cache that grows dynamically as more tokens are generated. This is the default for generative models. +class QEffDynamicCache: + def __init__(self) -> None: + self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] - It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is - `[batch_size, num_heads, seq_len, head_dim]`. + @classmethod + def from_legacy_cache( + cls, past_key_values: tuple[tuple[torch.FloatTensor, torch.FloatTensor], ...] + ) -> "QEffDynamicCache": + """ + Converts a cache in the legacy cache format into an equivalent `Cache`. Used for + backward compatibility. + """ + cache = cls() + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + # Directly populate the cache lists + cache.key_cache.append(key_states) + cache.value_cache.append(value_states) + return cache - - Optimized implementation for the Cloud AI 100 to reuse KV Cache. - - get the position_ids input using kwargs. - - Use custom Onnxscript ops to write optimized version to generate Onnx model. 
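For orientation, a minimal pure-PyTorch sketch of the numerics that the `update` method added further down in this class implements via the `CtxScatterFunc`/`CtxGatherFunc` custom ops: new key/value states are scattered into the pre-allocated cache at `position_ids`, then the full context is read back with positions beyond the largest written index masked out. Shapes follow the diff (`[batch, heads, seq, head_dim]`); the helper name and the plain `scatter` call are illustrative only and assume non-negative `position_ids` — this is not the QEfficient/AIC implementation, which uses the custom ops so the pattern exports to an AIC-friendly ONNX graph.

```python
import torch

def toy_kv_update(k_cache, v_cache, k_new, v_new, position_ids):
    # Scatter: write the new key/value states into the pre-allocated cache at their positions.
    idx = position_ids[:, None, :, None].expand(-1, k_cache.shape[1], -1, k_cache.shape[-1])
    k_cache = k_cache.scatter(2, idx, k_new)
    v_cache = v_cache.scatter(2, idx, v_new)

    # Gather: read the whole context back, zeroing value rows past the largest
    # position written so far (mirrors the invalid_mask handling in `update`).
    ctx_indices = torch.arange(k_cache.shape[2])[None, None, :]
    invalid = ctx_indices > position_ids.max(1, keepdim=True).values.unsqueeze(1)
    v_out = torch.where(invalid.unsqueeze(-1), torch.zeros((), dtype=v_cache.dtype), v_cache)
    return k_cache, v_out
```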
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` + is_empty_layer = ( + len(self.key_cache) == 0 # no cache in any layer + or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it + or not self.key_cache[layer_idx].numel() # the layer has no cache + ) + layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 + return layer_seq_length - """ + def get_max_cache_shape(self) -> Optional[int]: + """Returns the maximum sequence length of the cache object. DynamicCache does not have a maximum length.""" + return None + + def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: + """Given the sequence length of the new inputs, returns the usable length of the cache.""" + # Cache without size limit -> all cache is usable + # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache + # length, we will need to evict part of the cache (and thus not all cache is usable) + max_length = self.get_max_cache_shape() + previous_seq_length = self.get_seq_length(layer_idx) + if max_length is not None and previous_seq_length + new_seq_length > max_length: + return max_length - new_seq_length + return previous_seq_length + + def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: + """ + Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the + sequence length. + """ + if layer_idx < len(self): + return (self.key_cache[layer_idx], self.value_cache[layer_idx]) + else: + raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") + + def __iter__(self): + """ + Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over + keys and values + """ + for layer_idx in range(len(self)): + yield (self.key_cache[layer_idx], self.value_cache[layer_idx]) + + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. + """ + return len(self.key_cache) + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for + backward compatibility.""" + legacy_cache = () + for layer_idx in range(len(self)): + legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) + return legacy_cache + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. 
+ """ + # Update the cache + if len(self.key_cache) <= layer_idx: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + k_out, v_out = key_states, value_states + else: + position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index", None) + # Scatter + if batch_index is not None: + invalid_scatter_index = torch.iinfo(torch.int32).max + scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids) + self.key_cache[layer_idx] = CtxScatterFuncCB.apply( + self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states + ).clone() + self.value_cache[layer_idx] = CtxScatterFuncCB.apply( + self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states + ).clone() + else: + self.key_cache[layer_idx] = CtxScatterFunc.apply( + self.key_cache[layer_idx], position_ids, key_states + ).clone() + self.value_cache[layer_idx] = CtxScatterFunc.apply( + self.value_cache[layer_idx], position_ids, value_states + ).clone() + k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] + # Gather + ctx_len = k_out.shape[2] + ctx_indices = torch.arange(ctx_len)[None, None, ...] + gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = torch.iinfo(torch.int32).max + else: + invalid_idx_value = 0 + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + if batch_index is not None: + k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices) + else: + k_out = CtxGatherFunc.apply(k_out, ctx_indices) + v_out = CtxGatherFunc.apply(v_out, ctx_indices) + v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) + return k_out, v_out def write_only(self, key_states, value_states, layer_idx, cache_kwargs): """ @@ -113,80 +237,6 @@ def read_only(self, layer_idx, cache_kwargs): v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) return k_out, v_out - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. - - Return: - A tuple containing the updated key and value states. 
- """ - # Update the cache - if len(self.key_cache) <= layer_idx: - self.key_cache.append(key_states) - self.value_cache.append(value_states) - k_out, v_out = key_states, value_states - else: - position_ids = cache_kwargs.get("position_ids") - batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value form the kwargs - - # Scatter - if batch_index is not None: - invalid_scatter_index = torch.iinfo(torch.int32).max - scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids) - - self.key_cache[layer_idx] = CtxScatterFuncCB.apply( - self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states - ) - - self.value_cache[layer_idx] = CtxScatterFuncCB.apply( - self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states - ) - else: - self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states) - self.value_cache[layer_idx] = CtxScatterFunc.apply( - self.value_cache[layer_idx], position_ids, value_states - ) - - k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] - - # Gather - ctx_len = k_out.shape[2] - ctx_indices = torch.arange(ctx_len)[None, None, ...] - gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) - invalid_mask = ctx_indices > gather_limit - - if torch.onnx.is_in_onnx_export(): - invalid_idx_value = torch.iinfo(torch.int32).max - else: - invalid_idx_value = 0 - - ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) - if batch_index is not None: - k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices) - v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices) - else: - k_out = CtxGatherFunc.apply(k_out, ctx_indices) - v_out = CtxGatherFunc.apply(v_out, ctx_indices) - v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) - - return k_out, v_out - def update3D( self, key_states: torch.Tensor, @@ -488,3 +538,102 @@ def update( ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out) return k_out, v_out + + +# This is a hack for now, until we get to merging this code with HybridCache class, +# We don't really need to inherit transformers classes as their cache classes are made to work with pytorch and +# ours are made to work with AIC +class QEffHybridCacheForGPTOSS: + def __init__(self, config, batch_size, max_cache_len, sliding_window_len): + self.max_cache_len = max_cache_len + self.batch_size = batch_size + self.sliding_window_len = sliding_window_len + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + + @classmethod + def from_legacy_cache( + cls, config, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + ) -> "HybridCache": + """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for + backward compatibility.""" + cache = cls( + config, + batch_size=past_key_values[0][0].shape[0], + max_cache_len=past_key_values[1][0].shape[2], + sliding_window_len=past_key_values[0][0].shape[2], + ) + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) + return cache + + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. 
This value corresponds + to the number of layers in the model. + """ + return len(self.key_cache) + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` + is_empty_layer = ( + len(self.key_cache) == 0 # no cache in any layer + or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it + or len(self.key_cache[layer_idx]) == 0 # the layer has no cache + ) + layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 + return layer_seq_length + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for + backward compatibility.""" + legacy_cache = () + for layer_idx in range(len(self)): + legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) + return legacy_cache + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if len(self.key_cache) <= layer_idx: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + k_out, v_out = key_states, value_states + else: + position_ids = cache_kwargs.get("position_ids") + is_sliding_layer = cache_kwargs.get("is_sliding") + sliding_window = cache_kwargs.get("sliding_window") + + if is_sliding_layer: + kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % sliding_window) + else: + kv_position_ids = position_ids + + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) + self.value_cache[layer_idx] = CtxScatterFunc.apply( + self.value_cache[layer_idx], kv_position_ids, value_states + ) + k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] + + # Original Gather + ctx_len = self.key_cache[layer_idx].shape[2] + ctx_indices = torch.arange(ctx_len)[None, None, ...] + gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = torch.iinfo(torch.int32).max + else: + invalid_idx_value = 0 + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + + k_out = CtxGatherFunc.apply(k_out, ctx_indices) + v_out = CtxGatherFunc.apply(v_out, ctx_indices) + v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) + return k_out, v_out diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index 72b7acd98..fa6bfd0f4 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -183,6 +183,7 @@ ] ) +# This is for supporting different seq_len for different layers for Sliding window attn, chunked attn etc. 
DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"} # Define a transformers layers to QEff layers dictionary diff --git a/QEfficient/transformers/models/gpt_oss/__init__.py b/QEfficient/transformers/models/gpt_oss/__init__.py new file mode 100644 index 000000000..75daf1953 --- /dev/null +++ b/QEfficient/transformers/models/gpt_oss/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py new file mode 100644 index 000000000..bc460fea6 --- /dev/null +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -0,0 +1,711 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +from typing import Callable, Optional, Union + +import torch +from torch import nn +from torch.nn import functional as F +from transformers.cache_utils import Cache +from transformers.modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, +) +from transformers.models.gpt_oss.modeling_gpt_oss import ( + GptOssAttention, + GptOssConfig, + GptOssDecoderLayer, + GptOssExperts, + GptOssForCausalLM, + GptOssMLP, + GptOssModel, + GptOssRotaryEmbedding, + repeat_kv, +) +from transformers.processing_utils import Unpack +from transformers.utils import TransformersKwargs + +from QEfficient.transformers.cache_utils import QEffHybridCacheForGPTOSS +from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE + + +class QEffGptOssExperts(GptOssExperts): + def __qeff_init__(self): + self.gate_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, self.expert_dim)) + self.up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, self.expert_dim)) + self.gate_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim)) + self.up_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim)) + + +class QEffGptOssMLP(GptOssMLP): + def alt_forward(self, hidden: torch.Tensor): + B, S, H = hidden.shape + T = B * S + hidden = hidden.view(T, H) + + # Router computation + router_logits = F.linear(hidden, self.router.weight, self.router.bias) + + # Top-k selection + top_w, top_i = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K] + top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype) + + masked_logits = torch.zeros_like(router_logits) + masked_logits.scatter_(1, top_i, top_w) + + # Routing weights for each expert [T, E] + routing_weights = masked_logits + + # ────────────────── allocate the output tensor ───── + expert_out = hidden.new_zeros((T, H)) # accumulation buffer + + # ───────────────────────── Expert computation loop ───────────────────────────── + for e in range(self.experts.num_experts): + routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1] + + W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I] + b_g, b_u = self.experts.gate_proj_bias[e], 
self.experts.up_proj_bias[e] # [I], [I] + W_d = self.experts.down_proj[e] # [I, H] + b_d = self.experts.down_proj_bias[e] # [H] + + # Gate and Up projections + gate = (hidden @ W_g) + b_g # [T, I] + up = (hidden @ W_u) + b_u # [T, I] + + # Apply GptOss activation with clamping + gate = gate.clamp(min=None, max=self.experts.limit) + up = up.clamp(min=-self.experts.limit, max=self.experts.limit) + + # GLU activation + glu = gate * torch.sigmoid(gate * self.experts.alpha) + intermediate = (up + 1) * glu # [T, I] + + # Down projection + down_out = (intermediate @ W_d) + b_d # [T, H] + + # Apply routing weights and accumulate + masked_down = torch.where(routing_weight > 0, down_out * routing_weight, torch.zeros_like(expert_out)) + expert_out += masked_down + + # original shape [B, S, H] + return expert_out.view(B, S, H), router_logits + + # ------------------- Gather based, weights as activation approach --------------- + def forward_weights_as_activation(self, hidden_states): + bs, seq_len, _ = hidden_states.shape + hidden_states = hidden_states.view(bs * seq_len, self.experts.hidden_size) + + # Router computation + router_logits = F.linear(hidden_states, self.router.weight, self.router.bias) + router_top_value, router_indices = torch.topk(router_logits, self.router.top_k, dim=-1) + router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) + + # GATHER - collect weights for selected experts + gate_up_proj = self.experts.gate_up_proj[router_indices.flatten()] + gate_up_proj_bias = self.experts.gate_up_proj_bias[router_indices.flatten()] + down_proj = self.experts.down_proj[router_indices.flatten()] + down_proj_bias = self.experts.down_proj_bias[router_indices.flatten()] + + # Apply Chosen Experts (without routing weights first) + # expert_in = hidden_states.repeat_interleave(self.router.top_k, dim=0) + # expert_in = expert_in.view(-1, 1, self.experts.hidden_size) + # Reshape for bmm: (bs*seq_len*top_k, 1, hidden_size) + expert_in = ( + hidden_states.unsqueeze(1) + .expand(-1, self.router.top_k, -1) + .contiguous() + .view(-1, 1, self.experts.hidden_size) + ) + + gate_up = torch.bmm(expert_in, gate_up_proj) + gate_up_proj_bias.unsqueeze(1) + gate, up = gate_up[..., ::2], gate_up[..., 1::2] + + # Apply activation with clamping + gate = gate.clamp(min=None, max=self.experts.limit) + up = up.clamp(min=-self.experts.limit, max=self.experts.limit) + glu = gate * torch.sigmoid(gate * self.experts.alpha) + gated_output = (up + 1) * glu + + experts_out = torch.bmm(gated_output, down_proj) + down_proj_bias.unsqueeze(1) + experts_out = experts_out.view(bs * seq_len, self.router.top_k, self.experts.hidden_size) + + # Apply routing weights AFTER expert computation (This is before on Llama4) + experts_out = experts_out * router_top_value.unsqueeze(-1) + experts_out = experts_out.sum(dim=1) + + return experts_out, router_logits + + # ------------------- Gather based, weights as activation approach, With Seperate Gate, up Projections --------------- + def forward(self, hidden_states): + # print("Seperate Split, Up, Gate Projections") + bs, seq_len, _ = hidden_states.shape + hidden_states = hidden_states.view(bs * seq_len, self.experts.hidden_size) + + # Router computation + router_logits = F.linear(hidden_states, self.router.weight, self.router.bias) + router_top_value, router_indices = torch.topk(router_logits, self.router.top_k, dim=-1) + router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) + + # GATHER - collect 
weights for selected experts (separate gate and up projections) + gate_proj = self.experts.gate_proj[router_indices.flatten()] + gate_proj_bias = self.experts.gate_proj_bias[router_indices.flatten()] + up_proj = self.experts.up_proj[router_indices.flatten()] + up_proj_bias = self.experts.up_proj_bias[router_indices.flatten()] + down_proj = self.experts.down_proj[router_indices.flatten()] + down_proj_bias = self.experts.down_proj_bias[router_indices.flatten()] + + # Reshape for bmm: (bs*seq_len*top_k, 1, hidden_size) + expert_in = ( + hidden_states.unsqueeze(1) + .expand(-1, self.router.top_k, -1) + .contiguous() + .view(-1, 1, self.experts.hidden_size) + ) + + # Apply gate and up projections separately using bmm + gate = torch.bmm(expert_in, gate_proj) + gate_proj_bias.unsqueeze(1) + up = torch.bmm(expert_in, up_proj) + up_proj_bias.unsqueeze(1) + + # Apply activation with clamping + gate = gate.clamp(min=None, max=self.experts.limit) + up = up.clamp(min=-self.experts.limit, max=self.experts.limit) + + # GLU activation + glu = gate * torch.sigmoid(gate * self.experts.alpha) + gated_output = (up + 1) * glu + + # Down projection + experts_out = torch.bmm(gated_output, down_proj) + down_proj_bias.unsqueeze(1) + experts_out = experts_out.view(bs * seq_len, self.router.top_k, self.experts.hidden_size) + + # Apply routing weights AFTER expert computation + experts_out = experts_out * router_top_value.unsqueeze(-1) + experts_out = experts_out.sum(dim=1) + + return experts_out, router_logits + + def optimized_moe_forward(self, hidden_states: torch.Tensor): + B, S, H = hidden_states.shape + T = B * S + hidden_states = hidden_states.view(T, H) + + # Router computation + router_logits = F.linear(hidden_states, self.router.weight, self.router.bias) + + # Top-k selection + top_w, selected_experts = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K] + top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype) + + # Creating experts mask and routing weights masked + awesome_experts_mask_1 = ( + torch.nn.functional.one_hot(selected_experts[:, 0], num_classes=self.experts.num_experts) + .bool() + .T.unsqueeze(-1) + ) + awesome_experts_mask_2 = ( + torch.nn.functional.one_hot(selected_experts[:, 1], num_classes=self.experts.num_experts) + .bool() + .T.unsqueeze(-1) + ) + awesome_experts_mask_3 = ( + torch.nn.functional.one_hot(selected_experts[:, 2], num_classes=self.experts.num_experts) + .bool() + .T.unsqueeze(-1) + ) + awesome_experts_mask_4 = ( + torch.nn.functional.one_hot(selected_experts[:, 3], num_classes=self.experts.num_experts) + .bool() + .T.unsqueeze(-1) + ) + + gateupout1 = torch.zeros(hidden_states.shape[0], self.experts.intermediate_size) # T, hs + gateupout2 = torch.zeros(hidden_states.shape[0], self.experts.intermediate_size) # T, hs + gateupout3 = torch.zeros(hidden_states.shape[0], self.experts.intermediate_size) # T, hs + gateupout4 = torch.zeros(hidden_states.shape[0], self.experts.intermediate_size) # T, hs + + # ───────────────────────── Expert computation loop ───────────────────────────── + for e in range(self.experts.num_experts): + W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I] + b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I] + + # Gate and Up projections + gate = (hidden_states @ W_g) + b_g # [T, I] + up = (hidden_states @ W_u) + b_u # [T, I] + + # Apply GptOss activation with clamping + gate = gate.clamp(min=None, max=self.experts.limit) + up = up.clamp(min=-self.experts.limit, 
max=self.experts.limit) + + # GLU activation + glu = gate * torch.sigmoid(gate * self.experts.alpha) + intermediate = (up + 1) * glu # [T, I] + + gateupout1 += torch.where(awesome_experts_mask_1[e], intermediate, torch.zeros_like(gateupout1)) + gateupout2 += torch.where(awesome_experts_mask_2[e], intermediate, torch.zeros_like(gateupout2)) + gateupout3 += torch.where(awesome_experts_mask_3[e], intermediate, torch.zeros_like(gateupout3)) + gateupout4 += torch.where(awesome_experts_mask_4[e], intermediate, torch.zeros_like(gateupout4)) + + concat_down = torch.zeros((self.router.top_k, T, H)) + concat_mask = torch.cat( + ( + awesome_experts_mask_1.unsqueeze(0), + awesome_experts_mask_2.unsqueeze(0), + awesome_experts_mask_3.unsqueeze(0), + awesome_experts_mask_4.unsqueeze(0), + ), + dim=0, + ) + + concat_gateout = torch.cat( + (gateupout1.unsqueeze(0), gateupout2.unsqueeze(0), gateupout3.unsqueeze(0), gateupout4.unsqueeze(0)), dim=0 + ) + + for e in range(self.experts.num_experts): + W_d = self.experts.down_proj[e] # [I, H] + b_d = self.experts.down_proj_bias[e] # [H] + + # Down projection + down_out = (concat_gateout @ W_d) + b_d # [T, H] + + concat_down += torch.where(concat_mask[:, e, :], down_out, torch.zeros_like(concat_down)) + + downout1, downout2, downout3, downout4 = concat_down[0], concat_down[1], concat_down[2], concat_down[3] + hidden_states = ( + downout1 * top_w[:, 0].unsqueeze(-1) + + downout2 * top_w[:, 1].unsqueeze(-1) + + downout3 * top_w[:, 2].unsqueeze(-1) + + downout4 * top_w[:, 3].unsqueeze(-1) + ).reshape(B, S, H) + + # original shape [B, S, H] + return hidden_states, router_logits + + +# Can be replaced with llama/modeling_llama.py::QEffLlamaRotaryEmbedding but keeping it following transformers ideology +class QEffGptOssRotaryEmbedding(GptOssRotaryEmbedding): + """ + Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + The only differences are: + - Add static sin/cos computations. + """ + + def __init__(self, config: GptOssConfig, device=None): + super().__init__(config=config) + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + ) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). 
+ + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension seperately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) + + sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) + combined_logits = torch.cat([attn_weights, sinks], dim=-1) + + # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16 + # when training with bsz>1 we clamp max values. 
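+    # The appended sink column takes part in the softmax normalization but is dropped again below
+    # ("we drop the sink here"), so the retained attention weights can sum to less than 1 per query.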
+ combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) + scores = probs[..., :-1] # we drop the sink here + attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +class QEffGptOssAttention(GptOssAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __qeff_init__(self): + self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + sliding_mask=None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + # kv_seq_len = key_states.shape[-2] + + # kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "config": self.config, + "is_sliding": self.sliding_window is not None, + "sliding_window": past_key_value.sliding_window_len, + } + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + if self.sliding_window is not None: + attention_mask = sliding_mask + else: + attention_mask = attention_mask + + attention_interface: Callable = eager_attention_forward + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + s_aux=self.sinks, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value + + +class QEffGptOssDecoderLayer(GptOssDecoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + sliding_mask=None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) 
+ # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + batch_index=batch_index, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + sliding_mask=sliding_mask, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, _ = self.mlp(hidden_states) # diff with llama: router scores + # alth, _ = self.mlp.alt_forward(hidden_states) + hidden_states = residual + hidden_states + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class QEffGptOssModel(GptOssModel): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + past_key_values = QEffHybridCacheForGPTOSS.from_legacy_cache(self.config, past_key_values) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_key_values.max_cache_len) + sliding_mask = _create_causal_mask( + position_ids=position_ids, + target_length=past_key_values.sliding_window_len, + sliding_window=past_key_values.sliding_window_len, + ) + + hidden_states = inputs_embeds + # position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + 
past_key_value=past_key_values, + batch_index=batch_index, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + sliding_mask=sliding_mask, + **kwargs, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + +class QEffGptOssForCausalLM(GptOssForCausalLM): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, GptOssForCausalLM + + >>> model = GptOssForCausalLM.from_pretrained("mistralai/GptOss-8x7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/GptOss-8x7B-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + logits = self.lm_head(hidden_states) + logits = logits.float() + + return MoeCausalLMOutputWithPast( + loss=None, + aux_loss=None, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + def get_pkv_dynamic_axes( + self, + ): + pkv_dynamic_axes = [] + for layer_type in self.config.layer_types: + if layer_type == "sliding_attention": + pkv_dynamic_axes.append({0: "batch_size", 2: "sliding_window"}) + elif layer_type == "full_attention": + pkv_dynamic_axes.append({0: "batch_size", 2: "ctx_len"}) + return pkv_dynamic_axes diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index 60b1c929d..7a084abcb 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -24,7 +24,6 @@ MistralForCausalLM, MistralModel, MistralRotaryEmbedding, - logger, repeat_kv, rotate_half, ) @@ -298,10 +297,6 @@ def forward( if use_cache and not isinstance(past_key_values, Cache) and not self.training: past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) return_legacy_cache = True - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. 
" - "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" - ) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 2f3ee3dc0..470bf65d6 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1591,10 +1591,20 @@ def export(self, export_dir: Optional[str] = None) -> str: output_names.append(f"past_{kv}.{i}_RetainedState") else: + # HACK: create common function for this including above if condition code + pkv_dynamic_axes = ( + self.model.get_pkv_dynamic_axes() if hasattr(self.model, "get_pkv_dynamic_axes") else pkv_dynamic_axes + ) + pkv_dynamic_axes = ( + [pkv_dynamic_axes] * self.model.config.num_hidden_layers + if isinstance(pkv_dynamic_axes, dict) + else pkv_dynamic_axes + ) + for i in range(self.num_layers): for kv in ["key", "value"]: example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) - dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes + dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i] output_names.append(f"past_{kv}.{i}_RetainedState") if self.continuous_batching: @@ -1841,6 +1851,11 @@ def compile( for kv in ["key", "value"]: custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype + # HACK for now + if self.model.config.model_type == "gpt_oss": + for spec in specializations: + spec.update({"sliding_window": 128}) + qpc_path = self._compile( onnx_path=onnx_path, compile_dir=compile_dir, diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index ca74c0ddd..20839590f 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -51,6 +51,15 @@ GPTBigCodeForCausalLM, GPTBigCodeModel, ) +from transformers.models.gpt_oss.modeling_gpt_oss import ( + GptOssAttention, + GptOssDecoderLayer, + GptOssExperts, + GptOssForCausalLM, + GptOssMLP, + GptOssModel, + GptOssRMSNorm, +) from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJForCausalLM, GPTJModel from transformers.models.granite.modeling_granite import ( GraniteAttention, @@ -199,6 +208,14 @@ QEffGPTBigCodeForCausalLM, QEffGPTBigCodeModel, ) +from QEfficient.transformers.models.gpt_oss.modeling_gpt_oss import ( + QEffGptOssAttention, + QEffGptOssDecoderLayer, + QEffGptOssExperts, + QEffGptOssForCausalLM, + QEffGptOssMLP, + QEffGptOssModel, +) from QEfficient.transformers.models.gptj.modeling_gptj import ( QEffGPTJAttention, QEffGPTJBlock, @@ -338,6 +355,7 @@ class CustomOpsTransform(ModuleMappingTransform): MllamaTextRMSNorm: CustomRMSNormAIC, GraniteRMSNorm: CustomRMSNormAIC, GraniteMoeRMSNorm: CustomRMSNormAIC, + GptOssRMSNorm: CustomRMSNormAIC, Gemma3RMSNorm: QEffGemma3CustomRMSNormAIC, } @@ -399,6 +417,13 @@ class KVCacheTransform(ModuleMappingTransform): Gemma3TextModel: QEffGemma3TextModel, Gemma3ForCausalLM: QEffGemma3ForCausalLMModel, Gemma3ForConditionalGeneration: QEffGemma3ForConditionalGeneration, + # GPT_OSS + GptOssAttention: QEffGptOssAttention, + GptOssDecoderLayer: QEffGptOssDecoderLayer, + GptOssModel: QEffGptOssModel, + GptOssForCausalLM: QEffGptOssForCausalLM, + GptOssMLP: QEffGptOssMLP, + GptOssExperts: QEffGptOssExperts, # Granite GraniteModel: QEffGraniteModel, GraniteForCausalLM: 
QEffGraniteForCausalLM, diff --git a/QEfficient/utils/generate_inputs.py b/QEfficient/utils/generate_inputs.py index 361be3080..fd81e6306 100644 --- a/QEfficient/utils/generate_inputs.py +++ b/QEfficient/utils/generate_inputs.py @@ -91,9 +91,13 @@ def prepare_pytorch_inputs(self): inputs["batch_index"] = torch.arange(1).view(-1, 1) past_key_values = [] + sliding_padding_shape = self.padding_shape[:2] + [self.config.sliding_window] + [self.padding_shape[-1]] for i in range(self.n_layer): - past_key = torch.zeros((self.padding_shape), dtype=torch.float32) - past_value = torch.zeros((self.padding_shape), dtype=torch.float32) + pad_shape = ( + sliding_padding_shape if self.config.layer_types[i] == "sliding_attention" else self.padding_shape + ) + past_key = torch.zeros((pad_shape), dtype=torch.float32) + past_value = torch.zeros((pad_shape), dtype=torch.float32) pkv = (past_key, past_value) past_key_values.append(pkv) inputs["past_key_values"] = tuple(past_key_values) diff --git a/examples/gpt_oss.py b/examples/gpt_oss.py new file mode 100644 index 000000000..d33500f92 --- /dev/null +++ b/examples/gpt_oss.py @@ -0,0 +1,54 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +## BEFORE RUNNING PLS, RUN THE CONVERT SCRIPT TO CONVERT THE SAFETENSORS FROM FP4 to BF16 +## SEE DETAILS HERE: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py +## ONCE CONVERTED, PASS THE MODIFIED WEIGHTS TO THE MODEL_ID BELOW +import torch +from transformers import AutoConfig, GptOssForCausalLM, TextStreamer + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.utils._utils import load_hf_tokenizer +from QEfficient.utils.constants import Constants +from QEfficient.utils.run_utils import ApiRunner + +torch.manual_seed(42) +model_id = "CONVERTED_WEIGHTS" # See Comments above to convert saftensors to BF16 +config = AutoConfig.from_pretrained(model_id) + +model = GptOssForCausalLM.from_pretrained( + model_id, torch_dtype=torch.float32, attn_implementation="eager", config=config +) +model.eval() + +tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_id) +config = model.config +batch_size = len(Constants.INPUT_STR) + +api_runner = ApiRunner(batch_size, tokenizer, config, Constants.INPUT_STR, Constants.PROMPT_LEN, Constants.CTX_LEN) + +qeff_model = QEFFAutoModelForCausalLM(model, continuous_batching=False) +onnx_model_path = qeff_model.export() +qpc_path = qeff_model.compile( + prefill_seq_len=32, + ctx_len=256, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=4, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, +) +print(f"qpc path is {qpc_path}") +streamer = TextStreamer(tokenizer) +exec_info = qeff_model.generate( + tokenizer, + streamer=streamer, + prompts="Who is your creator? 
and What all you are allowed to do?", + device_ids=[0, 1, 2, 3], +) diff --git a/pyproject.toml b/pyproject.toml index 479736c22..5dfb00e63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,11 +19,11 @@ classifiers = [ ] requires-python = ">=3.8,<3.11" dependencies = [ - "transformers==4.51.3", - "huggingface-hub==0.30.0", + "transformers==4.55.0", + "huggingface-hub", "hf_transfer==0.1.9", - "peft==0.13.2", - "datasets==2.20.0", + "peft", + "datasets", "fsspec==2023.6.0", "multidict==6.0.4", "urllib3<2", diff --git a/tests/test_gpt.py b/tests/test_gpt.py new file mode 100644 index 000000000..8e44f2f82 --- /dev/null +++ b/tests/test_gpt.py @@ -0,0 +1,73 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import torch +from transformers import AutoConfig, GptOssForCausalLM, TextStreamer + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.utils._utils import load_hf_tokenizer +from QEfficient.utils.constants import Constants +from QEfficient.utils.run_utils import ApiRunner + +Constants.INPUT_STR = [ + "Make sure tokens don't repeat\n\nTo make a simple cup of coffee, start by boiling water. Add one to two teaspoons of instant coffee powder to a mug. Pour the hot water over the coffee and stir well. Add sugar and milk to taste, if desired. For brewed coffee, use a French press or drip filter. Add coarsely ground coffee to the device, pour hot water over it, and let it steep for four minutes. Press or filter the coffee, then serve" +] + +torch.manual_seed(42) +model_id = "openai/gpt-oss-20b" +config = AutoConfig.from_pretrained(model_id) +config.num_hidden_layers = 2 + +# Remove the quantization_config attribute if it exists, to avoid MXFP4 Issues +if hasattr(config, "quantization_config"): + delattr(config, "quantization_config") + +model = GptOssForCausalLM.from_pretrained( + "/home/vbaddi/transformers/src/transformers/models/gpt_oss/new_weights", + torch_dtype=torch.float32, + attn_implementation="eager", + config=config, +) +model.eval() +model.generation_config.sample = False +tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_id) +config = model.config +batch_size = len(Constants.INPUT_STR) + +api_runner = ApiRunner(batch_size, tokenizer, config, Constants.INPUT_STR, 97, 256) +pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model) + + +qeff_model = QEFFAutoModelForCausalLM(model, continuous_batching=False) +# pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) + +onnx_model_path = qeff_model.export() + + +qpc_path = qeff_model.compile( + prefill_seq_len=128, + ctx_len=256, + num_cores=16, + mxfp6_matmul=False, + mxint8_kv_cache=False, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, +) +print(f"qpc path is {qpc_path}") +streamer = TextStreamer(tokenizer) +exec_info = qeff_model.generate( + tokenizer, + streamer=streamer, + prompts=Constants.INPUT_STR[0], + device_ids=[0], +) + +print(pytorch_hf_tokens) +print(exec_info) +assert (exec_info.generated_ids[0][0, :159] == pytorch_hf_tokens).all()