
Added flux demo #3418

Open
wants to merge 33 commits into base: main

33 commits
431cc4d
Added CPU offloading
cehongwang Mar 26, 2025
8168f92
Changed CPU offload to default
cehongwang Mar 27, 2025
23ca669
Added support to module with graph break
cehongwang Mar 27, 2025
953f339
Added back the control flag and fixed the CI
cehongwang Apr 7, 2025
797c670
Changed CPU offload to default
cehongwang Mar 27, 2025
024992d
Added flux demo
cehongwang Feb 27, 2025
5b4beab
changed the file place and deleted unnecessary code
cehongwang Feb 28, 2025
c9573a1
Fixed memory overhead and enabled Flux with Mutable Module
cehongwang Mar 3, 2025
2e90e73
Supported LoRA
cehongwang Mar 13, 2025
ef5bca8
Refined Flux demo, solved a bug of device mismatch, and prototyped Cu…
cehongwang Mar 18, 2025
a34d25c
Enabled Cuda Graph
cehongwang Mar 18, 2025
b8fafae
Enabled weight streaming and CudaGraph. Supported MTTM saving with dy…
cehongwang Mar 18, 2025
b6a96d8
Changed the Refitting test to disable CPU offload
cehongwang Mar 23, 2025
53d06f3
Fixed Cuda Error
cehongwang Mar 23, 2025
51c3a90
Fixed the bug of SDXL Cuda Error
cehongwang Mar 25, 2025
3920a63
Changed the way to enable CudaGraph for MTTM
cehongwang Mar 25, 2025
0cb1dc2
Finalize the refit revision
cehongwang Mar 26, 2025
6066d51
Fixed the comments
cehongwang Mar 27, 2025
d23853d
Correct the flux export example
cehongwang Mar 27, 2025
b7b433a
Added a textbox to display time the generation process takes
cehongwang Mar 31, 2025
7d2e1c3
Added perf script
cehongwang Apr 3, 2025
b941b75
added back control flag
cehongwang Apr 9, 2025
13bd604
trying to add quantization to Flux
cehongwang Apr 12, 2025
e6e817a
Enable int8 and fp8 quantization for FLUX
cehongwang Apr 24, 2025
41f1f80
Optimized FLUX compilation memory usage
cehongwang Apr 25, 2025
1346fd4
Optimized lowering and decomposition to benchmark quantization again
cehongwang May 2, 2025
084724e
Fixed the benchmark typo
cehongwang May 8, 2025
fb373a0
Use MutableTorchTensorRTModule to do quantization
cehongwang May 9, 2025
c67ee2f
Added quantization debug script
cehongwang May 12, 2025
9c7edb2
Fixed fp16 quantization error
cehongwang May 13, 2025
f536ac6
Added converter registration
cehongwang May 28, 2025
27a2001
Deleted unnecessary files
cehongwang Jun 5, 2025
122d192
Fixed the comments
cehongwang Jun 10, 2025

221 changes: 221 additions & 0 deletions examples/apps/flux-demo.py
@@ -0,0 +1,221 @@
import argparse
import os
import re
import sys
import time

import gradio as gr
import modelopt.torch.quantization as mtq
import torch
import torch_tensorrt
from diffusers import FluxPipeline

# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
sys.path.append(os.path.join(os.path.dirname(__file__), "../dynamo"))
from register_sdpa import *
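# (The star import above installs the lowering pass and converter defined in
# examples/dynamo/register_sdpa.py, so scaled dot-product attention is lowered
# as a single operator during compilation.)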

parser = argparse.ArgumentParser(
    description="Run Flux quantization with different dtypes"
)

parser.add_argument(
    "--dtype",
    choices=["fp8", "int8", "fp16"],
    default="fp16",
    help="Select the data type to use (fp8 or int8 or fp16)",
)
args = parser.parse_args()
# Update enabled precisions based on dtype argument
if args.dtype == "fp8":
    enabled_precisions = {torch.float8_e4m3fn, torch.float16}
    ptq_config = mtq.FP8_DEFAULT_CFG
elif args.dtype == "int8":
    enabled_precisions = {torch.int8, torch.float16}
    ptq_config = mtq.INT8_DEFAULT_CFG
    ptq_config["quant_cfg"]["*weight_quantizer"]["axis"] = None
elif args.dtype == "fp16":
    enabled_precisions = {torch.float16}
print(f"\nUsing {args.dtype}")


DEVICE = "cuda:0"
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float16,
)


pipe.to(DEVICE).to(torch.float16)
backbone = pipe.transformer
backbone.eval()


def filter_func(name):
    pattern = re.compile(
        r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*"
    )
    return pattern.match(name) is not None


def do_calibrate(
    pipe,
    prompt: str,
) -> None:
    """
    Run calibration steps on the pipeline using the given prompts.
    """
    image = pipe(
        prompt,
        output_type="pil",
        num_inference_steps=20,
        generator=torch.Generator("cuda").manual_seed(0),
    ).images[0]


def forward_loop(mod):
    # Switch the pipeline's backbone, run calibration
    pipe.transformer = mod
    do_calibrate(
        pipe=pipe,
        prompt="test",
    )
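# For fp8/int8, mtq.quantize below runs forward_loop to calibrate activation
# ranges and inserts quantizers into the backbone; disable_quantizer then turns
# quantization back off for the layers matched by filter_func (embedding,
# projection, and normalization layers, which tend to be sensitive to
# post-training quantization).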


if args.dtype != "fp16":
    backbone = mtq.quantize(backbone, ptq_config, forward_loop)
    mtq.disable_quantizer(backbone, filter_func)

batch_size = 2
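# In the dynamic-shape spec below, dimensions tagged with BATCH may vary between
# 1 and 8 at runtime; entries left as empty dicts or None are treated as static.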

BATCH = torch.export.Dim("batch", min=1, max=8)
dynamic_shapes = {
    "hidden_states": {0: BATCH},
    "encoder_hidden_states": {0: BATCH},
    "pooled_projections": {0: BATCH},
    "timestep": {0: BATCH},
    "txt_ids": {},
    "img_ids": {},
    "guidance": {0: BATCH},
    "joint_attention_kwargs": {},
    "return_dict": None,
}
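# In the compilation settings below, "immutable_weights": False keeps the engines
# refittable (needed for the LoRA flow further down), and
# "offload_module_to_cpu": True moves the original PyTorch weights to host memory
# during compilation to reduce peak GPU memory usage.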

settings = {
    "strict": False,
    "allow_complex_guards_as_runtime_asserts": True,
    "enabled_precisions": enabled_precisions,
    "truncate_double": True,
    "min_block_size": 1,
    "debug": False,
    "use_python_runtime": True,
    "immutable_weights": False,
    "offload_module_to_cpu": True,
}

trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes)
pipe.transformer = trt_gm


def generate_image(prompt, inference_step, batch_size=2):
    start_time = time.time()
    image = pipe(
        prompt,
        output_type="pil",
        num_inference_steps=inference_step,
        num_images_per_prompt=batch_size,
    ).images
    end_time = time.time()
    return image, end_time - start_time
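# The first call below triggers compilation of the MutableTorchTensorRTModule;
# subsequent calls reuse the compiled engines.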


generate_image(["Test"], 2)
torch.cuda.empty_cache()


def model_change(model):
    if model == "Torch Model":
        pipe.transformer = backbone
        backbone.to(DEVICE)
    else:
        backbone.to("cpu")
        pipe.transformer = trt_gm
        torch.cuda.empty_cache()
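# load_lora below fuses the LoRA weights into the backbone in place; on the next
# call, the MutableTorchTensorRTModule detects the changed weights and refits the
# existing engines instead of recompiling from scratch.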


def load_lora(path):
    pipe.load_lora_weights(
        path,
        adapter_name="lora1",
    )
    pipe.set_adapters(["lora1"], adapter_weights=[1])
    pipe.fuse_lora()
    pipe.unload_lora_weights()
    print("LoRA loaded! Begin refitting")
    generate_image(["Test"], 2)
    print("Refitting Finished!")


# Create Gradio interface
with gr.Blocks(title="Flux Demo with Torch-TensorRT") as demo:
    gr.Markdown("# Flux Image Generation Demo Accelerated by Torch-TensorRT")

    with gr.Row():
        with gr.Column():
            # Input components
            prompt_input = gr.Textbox(
                label="Prompt", placeholder="Enter your prompt here...", lines=3
            )
            model_dropdown = gr.Dropdown(
                choices=["Torch Model", "Torch-TensorRT Accelerated Model"],
                value="Torch-TensorRT Accelerated Model",
                label="Model Variant",
            )

            lora_upload_path = gr.Textbox(
                label="LoRA Path",
                placeholder="Enter the LoRA checkpoint path here. It could be a local path or a Hugging Face URL.",
                value="gokaygokay/Flux-Engrave-LoRA",
                lines=2,
            )
            num_steps = gr.Slider(
                minimum=20, maximum=100, value=20, step=1, label="Inference Steps"
            )
            batch_size = gr.Slider(
                minimum=1, maximum=8, value=1, step=1, label="Batch Size"
            )

            generate_btn = gr.Button("Generate Image")
            load_lora_btn = gr.Button("Load LoRA")

        with gr.Column():
            # Output components
            output_image = gr.Gallery(label="Generated Image")
            time_taken = gr.Textbox(
                label="Generation Time (seconds)", interactive=False
            )

    # Connect the buttons to their callbacks
    model_dropdown.change(model_change, inputs=[model_dropdown])
    load_lora_btn.click(
        fn=load_lora,
        inputs=[
            lora_upload_path,
        ],
    )

    # Update generate button click to include time output
    generate_btn.click(
        fn=generate_image,
        inputs=[
            prompt_input,
            num_steps,
            batch_size,
        ],
        outputs=[output_image, time_taken],
    )

# Launch the interface
if __name__ == "__main__":
    demo.launch()
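For reference, a minimal way to launch the demo from the repository root (a sketch, assuming gradio, diffusers, nvidia-modelopt, and torch_tensorrt are installed; the --dtype values follow the argparse choices defined at the top of the script):

    python examples/apps/flux-demo.py --dtype fp16   # half-precision baseline
    python examples/apps/flux-demo.py --dtype fp8    # FP8 post-training quantization
    python examples/apps/flux-demo.py --dtype int8   # INT8 post-training quantization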
9 changes: 4 additions & 5 deletions examples/dynamo/mutable_torchtrt_module_example.py
@@ -22,6 +22,7 @@
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models
+from diffusers import DiffusionPipeline

np.random.seed(5)
torch.manual_seed(5)
@@ -31,7 +32,7 @@
# Initialize the Mutable Torch TensorRT Module with settings.
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
settings = {
"use_python": False,
"use_python_runtime": False,
"enabled_precisions": {torch.float32},
"immutable_weights": False,
}
@@ -40,7 +41,6 @@
mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)
# You can use the mutable module just like the original pytorch module. The compilation happens while you first call the mutable module.
mutable_module(*inputs)

# %%
# Make modifications to the mutable module.
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -73,13 +73,12 @@
# Stable Diffusion with Huggingface
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-from diffusers import DiffusionPipeline

with torch.no_grad():
settings = {
"use_python_runtime": True,
"enabled_precisions": {torch.float16},
"debug": True,
"debug": False,
"immutable_weights": False,
}

@@ -106,7 +105,7 @@
"text_embeds": {0: BATCH},
"time_ids": {0: BATCH},
},
"return_dict": False,
"return_dict": None,
}
pipe.unet.set_expected_dynamic_shape_range(
args_dynamic_shapes, kwargs_dynamic_shapes
1 change: 1 addition & 0 deletions examples/dynamo/refit_engine_example.py
@@ -101,6 +101,7 @@
)

# Check the output
+model2.to("cuda")
expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(*inputs)
for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
assert torch.allclose(