Changes from all commits · 45 commits
5cfe701
update xdit performance with 2 GPUs
feifeibear Dec 9, 2024
402c405
update chinese version
feifeibear Dec 9, 2024
d5315e7
gradio server update
Dec 9, 2024
0bef809
Merge pull request #90 from xdit-project/1209
JacobKong Dec 9, 2024
b47a158
Merge pull request #91 from hntee/th_gradio
JacobKong Dec 9, 2024
c4a9d77
update README
Dec 9, 2024
910a8de
fix the repeated saving for xDiT
Dec 10, 2024
6083d0d
fix floating point exception on CUDA 12
Dec 10, 2024
81ddd2f
Merge pull request #106 from OutstanderWang/main
JacobKong Dec 10, 2024
f3e6f25
Update README.md and README_zh.md for high quality report
ckczzj Dec 11, 2024
9d589df
Update README.md and README_zh.md for high quality report
ckczzj Dec 11, 2024
d25e5e1
Unify docker images and update inference scripts.
shlee007 Dec 11, 2024
80a1a69
update README
shlee007 Dec 11, 2024
164ef2f
Update README.md
zhoudaquan Dec 12, 2024
3ef9a88
Update README.md
zhoudaquan Dec 12, 2024
fae0bec
Merge pull request #115 from shlee007/dev/docker
JacobKong Dec 12, 2024
573b50a
Replace promotional video
ckczzj Dec 13, 2024
036275b
Update README.md
zhoudaquan Dec 16, 2024
66411aa
Update README.md
zhoudaquan Dec 16, 2024
83ab6ec
Update README.md
zhoudaquan Dec 16, 2024
ca2ed2d
Solve the issue #122 by updating inference.py
guankaisi Dec 16, 2024
a0d07e6
update fp8 infer
Dec 17, 2024
752f536
Add news for integration of Diffusers
ckczzj Dec 17, 2024
5a1b765
Fix inconsistent README
ckczzj Dec 17, 2024
144b7b3
Update README.md and README_zh.md
ckczzj Dec 18, 2024
5ab5ede
update fp8 readme
Dec 18, 2024
b2c9c5f
Merge branch 'Tencent:main' into main
mboboGO Dec 18, 2024
d132650
update readme
Dec 18, 2024
dc4b090
Merge branch 'main' of github.com:mboboGO/HunyuanVideo
Dec 18, 2024
a3381d6
Merge pull request #141 from mboboGO/main
JacobKong Dec 18, 2024
7850c47
update README for fp8
Dec 18, 2024
9ba2d3e
update README
Dec 18, 2024
e0ee948
Update projects that use HunyuanVideo
ckczzj Dec 18, 2024
eae6111
Merge pull request #130 from guankaisi/main
JacobKong Dec 18, 2024
b35a6ba
Update requirements.txt
JacobKong Dec 18, 2024
ccdf024
update README
Dec 18, 2024
fbdf7b3
update README
Dec 18, 2024
912950c
Update README.md and README_zh.md
ckczzj Dec 18, 2024
94c45d7
Update README.md and README_zh.md
ckczzj Dec 18, 2024
7e55ff8
Update README.md and README_zh.md
ckczzj Dec 18, 2024
466e1be
Update README.md and README_zh.md
ckczzj Dec 18, 2024
8ee7222
Update README.md
zhoudaquan Dec 18, 2024
c2d5f6a
Update README.md and README_zh.md
ckczzj Dec 19, 2024
7757b97
Update README.md
ckczzj Dec 19, 2024
29442f5
add fsdp for HunyuanVideo
xibosun Dec 19, 2024
219 changes: 150 additions & 69 deletions README.md

Large diffs are not rendered by default.

200 changes: 142 additions & 58 deletions README_zh.md

Large diffs are not rendered by default.

Binary file added assets/video_poster.png
3 changes: 3 additions & 0 deletions ckpts/README.md
@@ -7,6 +7,9 @@ HunyuanVideo
 │  ├──README.md
 │  ├──hunyuan-video-t2v-720p
 │  │  ├──transformers
+│  │  │  ├──mp_rank_00_model_states.pt
+│  │  │  ├──mp_rank_00_model_states_fp8.pt
+│  │  │  ├──mp_rank_00_model_states_fp8_map.pt
 │  │  ├──vae
 │  ├──text_encoder
 │  ├──text_encoder_2
41 changes: 0 additions & 41 deletions docker/Dockerfile_xDiT

This file was deleted.

8 changes: 0 additions & 8 deletions environment.yml

This file was deleted.

141 changes: 141 additions & 0 deletions gradio_server.py
@@ -0,0 +1,141 @@
import os
import time
from pathlib import Path
from loguru import logger
from datetime import datetime
import gradio as gr
import random

from hyvideo.utils.file_utils import save_videos_grid
from hyvideo.config import parse_args
from hyvideo.inference import HunyuanVideoSampler
from hyvideo.constants import NEGATIVE_PROMPT

def initialize_model(model_path):
    args = parse_args()
    models_root_path = Path(model_path)
    if not models_root_path.exists():
        raise ValueError(f"`models_root` does not exist: {models_root_path}")

    hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args)
    return hunyuan_video_sampler

def generate_video(
    model,
    prompt,
    resolution,
    video_length,
    seed,
    num_inference_steps,
    guidance_scale,
    flow_shift,
    embedded_guidance_scale
):
    seed = None if seed == -1 else seed
    width, height = resolution.split("x")
    width, height = int(width), int(height)
    negative_prompt = ""  # not applicable in this inference setup

    outputs = model.predict(
        prompt=prompt,
        height=height,
        width=width,
        video_length=video_length,
        seed=seed,
        negative_prompt=negative_prompt,
        infer_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_videos_per_prompt=1,
        flow_shift=flow_shift,
        batch_size=1,
        embedded_guidance_scale=embedded_guidance_scale
    )

    samples = outputs['samples']
    sample = samples[0].unsqueeze(0)

    save_path = os.path.join(os.getcwd(), "gradio_outputs")
    os.makedirs(save_path, exist_ok=True)

    time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S")
    video_path = f"{save_path}/{time_flag}_seed{outputs['seeds'][0]}_{outputs['prompts'][0][:100].replace('/','')}.mp4"
    save_videos_grid(sample, video_path, fps=24)
    logger.info(f'Sample saved to: {video_path}')

    return video_path

def create_demo(model_path, save_path):
    # NOTE: save_path is currently unused; videos are written to ./gradio_outputs (see generate_video).
    model = initialize_model(model_path)

    with gr.Blocks() as demo:
        gr.Markdown("# Hunyuan Video Generation")

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="A cat walks on the grass, realistic style.")
                with gr.Row():
                    resolution = gr.Dropdown(
                        choices=[
                            # 720p
                            ("1280x720 (16:9, 720p)", "1280x720"),
                            ("720x1280 (9:16, 720p)", "720x1280"),
                            ("1104x832 (4:3, 720p)", "1104x832"),
                            ("832x1104 (3:4, 720p)", "832x1104"),
                            ("960x960 (1:1, 720p)", "960x960"),
                            # 540p
                            ("960x544 (16:9, 540p)", "960x544"),
                            ("544x960 (9:16, 540p)", "544x960"),
                            ("832x624 (4:3, 540p)", "832x624"),
                            ("624x832 (3:4, 540p)", "624x832"),
                            ("720x720 (1:1, 540p)", "720x720"),
                        ],
                        value="1280x720",
                        label="Resolution"
                    )
                    video_length = gr.Dropdown(
                        label="Video Length",
                        choices=[
                            ("2s (65f)", 65),
                            ("5s (129f)", 129),
                        ],
                        value=129,
                    )
                num_inference_steps = gr.Slider(1, 100, value=50, step=1, label="Number of Inference Steps")
                show_advanced = gr.Checkbox(label="Show Advanced Options", value=False)
                with gr.Row(visible=False) as advanced_row:
                    with gr.Column():
                        seed = gr.Number(value=-1, label="Seed (-1 for random)")
                        guidance_scale = gr.Slider(1.0, 20.0, value=1.0, step=0.5, label="Guidance Scale")
                        flow_shift = gr.Slider(0.0, 10.0, value=7.0, step=0.1, label="Flow Shift")
                        embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale")
                show_advanced.change(fn=lambda x: gr.Row(visible=x), inputs=[show_advanced], outputs=[advanced_row])
                generate_btn = gr.Button("Generate")

            with gr.Column():
                output = gr.Video(label="Generated Video")

        generate_btn.click(
            fn=lambda *inputs: generate_video(model, *inputs),
            inputs=[
                prompt,
                resolution,
                video_length,
                seed,
                num_inference_steps,
                guidance_scale,
                flow_shift,
                embedded_guidance_scale
            ],
            outputs=output
        )

    return demo

if __name__ == "__main__":
    os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
    server_name = os.getenv("SERVER_NAME", "0.0.0.0")
    server_port = int(os.getenv("SERVER_PORT", "8081"))
    args = parse_args()
    print(args)
    demo = create_demo(args.model_base, args.save_path)
    demo.launch(server_name=server_name, server_port=server_port)
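For reference, launching the demo locally looks roughly like this (a sketch, not part of the diff; --model-base and --save-path mirror how args.model_base and args.save_path are read in the __main__ block, and the checkpoint layout follows ckpts/README.md):

SERVER_NAME=127.0.0.1 SERVER_PORT=8081 python gradio_server.py --model-base ckpts --save-path ./gradio_outputs

The SERVER_NAME and SERVER_PORT environment variables override the 0.0.0.0:8081 defaults read via os.getenv above.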
11 changes: 11 additions & 0 deletions hyvideo/config.py
@@ -346,6 +346,12 @@ def add_inference_args(parser: argparse.ArgumentParser):
help="Embeded classifier free guidance scale.",
)

group.add_argument(
"--use-fp8",
action="store_true",
help="Enable use fp8 for inference acceleration."
)

group.add_argument(
"--reproduce",
action="store_true",
Expand All @@ -371,6 +377,11 @@ def add_parallel_args(parser: argparse.ArgumentParser):
default=1,
help="Ulysses degree.",
)
group.add_argument(
"--use-fsdp",
action="store_true",
help="use FSDP.",
)

return parser

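Combined with the fp8 checkpoint listed in ckpts/README.md above, fp8 inference would be requested roughly like this (a sketch: --use-fp8 comes from this diff, --dit-weight is inferred from args.dit_weight in hyvideo/inference.py, and sample_video.py as the entry point plus any remaining flags are assumptions beyond this diff):

python sample_video.py --use-fp8 \
    --dit-weight ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states_fp8.pt \
    --prompt "A cat walks on the grass, realistic style."

Note that convert_fp8_linear (added in hyvideo/modules/fp8_optimization.py below) locates the per-layer scales by rewriting the weight path from .pt to _map.pt, so mp_rank_00_model_states_fp8_map.pt must sit next to the fp8 weights.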
18 changes: 11 additions & 7 deletions hyvideo/inference.py
@@ -9,12 +9,14 @@

 import torch
 import torch.distributed as dist
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from hyvideo.constants import PROMPT_TEMPLATE, NEGATIVE_PROMPT, PRECISION_TO_TYPE
 from hyvideo.vae import load_vae
 from hyvideo.modules import load_model
 from hyvideo.text_encoder import TextEncoder
 from hyvideo.utils.data_utils import align_to
 from hyvideo.modules.posemb_layers import get_nd_rotary_pos_embed
+from hyvideo.modules.fp8_optimization import convert_fp8_linear
 from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler
 from hyvideo.diffusion.pipelines import HunyuanVideoPipeline

@@ -196,7 +198,13 @@ def from_pretrained(cls, pretrained_model_path, args, device=None, **kwargs):
             out_channels=out_channels,
             factor_kwargs=factor_kwargs,
         )
-        model = model.to(device)
+        if args.use_fp8:
+            convert_fp8_linear(model, args.dit_weight, original_dtype=PRECISION_TO_TYPE[args.precision])
+
+        if args.use_fsdp:
+            model = FSDP(model, ignored_modules=[model.final_layer])
+        else:
+            model = model.to(device)
         model = Inference.load_state_dict(args, model, pretrained_model_path)
         model.eval()

@@ -402,6 +410,8 @@ def __init__(
         )

         self.default_negative_prompt = NEGATIVE_PROMPT
+        if self.parallel_args['ulysses_degree'] > 1 or self.parallel_args['ring_degree'] > 1:
+            parallelize_transformer(self.pipeline)

     def load_diffusion_pipeline(
         self,
@@ -521,12 +531,6 @@ def predict(
             num_images_per_prompt (int): The number of images per prompt. Default is 1.
             infer_steps (int): The number of inference steps. Default is 100.
         """
-        if self.parallel_args['ulysses_degree'] > 1 or self.parallel_args['ring_degree'] > 1:
-            assert seed is not None, \
-                "You have to set a seed in the distributed environment, please rerun with --seed <your-seed>."
-
-            parallelize_transformer(self.pipeline)
-
         out_dict = dict()

         # ========================================================================
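Two things worth noting in these hunks: parallelize_transformer moves from predict() into __init__, so the pipeline is wrapped once at startup instead of on every call (the old assertion requiring a seed in distributed runs is dropped along with it), and --use-fsdp shards the transformer across ranks (final_layer excluded), which needs a distributed launch. A multi-GPU run would look roughly like this (a sketch; torchrun and the remaining flags are assumptions beyond this diff):

torchrun --nproc_per_node=8 sample_video.py --use-fsdp ...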
102 changes: 102 additions & 0 deletions hyvideo/modules/fp8_optimization.py
@@ -0,0 +1,102 @@
import os

import torch
import torch.nn as nn
from torch.nn import functional as F

def get_fp_maxval(bits=8, mantissa_bit=3, sign_bits=1):
    _bits = torch.tensor(bits)
    _mantissa_bit = torch.tensor(mantissa_bit)
    _sign_bits = torch.tensor(sign_bits)
    M = torch.clamp(torch.round(_mantissa_bit), 1, _bits - _sign_bits)
    E = _bits - _sign_bits - M
    bias = 2 ** (E - 1) - 1
    mantissa = 1
    for i in range(mantissa_bit - 1):
        mantissa += 1 / (2 ** (i + 1))
    maxval = mantissa * 2 ** (2 ** E - 1 - bias)
    return maxval

def quantize_to_fp8(x, bits=8, mantissa_bit=3, sign_bits=1):
    """
    Default is E4M3.
    """
    bits = torch.tensor(bits)
    mantissa_bit = torch.tensor(mantissa_bit)
    sign_bits = torch.tensor(sign_bits)
    M = torch.clamp(torch.round(mantissa_bit), 1, bits - sign_bits)
    E = bits - sign_bits - M
    bias = 2 ** (E - 1) - 1
    mantissa = 1
    for i in range(mantissa_bit - 1):
        mantissa += 1 / (2 ** (i + 1))
    maxval = mantissa * 2 ** (2 ** E - 1 - bias)
    minval = -maxval if sign_bits == 1 else torch.zeros_like(maxval)
    input_clamp = torch.min(torch.max(x, minval), maxval)
    log_scales = torch.clamp((torch.floor(torch.log2(torch.abs(input_clamp)) + bias)).detach(), 1.0)
    log_scales = 2.0 ** (log_scales - M - bias.type(x.dtype))
    # quantize-dequantize: snap to the fp8 grid, then map back to the input dtype
    qdq_out = torch.round(input_clamp / log_scales) * log_scales
    return qdq_out, log_scales

def fp8_tensor_quant(x, scale, bits=8, mantissa_bit=3, sign_bits=1):
    # broadcast the per-tensor scale across all trailing dims of x
    for _ in range(len(x.shape) - 1):
        scale = scale.unsqueeze(-1)
    new_x = x / scale
    quant_dequant_x, log_scales = quantize_to_fp8(new_x, bits=bits, mantissa_bit=mantissa_bit, sign_bits=sign_bits)
    return quant_dequant_x, scale, log_scales

def fp8_activation_dequant(qdq_out, scale, dtype):
    qdq_out = qdq_out.type(dtype)
    quant_dequant_x = qdq_out * scale.to(dtype)
    return quant_dequant_x

def fp8_linear_forward(cls, original_dtype, input):
    weight_dtype = cls.weight.dtype
    if cls.weight.dtype != torch.float8_e4m3fn:
        # weight not yet in fp8: derive a per-tensor scale and quantize on the fly
        maxval = get_fp_maxval()
        scale = torch.max(torch.abs(cls.weight.flatten())) / maxval
        linear_weight, scale, log_scales = fp8_tensor_quant(cls.weight, scale)
        linear_weight = linear_weight.to(torch.float8_e4m3fn)
        weight_dtype = linear_weight.dtype
    else:
        scale = cls.fp8_scale.to(cls.weight.device)
        linear_weight = cls.weight

    if weight_dtype == torch.float8_e4m3fn and cls.weight.sum() != 0:
        # dequantize the fp8 weight back to the activation dtype and run a plain linear
        cls_dequant = fp8_activation_dequant(linear_weight, scale, original_dtype)
        if cls.bias is not None:
            output = F.linear(input, cls_dequant, cls.bias)
        else:
            output = F.linear(input, cls_dequant)
        return output
    else:
        return cls.original_forward(input)

def convert_fp8_linear(module, dit_weight_path, original_dtype, params_to_keep={}):
    setattr(module, "fp8_matmul_enabled", True)

    # load the fp8 scale map that ships alongside the fp8 checkpoint
    fp8_map_path = dit_weight_path.replace('.pt', '_map.pt')
    if os.path.exists(fp8_map_path):
        fp8_map = torch.load(fp8_map_path, map_location=lambda storage, loc: storage)
    else:
        raise ValueError(f"Invalid fp8_map path: {fp8_map_path}.")

    fp8_layers = []
    for key, layer in module.named_modules():
        if isinstance(layer, nn.Linear) and ('double_blocks' in key or 'single_blocks' in key):
            fp8_layers.append(key)
            original_forward = layer.forward
            layer.weight = torch.nn.Parameter(layer.weight.to(torch.float8_e4m3fn))
            setattr(layer, "fp8_scale", fp8_map[key].to(dtype=original_dtype))
            setattr(layer, "original_forward", original_forward)
            # bind `layer` via a default argument so each lambda captures its own module
            setattr(layer, "forward", lambda input, m=layer: fp8_linear_forward(m, original_dtype, input))
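A quick sanity check on the E4M3 defaults above (not part of the PR): with bits=8, M=3, E=4, the bias is 2^(4-1) - 1 = 7, the largest mantissa is 1 + 1/2 + 1/4 = 1.75, and the representable maximum is 1.75 * 2^(2^4 - 1 - 7) = 448, matching PyTorch's own float8 metadata:

import torch
from hyvideo.modules.fp8_optimization import get_fp_maxval, quantize_to_fp8, fp8_tensor_quant

print(get_fp_maxval())                       # tensor(448.)
print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0

# round-trip a tensor through the quantize-dequantize path
x = torch.randn(4, 4)
qdq, log_scales = quantize_to_fp8(x)
print((x - qdq).abs().max())                 # small quantization error

# per-tensor scale as derived in fp8_linear_forward: max|W| / maxval
w = torch.randn(64, 64)
scale = w.abs().max() / get_fp_maxval()
w_qdq, scale, _ = fp8_tensor_quant(w, scale)  # still in w.dtype; the caller casts to float8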


4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,6 +1,5 @@
-torchvision==0.16.1
 opencv-python==4.9.0.80
-diffusers==0.30.2
+diffusers==0.31.0
 transformers==4.46.3
 tokenizers==0.20.3
 accelerate==1.1.1
@@ -12,3 +11,4 @@ loguru==0.7.2
 imageio==2.34.0
 imageio-ffmpeg==0.5.1
 safetensors==0.4.3
+gradio==5.0.0