109 changes: 109 additions & 0 deletions examples/lumina_cfg_example.py
import functools
import time
from typing import Any, Dict, Optional

import torch
import torch.distributed as dist

from diffusers import DiffusionPipeline, LuminaPipeline
from xfuser import xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
init_distributed_environment,
initialize_model_parallel,
get_world_group,
get_classifier_free_guidance_world_size,
get_classifier_free_guidance_rank,
get_cfg_group,
)

def parallelize_transformer(pipe: DiffusionPipeline):
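    """Monkey-patch the transformer's forward so the CFG batch (conditional and
    unconditional halves, concatenated along dim 0) is sharded across the
    classifier-free-guidance group and re-gathered after the forward pass."""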
    transformer = pipe.transformer
    original_forward = transformer.forward

    @functools.wraps(transformer.__class__.forward)
    def new_forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_mask: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        # Keep only this rank's shard of the concatenated (cond + uncond) batch.
        cfg_world_size = get_classifier_free_guidance_world_size()
        cfg_rank = get_classifier_free_guidance_rank()
        timestep = torch.chunk(timestep, cfg_world_size, dim=0)[cfg_rank]
        hidden_states = torch.chunk(hidden_states, cfg_world_size, dim=0)[cfg_rank]
        encoder_hidden_states = torch.chunk(encoder_hidden_states, cfg_world_size, dim=0)[cfg_rank]
        encoder_mask = torch.chunk(encoder_mask, cfg_world_size, dim=0)[cfg_rank]
        image_rotary_emb = torch.chunk(image_rotary_emb, cfg_world_size, dim=0)[cfg_rank]

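        # Run the original forward on the local shard only.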
        output = original_forward(
            hidden_states,
            timestep,
            encoder_hidden_states,
            encoder_mask,
            image_rotary_emb=image_rotary_emb,
            cross_attention_kwargs=cross_attention_kwargs,
            **kwargs,
        )

        return_dict = not isinstance(output, tuple)
        sample = output[0]
        # All-gather so every rank ends up with the full (cond + uncond) output.
        sample = get_cfg_group().all_gather(sample, dim=0)
        if return_dict:
            return output.__class__(sample, *output[1:])
        return (sample, *output[1:])

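    # Bind new_forward as an instance method and install it on the transformer.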
    new_forward = new_forward.__get__(transformer)
    transformer.forward = new_forward

if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="xFuser Arguments")
    args = xFuserArgs.add_cli_args(parser).parse_args()
    engine_args = xFuserArgs.from_cli_args(args)

    engine_config, input_config = engine_args.create_config()
    local_rank = get_world_group().local_rank
    device = torch.device(f"cuda:{local_rank}")

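    # Build the parallel process groups. This example only uses CFG parallelism,
    # so every other parallel degree is left at its default of 1.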
    initialize_model_parallel(
        classifier_free_guidance_degree=engine_config.parallel_config.cfg_degree,
    )
    pipe = LuminaPipeline.from_pretrained(
        pretrained_model_name_or_path=engine_config.model_config.model,
        torch_dtype=torch.bfloat16,
    )
    pipe = pipe.to(device)

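    # Tile the VAE decode to reduce peak memory at high resolutions.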
    pipe.vae.enable_tiling()

    parallelize_transformer(pipe)

    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()

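    # The pipeline call itself is unchanged: every rank receives the full prompt,
    # and the patched forward transparently shards the CFG batch across ranks.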
    output = pipe(
        prompt=input_config.prompt,
        num_inference_steps=input_config.num_inference_steps,
        guidance_scale=input_config.guidance_scale,
        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
    ).images[0]

    end_time = time.time()
    elapsed_time = end_time - start_time

    parallel_info = (
        f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_"
        f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
        f"tp{engine_args.tensor_parallel_degree}_"
        f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
    )
    if local_rank == 0:
        output.save(f"results/lumina_cfg_{parallel_info}.png")
        print(f"inference time: {elapsed_time:.2f} sec")

    dist.destroy_process_group()
40 changes: 40 additions & 0 deletions examples/run_lumina.sh
#!/bin/bash
set -x

export PYTHONPATH=$PWD:$PYTHONPATH

# Lumina configuration
SCRIPT="lumina_cfg_example.py"
MODEL_ID="Alpha-VLLM/Lumina-Next-SFT-diffusers"
INFERENCE_STEP=50

mkdir -p ./results

# Lumina specific task args
TASK_ARGS="--height 1024 --width 1024 --guidance_scale 3.5"

# Lumina parallel configuration
N_GPUS=2
PARALLEL_ARGS="--ulysses_degree 1 --ring_degree 1"
CFG_ARGS="--use_cfg_parallel"
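# --use_cfg_parallel runs the conditional and unconditional branches on separate
# ranks (cfg_degree=2); with ulysses and ring degrees of 1, that maps one branch
# to each of the 2 GPUs.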

# Uncomment and modify these as needed
# PIPEFUSION_ARGS="--num_pipeline_patch 8"
# OUTPUT_ARGS="--output_type latent"
# PARALLEL_VAE="--use_parallel_vae"
ENABLE_TILING="--enable_tiling"
# COMPILE_FLAG="--use_torch_compile"

torchrun --nproc_per_node=$N_GPUS ./examples/$SCRIPT \
--model $MODEL_ID \
$PARALLEL_ARGS \
$TASK_ARGS \
$PIPEFUSION_ARGS \
$OUTPUT_ARGS \
--num_inference_steps $INFERENCE_STEP \
--warmup_steps 0 \
--prompt "Upper body of a young woman in a Victorian-era outfit with brass goggles and leather straps. Background shows an industrial revolution cityscape with smoky skies and tall, metal structures." \
$CFG_ARGS \
$PARALLEL_VAE \
$ENABLE_TILING \
$COMPILE_FLAG