
Commit b0bd1b0

B 436918994 (#233)
* Add an option to replicate the VAE. * Fix splash attention block sizes for cross-attention. * Lint.
1 parent c0f89e6 commit b0bd1b0

7 files changed: 63 additions, 49 deletions


src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 3 additions & 0 deletions

@@ -40,6 +40,9 @@ weights_dtype: 'bfloat16'
 # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
 activations_dtype: 'bfloat16'
 
+# Replicates vae across devices instead of using the model's sharding annotations for sharding.
+replicate_vae: False
+
 # matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
 # Options are "DEFAULT", "HIGH", "HIGHEST"
 # fp32 activations and fp32 weights with HIGHEST will provide the best precision
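
To confirm the new default, a quick standalone check (assumes PyYAML is installed and that this is run from the repo root; illustrative only, not part of the commit):

  # Load the config and inspect the new flag.
  import yaml

  with open("src/maxdiffusion/configs/base_wan_14b.yml") as f:
    cfg = yaml.safe_load(f)

  # False by default; set to True to place a full copy of the VAE on every
  # device instead of following the model's sharding annotations.
  print(cfg["replicate_vae"])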

src/maxdiffusion/max_utils.py

Lines changed: 4 additions & 0 deletions

@@ -221,6 +221,10 @@ def walk_and_upload_blobs(config, output_dir):
 
 
 def device_put_replicated(x, sharding):
+  """
+  Although the name indicates replication, this function can also
+  be used to shard an array based on sharding.
+  """
   return jax.make_array_from_callback(x.shape, sharding, lambda index: x[index])
 
 
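
As a sanity check on that docstring, here is a small standalone sketch (not from the repo; it assumes a 1-D mesh named "data" and an array whose leading dimension divides evenly across the available devices) showing that the same call either replicates or shards, depending on the NamedSharding passed in:

  import jax
  import numpy as np
  from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


  def device_put_replicated(x, sharding):
    # Each device materializes only the slice of x that its shard covers.
    return jax.make_array_from_callback(x.shape, sharding, lambda index: x[index])


  mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
  x = np.arange(16 * 4, dtype=np.float32).reshape(16, 4)

  # P() -> every device holds a full copy of x (true replication).
  replicated = device_put_replicated(x, NamedSharding(mesh, P()))

  # P("data") -> the leading axis is split across the "data" mesh axis.
  sharded = device_put_replicated(x, NamedSharding(mesh, P("data")))

  print(replicated.sharding.is_fully_replicated)  # True
  print(sharded.sharding.is_fully_replicated)     # False on a multi-device mesh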

src/maxdiffusion/models/attention_flax.py

Lines changed: 15 additions & 10 deletions

@@ -166,20 +166,25 @@ def _tpu_flash_attention(
     dtype: jnp.dtype = jnp.float32,
 ) -> jax.Array:
   """TPU Flash Attention"""
-
-  max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+  q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+  # Cross-attention where kv dims are much smaller due to encoder_hidden_states.
+  # If kv seq_len is padded too much, it causes issues in attention calculations.
+  if key.shape[1] != query.shape[1]:
+    kv_max_block_size = key.shape[1]
+  else:
+    kv_max_block_size = q_max_block_size
   if flash_block_sizes:
     block_sizes = flash_block_sizes
   else:
     block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=min(max_block_size, query.shape[2]),
-        block_kv_compute=min(max_block_size, key.shape[2]),
-        block_kv=min(max_block_size, key.shape[2]),
-        block_q_dkv=min(max_block_size, query.shape[2]),
-        block_kv_dkv=min(max_block_size, key.shape[2]),
-        block_kv_dkv_compute=min(max_block_size, query.shape[2]),
-        block_q_dq=min(max_block_size, query.shape[2]),
-        block_kv_dq=min(max_block_size, query.shape[2]),
+        block_q=min(q_max_block_size, query.shape[2]),
+        block_kv_compute=min(kv_max_block_size, key.shape[2]),
+        block_kv=min(kv_max_block_size, key.shape[2]),
+        block_q_dkv=min(q_max_block_size, query.shape[2]),
+        block_kv_dkv=min(kv_max_block_size, key.shape[2]),
+        block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
+        block_q_dq=min(q_max_block_size, query.shape[2]),
+        block_kv_dq=min(kv_max_block_size, query.shape[2]),
     )
 
   num_fsdp_shards = mesh.shape["fsdp"]
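
To make the clamping concrete, a numbers-only sketch with hypothetical sequence lengths (illustrative values, not taken from any Wan config):

  # A long video-latent query sequence cross-attending to a short
  # text-encoder context (encoder_hidden_states).
  q_seq_len = 75_600
  kv_seq_len = 512

  q_max_block_size = 1024  # bfloat16 path, as in the diff

  # Per the diff's comment: when kv and q lengths differ (cross-attention),
  # cap the kv block size at the real kv length so the kernel does not
  # operate on kv blocks padded far beyond the actual context.
  if kv_seq_len != q_seq_len:
    kv_max_block_size = kv_seq_len
  else:
    kv_max_block_size = q_max_block_size

  block_q = min(q_max_block_size, q_seq_len)     # 1024
  block_kv = min(kv_max_block_size, kv_seq_len)  # 512
  print(block_q, block_kv)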

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 17 additions & 15 deletions

@@ -220,6 +220,8 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
     params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
     for path, val in flax.traverse_util.flatten_dict(params).items():
       sharding = logical_state_sharding[path].value
+      if config.replicate_vae:
+        sharding = NamedSharding(mesh, P())
       state[path].value = device_put_replicated(val, sharding)
     state = nnx.from_flat_state(state)
 

@@ -231,11 +233,11 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
   def get_basic_config(cls, dtype):
     rules = [
         qwix.QtRule(
-            module_path='.*', # Apply to all modules
-            weight_qtype=dtype,
-            act_qtype=dtype,
+            module_path=".*",  # Apply to all modules
+            weight_qtype=dtype,
+            act_qtype=dtype,
         )
-      ]
+    ]
     return rules
 
   @classmethod

@@ -247,17 +249,17 @@ def get_fp8_config(cls, quantization_calibration_method: str):
     """
     rules = [
         qwix.QtRule(
-            module_path='.*', # Apply to all modules
-            weight_qtype=jnp.float8_e4m3fn,
-            act_qtype=jnp.float8_e4m3fn,
-            bwd_qtype=jnp.float8_e5m2,
-            bwd_use_original_residuals=True,
-            disable_channelwise_axes=True, # per_tensor calibration
-            weight_calibration_method = quantization_calibration_method,
-            act_calibration_method = quantization_calibration_method,
-            bwd_calibration_method = quantization_calibration_method,
+            module_path=".*",  # Apply to all modules
+            weight_qtype=jnp.float8_e4m3fn,
+            act_qtype=jnp.float8_e4m3fn,
+            bwd_qtype=jnp.float8_e5m2,
+            bwd_use_original_residuals=True,
+            disable_channelwise_axes=True,  # per_tensor calibration
+            weight_calibration_method=quantization_calibration_method,
+            act_calibration_method=quantization_calibration_method,
+            bwd_calibration_method=quantization_calibration_method,
         )
-      ]
+    ]
     return rules
 
   @classmethod

@@ -286,7 +288,7 @@ def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline
 
     batch_size = int(config.per_device_batch_size * jax.local_device_count())
     latents, prompt_embeds, timesteps = get_dummy_wan_inputs(config, pipeline, batch_size)
-    model_inputs= (latents, timesteps, prompt_embeds)
+    model_inputs = (latents, timesteps, prompt_embeds)
     with mesh:
       quantized_model = qwix.quantize_model(model, q_rules, *model_inputs)
     max_logging.log("Qwix Quantization complete.")
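
For the replicate_vae branch in the first hunk, a standalone sketch (assuming a 1-D mesh named "fsdp" and a dummy parameter; this is not the pipeline's actual parameter loop) of what substituting NamedSharding(mesh, P()) does: every device receives a full copy of the parameter instead of the annotated shard.

  import jax
  import numpy as np
  from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

  mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))
  param = np.zeros((1024, 128), dtype=np.float32)  # stand-in for a VAE weight

  annotated = NamedSharding(mesh, P("fsdp"))  # what a sharding annotation might request
  replicated = NamedSharding(mesh, P())       # what replicate_vae forces instead

  replicate_vae = True  # mirrors config.replicate_vae
  sharding = replicated if replicate_vae else annotated

  placed = jax.device_put(param, sharding)
  print(placed.sharding.is_fully_replicated)  # True when replicate_vae is set

Replication costs memory on every device but means the VAE weights need no cross-device communication when they are used.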

src/maxdiffusion/pyconfig.py

Lines changed: 3 additions & 1 deletion

@@ -142,7 +142,9 @@ def wan_init(raw_keys):
     if "quantization" not in raw_keys:
       raise ValueError("Quantization type is not set when use_qwix_quantization is enabled.")
     elif raw_keys["quantization"] not in ["int8", "fp8", "fp8_full"]:
-      raise ValueError(f"Quantization type is not supported when use_qwix_quantization is enabled: {raw_keys['quantization']}")
+      raise ValueError(
+          f"Quantization type is not supported when use_qwix_quantization is enabled: {raw_keys['quantization']}"
+      )
 
   @staticmethod
   def calculate_global_batch_sizes(per_device_batch_size):

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 9 additions & 21 deletions

@@ -13,6 +13,8 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
+from qwix import QtProvider
 import os
 import jax
 import jax.numpy as jnp

@@ -290,7 +292,7 @@ def test_get_qt_provider(self, mock_qt_rule):
     config_int8 = Mock(spec=HyperParameters)
     config_int8.use_qwix_quantization = True
     config_int8.quantization = "int8"
-    provider_int8 = WanPipeline.get_qt_provider(config_int8)
+    provider_int8: QtProvider = WanPipeline.get_qt_provider(config_int8)
     self.assertIsNotNone(provider_int8)
     mock_qt_rule.assert_called_once_with(
         module_path='.*',

@@ -305,11 +307,7 @@ def test_get_qt_provider(self, mock_qt_rule):
     config_fp8.quantization = "fp8"
     provider_fp8 = WanPipeline.get_qt_provider(config_fp8)
     self.assertIsNotNone(provider_fp8)
-    mock_qt_rule.assert_called_once_with(
-        module_path='.*',
-        weight_qtype=jnp.float8_e4m3fn,
-        act_qtype=jnp.float8_e4m3fn
-    )
+    self.assertEqual(provider_fp8.rules[0].kwargs["weight_qtype"], jnp.float8_e4m3fn)
 
     # Case 4: Quantization enabled, type 'fp8_full'
     mock_qt_rule.reset_mock()

@@ -319,17 +317,7 @@ def test_get_qt_provider(self, mock_qt_rule):
     config_fp8_full.quantization_calibration_method = "absmax"
     provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full)
     self.assertIsNotNone(provider_fp8_full)
-    mock_qt_rule.assert_called_once_with(
-        module_path='.*', # Apply to all modules
-        weight_qtype=jnp.float8_e4m3fn,
-        act_qtype=jnp.float8_e4m3fn,
-        bwd_qtype=jnp.float8_e5m2,
-        bwd_use_original_residuals=True,
-        disable_channelwise_axes=True, # per_tensor calibration
-        weight_calibration_method = config_fp8_full.quantization_calibration_method,
-        act_calibration_method = config_fp8_full.quantization_calibration_method,
-        bwd_calibration_method = config_fp8_full.quantization_calibration_method,
-    )
+    self.assertEqual(provider_fp8_full.rules[0].kwargs["bwd_qtype"], jnp.float8_e5m2)
 
     # Case 5: Invalid quantization type
     config_invalid = Mock(spec=HyperParameters)

@@ -338,8 +326,8 @@ def test_get_qt_provider(self, mock_qt_rule):
     self.assertIsNone(WanPipeline.get_qt_provider(config_invalid))
 
   # To test quantize_transformer, we patch its external dependencies
-  @patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model')
-  @patch('maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs')
+  @patch("maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model")
+  @patch("maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs")
   def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize_model):
     """
     Tests that quantize_transformer calls qwix when quantization is enabled.

@@ -370,14 +358,14 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
     # Check that the model returned is the new quantized model
     self.assertIs(result, mock_quantized_model_obj)
 
-  @patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model')
+  @patch("maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model")
   def test_quantize_transformer_disabled(self, mock_quantize_model):
     """
     Tests that quantize_transformer is skipped when quantization is disabled.
     """
     # Setup Mocks
     mock_config = Mock(spec=HyperParameters)
-    mock_config.use_qwix_quantization = False # Main condition for this test
+    mock_config.use_qwix_quantization = False  # Main condition for this test
 
     mock_model = Mock(spec=WanModel)

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 12 additions & 2 deletions

@@ -36,6 +36,7 @@
 from maxdiffusion.utils import load_video
 from skimage.metrics import structural_similarity as ssim
 from flax.training import train_state
+from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
 
 
 class TrainState(train_state.TrainState):

@@ -53,6 +54,12 @@ def generate_sample(config, pipeline, filename_prefix):
   """
   Generates a video to validate training did not corrupt the model
   """
+  if not hasattr(pipeline, "vae"):
+    wan_vae, vae_cache = WanPipeline.load_vae(
+        pipeline.mesh.devices, pipeline.mesh, nnx.Rngs(jax.random.key(config.seed)), config
+    )
+    pipeline.vae = wan_vae
+    pipeline.vae_cache = vae_cache
   return generate_wan(config, pipeline, filename_prefix)
 
 

@@ -141,10 +148,13 @@ def prepare_sample(features):
   def start_training(self):
 
     pipeline = self.load_checkpoint()
-    # del pipeline.vae
-
     # Generate a sample before training to compare against generated sample after training.
     pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
+
+    # save some memory.
+    del pipeline.vae
+    del pipeline.vae_cache
+
     mesh = pipeline.mesh
     train_data_iterator = self.load_dataset(mesh, is_training=True)
 
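
The trainer now loads the VAE lazily inside generate_sample and deletes it after the pre-training sample to free device memory during training. A minimal sketch of that pattern (SimplePipeline and load_vae below are placeholders, not the repo's WanPipeline API):

  class SimplePipeline:
    """Placeholder pipeline; only the transformer stays resident for training."""

    def __init__(self):
      self.transformer = object()


  def load_vae():
    # Stand-in for WanPipeline.load_vae(...); returns (vae, vae_cache).
    return object(), object()


  def generate_sample(pipeline):
    if not hasattr(pipeline, "vae"):  # lazily (re)load the decoder before sampling
      pipeline.vae, pipeline.vae_cache = load_vae()
    # ... decode latents with pipeline.vae ...


  pipeline = SimplePipeline()
  generate_sample(pipeline)              # pre-training sample
  del pipeline.vae, pipeline.vae_cache   # free memory for the training loop
  # ... training loop runs without the VAE on device ...
  generate_sample(pipeline)              # post-training sample reloads it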

0 commit comments
