
Commit 503e9d6

Revert "Attention bug fixes, tokamax splash defaulting logic (#282)" (#287)
This reverts commit 4896870.
1 parent 37ac734 · commit 503e9d6

5 files changed: +70 additions, −110 deletions


docs/attention_blocks_flowchart.md

Lines changed: 0 additions & 30 deletions
This file was deleted.

A separate binary file (−229 KB) is not shown.

src/maxdiffusion/max_utils.py

Lines changed: 10 additions & 19 deletions
@@ -501,26 +501,17 @@ def get_flash_block_sizes(config):
   """Create custom flash attention BlockSizes."""
   flash_block_sizes = None
   if len(config.flash_block_sizes.keys()) > 0:
-    attention_is_tokamax = "tokamax" in config.attention
-    user_block_sizes:Dict[str, int] = config.flash_block_sizes
-    if attention_is_tokamax:
-      max_logging.log("Tokamax kernel specified, Note: Tokamax only supports fused backward kernel."
-                      "Hence following flash block properties specified will be ignored:"
-                      f"block_q: {user_block_sizes['block_q']},"
-                      f"block_q_dq: {user_block_sizes.get('block_q_dq')},"
-                      f"block_kv_dq: {user_block_sizes.get('block_kv_dq')},"
-                      f"use_fused_bwd_kernel: {user_block_sizes.get('use_fused_bwd_kernel')}"
-                      )
+    use_fused_bwd_kernel = config.flash_block_sizes.get("use_fused_bwd_kernel", False)
     flash_block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=user_block_sizes.get("block_q_dkv", user_block_sizes["block_kv"]) if attention_is_tokamax else user_block_sizes["block_q"],
-        block_kv_compute=user_block_sizes["block_kv_compute"],
-        block_kv=user_block_sizes["block_kv"],
-        block_q_dkv=user_block_sizes["block_q_dkv"],
-        block_kv_dkv=user_block_sizes["block_kv_dkv"],
-        block_kv_dkv_compute=user_block_sizes["block_kv_dkv_compute"],
-        block_q_dq=None if attention_is_tokamax else value_or_none(user_block_sizes, "block_q_dq"),
-        block_kv_dq=None if attention_is_tokamax else value_or_none(user_block_sizes, "block_kv_dq"),
-        use_fused_bwd_kernel=True if attention_is_tokamax else value_or_none(user_block_sizes, "use_fused_bwd_kernel"),
+        block_q=config.flash_block_sizes["block_q"],
+        block_kv_compute=config.flash_block_sizes["block_kv_compute"],
+        block_kv=config.flash_block_sizes["block_kv"],
+        block_q_dkv=config.flash_block_sizes["block_q_dkv"],
+        block_kv_dkv=config.flash_block_sizes["block_kv_dkv"],
+        block_kv_dkv_compute=config.flash_block_sizes["block_kv_dkv_compute"],
+        block_q_dq=value_or_none(config.flash_block_sizes, "block_q_dq"),
+        block_kv_dq=value_or_none(config.flash_block_sizes, "block_kv_dq"),
+        use_fused_bwd_kernel=value_or_none(config.flash_block_sizes, "use_fused_bwd_kernel"),
     )
   return flash_block_sizes
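For orientation, here is a minimal, hypothetical usage sketch of the helper this hunk reverts to. It is not repo code: the block-size values are made up, and only the config attribute that get_flash_block_sizes reads is modeled (it assumes a maxdiffusion checkout so the import resolves).

# Hypothetical sketch: illustrative block sizes, not values from the repo configs.
from types import SimpleNamespace
from maxdiffusion.max_utils import get_flash_block_sizes

config = SimpleNamespace(
    flash_block_sizes={
        "block_q": 512,
        "block_kv_compute": 512,
        "block_kv": 512,
        "block_q_dkv": 512,
        "block_kv_dkv": 512,
        "block_kv_dkv_compute": 512,
        # "block_q_dq", "block_kv_dq", "use_fused_bwd_kernel" are optional here;
        # value_or_none(...) returns None for keys that are absent.
    }
)

block_sizes = get_flash_block_sizes(config)  # a splash_attention_kernel.BlockSizes, or None if the dict were empty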

src/maxdiffusion/models/attention_flax.py

Lines changed: 3 additions & 4 deletions
@@ -234,15 +234,14 @@ def _tpu_flash_attention(
   if flash_block_sizes and key.shape[1] == query.shape[1]:
     block_sizes = flash_block_sizes
   else:
-    block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
     block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=block_size_q,
+        block_q=min(q_max_block_size, query.shape[2]),
         block_kv_compute=min(kv_max_block_size, key.shape[2]),
         block_kv=min(kv_max_block_size, key.shape[2]),
-        block_q_dkv=block_size_q,
+        block_q_dkv=min(q_max_block_size, query.shape[2]),
         block_kv_dkv=min(kv_max_block_size, key.shape[2]),
         block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
-        block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q,
+        block_q_dq=None if attention_kernel == "tokamax_flash" else block_sizes.block_q_dq,
         block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]),
         use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
     )
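As a reading aid for the fallback branch above, a toy sketch of the min() clamping pattern it restores; the caps and sequence lengths below are made-up numbers, not repo values. Each block dimension is capped by the actual query/key sequence length so a block is never larger than the data it tiles.

# Toy illustration (hypothetical numbers): q_max_block_size / kv_max_block_size stand in
# for the kernel's upper bounds, the *_seq_len values for query.shape[2] / key.shape[2].
q_max_block_size, kv_max_block_size = 1024, 2048
query_seq_len, key_seq_len = 75600, 512

block_q = min(q_max_block_size, query_seq_len)   # 1024: the cap wins for a long query
block_kv = min(kv_max_block_size, key_seq_len)   # 512: the sequence length wins when shorter
print(block_q, block_kv)                         # 1024 512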

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 57 additions & 57 deletions
@@ -179,69 +179,69 @@ def test_wan_block(self):
     dummy_encoder_hidden_states = jnp.ones((batch_size, 512, dim))
 
     dummy_temb = jnp.ones((batch_size, 6, dim))
-
-    wan_block = WanTransformerBlock(
-        rngs=rngs,
-        dim=dim,
-        ffn_dim=ffn_dim,
-        num_heads=num_heads,
-        qk_norm=qk_norm,
-        cross_attn_norm=cross_attn_norm,
-        eps=eps,
-        attention="flash",
-        mesh=mesh,
-        flash_block_sizes=flash_block_sizes,
-    )
-    with mesh:
+    with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      wan_block = WanTransformerBlock(
+          rngs=rngs,
+          dim=dim,
+          ffn_dim=ffn_dim,
+          num_heads=num_heads,
+          qk_norm=qk_norm,
+          cross_attn_norm=cross_attn_norm,
+          eps=eps,
+          attention="flash",
+          mesh=mesh,
+          flash_block_sizes=flash_block_sizes,
+      )
       dummy_output = wan_block(dummy_hidden_states, dummy_encoder_hidden_states, dummy_temb, dummy_rotary_emb)
     assert dummy_output.shape == dummy_hidden_states.shape
 
   def test_wan_attention(self):
-    for attention_kernel in ["flash", "tokamax_flash"]:
-      pyconfig.initialize(
-          [
-              None,
-              os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
-              f"attention={attention_kernel}"
-          ],
-          unittest=True
+    pyconfig.initialize(
+        [
+            None,
+            os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
+        ],
+        unittest=True,
+    )
+    config = pyconfig.config
+
+    batch_size = 1
+    channels = 16
+    frames = 21
+    height = 90
+    width = 160
+    hidden_states_shape = (batch_size, frames, height, width, channels)
+    dummy_hidden_states = jnp.ones(hidden_states_shape)
+    wan_rot_embed = WanRotaryPosEmbed(attention_head_dim=128, patch_size=[1, 2, 2], max_seq_len=1024)
+    dummy_rotary_emb = wan_rot_embed(dummy_hidden_states)
+
+    key = jax.random.key(0)
+    rngs = nnx.Rngs(key)
+    devices_array = create_device_mesh(config)
+
+    flash_block_sizes = get_flash_block_sizes(config)
+
+    mesh = Mesh(devices_array, config.mesh_axes)
+    batch_size = 1
+    query_dim = 5120
+    with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      attention = FlaxWanAttention(
+          rngs=rngs,
+          query_dim=query_dim,
+          heads=40,
+          dim_head=128,
+          attention_kernel="flash",
+          mesh=mesh,
+          flash_block_sizes=flash_block_sizes,
       )
-      config = pyconfig.config
-      batch_size = 1
-      channels = 16
-      frames = 21
-      height = 90
-      width = 160
-      hidden_states_shape = (batch_size, frames, height, width, channels)
-      dummy_hidden_states = jnp.ones(hidden_states_shape)
-      wan_rot_embed = WanRotaryPosEmbed(attention_head_dim=128, patch_size=[1, 2, 2], max_seq_len=1024)
-      dummy_rotary_emb = wan_rot_embed(dummy_hidden_states)
-
-      key = jax.random.key(0)
-      rngs = nnx.Rngs(key)
-      devices_array = create_device_mesh(config)
-      mesh = Mesh(devices_array, config.mesh_axes)
-      batch_size = 1
-      query_dim = 5120
-      with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
-        flash_block_sizes = get_flash_block_sizes(config)
-        attention = FlaxWanAttention(
-            rngs=rngs,
-            query_dim=query_dim,
-            heads=40,
-            dim_head=128,
-            attention_kernel=attention_kernel,
-            mesh=mesh,
-            flash_block_sizes=flash_block_sizes,
-        )
-        dummy_hidden_states_shape = (batch_size, 75600, query_dim)
+      dummy_hidden_states_shape = (batch_size, 75600, query_dim)
 
-        dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
-        dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
-        dummy_output = attention(
-            hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
-        )
-        assert dummy_output.shape == dummy_hidden_states_shape
+      dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
+      dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
+      dummy_output = attention(
+          hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
+      )
+      assert dummy_output.shape == dummy_hidden_states_shape
 
     # dot product
     try:
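A quick arithmetic check of the constants used in test_wan_attention above, assuming WanRotaryPosEmbed patchifies (frames, height, width) by patch_size = [1, 2, 2] and that query_dim equals heads * dim_head (both are assumptions, not stated in the diff):

# Sanity check of the test's shape constants (assumptions noted in the lead-in).
frames, height, width = 21, 90, 160
patch_t, patch_h, patch_w = 1, 2, 2
seq_len = (frames // patch_t) * (height // patch_h) * (width // patch_w)
assert seq_len == 75600            # matches dummy_hidden_states_shape[1]

heads, dim_head = 40, 128
assert heads * dim_head == 5120    # matches query_dim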
