File tree Expand file tree Collapse file tree 3 files changed +36
-12
lines changed
modifiers/quantization/gptq
tests/llmcompressor/modifiers/quantization Expand file tree Collapse file tree 3 files changed +36
-12
lines changed Original file line number Diff line number Diff line change @@ -120,17 +120,6 @@ class GPTQModifier(Modifier, QuantizationMixin):
120
120
_hessians : Dict [torch .nn .Module , torch .Tensor ] = PrivateAttr (default_factory = dict )
121
121
_num_samples : Dict [torch .nn .Module , int ] = PrivateAttr (default_factory = dict )
122
122
123
- @field_validator ("sequential_update" , mode = "before" )
124
- def validate_sequential_update (cls , value : bool ) -> bool :
125
- if not value :
126
- warnings .warn (
127
- "`sequential_update=False` is no longer supported, setting "
128
- "sequential_update=True" ,
129
- DeprecationWarning ,
130
- )
131
-
132
- return True
133
-
134
123
def resolve_quantization_config (self ) -> QuantizationConfig :
135
124
config = super ().resolve_quantization_config ()
136
125
@@ -317,3 +306,14 @@ def _maybe_onload_hessian(self, module: torch.nn.Module):
317
306
if self .offload_hessians :
318
307
if module in self ._hessians : # may have been deleted in context
319
308
self ._hessians [module ] = self ._hessians [module ].to (device = "cpu" )
309
+
310
+ @field_validator ("sequential_update" , mode = "before" )
311
+ def validate_sequential_update (cls , value : bool ) -> bool :
312
+ if not value :
313
+ warnings .warn (
314
+ "`sequential_update=False` is no longer supported, setting "
315
+ "sequential_update=True" ,
316
+ DeprecationWarning ,
317
+ )
318
+
319
+ return True
Original file line number Diff line number Diff line change @@ -45,7 +45,13 @@ def __reduce__(self):
45
45
46
46
@classmethod
47
47
def __get_pydantic_core_schema__ (cls , _source_type , _handler ):
48
- return core_schema .no_info_plain_validator_function (cls .validate )
48
+ return core_schema .no_info_after_validator_function (
49
+ cls .validate ,
50
+ schema = core_schema .str_schema (),
51
+ serialization = core_schema .plain_serializer_function_ser_schema (
52
+ lambda v : str (v )
53
+ ),
54
+ )
49
55
50
56
@classmethod
51
57
def validate (cls , value : "Sentinel" ) -> "Sentinel" :
Original file line number Diff line number Diff line change @@ -142,3 +142,21 @@ def test_config_resolution(strategies, actorder):
142
142
for config_group in modifier .config_groups .values ():
143
143
if config_group .weights .strategy == "group" :
144
144
assert config_group .weights .actorder == actorder
145
+
146
+
147
+ @pytest .mark .parametrize (
148
+ "has_actorder,actorder,exp_actorder" ,
149
+ [
150
+ (False , "N/A" , "static" ),
151
+ (True , None , None ),
152
+ (True , "static" , "static" ),
153
+ (True , "group" , "group" ),
154
+ ],
155
+ )
156
+ def test_serialize_actorder (has_actorder , actorder , exp_actorder ):
157
+ if has_actorder :
158
+ modifier = GPTQModifier (targets = ["Linear" ], actorder = actorder )
159
+ else :
160
+ modifier = GPTQModifier (targets = ["Linear" ])
161
+
162
+ assert modifier .model_dump ()["actorder" ] == exp_actorder
You can’t perform that action at this time.
0 commit comments