13
13
See the License for the specific language governing permissions and
14
14
limitations under the License.
15
15
"""
16
-
17
- from qwix import QtProvider
18
16
import os
19
17
import jax
20
18
import jax .numpy as jnp
@@ -292,7 +290,7 @@ def test_get_qt_provider(self, mock_qt_rule):
292
290
config_int8 = Mock (spec = HyperParameters )
293
291
config_int8 .use_qwix_quantization = True
294
292
config_int8 .quantization = "int8"
295
- provider_int8 : QtProvider = WanPipeline .get_qt_provider (config_int8 )
293
+ provider_int8 = WanPipeline .get_qt_provider (config_int8 )
296
294
self .assertIsNotNone (provider_int8 )
297
295
mock_qt_rule .assert_called_once_with (
298
296
module_path = '.*' ,
@@ -307,7 +305,11 @@ def test_get_qt_provider(self, mock_qt_rule):
307
305
config_fp8 .quantization = "fp8"
308
306
provider_fp8 = WanPipeline .get_qt_provider (config_fp8 )
309
307
self .assertIsNotNone (provider_fp8 )
310
- self .assertEqual (provider_fp8 .rules [0 ].kwargs ["weight_qtype" ], jnp .float8_e4m3fn )
308
+ mock_qt_rule .assert_called_once_with (
309
+ module_path = '.*' ,
310
+ weight_qtype = jnp .float8_e4m3fn ,
311
+ act_qtype = jnp .float8_e4m3fn
312
+ )
311
313
312
314
# Case 4: Quantization enabled, type 'fp8_full'
313
315
mock_qt_rule .reset_mock ()
@@ -317,7 +319,17 @@ def test_get_qt_provider(self, mock_qt_rule):
317
319
config_fp8_full .quantization_calibration_method = "absmax"
318
320
provider_fp8_full = WanPipeline .get_qt_provider (config_fp8_full )
319
321
self .assertIsNotNone (provider_fp8_full )
320
- self .assertEqual (provider_fp8_full .rules [0 ].kwargs ["bwd_qtype" ], jnp .float8_e5m2 )
322
+ mock_qt_rule .assert_called_once_with (
323
+ module_path = '.*' , # Apply to all modules
324
+ weight_qtype = jnp .float8_e4m3fn ,
325
+ act_qtype = jnp .float8_e4m3fn ,
326
+ bwd_qtype = jnp .float8_e5m2 ,
327
+ bwd_use_original_residuals = True ,
328
+ disable_channelwise_axes = True , # per_tensor calibration
329
+ weight_calibration_method = config_fp8_full .quantization_calibration_method ,
330
+ act_calibration_method = config_fp8_full .quantization_calibration_method ,
331
+ bwd_calibration_method = config_fp8_full .quantization_calibration_method ,
332
+ )
321
333
322
334
# Case 5: Invalid quantization type
323
335
config_invalid = Mock (spec = HyperParameters )
@@ -326,8 +338,8 @@ def test_get_qt_provider(self, mock_qt_rule):
326
338
self .assertIsNone (WanPipeline .get_qt_provider (config_invalid ))
327
339
328
340
# To test quantize_transformer, we patch its external dependencies
329
- @patch (" maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model" )
330
- @patch (" maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs" )
341
+ @patch (' maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model' )
342
+ @patch (' maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs' )
331
343
def test_quantize_transformer_enabled (self , mock_get_dummy_inputs , mock_quantize_model ):
332
344
"""
333
345
Tests that quantize_transformer calls qwix when quantization is enabled.
@@ -358,14 +370,14 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
358
370
# Check that the model returned is the new quantized model
359
371
self .assertIs (result , mock_quantized_model_obj )
360
372
361
- @patch (" maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model" )
373
+ @patch (' maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model' )
362
374
def test_quantize_transformer_disabled (self , mock_quantize_model ):
363
375
"""
364
376
Tests that quantize_transformer is skipped when quantization is disabled.
365
377
"""
366
378
# Setup Mocks
367
379
mock_config = Mock (spec = HyperParameters )
368
- mock_config .use_qwix_quantization = False # Main condition for this test
380
+ mock_config .use_qwix_quantization = False # Main condition for this test
369
381
370
382
mock_model = Mock (spec = WanModel )
371
383
0 commit comments