Add FP8 quantization example for Granite4 #1814
Merged
Commits (changes shown from 9 of 15 commits):
- 1ea061d: granite4 moe quantization (chichun-charlie-liu)
- e476bbc: device_map was missing, can't shard model without it (chichun-charlie-liu)
- 8c1d9dc: placeholder for 3d tensor save func (chichun-charlie-liu)
- 272fc1f: Add 2d-to-3d expert conversion before ckpt save (andrea-fasoli)
- a47b845: Add notes (andrea-fasoli)
- dc1a9b5: granite4_example and related codes clean-up (chichun-charlie-liu)
- 1accba4: Merge pull request #1 from chichun-charlie-liu/save_3d_experts (chichun-charlie-liu)
- b39d38c: Merge branch 'vllm-project:main' into main (chichun-charlie-liu)
- 6b8b037: cleanup granite4_example.py and related codes (chichun-charlie-liu)
- f050572: suggested by gemini code review (chichun-charlie-liu)
- 5bafb22: further clean-up and simplification (chichun-charlie-liu)
- d3f4127: fix for quality check and update dep ver suggestions (chichun-charlie-liu)
- 92a5c17: Merge branch 'main' into main (dsikka)
- 2bf60aa: remove device mapping in gr4_ex.py (chichun-charlie-liu)
- d7f93b5: Merge branch 'main' into main (dsikka)
@@ -0,0 +1,125 @@
from compressed_tensors.utils import replace_module
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.granitemoehybrid.modeling_granitemoehybrid import (
    GraniteMoeHybridParallelExperts,
)

from llmcompressor import oneshot
from llmcompressor.modeling.granite4 import GraniteMoeHybridParallelExpertsLinear
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.utils import dispatch_for_generation

""" | ||
There are three "Linear-like" layers in Granite4's GraniteMoeHybridMoe (moe module) that | ||
could be quantized. Among the three layers, usually "router" should be kept in high | ||
precision, therefore, user could choose whether to quantize the other two layers, | ||
.input_linear and .output_linear. This example demonstrates the quantization of these | ||
"Linear-like" input/output layers with minimal changes in llm-compressor. | ||
|
||
Note that input_linear and output_linear are `GraniteMoeHybridParallelExperts`, which | ||
subclasses nn.Modules instead of nn.Linear, for it needs to store weights in 3D, i.e. | ||
[num_experts, out_feat, in_feat]. Because llm-compressor can only handle nn.Linear at | ||
the moment, our simple workaround would be: | ||
1. Swap `GraniteMoeHybridParallelExperts` with `GraniteMoeHybridParallelExpertsLinear` | ||
The custom class is equivalent to the original one, except it subclasses nn.Linear | ||
and stores 2D weights. Moe expert weight tensors will be converted from 3D to 2D, | ||
i.e. from [num_experts, out_feat, in_feat] to [num_experts * out_feat, in_feat]. | ||
2. Perform dynamic fp8 quantization | ||
The new class is compatible with typical per-channel weight quantization, | ||
llm-compressor will be able to identify those layers and process them normally. The | ||
resulting scales will have shape of [num_experts * out_feat, 1] | ||
3. Reshape weights and scales back to 3D before saving the checkpoint | ||
|
||
NOTE: This checkpoint format requires a recent vllm (ver >= 0.10.1.1) to run correctly.
Test settings:
1. DEP VERSIONS: vllm=0.10.1.1, lm_eval=0.4.9.1, flash-attn=2.7.3, torch=2.7.1
2. ENV VARS: VLLM_USE_V1=0, VLLM_WORKER_MULTIPROC_METHOD=spawn
3. device: H100-80G

Results:
1. Base model

`lm_eval --model vllm --model_args pretrained=ibm-granite/granite-4.0-tiny-preview,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.95,enable_prefix_caching=False,max_model_len=8192 --batch_size auto --trust_remote_code --cache_requests true --tasks gsm8k`

|Tasks|Version| Filter         |n-shot| Metric    |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.602|±  |0.0135|
|     |       |strict-match    |     5|exact_match|↑  |0.583|±  |0.0136|

2. FP8 version

`lm_eval --model vllm --model_args pretrained=gr4_fp8_skipRouter_lin_exp3d,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.95,enable_prefix_caching=False,max_model_len=8192 --batch_size auto --trust_remote_code --cache_requests true --tasks gsm8k`

|Tasks|Version| Filter         |n-shot| Metric    |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6073|±  |0.0135|
|     |       |strict-match    |     5|exact_match|↑  |0.5921|±  |0.0135|

If running with hf instead of vllm (e.g., with a command like the one below), an error
related to "weight_scale" will be raised when the FP8 ckpt is used.

`lm_eval --model hf --model_args pretrained=ibm-granite/granite-4.0-tiny-preview,dtype=auto --batch_size 16 --trust_remote_code --tasks gsm8k`
"""

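# --- Illustrative sketch (added for clarity, not part of the original example) ---------
# The 2D <-> 3D handling in steps 1 and 3 of the docstring boils down to a view() over
# the expert weight tensor, and the per-channel FP8 scales follow the same flattening.
# The shapes below are made up for demonstration only.
import torch

num_experts, out_feat, in_feat = 4, 16, 8
w3d = torch.randn(num_experts, out_feat, in_feat)  # 3D expert weights
w2d = w3d.view(-1, in_feat)                        # [num_experts * out_feat, in_feat]
scale = w2d.abs().amax(dim=1, keepdim=True)        # per-channel scale, [num_experts * out_feat, 1]
assert w2d.view(num_experts, out_feat, in_feat).equal(w3d)  # step 3 restores the 3D shape
# ----------------------------------------------------------------------------------------
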
MODEL_ID = "ibm-granite/granite-4.0-tiny-preview"

# Load model.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype="bfloat16", device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

skip_router_only = True  # assume we want to quantize the input/output MoE layers

if skip_router_only:
    for n, m in model.named_modules():
        if isinstance(m, GraniteMoeHybridParallelExperts):
            new_mod = GraniteMoeHybridParallelExpertsLinear.from_3d_expert(m)
            replace_module(model, n, new_mod)
            print(f"Replaced {n}")
    ignore_lay = ["re:.*lm_head", "re:.*block_sparse_moe.router"]
    SAVE_DIR = "gr4_fp8_skipRouter_lin_exp3d"
else:
    # Skip all .input_linear, .output_linear, and router layers.
    ignore_lay = ["re:.*lm_head", "re:.*block_sparse_moe"]
    SAVE_DIR = "gr4_fp8_skipMoe_lin"

recipe = QuantizationModifier(
    targets=["Linear", "GraniteMoeHybridParallelExpertsLinear"],
    scheme="FP8_DYNAMIC",
    ignore=ignore_lay,
)

# Apply quantization and save in compressed-tensors format.
# NOTE: Do NOT save the model with oneshot(..., output_dir=SAVE_DIR) here, as that would
# trigger conversion of the weights from BF16 to FP8 and subsequently cause a dtype
# mismatch in the generation test below. For example, F.linear(x, W) in forward() would
# throw an error because x is in BF16 while W is in FP8.
oneshot(model=model, recipe=recipe)

# Confirm generations of the quantized model look sane.
print("After module swapping")
print(model)

print("========== SAMPLE GENERATION ==============") | ||
dispatch_for_generation(model) | ||
input_ids = tokenizer( | ||
"What is your favorite TV show?", return_tensors="pt" | ||
).input_ids.to("cuda") | ||
output = model.generate(input_ids, max_new_tokens=20) | ||
print(tokenizer.decode(output[0])) | ||
print("==========================================") | ||
|
||
# Revert weights of MoE experts to 3D format (num_experts, output_size, input_size) | ||
dsikka marked this conversation as resolved.
Show resolved
Hide resolved
|
||
for n, m in model.named_modules(): | ||
if isinstance(m, GraniteMoeHybridParallelExpertsLinear): | ||
# NOTE: can assert type != "meta" instead, which is sign of offloading | ||
assert m.weight.device.type == "cuda", ( | ||
"Found some offloaded weights. This is not compatible with reshaping " | ||
"experts to 3D prior model save. Ensure the model is fully on cuda." | ||
) | ||
m.to_3d_expert() | ||
print(f"Updated experts of {n}") | ||
|
||
model.save_pretrained(SAVE_DIR) | ||
tokenizer.save_pretrained(SAVE_DIR) |
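A rough, optional follow-up (not part of the PR): the saved directory could be loaded for offline inference with vLLM's LLM API, assuming the skip-router SAVE_DIR used above and vllm >= 0.10.1.1 as noted in the docstring.

# Hedged usage sketch: serve the FP8 checkpoint produced by the example above.
from vllm import LLM, SamplingParams

llm = LLM(model="gr4_fp8_skipRouter_lin_exp3d", max_model_len=8192)  # assumes SAVE_DIR above
outputs = llm.generate(
    ["What is your favorite TV show?"],
    SamplingParams(max_tokens=20),
)
print(outputs[0].outputs[0].text)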
@@ -0,0 +1,100 @@
import torch
from compressed_tensors.quantization import QuantizationStatus
from compressed_tensors.utils import register_offload_parameter
from transformers.models.granitemoehybrid.modeling_granitemoehybrid import (
    GraniteMoeHybridParallelExperts,
)


class GraniteMoeHybridParallelExpertsLinear(torch.nn.Linear):
    def __init__(self, num_experts: int, input_size: int, output_size: int) -> None:
        """Use a real Linear so that llm-compressor and vllm can handle it more easily.

        1. Change .weight from 3D [num_experts, output_size, input_size] to 2D
           [num_experts * output_size, input_size] before calling llm-compressor.
        2. Change it back to 3D before saving the ckpt.
        """
        super().__init__(
            input_size, output_size * num_experts, bias=False, device="meta"
        )
        self.num_experts = num_experts
        self.input_size = input_size
        self.output_size = output_size
        self.is_2d: bool = True

    @classmethod
    def from_3d_expert(cls, original: GraniteMoeHybridParallelExperts):
        """Reshape the weights of a GraniteMoeHybridParallelExperts module into 2D and
        store them as the weights of this "Linear" module.
        """
        newMoeLin = cls(original.num_experts, original.input_size, original.output_size)
        newMoeLin.weight = torch.nn.Parameter(
            original.weight.view(-1, original.input_size).clone(),
            requires_grad=False,
        )
        original.to("cpu")
        newMoeLin.is_2d = True
        return newMoeLin

    def to_3d_expert(self) -> None:
        """Convert weights and quantization parameters from 2D to 3D shape."""
        dim0_mul = self.num_experts * self.output_size
        assert (
            self.weight.shape == torch.Size((dim0_mul, self.input_size))
            and hasattr(self, "weight_scale")
            and self.weight_scale.shape == torch.Size((dim0_mul, 1))
        ), "Shape mismatch, please check."

        self.weight = torch.nn.Parameter(
            self.weight.view(
                self.num_experts, self.output_size, self.input_size
            ).clone(),
            requires_grad=False,
        )
        self.weight_scale = torch.nn.Parameter(
            self.weight_scale.view(self.num_experts, self.output_size, 1).clone(),
            requires_grad=False,
        )
        if hasattr(self, "weight_zero_point"):
            assert self.weight_zero_point.shape == torch.Size((dim0_mul, 1))
            self.weight_zero_point = torch.nn.Parameter(
                self.weight_zero_point.view(
                    self.num_experts, self.output_size, 1
                ).clone(),
                requires_grad=False,
            )
        self.is_2d = False

    def forward(self, inputs, expert_size):
        """Modified from the original forward()."""
        input_list = inputs.split(expert_size, dim=0)

        # Consider the case of CompressedLinear: decompress the weights once, then mark
        # the module as frozen so later calls reuse the restored dense weights.
        if getattr(self, "quantization_status", None) == QuantizationStatus.COMPRESSED:
            weight_data = self.compressor.decompress_module(self)
            param = torch.nn.Parameter(
                weight_data.to(torch.bfloat16), requires_grad=False
            )
            register_offload_parameter(self, "weight", param)

            self.quantization_status = QuantizationStatus.FROZEN

        weight_3d = self.weight.view(
            self.num_experts, self.output_size, self.input_size
        )
        output_list = []
        for i in range(self.num_experts):
            output_list.append(torch.nn.functional.linear(input_list[i], weight_3d[i]))

        results = torch.cat(output_list, dim=0)
        return results

    def __repr__(self):
        if self.is_2d:
            sizes_str = f"(out={self.weight.shape[0]},in={self.weight.shape[1]})"
        else:
            sizes_str = (
                f"(exp={self.weight.shape[0]},out={self.weight.shape[1]},"
                f"in={self.weight.shape[2]})"
            )
        return f"{self.__class__.__name__}{sizes_str}"
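For reference, a minimal hedged sketch of the forward() contract above (the shapes and per-expert token counts are invented for illustration; in the real model the Granite4 MoE block concatenates the routed tokens along dim 0 and supplies the per-expert counts):

# Hypothetical standalone check of GraniteMoeHybridParallelExpertsLinear.forward().
import torch

layer = GraniteMoeHybridParallelExpertsLinear(num_experts=4, input_size=8, output_size=16)
layer.weight = torch.nn.Parameter(torch.randn(4 * 16, 8), requires_grad=False)  # 2D weights

expert_size = [3, 1, 0, 2]        # tokens routed to each expert (sums to 6)
tokens = torch.randn(6, 8)        # routed tokens, concatenated along dim 0
out = layer(tokens, expert_size)  # split per expert, per-expert linear, re-concatenate
assert out.shape == (6, 16)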