Skip to content

Commit a7bc45d

Browse files
kylesayrs and dsikka authored
[Actorder] Fix GPTQ actorder serialization (#1818)
## Purpose ##
* Fix serialization of actorder sentinel values

## Prerequisites ##
* #1815

## Changes ##
* Write explicit serializer for actorder field
* Move down deprecated field checker, since it's less important to read first

## Testing ##
* Added test which fails without these changes

---------

Signed-off-by: Kyle Sayers <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
1 parent 84575f2 commit a7bc45d

File tree

3 files changed

+36
-12
lines changed

3 files changed

+36
-12
lines changed

src/llmcompressor/modifiers/quantization/gptq/base.py

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -120,17 +120,6 @@ class GPTQModifier(Modifier, QuantizationMixin):
120120
_hessians: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(default_factory=dict)
121121
_num_samples: Dict[torch.nn.Module, int] = PrivateAttr(default_factory=dict)
122122

123-
@field_validator("sequential_update", mode="before")
124-
def validate_sequential_update(cls, value: bool) -> bool:
125-
if not value:
126-
warnings.warn(
127-
"`sequential_update=False` is no longer supported, setting "
128-
"sequential_update=True",
129-
DeprecationWarning,
130-
)
131-
132-
return True
133-
134123
def resolve_quantization_config(self) -> QuantizationConfig:
135124
config = super().resolve_quantization_config()
136125

@@ -317,3 +306,14 @@ def _maybe_onload_hessian(self, module: torch.nn.Module):
317306
if self.offload_hessians:
318307
if module in self._hessians: # may have been deleted in context
319308
self._hessians[module] = self._hessians[module].to(device="cpu")
309+
310+
@field_validator("sequential_update", mode="before")
311+
def validate_sequential_update(cls, value: bool) -> bool:
312+
if not value:
313+
warnings.warn(
314+
"`sequential_update=False` is no longer supported, setting "
315+
"sequential_update=True",
316+
DeprecationWarning,
317+
)
318+
319+
return True

src/llmcompressor/sentinel.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,13 @@ def __reduce__(self):
4545

4646
@classmethod
4747
def __get_pydantic_core_schema__(cls, _source_type, _handler):
48-
return core_schema.no_info_plain_validator_function(cls.validate)
48+
return core_schema.no_info_after_validator_function(
49+
cls.validate,
50+
schema=core_schema.str_schema(),
51+
serialization=core_schema.plain_serializer_function_ser_schema(
52+
lambda v: str(v)
53+
),
54+
)
4955

5056
@classmethod
5157
def validate(cls, value: "Sentinel") -> "Sentinel":

tests/llmcompressor/modifiers/quantization/test_base.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -142,3 +142,21 @@ def test_config_resolution(strategies, actorder):
142142
for config_group in modifier.config_groups.values():
143143
if config_group.weights.strategy == "group":
144144
assert config_group.weights.actorder == actorder
145+
146+
147+
@pytest.mark.parametrize(
148+
"has_actorder,actorder,exp_actorder",
149+
[
150+
(False, "N/A", "static"),
151+
(True, None, None),
152+
(True, "static", "static"),
153+
(True, "group", "group"),
154+
],
155+
)
156+
def test_serialize_actorder(has_actorder, actorder, exp_actorder):
157+
if has_actorder:
158+
modifier = GPTQModifier(targets=["Linear"], actorder=actorder)
159+
else:
160+
modifier = GPTQModifier(targets=["Linear"])
161+
162+
assert modifier.model_dump()["actorder"] == exp_actorder

0 commit comments

Comments
 (0)