Commit aad9839

Fix wan qwix test (#235)
1 parent d028efa · commit aad9839

File tree

1 file changed: +21 -9 lines changed


src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 21 additions & 9 deletions
@@ -13,8 +13,6 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
-
-from qwix import QtProvider
 import os
 import jax
 import jax.numpy as jnp
@@ -292,7 +290,7 @@ def test_get_qt_provider(self, mock_qt_rule):
     config_int8 = Mock(spec=HyperParameters)
     config_int8.use_qwix_quantization = True
     config_int8.quantization = "int8"
-    provider_int8: QtProvider = WanPipeline.get_qt_provider(config_int8)
+    provider_int8 = WanPipeline.get_qt_provider(config_int8)
     self.assertIsNotNone(provider_int8)
     mock_qt_rule.assert_called_once_with(
         module_path='.*',
@@ -307,7 +305,11 @@ def test_get_qt_provider(self, mock_qt_rule):
     config_fp8.quantization = "fp8"
     provider_fp8 = WanPipeline.get_qt_provider(config_fp8)
     self.assertIsNotNone(provider_fp8)
-    self.assertEqual(provider_fp8.rules[0].kwargs["weight_qtype"], jnp.float8_e4m3fn)
+    mock_qt_rule.assert_called_once_with(
+        module_path='.*',
+        weight_qtype=jnp.float8_e4m3fn,
+        act_qtype=jnp.float8_e4m3fn
+    )

     # Case 4: Quantization enabled, type 'fp8_full'
     mock_qt_rule.reset_mock()
@@ -317,7 +319,17 @@ def test_get_qt_provider(self, mock_qt_rule):
     config_fp8_full.quantization_calibration_method = "absmax"
     provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full)
     self.assertIsNotNone(provider_fp8_full)
-    self.assertEqual(provider_fp8_full.rules[0].kwargs["bwd_qtype"], jnp.float8_e5m2)
+    mock_qt_rule.assert_called_once_with(
+        module_path='.*',  # Apply to all modules
+        weight_qtype=jnp.float8_e4m3fn,
+        act_qtype=jnp.float8_e4m3fn,
+        bwd_qtype=jnp.float8_e5m2,
+        bwd_use_original_residuals=True,
+        disable_channelwise_axes=True,  # per_tensor calibration
+        weight_calibration_method = config_fp8_full.quantization_calibration_method,
+        act_calibration_method = config_fp8_full.quantization_calibration_method,
+        bwd_calibration_method = config_fp8_full.quantization_calibration_method,
+    )

     # Case 5: Invalid quantization type
     config_invalid = Mock(spec=HyperParameters)
@@ -326,8 +338,8 @@ def test_get_qt_provider(self, mock_qt_rule):
     self.assertIsNone(WanPipeline.get_qt_provider(config_invalid))

   # To test quantize_transformer, we patch its external dependencies
-  @patch("maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model")
-  @patch("maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs")
+  @patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model')
+  @patch('maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs')
   def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize_model):
     """
     Tests that quantize_transformer calls qwix when quantization is enabled.
@@ -358,14 +370,14 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
     # Check that the model returned is the new quantized model
     self.assertIs(result, mock_quantized_model_obj)

-  @patch("maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model")
+  @patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model')
   def test_quantize_transformer_disabled(self, mock_quantize_model):
     """
     Tests that quantize_transformer is skipped when quantization is disabled.
     """
     # Setup Mocks
     mock_config = Mock(spec=HyperParameters)
-    mock_config.use_qwix_quantization = False # Main condition for this test
+    mock_config.use_qwix_quantization = False  # Main condition for this test

     mock_model = Mock(spec=WanModel)

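The fix replaces assertions on provider.rules[0].kwargs with assertions on the patched rule constructor itself (mock_qt_rule.assert_called_once_with(...)), so the test no longer reaches into the returned provider's internals. Below is a minimal, self-contained sketch of that mock-assertion pattern using only unittest.mock; Rule and build_provider are hypothetical stand-ins for illustration, not maxdiffusion or qwix APIs.

# Sketch of the pattern: patch the rule constructor, then assert on how it was
# called, rather than inspecting the internals of the returned provider.
# `Rule` and `build_provider` are hypothetical stand-ins, not real library APIs.
from unittest.mock import patch

class Rule:
  """Stand-in for a quantization rule class."""

  def __init__(self, **kwargs):
    self.kwargs = kwargs

def build_provider(qtype):
  """Stand-in for a factory such as WanPipeline.get_qt_provider."""
  return [Rule(module_path='.*', weight_qtype=qtype, act_qtype=qtype)]

with patch(f"{__name__}.Rule") as mock_rule:
  provider = build_provider("fp8")
  # The assertion only checks the keyword arguments the rule was built with,
  # so it keeps passing even if the provider's internal representation changes.
  mock_rule.assert_called_once_with(module_path='.*', weight_qtype="fp8", act_qtype="fp8")

As long as the factory constructs the rule with the same keyword arguments, the test stays valid regardless of how the provider stores its rules, which is what makes this style of assertion more robust than inspecting rules[0].kwargs directly.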