@@ -179,69 +179,69 @@ def test_wan_block(self):
     dummy_encoder_hidden_states = jnp.ones((batch_size, 512, dim))
 
     dummy_temb = jnp.ones((batch_size, 6, dim))
-
-    wan_block = WanTransformerBlock(
-        rngs=rngs,
-        dim=dim,
-        ffn_dim=ffn_dim,
-        num_heads=num_heads,
-        qk_norm=qk_norm,
-        cross_attn_norm=cross_attn_norm,
-        eps=eps,
-        attention="flash",
-        mesh=mesh,
-        flash_block_sizes=flash_block_sizes,
-    )
-    with mesh:
+    with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      wan_block = WanTransformerBlock(
+          rngs=rngs,
+          dim=dim,
+          ffn_dim=ffn_dim,
+          num_heads=num_heads,
+          qk_norm=qk_norm,
+          cross_attn_norm=cross_attn_norm,
+          eps=eps,
+          attention="flash",
+          mesh=mesh,
+          flash_block_sizes=flash_block_sizes,
+      )
       dummy_output = wan_block(dummy_hidden_states, dummy_encoder_hidden_states, dummy_temb, dummy_rotary_emb)
       assert dummy_output.shape == dummy_hidden_states.shape
 
   def test_wan_attention(self):
-    for attention_kernel in ["flash", "tokamax_flash"]:
-      pyconfig.initialize(
-          [
-              None,
-              os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
-              f"attention={attention_kernel}"
-          ],
-          unittest=True
+    pyconfig.initialize(
+        [
+            None,
+            os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
+        ],
+        unittest=True,
+    )
+    config = pyconfig.config
+
+    batch_size = 1
+    channels = 16
+    frames = 21
+    height = 90
+    width = 160
+    hidden_states_shape = (batch_size, frames, height, width, channels)
+    dummy_hidden_states = jnp.ones(hidden_states_shape)
+    wan_rot_embed = WanRotaryPosEmbed(attention_head_dim=128, patch_size=[1, 2, 2], max_seq_len=1024)
+    dummy_rotary_emb = wan_rot_embed(dummy_hidden_states)
+
+    key = jax.random.key(0)
+    rngs = nnx.Rngs(key)
+    devices_array = create_device_mesh(config)
+
+    flash_block_sizes = get_flash_block_sizes(config)
+
+    mesh = Mesh(devices_array, config.mesh_axes)
+    batch_size = 1
+    query_dim = 5120
+    with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      attention = FlaxWanAttention(
+          rngs=rngs,
+          query_dim=query_dim,
+          heads=40,
+          dim_head=128,
+          attention_kernel="flash",
+          mesh=mesh,
+          flash_block_sizes=flash_block_sizes,
       )
-      config = pyconfig.config
-      batch_size = 1
-      channels = 16
-      frames = 21
-      height = 90
-      width = 160
-      hidden_states_shape = (batch_size, frames, height, width, channels)
-      dummy_hidden_states = jnp.ones(hidden_states_shape)
-      wan_rot_embed = WanRotaryPosEmbed(attention_head_dim=128, patch_size=[1, 2, 2], max_seq_len=1024)
-      dummy_rotary_emb = wan_rot_embed(dummy_hidden_states)
-
-      key = jax.random.key(0)
-      rngs = nnx.Rngs(key)
-      devices_array = create_device_mesh(config)
-      mesh = Mesh(devices_array, config.mesh_axes)
-      batch_size = 1
-      query_dim = 5120
-      with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
-        flash_block_sizes = get_flash_block_sizes(config)
-        attention = FlaxWanAttention(
-            rngs=rngs,
-            query_dim=query_dim,
-            heads=40,
-            dim_head=128,
-            attention_kernel=attention_kernel,
-            mesh=mesh,
-            flash_block_sizes=flash_block_sizes,
-        )
-        dummy_hidden_states_shape = (batch_size, 75600, query_dim)
+      dummy_hidden_states_shape = (batch_size, 75600, query_dim)
 
-        dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
-        dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
-        dummy_output = attention(
-            hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
-        )
-        assert dummy_output.shape == dummy_hidden_states_shape
+      dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
+      dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
+      dummy_output = attention(
+          hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
+      )
+      assert dummy_output.shape == dummy_hidden_states_shape
 
       # dot product
       try: