Conversation

@JyotinderSingh JyotinderSingh (Collaborator) commented Sep 8, 2025

Description of the change

This PR introduces a dedicated GPTQDTypePolicy and integrates GPTQ as a first-class quantization mode in Keras. It adds build/call paths for "gptq" in Dense and EinsumDense, persists GPTQ parameters (quantized weights, scales, zero points, group indices), and wires the model-level quantize() flow to allocate GPTQ variables, execute the GPTQ pass, and rebuild execution graphs as needed.

The core GPTQ quantizer is modified to return quantized weights and all required group parameters in a single pass. Serialization is supported by recording weight_bits and group_size in the dtype policy (and in DTypePolicyMap entries when used), so GPTQ-quantized KerasHub models reload correctly when using the Keras model format.

Features

keras.dtype_policies.GPTQDTypePolicy

  • Encodes mode="gptq", weight_bits, and group_size; a usage sketch follows below.
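For illustration, a hedged sketch of resolving such a policy from its string form. The "gptq/<weight_bits>/<group_size>_from_<source>" naming is an assumption based on the `f"{policy_name}_from_{...}"` pattern used in this PR (see the review discussion below); the exact format may differ:

```python
from keras import dtype_policies

# Hypothetical policy string: 4-bit weights, group size 128, quantized
# from a float32 source policy. The string format is an assumption based
# on the policy-name pattern discussed in this PR's review thread.
policy = dtype_policies.get("gptq/4/128_from_float32")
```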

Layer-level GPTQ support (inference path)

  • Dense and EinsumDense now include _gptq_build(...) and _gptq_call(...).
  • Runtime forward pass uses dequantize_with_sz_map() to reconstruct weights on the fly from:
    • quantized_kernel: uint8
    • kernel_scale: float32
    • kernel_zero: uint8
    • g_idx: int32
  • Group-wise quantization/dequantization happens purely through vectorized operations to minimize overhead (see the sketch after this list).
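To make the runtime path concrete, here is a minimal NumPy sketch of the group-mapped dequantization idea; the real dequantize_with_sz_map is a backend-agnostic op whose exact signature and axis conventions may differ:

```python
import numpy as np

def dequantize_sketch(quantized_kernel, kernel_scale, kernel_zero, g_idx):
    """Hedged sketch of group-mapped dequantization.

    quantized_kernel: (rows, cols) uint8
    kernel_scale:     (groups, cols) float32 per-group scales
    kernel_zero:      (groups, cols) uint8 per-group zero points
    g_idx:            (rows,) int32 mapping each row to its group
    """
    # Gather each row's scale/zero via the group index, then apply the
    # affine map w = (q - zero) * scale in one vectorized expression.
    scale = kernel_scale[g_idx]                   # (rows, cols)
    zero = kernel_zero[g_idx].astype(np.float32)  # (rows, cols)
    return (quantized_kernel.astype(np.float32) - zero) * scale
```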

Model integration

  • Model.quantize("gptq", config=GPTQConfig(...)) now:
    1. Invokes layer-local quantize(..., config=...) to allocate GPTQ variables and set dtype policy (with weight_bits/group_size baked in).
    2. Runs the GPTQ pass (gptq_quantize(self, config)).
    3. Rebuilds training/eval/predict functions when the graph is modified. (A usage sketch follows.)
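As a hedged usage sketch (the GPTQConfig calibration arguments shown here, dataset and tokenizer, are assumptions based on the standard GPTQ workflow rather than the exact signature):

```python
from keras.quantizers import GPTQConfig

# Illustrative calibration inputs; names are hypothetical.
config = GPTQConfig(
    dataset=calibration_texts,  # small corpus of representative text
    tokenizer=tokenizer,
    weight_bits=4,
    group_size=128,
)

# Runs the three steps above: allocate GPTQ variables, execute the
# GPTQ pass, and rebuild the execution functions.
model.quantize("gptq", config=config)
```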

Serialization

  • Modified the save_own_variables(...) and load_own_variables(...) methods in Dense/EinsumDense to support serialization of GPTQ-quantized weights, scales, zeros, and group indices.

Quantization utilities

  • Public helpers quantize_with_sz_map() and dequantize_with_sz_map() for group-mapped (de)quantization (the quantization direction is sketched below).
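For symmetry with the dequantization sketch above, the quantization direction in minimal NumPy form; again a hedged sketch, not the op's actual implementation:

```python
import numpy as np

def quantize_sketch(kernel, kernel_scale, kernel_zero, g_idx, weight_bits=4):
    """Hedged sketch: inverse of the dequantization sketch above."""
    scale = kernel_scale[g_idx]
    zero = kernel_zero[g_idx].astype(np.float32)
    # Affine quantization q = round(w / scale) + zero, clipped to the
    # representable unsigned range for the given bit width.
    q = np.round(kernel / scale) + zero
    return np.clip(q, 0, 2**weight_bits - 1).astype(np.uint8)
```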

Benchmarks

Complete comparative analysis

Benchmark Summary

| Metric | Mean change (native − simulated) | What it means |
| --- | --- | --- |
| Quantization time | −4.69% | Native calibrates a bit faster overall. |
| Quantization CPU peak | −0.43% | Roughly the same CPU peak during calibration. |
| Quantization GPU peak | −31.40% | Native uses far less GPU memory while quantizing. |
| Disk reduction vs FP32 | −49.67% | Native actually shrinks checkpoints; simulated does not. |
| Inference VRAM | −48.49% | About half the VRAM at inference with native. |
| First-token latency | +12.18% | Native is slightly slower to first token (runtime dequant cost). |
| Throughput | −1.40% | Throughput is roughly a wash overall (model-dependent). |

Absolute Means for Context

| Metric | Simulated (mean) | Native (mean) |
| --- | --- | --- |
| Inference VRAM (GB) | 2.611 | 1.254 |
| Quant GPU peak (GB) | 3.989 | 3.093 |
| First-token latency (ms) | 40.13 | 42.78 |
| Throughput (tokens/s) | 419.53 | 406.05 |
| Quant time (sec) | 662.44 | 593.84 |
| Avg. accuracy degradation vs FP (pct) | 8.04 | 5.31 |
| Avg. disk reduction vs FP32 | 0.0% | 49.67% |

Per-model Comparison

| Model (preset) | d_PPL_sim % | d_PPL_nat % | Disk GB (sim → nat) | VRAM GB (sim → nat) | First-token ms (sim → nat) | Throughput t/s (sim → nat) |
| --- | --- | --- | --- | --- | --- | --- |
| gpt2_causal_lm (gpt2_base_en_cnn_dailymail) | 1.11 | 0.98 | 0.465 → 0.232 | 0.582 → 0.343 | 38.64 → 38.92 | 402.63 → 483.40 |
| opt_causal_lm (opt_125m_en) | 11.83 | 9.97 | 0.468 → 0.235 | 0.656 → 0.348 | 44.51 → 47.49 | 813.74 → 686.00 |
| bloom_causal_lm (bloom_1.1b_multi) | 5.23 | 7.03 | 3.985 → 2.113 | 4.995 → 2.296 | 61.66 → 62.79 | 156.00 → 131.55 |
| gemma3_causal_lm (gemma3_1b) | 14.00 | 3.24 | 3.730 → 1.808 | 4.210 → 2.029 | 15.72 → 21.93 | 305.75 → 323.26 |

Notes:

  1. Lower d_PPL is better.
  2. Disk "sim" shows no meaningful reduction vs FP32 overall; the big cuts are purely from the native layout.
  3. Throughput impact varies by model; Gemma3/GPT-2 improved, OPT/BLOOM dipped a bit. First-token is consistently a touch slower with native due to on-the-fly dequant.

Testing

  • Added tests to verify that GPTQ-quantized layers can be serialized/deserialized and rebuilt with correct shapes and params.
  • Added tests to exercise full quantization path: layer quantize → Hessian update → GPTQ pass → state cleanup.
  • Added tests to validate per-column equivalence via dequantize_with_sz_map.
  • Added tests to ensure activation permutation is undone (ordered vs unordered produce identical final weights).
  • Expanded combination coverage: group-wise, activation-order, symmetric, per-channel, and 8-bit variants.

Follow-up Work

  • Support for 2-bit and 4-bit integer packing will be added in subsequent PRs.
  • We can explore activation quantization to enable low-precision GEMM speedups.
  • To reduce numeric variance and limit the surface for bugs within the scope of this PR, we haven't migrated linalg.inverse to cholesky/cholesky_inverse (added in "Added cholesky inverse operation to all the backends" #21554). This should be trivial to implement in a future change; a sketch of the intended swap follows.
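For reference, a hedged sketch of the intended swap. The GPTQ Hessian is symmetric positive definite, which is what makes the Cholesky route applicable; the exact op names introduced by #21554 may differ:

```python
from keras import ops

# Today (per the note above), the Hessian is inverted directly:
#     hessian_inv = ops.inv(hessian)
# Intended follow-up: factor once, then invert via the Cholesky factor,
# which is cheaper and numerically more stable for SPD matrices.
factor = ops.cholesky(hessian)
hessian_inv = ops.cholesky_inverse(factor)  # op added in #21554
```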

@codecov-commenter commented Sep 8, 2025

Codecov Report

❌ Patch coverage is 80.25890% with 61 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.53%. Comparing base (efb24b2) to head (bf2bd8d).
⚠️ Report is 8 commits behind head on master.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| keras/src/layers/core/einsum_dense.py | 73.19% | 14 Missing and 12 partials ⚠️ |
| keras/src/layers/core/dense.py | 80.76% | 8 Missing and 7 partials ⚠️ |
| keras/src/dtype_policies/dtype_policy.py | 66.66% | 8 Missing and 3 partials ⚠️ |
| keras/src/quantizers/gptq.py | 90.69% | 3 Missing and 1 partial ⚠️ |
| keras/src/models/model.py | 78.57% | 2 Missing and 1 partial ⚠️ |
| ...ras/api/_tf_keras/keras/dtype_policies/__init__.py | 0.00% | 1 Missing ⚠️ |
| keras/src/layers/layer.py | 80.00% | 1 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21641      +/-   ##
==========================================
- Coverage   82.55%   82.53%   -0.03%     
==========================================
  Files         571      571              
  Lines       57626    57900     +274     
  Branches     9001     9056      +55     
==========================================
+ Hits        47572    47785     +213     
- Misses       7759     7800      +41     
- Partials     2295     2315      +20     
| Flag | Coverage Δ |
| --- | --- |
| keras | 82.33% <80.25%> (-0.03%) ⬇️ |
| keras-jax | 63.49% <80.25%> (-0.03%) ⬇️ |
| keras-numpy | 57.84% <53.72%> (-0.02%) ⬇️ |
| keras-openvino | 34.38% <14.88%> (-0.01%) ⬇️ |
| keras-tensorflow | 64.20% <80.25%> (-0.02%) ⬇️ |
| keras-torch | 63.69% <74.11%> (-0.05%) ⬇️ |

@JyotinderSingh JyotinderSingh changed the title Adds runtime hooks for GPTQ End-to-End GPTQ layer support, Runtime Execution, and Serialization Sep 10, 2025
@JyotinderSingh JyotinderSingh changed the title End-to-End GPTQ layer support, Runtime Execution, and Serialization Adds Native GPTQ Layer, Runtime Execution, and Serialization Support Sep 10, 2025
@github-actions github-actions bot added the Gemma Gemma model specific issues label Sep 10, 2025
@hertschuh hertschuh (Collaborator) left a comment

Thanks for the PR!

This is a first pass. I have a high-level question, and I'm now having second thoughts about the config parameter in quantize.

  • What is supposed to come from mode? All the information needed to quantize or not?
  • How are GPTQConfig and GPTQDTypePolicy supposed to work together? Do we need both?

targets.append(self.bias)
targets.extend(getattr(self, name) for name in MODE_SPEC[mode])

for i, variable in enumerate(targets):
    variable.assign(store[str(i)])
Collaborator:

It's unfortunate that we didn't keep the names and instead assigned numbers.

What was the reason, @james77777778?

Contributor:

The reason is that the variables are built or modified in two ways when quantization is involved:

  1. Built directly in __init__ using dtype (an instance of DTypePolicyMap).
  2. Modified during a call to quantize in keras.Model.

The order of the variables differs between these two paths. That's why we need to sort the variables within save_own_variables/load_own_variables.

Collaborator:

But if we kept the string keys in the dictionary, like "kernel_scale", the order wouldn't matter.

Contributor:

The reason for using integers is to maintain consistency with other layers. If we change the keys, older models will break.

> But if we kept the string keys in the dictionary, like "kernel_scale", the order wouldn't matter.

This is true, and I agree that this is a better solution for the saving/loading process of quantization. Maybe we can find a way to do this while remaining compatible with old models?
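One possible shape for that compatibility path, as a hedged sketch (names and targets are assumed stand-ins for the layer's key list and ordered variables; this is not the PR's implementation):

```python
def load_own_variables(self, store):
    # `targets` is the ordered variable list from the hunk above;
    # `names` is a hypothetical parallel list of string keys.
    for i, (name, variable) in enumerate(zip(names, targets)):
        # Prefer the new string key; fall back to the legacy integer
        # key so checkpoints saved by older Keras versions still load.
        key = name if name in store else str(i)
        variable.assign(store[key])
```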

@JyotinderSingh JyotinderSingh (Collaborator Author) commented Sep 11, 2025

> Thanks for the PR!
>
> This is a first pass. I have a high-level question, and I'm now having second thoughts about the config parameter in quantize.
>
>   • What is supposed to come from mode? All the information needed to quantize or not?
>   • How are GPTQConfig and GPTQDTypePolicy supposed to work together? Do we need both?

I am okay with removing the GPTQConfig in favor of independent kwargs that can be propagated down to the quantize(...) methods of the layers. But maybe we should take that up in a subsequent PR (since this one is already quite large).

  1. This bit of information is redundant; it is largely kept to stay consistent with the existing convention of specifying a mode parameter in the quantize() call.
  2. You don't need both at the same time. When quantizing a full-precision model, you provide the GPTQConfig, which carries all the information required to (i) build the layer variables and (ii) run the GPTQ quantization process. When loading a serialized model, however, the config object is not present; in that case we use the information serialized as part of the GPTQDTypePolicy to restore the layer. The benefit of this approach is that KerasHub models already have a sophisticated dtype propagation mechanism in place, which ensures each layer is restored with the correct (quantized) dtype. (See the sketch below.)
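To illustrate the two paths in code (a hedged sketch; the GPTQConfig arguments are illustrative, while the save/load_model usage is standard Keras):

```python
import keras
from keras.quantizers import GPTQConfig

# Path 1: quantize a full-precision model. The config carries everything
# needed to build the GPTQ variables and run the quantization pass.
config = GPTQConfig(
    dataset=calibration_texts,  # hypothetical calibration corpus
    tokenizer=tokenizer,
    weight_bits=4,
    group_size=128,
)
model.quantize("gptq", config=config)
model.save("model.keras")

# Path 2: reload. No config is present; the serialized GPTQDTypePolicy
# (carrying weight_bits/group_size) restores each quantized layer.
restored = keras.saving.load_model("model.keras")
```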

@@ -352,6 +408,65 @@ def _int8_build(self, kernel_shape):
            trainable=False,
        )

    def _gptq_build(self, kernel_shape, config):
        self.gptq = False
Collaborator:

It's very unintuitive that the result of _gptq_build() would be to "turn off" gptq.

Either we need a better name than self.gptq here, or at least a comment explaining.

@JyotinderSingh JyotinderSingh (Collaborator Author) commented Sep 16, 2025

Renamed it to is_gptq_calibrated, and added an explanatory comment as well.

if mode == "gptq":
    policy_name = config.dtype_policy_string()
    policy = dtype_policies.get(
        f"{policy_name}_from_{self.dtype_policy.name}"
Collaborator:

That will give you gptq_from_.... Don't you need the other parameters, i.e. gptq/.../..._from_...?

Collaborator Author:

The line right above this generates the policy name from the config object in the correct format.

if mode == "gptq":
    policy_name = config.dtype_policy_string()

@JyotinderSingh JyotinderSingh (Collaborator Author) left a comment

Thanks for the reviews @hertschuh! I've addressed your comments.

