|
| 1 | +#!/bin/bash |
| 2 | +bash docker_build_dependency_image.sh |
| 3 | +docker tag maxdiffusion_base_image:latest gcr.io/cloud-tpu-multipod-dev/sanbao/maxdiffusion_base_image:latest |
| 4 | +docker push gcr.io/cloud-tpu-multipod-dev/sanbao/maxdiffusion_base_image:latest |
| 5 | +CLUSTER_NAME=bodaborg-tpu7x-128 |
| 6 | +DEVICE_TYPE=tpu7x-128 # can change to any size <= tpu7x-256 |
| 7 | +PROJECT=cloud-tpu-multipod-dev |
| 8 | +ZONE=us-central1 |
| 9 | + |
| 10 | +# Please change the RUN_NAME and OUTPUT_DIR to your own GCS bucket path. |
| 11 | +export RUN_NAME=sanbao-wan-v7x-20k-${RANDOM} |
| 12 | +OUTPUT_DIR=gs://sanbao-bucket/wan/${RUN_NAME} |
| 13 | +# OUTPUT_DIR=gs://sanbao-bucket/wan/sanbao-wan-train-test |
| 14 | +DATASET_DIR=gs://sanbao-bucket/wan_tfr_dataset_pusa_v1/train/ |
| 15 | +EVAL_DATA_DIR=gs://sanbao-bucket/wan_tfr_dataset_pusa_v1/eval_timesteps/ |
| 16 | +SAVE_DATASET_DIR=gs://sanbao-bucket/wan_tfr_dataset_pusa_v1/save/ |
| 17 | +RANDOM=123456789 |
| 18 | +IMAGE_DIR=gcr.io/cloud-tpu-multipod-dev/sanbao/maxdiffusion_base_image:latest |
| 19 | +# IMAGE_DIR=gcr.io/tpu-prod-env-multipod/maxdiffusion_jax_stable_stack_nightly@sha256:fd27d49a3be7f743f08e3b6b03e5ae00196794944310e3fee2a7795b99d81195 |
| 20 | +LIBTPU_VERSION=libtpu-0.0.25.dev20251013+tpu7x-cp312-cp312-manylinux_2_31_x86_64.whl |
| 21 | + |
| 22 | +xpk workload create \ |
| 23 | +--cluster=$CLUSTER_NAME \ |
| 24 | +--project=$PROJECT \ |
| 25 | +--zone=$ZONE \ |
| 26 | +--device-type=$DEVICE_TYPE \ |
| 27 | +--num-slices=1 \ |
| 28 | +--command=" \ |
| 29 | +pip install . && \ |
| 30 | +gsutil cp gs://libtpu-tpu7x-releases/wheels/libtpu/${LIBTPU_VERSION} . && \ |
| 31 | +python -m pip install ${LIBTPU_VERSION} && \ |
| 32 | +export LIBTPU_INIT_ARGS='--xla_enable_async_all_gather=true \ |
| 33 | +--xla_tpu_enable_async_collective_fusion=true \ |
| 34 | +--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true \ |
| 35 | +--xla_enable_async_all_reduce=true \ |
| 36 | +--xla_tpu_enable_sparse_core_collective_offload_all_reduce=true \ |
| 37 | +--xla_max_concurrent_async_all_gathers=4 \ |
| 38 | +--xla_tpu_enable_async_all_to_all=true \ |
| 39 | +--xla_latency_hiding_scheduler_rerun=5 \ |
| 40 | +--xla_tpu_rwb_fusion=false \ |
| 41 | +--xla_tpu_enable_sublane_major_scaling_bitcast_fusion=false \ |
| 42 | +--xla_tpu_impure_enable_packed_bf16_math_ops=false \ |
| 43 | +--xla_tpu_enable_sparse_core_reduce_scatter_v2=true \ |
| 44 | +--xla_tpu_enable_sparse_core_collective_offload_all_gather=true \ |
| 45 | +--xla_tpu_enable_sparse_core_collective_offload_2d_all_gather=true \ |
| 46 | +--xla_tpu_enable_all_gather_offload_tracing=true \ |
| 47 | +--xla_tpu_use_tc_device_shape_on_sc=true \ |
| 48 | +--xla_tpu_prefer_async_allgather_to_allreduce=true \ |
| 49 | +--xla_tpu_enable_sparse_core_collective_offload_reduce_scatter=true \ |
| 50 | +--xla_tpu_scoped_vmem_limit_kib=65536 \ |
| 51 | +--xla_tpu_enable_tpu_custom_call_scoped_vmem_adjustments=true \ |
| 52 | +--xla_enable_transpose_trace=false' && \ |
| 53 | +echo 'Starting WAN training ...' && \ |
| 54 | +HF_HUB_CACHE=/dev/shm python src/maxdiffusion/train_wan.py \ |
| 55 | + src/maxdiffusion/configs/base_wan_14b.yml \ |
| 56 | + attention='flash' \ |
| 57 | + weights_dtype=bfloat16 \ |
| 58 | + activations_dtype=bfloat16 \ |
| 59 | + guidance_scale=5.0 \ |
| 60 | + flow_shift=5.0 \ |
| 61 | + fps=16 \ |
| 62 | + skip_jax_distributed_system=False \ |
| 63 | + run_name='test-wan-training-new' \ |
| 64 | + output_dir=${OUTPUT_DIR} \ |
| 65 | + train_data_dir=${DATASET_DIR} \ |
| 66 | + load_tfrecord_cached=True \ |
| 67 | + height=1280 \ |
| 68 | + width=720 \ |
| 69 | + num_frames=81 \ |
| 70 | + num_inference_steps=50 \ |
| 71 | + prompt='a japanese pop star young woman with black hair is singing with a smile. She is inside a studio with dim lighting and musical instruments.' \ |
| 72 | + jax_cache_dir=${OUTPUT_DIR}/jax_cache/ \ |
| 73 | + enable_profiler=True \ |
| 74 | + dataset_save_location=${SAVE_DATASET_DIR} \ |
| 75 | + remat_policy='HIDDEN_STATE_WITH_OFFLOAD' \ |
| 76 | + flash_min_seq_length=0 \ |
| 77 | + seed=$RANDOM \ |
| 78 | + skip_first_n_steps_for_profiler=3 \ |
| 79 | + profiler_steps=3 \ |
| 80 | + per_device_batch_size=0.5 \ |
| 81 | + ici_data_parallelism=64 \ |
| 82 | + ici_fsdp_parallelism=2 \ |
| 83 | + ici_tensor_parallelism=1 \ |
| 84 | + allow_split_physical_axes=True \ |
| 85 | + max_train_steps=150 \ |
| 86 | + scan_layers=true \ |
| 87 | + flash_block_sizes='{\"block_q\":2048,\"block_kv_compute\":512,\"block_kv\":2048,\"block_q_dkv\":2048,\"block_kv_dkv\":2048,\"block_kv_dkv_compute\":512,\"use_fused_bwd_kernel\":true}' \ |
| 88 | + " \ |
| 89 | +--base-docker-image=${IMAGE_DIR} \ |
| 90 | +--enable-debug-logs \ |
| 91 | +--workload=${RUN_NAME} \ |
| 92 | +--priority=medium \ |
| 93 | +--max-restarts=0 |
0 commit comments