-
Notifications
You must be signed in to change notification settings - Fork 6.5k
Open
Labels
bugSomething isn't workingSomething isn't working
Description
Describe the bug
When running the Wan2.1-T2V-14B-Diffusers model with Ulysses Attention, the last few frames of the generated video are completely noisy.
Reproduction
# infer_wan_t2v.py
import numpy as np
import torch
from diffusers import AutoencoderKLWan, WanPipeline, ContextParallelConfig, WanTransformer3DModel
from diffusers.utils import export_to_video
from diffusers.hooks import apply_group_offloading
import os
def launched_with_torchrun() -> bool:
    """Return True when the process environment looks like a torchrun launch.

    The check is purely environment based: all three of ``RANK``,
    ``WORLD_SIZE`` and ``LOCAL_RANK`` must be present in ``os.environ``.
    Their values are not inspected.
    """
    required_vars = ("RANK", "WORLD_SIZE", "LOCAL_RANK")
    return all(name in os.environ for name in required_vars)
# --- Run configuration -------------------------------------------------------
# When True, the text encoder, transformer and VAE are set up with leaf-level
# group offloading instead of moving the whole pipeline to the GPU.
cpu_offload = True
# When True, VAE tiling is enabled with tile sizes derived from height/width.
vae_tiling = True
# Local checkpoint directory passed to `from_pretrained`.
model_id = "/home/weights/Wan2.1-T2V-14B-Diffusers/"
# Output resolution passed to the pipeline call (also used for the tile sizes).
height = 480
width = 832
# Positive and negative text prompts passed to the pipeline call.
prompt = "A cat walks on the grass, realistic"
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
# Reproduction driver: set up (optionally distributed) devices, build the Wan
# text-to-video pipeline, enable Ulysses context parallelism / offloading /
# VAE tiling as configured above, generate a clip and export it from rank 0.
try:
    if launched_with_torchrun():
        torch.distributed.init_process_group("nccl")
        rank = torch.distributed.get_rank()
        device = torch.device("cuda", rank % torch.cuda.device_count())
        world_size = torch.distributed.get_world_size()
    else:
        rank = 0
        # torch.cuda.set_device() requires a device with an explicit index,
        # so resolve the current CUDA device instead of a bare "cuda".
        device = torch.device("cuda", torch.cuda.current_device())
    torch.cuda.set_device(device)

    onload_device = device
    offload_device = torch.device("cpu")

    # The VAE is loaded in float32 while the rest of the pipeline uses bfloat16.
    vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
    pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
    pipe.transformer.set_attention_backend("_native_cudnn")

    if launched_with_torchrun():
        # Shard attention across every rank in the process group.
        pipe.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=world_size))

    if cpu_offload:
        offload_kwargs = {
            "onload_device": onload_device,
            "offload_device": offload_device,
            "offload_type": "leaf_level",
            "use_stream": True,
        }
        apply_group_offloading(pipe.text_encoder, **offload_kwargs)
        pipe.transformer.enable_group_offload(**offload_kwargs)
        pipe.vae.enable_group_offload(**offload_kwargs)
    else:
        pipe.to(device)

    if vae_tiling:
        pipe.vae.enable_tiling(
            tile_sample_min_height=int(height / 2 * 3),
            tile_sample_min_width=int(width / 2 * 3),
            tile_sample_stride_height=int((height + 1) / 2),
            tile_sample_stride_width=int((width + 1) / 2),
        )

    # With context parallelism every rank must start from the same initial
    # latents; an identically seeded generator on each rank guarantees that.
    # NOTE(review): this is a documented requirement for context-parallel
    # inference; whether it fully explains the noisy trailing frames reported
    # in this issue still needs to be confirmed.
    generator = torch.Generator().manual_seed(42)

    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_frames=49,
        guidance_scale=5.0,
        num_inference_steps=3,
        generator=generator,
    ).frames[0]

    # Only one process writes the result to avoid concurrent writes to the file.
    if rank == 0:
        export_to_video(output, "output.mp4", fps=16)
except Exception as e:
    print(f"An error occurred: {e}")
    # Bare raise re-raises the active exception with its original traceback.
    raise
finally:
    # Tear down the process group only if this run actually created one.
    if launched_with_torchrun() and torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
Run with the command:
torchrun --nproc_per_node=8 infer_wan_t2v.py
Logs
output.mp4
System Info
diffusers: 0.36.0.dev0
compiled locally from the main branch
Who can help?
No response
Metadata
Metadata
Assignees
Labels
bugSomething isn't workingSomething isn't working