
Commit f1ff3cc

Revert "Cross self attention switch (#251)" (#288)
This reverts commit 65c4e40.
1 parent 503e9d6 commit f1ff3cc

26 files changed: +347, -542 lines

.github/workflows/UnitTests.yml

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ jobs:
 pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets
 - name: PyTest
 run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
-HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false LIBTPU_INIT_ARGS="--xla_tpu_scoped_vmem_limit_kib=65472" python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
+HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
 # add_pull_ready:
 #   if: github.ref != 'refs/heads/main'
 #   permissions:

.gitignore

Lines changed: 1 addition & 1 deletion
@@ -4,6 +4,7 @@
 __pycache__/
 *.py[cod]
 *$py.class
+
 # C extensions
 *.so

@@ -97,7 +98,6 @@ celerybeat-schedule

 # Environments
 .env
-.history
 .venv
 env/
 venv/

preview-xpk.sh

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
+#!/bin/bash
+bash docker_build_dependency_image.sh
+docker tag maxdiffusion_base_image:latest gcr.io/cloud-tpu-multipod-dev/sanbao/maxdiffusion_base_image:latest
+docker push gcr.io/cloud-tpu-multipod-dev/sanbao/maxdiffusion_base_image:latest
+CLUSTER_NAME=bodaborg-tpu7x-128
+DEVICE_TYPE=tpu7x-128 # can change to any size <= tpu7x-256
+PROJECT=cloud-tpu-multipod-dev
+ZONE=us-central1
+
+# Please change the RUN_NAME and OUTPUT_DIR to your own GCS bucket path.
+export RUN_NAME=sanbao-wan-v7x-20k-${RANDOM}
+OUTPUT_DIR=gs://sanbao-bucket/wan/${RUN_NAME}
+# OUTPUT_DIR=gs://sanbao-bucket/wan/sanbao-wan-train-test
+DATASET_DIR=gs://sanbao-bucket/wan_tfr_dataset_pusa_v1/train/
+EVAL_DATA_DIR=gs://sanbao-bucket/wan_tfr_dataset_pusa_v1/eval_timesteps/
+SAVE_DATASET_DIR=gs://sanbao-bucket/wan_tfr_dataset_pusa_v1/save/
+RANDOM=123456789
+IMAGE_DIR=gcr.io/cloud-tpu-multipod-dev/sanbao/maxdiffusion_base_image:latest
+# IMAGE_DIR=gcr.io/tpu-prod-env-multipod/maxdiffusion_jax_stable_stack_nightly@sha256:fd27d49a3be7f743f08e3b6b03e5ae00196794944310e3fee2a7795b99d81195
+LIBTPU_VERSION=libtpu-0.0.25.dev20251013+tpu7x-cp312-cp312-manylinux_2_31_x86_64.whl
+
+xpk workload create \
+--cluster=$CLUSTER_NAME \
+--project=$PROJECT \
+--zone=$ZONE \
+--device-type=$DEVICE_TYPE \
+--num-slices=1 \
+--command=" \
+pip install . && \
+gsutil cp gs://libtpu-tpu7x-releases/wheels/libtpu/${LIBTPU_VERSION} . && \
+python -m pip install ${LIBTPU_VERSION} && \
+export LIBTPU_INIT_ARGS='--xla_enable_async_all_gather=true \
+--xla_tpu_enable_async_collective_fusion=true \
+--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true \
+--xla_enable_async_all_reduce=true \
+--xla_tpu_enable_sparse_core_collective_offload_all_reduce=true \
+--xla_max_concurrent_async_all_gathers=4 \
+--xla_tpu_enable_async_all_to_all=true \
+--xla_latency_hiding_scheduler_rerun=5 \
+--xla_tpu_rwb_fusion=false \
+--xla_tpu_enable_sublane_major_scaling_bitcast_fusion=false \
+--xla_tpu_impure_enable_packed_bf16_math_ops=false \
+--xla_tpu_enable_sparse_core_reduce_scatter_v2=true \
+--xla_tpu_enable_sparse_core_collective_offload_all_gather=true \
+--xla_tpu_enable_sparse_core_collective_offload_2d_all_gather=true \
+--xla_tpu_enable_all_gather_offload_tracing=true \
+--xla_tpu_use_tc_device_shape_on_sc=true \
+--xla_tpu_prefer_async_allgather_to_allreduce=true \
+--xla_tpu_enable_sparse_core_collective_offload_reduce_scatter=true \
+--xla_tpu_scoped_vmem_limit_kib=65536 \
+--xla_tpu_enable_tpu_custom_call_scoped_vmem_adjustments=true \
+--xla_enable_transpose_trace=false' && \
+echo 'Starting WAN training ...' && \
+HF_HUB_CACHE=/dev/shm python src/maxdiffusion/train_wan.py \
+src/maxdiffusion/configs/base_wan_14b.yml \
+attention='flash' \
+weights_dtype=bfloat16 \
+activations_dtype=bfloat16 \
+guidance_scale=5.0 \
+flow_shift=5.0 \
+fps=16 \
+skip_jax_distributed_system=False \
+run_name='test-wan-training-new' \
+output_dir=${OUTPUT_DIR} \
+train_data_dir=${DATASET_DIR} \
+load_tfrecord_cached=True \
+height=1280 \
+width=720 \
+num_frames=81 \
+num_inference_steps=50 \
+prompt='a japanese pop star young woman with black hair is singing with a smile. She is inside a studio with dim lighting and musical instruments.' \
+jax_cache_dir=${OUTPUT_DIR}/jax_cache/ \
+enable_profiler=True \
+dataset_save_location=${SAVE_DATASET_DIR} \
+remat_policy='HIDDEN_STATE_WITH_OFFLOAD' \
+flash_min_seq_length=0 \
+seed=$RANDOM \
+skip_first_n_steps_for_profiler=3 \
+profiler_steps=3 \
+per_device_batch_size=0.5 \
+ici_data_parallelism=64 \
+ici_fsdp_parallelism=2 \
+ici_tensor_parallelism=1 \
+allow_split_physical_axes=True \
+max_train_steps=150 \
+scan_layers=true \
+flash_block_sizes='{\"block_q\":2048,\"block_kv_compute\":512,\"block_kv\":2048,\"block_q_dkv\":2048,\"block_kv_dkv\":2048,\"block_kv_dkv_compute\":512,\"use_fused_bwd_kernel\":true}' \
+" \
+--base-docker-image=${IMAGE_DIR} \
+--enable-debug-logs \
+--workload=${RUN_NAME} \
+--priority=medium \
+--max-restarts=0
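Note: the flash_block_sizes argument at the end of the command is a JSON string whose keys correspond to the splash-attention BlockSizes fields (common_types.py aliases splash_attention_kernel.BlockSizes). The following is a minimal sketch of how such a string could be parsed into that object, assuming the keys map one-to-one onto BlockSizes constructor arguments; maxdiffusion's actual config handling may differ.

import json

from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel

# Hypothetical parsing of the flash_block_sizes JSON passed in the command above;
# the keys are assumed to match the BlockSizes constructor arguments one-to-one.
raw = (
    '{"block_q":2048,"block_kv_compute":512,"block_kv":2048,'
    '"block_q_dkv":2048,"block_kv_dkv":2048,"block_kv_dkv_compute":512,'
    '"use_fused_bwd_kernel":true}'
)
block_sizes = splash_attention_kernel.BlockSizes(**json.loads(raw))
print(block_sizes)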

requirements.txt

Lines changed: 0 additions & 1 deletion
@@ -13,7 +13,6 @@ ftfy
 tensorboard>=2.17.0
 tensorboardx>=2.6.2.2
 tensorboard-plugin-profile>=2.15.2
-tokamax
 Jinja2
 scikit-image
 parameterized

src/maxdiffusion/common_types.py

Lines changed: 1 addition & 33 deletions
@@ -33,11 +33,7 @@
 BlockSizes = splash_attention_kernel.BlockSizes

 AxisNames = tuple[str, ...]
-# Physical axis names for device meshes.
-DATA = "data"
-FSDP = "fsdp"
-TENSOR = "tensor"
-# Logical axis names for model parameters and activations.
+
 BATCH = "activation_batch"
 LENGTH = "activation_length"
 KV_LENGTH = "activation_kv_length"
@@ -48,32 +44,4 @@
 KEEP_2 = "activation_keep_2"
 CONV_OUT = "activation_conv_out_channels"

-# For setting self/cross attention independently in splash kernel
-SELF_ATTN_HEAD = "activation_self_attn_heads"
-SELF_ATTN_Q_LENGTH = "activation_self_attn_q_length"
-SELF_ATTN_KV_LENGTH = "activation_self_attn_kv_length"
-CROSS_ATTN_HEAD = "activation_cross_attn_heads"
-CROSS_ATTN_Q_LENGTH = "activation_cross_attn_q_length"
-CROSS_ATTN_KV_LENGTH = "activation_cross_attn_kv_length"
-
-
 WAN_MODEL = "Wan2.1"
-
-### Common axis rules for ring attention ###
-RING_ATTENTION_AXIS_RULES = [
-    [SELF_ATTN_HEAD, None],
-    [SELF_ATTN_Q_LENGTH, FSDP],
-    [SELF_ATTN_KV_LENGTH, FSDP],
-    [CROSS_ATTN_HEAD, None],
-    [CROSS_ATTN_Q_LENGTH, FSDP],
-    [CROSS_ATTN_KV_LENGTH, FSDP],
-]
-
-SEQUENCE_PARALLEL_AXIS_RULES = [
-    [SELF_ATTN_HEAD, None],
-    [SELF_ATTN_Q_LENGTH, FSDP],
-    [SELF_ATTN_KV_LENGTH, None],
-    [CROSS_ATTN_HEAD, None],
-    [CROSS_ATTN_Q_LENGTH, FSDP],
-    [CROSS_ATTN_KV_LENGTH, None],
-]
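For context, the RING_ATTENTION_AXIS_RULES and SEQUENCE_PARALLEL_AXIS_RULES removed above pair logical activation axis names with physical mesh axes (DATA/FSDP/TENSOR). Below is a minimal, hypothetical sketch of how such rules are typically consumed with Flax's logical partitioning helpers, assuming a ("data", "fsdp", "tensor") mesh and illustrative array shapes; the actual wiring inside maxdiffusion may differ.

import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from jax.sharding import Mesh

# Hypothetical mesh mirroring the removed DATA/FSDP/TENSOR axis names.
devices = np.array(jax.devices()).reshape(-1, 1, 1)
mesh = Mesh(devices, axis_names=("data", "fsdp", "tensor"))

# Subset of the removed RING_ATTENTION_AXIS_RULES, written as (logical, mesh) pairs.
ring_attention_axis_rules = [
    ("activation_self_attn_heads", None),
    ("activation_self_attn_q_length", "fsdp"),
    ("activation_self_attn_kv_length", "fsdp"),
]

with mesh, nn.logical_axis_rules(ring_attention_axis_rules):
    q = jnp.zeros((8, 1024, 128))  # (heads, q_length, head_dim); sizes illustrative
    # The logical names resolve to mesh axes via the rules, sharding q_length over fsdp.
    q = nn.with_logical_constraint(
        q, ("activation_self_attn_heads", "activation_self_attn_q_length", None)
    )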

src/maxdiffusion/configs/base14.yml

Lines changed: 0 additions & 9 deletions
@@ -50,15 +50,6 @@ jit_initializers: True
 from_pt: False
 split_head_dim: True
 attention: 'dot_product' # Supported attention: dot_product, flash
-# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
-# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
-# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
-mask_padding_tokens: True
-# Maxdiffusion has 2 types of attention sharding strategies:
-# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
-# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
-# in cross attention q.
-attention_sharding_uniform: True
 flash_block_sizes: {}
 # GroupNorm groups
 norm_num_groups: 32
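The mask_padding_tokens option removed here (and in the configs below) controlled whether segment ids were passed to splash attention so that padding tokens are never attended to. The following is a hypothetical, plain dot-product sketch of what segment-id masking does conceptually; it is not the splash attention kernel API, and the sizes are illustrative only.

import jax
import jax.numpy as jnp

# Conceptual sketch only: real tokens get segment id 1, padding gets 0, and any
# query/key pair in different segments (or involving padding) is masked out.
seq_len, pad_len, dim = 6, 2, 4
q = jnp.ones((seq_len, dim))
k = jnp.ones((seq_len, dim))
v = jnp.ones((seq_len, dim))
segment_ids = jnp.array([1] * (seq_len - pad_len) + [0] * pad_len)

logits = q @ k.T / jnp.sqrt(dim)
# Keep only pairs in the same non-padding segment.
mask = (segment_ids[:, None] == segment_ids[None, :]) & (segment_ids[None, :] > 0)
logits = jnp.where(mask, logits, -1e9)
out = jax.nn.softmax(logits, axis=-1) @ v  # rows for padded queries are discarded downstream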

src/maxdiffusion/configs/base21.yml

Lines changed: 0 additions & 10 deletions
@@ -49,16 +49,6 @@ jit_initializers: True
 from_pt: False
 split_head_dim: True
 attention: 'dot_product' # Supported attention: dot_product, flash
-# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
-# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
-# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
-mask_padding_tokens: True
-# Maxdiffusion has 2 types of attention sharding strategies:
-# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
-# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
-# in cross attention q.
-attention_sharding_uniform: True
-
 flash_block_sizes: {}
 # GroupNorm groups
 norm_num_groups: 32

src/maxdiffusion/configs/base_2_base.yml

Lines changed: 0 additions & 10 deletions
@@ -50,16 +50,6 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash
-# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
-# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
-# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
-mask_padding_tokens: True
-# Maxdiffusion has 2 types of attention sharding strategies:
-# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
-# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
-# in cross attention q.
-attention_sharding_uniform: True
-
 flash_block_sizes: {}
 # to override default block sizes for flash attention
 # flash_block_sizes:

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 0 additions & 9 deletions
@@ -63,15 +63,6 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
-# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
-# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
-# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
-mask_padding_tokens: True
-# Maxdiffusion has 2 types of attention sharding strategies:
-# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
-# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
-# in cross attention q.
-attention_sharding_uniform: True

 flash_block_sizes: {}
 # Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.

src/maxdiffusion/configs/base_flux_dev_multi_res.yml

Lines changed: 0 additions & 9 deletions
@@ -63,15 +63,6 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
-# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
-# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
-# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
-mask_padding_tokens: True
-# Maxdiffusion has 2 types of attention sharding strategies:
-# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
-# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
-# in cross attention q.
-attention_sharding_uniform: True

 #flash_block_sizes: {}
 # Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
