Adds Native GPTQ Layer, Runtime Execution, and Serialization Support #21641
Codecov Report

```diff
@@            Coverage Diff             @@
##           master   #21641      +/-   ##
==========================================
- Coverage   82.55%   82.53%   -0.03%
==========================================
  Files         571      571
  Lines       57626    57900     +274
  Branches     9001     9056      +55
==========================================
+ Hits        47572    47785     +213
- Misses       7759     7800      +41
- Partials     2295     2315      +20
```
Force-pushed from 8e0ac65 to 904d0a0.
Thanks for the PR!

This is a first pass. I have a high-level question, and now I'm having second thoughts about the `config` parameter in `quantize`.

- What is supposed to come from `mode`? All the information needed to quantize or not?
- How are `GPTQConfig` and `GPTQDTypePolicy` supposed to work together? Do we need both?
```python
targets.append(self.bias)
targets.extend(getattr(self, name) for name in MODE_SPEC[mode])

for i, variable in enumerate(targets):
    variable.assign(store[str(i)])
```
It's unfortunate that we didn't keep the names and instead assigned numbers.
What was the reason, @james77777778?
The variables are built or modified in two ways when considering quantization, which is the reason:

- Directly built from `__init__` using `dtype` (an instance of `DTypePolicyMap`).
- Modified during a call to `quantize` in `keras.Model`.

The order of variables from these two methods is different. That's why we need to sort the variables within `load|save_own_variables`.
But if we kept the string keys in the dictionary, like `"kernel_scale"`, the order wouldn't matter.
The reason for using integers is to maintain consistency with other layers. If we change the keys, older models will break.

> But if we kept the string keys in the dictionary, like `"kernel_scale"`, the order wouldn't matter.

This is true, and I agree that this is a better solution for the saving/loading process of quantization. Maybe we can find a way to do this while remaining compatible with old models?
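A minimal sketch of the name-keyed scheme under discussion, with a fallback to the legacy integer keys so older saved models still load. The dict-based store and helper names here are illustrative stand-ins, not the actual Keras `save_own_variables`/`load_own_variables` implementation:

```python
# Sketch: key variables by name, but accept legacy integer keys on load.
# "variables" is a list of (name, value) pairs standing in for Keras
# variables; "store" stands in for the weights file's key-value store.

def save_variables(variables, store):
    # Key each variable by its name, so ordering no longer matters.
    for name, value in variables:
        store[name] = value

def load_variables(variables, store):
    loaded = {}
    for i, (name, _) in enumerate(variables):
        # Prefer the string key; fall back to the legacy integer key
        # written by older models (which is order-sensitive).
        key = name if name in store else str(i)
        loaded[name] = store[key]
    return loaded

# New-style store keyed by names: variable order doesn't matter.
store = {"kernel_scale": 0.5, "quantized_kernel": [1, 2]}
vars_reordered = [("quantized_kernel", None), ("kernel_scale", None)]
print(load_variables(vars_reordered, store))
# {'quantized_kernel': [1, 2], 'kernel_scale': 0.5}

# Legacy store keyed by integers still loads, relying on order.
legacy = {"0": [1, 2], "1": 0.5}
print(load_variables(vars_reordered, legacy))
```

The fallback branch is what would preserve compatibility with models saved under the integer scheme.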
I am okay with removing the `GPTQConfig` in favor of independent kwargs that can be propagated down to the
keras/src/layers/core/dense.py (outdated)

```diff
@@ -352,6 +408,65 @@ def _int8_build(self, kernel_shape):
             trainable=False,
         )

+    def _gptq_build(self, kernel_shape, config):
+        self.gptq = False
```
It's very unintuitive that the result of `_gptq_build()` would be to "turn off" gptq. Either we need a better name than `self.gptq` here, or at least a comment explaining.
Renamed it to `is_gptq_calibrated`. Added a comment for explanation too.
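As a hedged illustration of the flag's lifecycle: the layer's GPTQ variables exist after build, but hold no meaningful data until calibration runs, so the flag starts out `False`. The class and method names below are invented for the sketch; only the `is_gptq_calibrated` name comes from the PR:

```python
class GPTQLayerSketch:
    """Toy stand-in for a layer with a GPTQ calibration flag."""

    def gptq_build(self):
        # Variables are allocated here, but they are uncalibrated:
        # the flag records that the GPTQ pass has not yet run.
        self.is_gptq_calibrated = False

    def calibrate(self):
        # Stand-in for the GPTQ pass that fills in quantized weights.
        self.is_gptq_calibrated = True

    def call(self, x):
        if not self.is_gptq_calibrated:
            # Before calibration, fall back to the float path.
            return ("float_path", x)
        return ("gptq_path", x)

layer = GPTQLayerSketch()
layer.gptq_build()
print(layer.call(1.0))   # ('float_path', 1.0)
layer.calibrate()
print(layer.call(1.0))   # ('gptq_path', 1.0)
```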
```python
if mode == "gptq":
    policy_name = config.dtype_policy_string()
    policy = dtype_policies.get(
        f"{policy_name}_from_{self.dtype_policy.name}"
    )
```
That will give you `gptq_from_...`. Don't you need the other parameters? `gptq/.../..._from_...`?
The line right above this generates the policy name from the config object in the correct format.

```python
if mode == "gptq":
    policy_name = config.dtype_policy_string()
```
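For illustration, a sketch of how the combined policy name could be assembled. The `gptq/<weight_bits>/<group_size>` format and the standalone `dtype_policy_string` function are assumptions for this sketch; the real string comes from `GPTQConfig.dtype_policy_string()`:

```python
# Hypothetical reconstruction of the policy-name composition discussed
# above. The exact format emitted by GPTQConfig is an assumption here.
def dtype_policy_string(weight_bits, group_size):
    # Encodes the quantization parameters into the policy name, so that
    # "gptq_from_..." alone is not ambiguous.
    return f"gptq/{weight_bits}/{group_size}"

policy_name = dtype_policy_string(weight_bits=4, group_size=128)
full_name = f"{policy_name}_from_float32"
print(full_name)  # gptq/4/128_from_float32
```

Baking `weight_bits` and `group_size` into the name is what lets the dtype policy alone reconstruct the layer on reload.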
Thanks for the reviews @hertschuh! I've addressed your comments.
Description of the change

This PR introduces a dedicated `GPTQDTypePolicy` and integrates GPTQ as a first-class quantization mode in Keras. It adds build/call paths for `"gptq"` in `Dense` and `EinsumDense`, persists GPTQ parameters (quantized weights, scales, zero points, group indices), and wires the model-level `quantize()` flow to allocate GPTQ variables, execute the GPTQ pass, and rebuild execution graphs as needed.

The core GPTQ quantizer is modified to return quantized weights and all required group parameters in a single pass. Serialization is supported by recording `weight_bits` and `group_size` in the dtype policy (and in `DTypePolicyMap` entries when used), so GPTQ-quantized KerasHub models reload correctly when using the Keras model format.

Features

- New dtype policy `keras.dtype_policies.GPTQDTypePolicy` with `mode="gptq"`, `weight_bits`, and `group_size`.

Layer-level GPTQ support (inference path)

- `Dense` and `EinsumDense` now include `_gptq_build(...)` and `_gptq_call(...)`.
- The call path uses `dequantize_with_sz_map()` to reconstruct weights on the fly from:
  - `quantized_kernel: uint8`
  - `kernel_scale: float32`
  - `kernel_zero: uint8`
  - `g_idx: int32`

Model integration

`Model.quantize("gptq", config=GPTQConfig(...))` now:

- calls `quantize(..., config=...)` to allocate GPTQ variables and set the dtype policy (with `weight_bits`/`group_size` baked in);
- runs the GPTQ pass (`gptq_quantize(self, config)`).

Serialization

- Added `load_own_weights(...)` and `save_own_weights(...)` methods in `Dense`/`EinsumDense` to support serialization of GPTQ-quantized weights, scales, zeros, and group indices.

Quantization utilities

- Added `quantize_with_sz_map()` and `dequantize_with_sz_map()` for group-mapped (de)quantization.
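A NumPy sketch of what group-mapped dequantization implies given the four tensors listed above. The real `dequantize_with_sz_map()` signature, and its handling of per-output-channel scales, are assumptions; only the per-group scale/zero-point arithmetic is illustrated:

```python
import numpy as np

def dequantize_with_sz_map_sketch(quantized_kernel, kernel_scale,
                                  kernel_zero, g_idx):
    """Reconstruct float weights from group-quantized values.

    Input row i belongs to group g_idx[i]; that group's scale and zero
    point are used to dequantize the whole row. Real GPTQ kernels keep
    per-output-channel scales as well; this sketch uses one scalar
    scale/zero per group for clarity.
    """
    scales = kernel_scale[g_idx]                      # per-row scales
    zeros = kernel_zero[g_idx].astype(np.float32)     # per-row zero points
    q = quantized_kernel.astype(np.float32)
    return (q - zeros[:, None]) * scales[:, None]

# Toy example: 4 input rows, group_size=2 -> 2 groups.
q = np.array([[3, 5], [7, 1], [2, 2], [0, 4]], dtype=np.uint8)
scale = np.array([0.5, 0.25], dtype=np.float32)
zero = np.array([2, 1], dtype=np.uint8)
g_idx = np.array([0, 0, 1, 1], dtype=np.int32)

w = dequantize_with_sz_map_sketch(q, scale, zero, g_idx)
```

The `g_idx` indirection is what allows non-contiguous group assignments (e.g. after activation-order permutation) without reshuffling the kernel itself.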
for group-mapped (de)quantization.Benchmarks
Native GPTQ benchmark: colab notebook
Results: benchmarks_native_gptq.csv
Existing Simulated GPTQ benchmark: colab notebook
Results: benchmarks_simulated_gptq.csv
Complete comparative analysis
Benchmark Summary
Absolute Means for Context
Per-model Comparison
Notes:
Testing

- Tests cover `dequantize_with_sz_map`.

Follow-up Work

- Change `linalg.inverse` to `cholesky`/`cholesky_inverse` (added in #21554, "Added cholesky inverse operation to all the backends"). This should be trivial to implement in future changes.
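For context on the proposed swap, a NumPy sketch comparing a generic matrix inverse with a Cholesky-based inverse of a symmetric positive-definite matrix (such as the Hessian GPTQ inverts). This uses plain `numpy.linalg`; the Keras op names are those referenced in #21554:

```python
import numpy as np

def cholesky_inverse(a):
    # For symmetric positive-definite a = L @ L.T, the inverse is
    # a^-1 = (L^-1).T @ (L^-1), which is cheaper and more numerically
    # stable than a general-purpose inverse.
    l_inv = np.linalg.inv(np.linalg.cholesky(a))
    return l_inv.T @ l_inv

# Symmetric positive-definite test matrix.
a = np.array([[4.0, 2.0], [2.0, 3.0]])
direct = np.linalg.inv(a)
via_cholesky = cholesky_inverse(a)
print(np.allclose(direct, via_cholesky))  # True
```

Both routes agree for SPD inputs; the Cholesky route exploits the symmetry the generic inverse ignores.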