Skip to content

generate noise when using Ulysses Attention #12567

@TmacAaron

Description

@TmacAaron

Describe the bug

When running the Wan2.1-T2V-14B-Diffusers model with Ulysses Attention, the last few frames of the generated video are completely noisy.

Reproduction

# infer_wan_t2v.py

import numpy as np

import torch

from diffusers import AutoencoderKLWan, WanPipeline, ContextParallelConfig, WanTransformer3DModel
from diffusers.utils import export_to_video
from diffusers.hooks import apply_group_offloading

import os

def launched_with_torchrun() -> bool:
    """Return True when the distributed-launch environment variables are set.

    All three of ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK`` must be present
    in ``os.environ``; if any one is missing the result is False.
    """
    required_vars = ("RANK", "WORLD_SIZE", "LOCAL_RANK")
    return all(name in os.environ for name in required_vars)

# Feature toggles read by the script body below.
cpu_offload = True  # when True, group offloading is applied instead of pipe.to(device)
vae_tiling = True  # when True, pipe.vae.enable_tiling(...) is called

# Checkpoint location and generation settings passed to the pipeline.
model_id = "/home/weights/Wan2.1-T2V-14B-Diffusers/"
height = 480
width = 832
prompt = "A cat walks on the grass, realistic"
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

# Script body: set up the device (and process group when launched
# distributed), build the pipeline, generate a clip, and write it from rank 0.
# The process group is destroyed in `finally` whether or not generation fails.
try:
    if launched_with_torchrun():
        # Distributed run: join the "nccl" process group and map this rank
        # onto one of the visible CUDA devices.
        torch.distributed.init_process_group("nccl")
        rank = torch.distributed.get_rank()
        device = torch.device("cuda", rank % torch.cuda.device_count())
        world_size = torch.distributed.get_world_size()
    else:
        # Single-process run. `world_size` is not defined on this path; it is
        # only read under the launched_with_torchrun() check further down.
        rank = 0
        device = torch.device("cuda")
    torch.cuda.set_device(device)

    # Devices used by the group-offloading calls below.
    onload_device = device
    offload_device = torch.device("cpu")

    # The VAE is loaded separately in float32; the pipeline itself is loaded
    # in bfloat16 and receives that VAE instance.
    vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
    pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)

    pipe.transformer.set_attention_backend("_native_cudnn")
    
    if launched_with_torchrun():
        # Ulysses degree is set to the full world size (8 with the torchrun
        # command given in this report).
        pipe.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=world_size))

    if cpu_offload:
        # Leaf-level group offloading with streams for the text encoder,
        # transformer and VAE, moving between `device` and CPU.
        apply_group_offloading(
            pipe.text_encoder,
            onload_device=onload_device,
            offload_device=offload_device,
            offload_type="leaf_level",
            use_stream=True,
        )
        pipe.transformer.enable_group_offload(
            onload_device=onload_device,
            offload_device=offload_device,
            offload_type="leaf_level",
            use_stream=True
        )
        pipe.vae.enable_group_offload(
            onload_device=onload_device,
            offload_device=offload_device,
            offload_type="leaf_level",
            use_stream=True
        )
    else:
        pipe.to(device)

    if vae_tiling:
        # With height=480 and width=832 these evaluate to a 720x1248 minimum
        # tile and a 240x416 stride.
        pipe.vae.enable_tiling(
            tile_sample_min_height=int(height / 2 * 3),
            tile_sample_min_width=int(width / 2 * 3),
            tile_sample_stride_height=int((height + 1) / 2),
            tile_sample_stride_width=int((width + 1) / 2),
        )

    # Request 49 frames with 3 inference steps and keep the first entry of
    # the returned `frames`.
    output = pipe(
        prompt=prompt, negative_prompt=negative_prompt, height=height, width=width, num_frames=49, guidance_scale=5.0, num_inference_steps=3
    ).frames[0]

    # Only rank 0 writes the video file.
    if rank == 0:
        export_to_video(output, "output.mp4", fps=16)

except Exception as e:
    # Print the error, then re-raise it so the failure is not swallowed.
    print(f"An error occurred: {e}")
    raise e

finally:
    # Tear down the process group if this was a distributed run and the
    # group was actually initialized.
    if launched_with_torchrun():
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()

Run with the command:

torchrun --nproc_per_node=8 infer_wan_t2v.py

Logs

output.mp4

System Info

diffusers: 0.36.0.dev0
compiled locally from the main branch

Who can help?

No response

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug: Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions