diff --git a/MODULE.bazel b/MODULE.bazel index b7a18e9eaa..a22e70f071 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -37,7 +37,7 @@ new_local_repository = use_repo_rule("@bazel_tools//tools/build_defs/repo:local. new_local_repository( name = "cuda", build_file = "@//third_party/cuda:BUILD", - path = "/usr/local/cuda-12.8/", + path = "/usr/local/cuda-12.9/", ) # for Jetson diff --git a/examples/apps/flux_demo.py b/examples/apps/flux_demo.py new file mode 100644 index 0000000000..e222b4f772 --- /dev/null +++ b/examples/apps/flux_demo.py @@ -0,0 +1,265 @@ +import argparse +import os +import re +import sys +import time + +import gradio as gr +import modelopt.torch.quantization as mtq +import torch +import torch_tensorrt +from accelerate.hooks import remove_hook_from_module +from diffusers import FluxPipeline +from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel + +# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py +sys.path.append(os.path.join(os.path.dirname(__file__), "../dynamo")) +from register_sdpa import * + +DEVICE = "cuda:0" + + +def compile_model( + args, +) -> tuple[ + FluxPipeline, FluxTransformer2DModel, torch_tensorrt.MutableTorchTensorRTModule +]: + + if args.dtype == "fp8": + enabled_precisions = {torch.float8_e4m3fn, torch.float16} + ptq_config = mtq.FP8_DEFAULT_CFG + + elif args.dtype == "int8": + enabled_precisions = {torch.int8, torch.float16} + ptq_config = mtq.INT8_DEFAULT_CFG + ptq_config["quant_cfg"]["*weight_quantizer"]["axis"] = None + + elif args.dtype == "fp16": + enabled_precisions = {torch.float16} + + print(f"\nUsing {args.dtype}") + + pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + torch_dtype=torch.float16, + ).to(torch.float16) + + if args.low_vram_mode: + pipe.enable_model_cpu_offload() + else: + pipe.to(DEVICE) + + backbone = pipe.transformer + backbone.eval() + + def filter_func(name): + pattern = re.compile( + r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*" + ) + return pattern.match(name) is not None + + def do_calibrate( + pipe, + prompt: str, + ) -> None: + """ + Run calibration steps on the pipeline using the given prompts. 
+ """ + image = pipe( + prompt, + output_type="pil", + num_inference_steps=20, + generator=torch.Generator("cuda").manual_seed(0), + ).images[0] + + def forward_loop(mod): + # Switch the pipeline's backbone, run calibration + pipe.transformer = mod + do_calibrate( + pipe=pipe, + prompt="a dog running in a park", + ) + + if args.dtype != "fp16": + backbone = mtq.quantize(backbone, ptq_config, forward_loop) + mtq.disable_quantizer(backbone, filter_func) + + batch_size = 2 if args.dynamic_shapes else 1 + if args.dynamic_shapes: + BATCH = torch.export.Dim("batch", min=1, max=8) + dynamic_shapes = { + "hidden_states": {0: BATCH}, + "encoder_hidden_states": {0: BATCH}, + "pooled_projections": {0: BATCH}, + "timestep": {0: BATCH}, + "txt_ids": {}, + "img_ids": {}, + "guidance": {0: BATCH}, + "joint_attention_kwargs": {}, + "return_dict": None, + } + else: + dynamic_shapes = None + + settings = { + "strict": False, + "allow_complex_guards_as_runtime_asserts": True, + "enabled_precisions": enabled_precisions, + "truncate_double": True, + "min_block_size": 1, + "debug": False, + "use_python_runtime": True, + "immutable_weights": False, + "offload_module_to_cpu": True, + } + if args.low_vram_mode: + pipe.remove_all_hooks() + pipe.enable_sequential_cpu_offload() + remove_hook_from_module(pipe.transformer, recurse=True) + pipe.transformer.to(DEVICE) + trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings) + if dynamic_shapes: + trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes) + pipe.transformer = trt_gm + + image = pipe( + "Test", + output_type="pil", + num_inference_steps=2, + num_images_per_prompt=batch_size, + ).images + + torch.cuda.empty_cache() + + if args.low_vram_mode: + pipe.remove_all_hooks() + pipe.to(DEVICE) + + return pipe, backbone, trt_gm + + +def launch_gradio(pipeline, backbone, trt_gm): + + def generate_image(prompt, inference_step, batch_size=2): + start_time = time.time() + image = pipeline( + prompt, + output_type="pil", + num_inference_steps=inference_step, + num_images_per_prompt=batch_size, + ).images + end_time = time.time() + return image, end_time - start_time + + def model_change(model): + if model == "Torch Model": + pipeline.transformer = backbone + backbone.to(DEVICE) + else: + backbone.to("cpu") + pipeline.transformer = trt_gm + torch.cuda.empty_cache() + + def load_lora(path): + pipeline.load_lora_weights( + path, + adapter_name="lora1", + ) + pipeline.set_adapters(["lora1"], adapter_weights=[1]) + pipeline.fuse_lora() + pipeline.unload_lora_weights() + print("LoRA loaded! Begin refitting") + generate_image(pipeline, ["Test"], 2) + print("Refitting Finished!") + + # Create Gradio interface + with gr.Blocks(title="Flux Demo with Torch-TensorRT") as demo: + gr.Markdown("# Flux Image Generation Demo Accelerated by Torch-TensorRT") + + with gr.Row(): + with gr.Column(): + # Input components + prompt_input = gr.Textbox( + label="Prompt", placeholder="Enter your prompt here...", lines=3 + ) + model_dropdown = gr.Dropdown( + choices=["Torch Model", "Torch-TensorRT Accelerated Model"], + value="Torch-TensorRT Accelerated Model", + label="Model Variant", + ) + + lora_upload_path = gr.Textbox( + label="LoRA Path", + placeholder="Enter the LoRA checkpoint path here. 
It could be a local path or a Hugging Face URL.", + value="gokaygokay/Flux-Engrave-LoRA", + lines=2, + ) + num_steps = gr.Slider( + minimum=20, maximum=100, value=20, step=1, label="Inference Steps" + ) + batch_size = gr.Slider( + minimum=1, maximum=8, value=1, step=1, label="Batch Size" + ) + + generate_btn = gr.Button("Generate Image") + load_lora_btn = gr.Button("Load LoRA") + + with gr.Column(): + # Output component + output_image = gr.Gallery(label="Generated Image") + time_taken = gr.Textbox( + label="Generation Time (seconds)", interactive=False + ) + + # Connect the button to the generation function + model_dropdown.change(model_change, inputs=[model_dropdown]) + load_lora_btn.click( + fn=load_lora, + inputs=[ + lora_upload_path, + ], + ) + + # Update generate button click to include time output + generate_btn.click( + fn=generate_image, + inputs=[ + prompt_input, + num_steps, + batch_size, + ], + outputs=[output_image, time_taken], + ) + demo.launch() + + +def main(args): + pipe, backbone, trt_gm = compile_model(args) + launch_gradio(pipe, backbone, trt_gm) + + +# Launch the interface +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Run Flux quantization with different dtypes" + ) + + parser.add_argument( + "--dtype", + choices=["fp8", "int8", "fp16"], + default="fp16", + help="Select the data type to use (fp8 or int8 or fp16)", + ) + parser.add_argument( + "--low_vram_mode", + action="store_true", + help="Use low VRAM mode when you have a small GPU (<=32GB)", + ) + parser.add_argument( + "--dynamic_shapes", + "-d", + action="store_true", + help="Use dynamic shapes", + ) + args = parser.parse_args() + main(args) diff --git a/examples/dynamo/mutable_torchtrt_module_example.py b/examples/dynamo/mutable_torchtrt_module_example.py index a6c8a5384e..665bda1b51 100644 --- a/examples/dynamo/mutable_torchtrt_module_example.py +++ b/examples/dynamo/mutable_torchtrt_module_example.py @@ -22,6 +22,7 @@ import torch import torch_tensorrt as torch_trt import torchvision.models as models +from diffusers import DiffusionPipeline np.random.seed(5) torch.manual_seed(5) @@ -31,7 +32,7 @@ # Initialize the Mutable Torch TensorRT Module with settings. # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ settings = { - "use_python": False, + "use_python_runtime": False, "enabled_precisions": {torch.float32}, "immutable_weights": False, } @@ -40,7 +41,6 @@ mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings) # You can use the mutable module just like the original pytorch module. The compilation happens while you first call the mutable module. mutable_module(*inputs) - # %% # Make modifications to the mutable module. 
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -73,13 +73,12 @@ # Stable Diffusion with Huggingface # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -from diffusers import DiffusionPipeline with torch.no_grad(): settings = { "use_python_runtime": True, "enabled_precisions": {torch.float16}, - "debug": True, + "debug": False, "immutable_weights": False, } @@ -106,7 +105,7 @@ "text_embeds": {0: BATCH}, "time_ids": {0: BATCH}, }, - "return_dict": False, + "return_dict": None, } pipe.unet.set_expected_dynamic_shape_range( args_dynamic_shapes, kwargs_dynamic_shapes diff --git a/examples/dynamo/refit_engine_example.py b/examples/dynamo/refit_engine_example.py index 66a1a70964..51202528c5 100644 --- a/examples/dynamo/refit_engine_example.py +++ b/examples/dynamo/refit_engine_example.py @@ -101,6 +101,7 @@ ) # Check the output +model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(*inputs) for expected_output, refitted_output in zip(expected_outputs, refitted_outputs): assert torch.allclose( diff --git a/examples/dynamo/torch_export_flux_dev.py b/examples/dynamo/torch_export_flux_dev.py index 32e75b06d8..4a6d36a960 100644 --- a/examples/dynamo/torch_export_flux_dev.py +++ b/examples/dynamo/torch_export_flux_dev.py @@ -114,6 +114,8 @@ min_block_size=1, use_fp32_acc=True, use_explicit_typing=True, + immutable_weights=False, + offload_module_to_cpu=True, ) # %% @@ -121,14 +123,13 @@ # --------------------------- # Release the GPU memory occupied by the exported program and the pipe.transformer # Set the transformer in the Flux pipeline to the Torch-TRT compiled model - -del ep -backbone.to("cpu") +pipe.transformer = None pipe.to(DEVICE) -torch.cuda.empty_cache() pipe.transformer = trt_gm +del ep +torch.cuda.empty_cache() pipe.transformer.config = config - +trt_gm.device = torch.device("cuda") # %% # Image generation using prompt # --------------------------- diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 830faf3373..e14a449aed 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -693,6 +693,7 @@ def compile( ) gm = exported_program.module() + # Move the weights in the state_dict to CPU logger.debug("Input graph: " + str(gm.graph)) # Apply lowering on the graph module @@ -914,7 +915,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: parse_graph_io(submodule, subgraph_data) dryrun_tracker.tensorrt_graph_count += 1 dryrun_tracker.per_subgraph_data.append(subgraph_data) - + torch.cuda.empty_cache() # Create TRT engines from submodule if not settings.dryrun: trt_module = convert_module( diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 23648facaf..15136a5170 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -2,6 +2,7 @@ import collections.abc import copy +import gc import logging from typing import Any, List, Optional, Sequence, Tuple @@ -36,7 +37,9 @@ TorchTensorRTModule, ) from torch_tensorrt.dynamo.utils import ( + CPU_DEVICE, check_module_output, + deallocate_module, get_model_device, get_torch_inputs, set_log_level, @@ -289,42 +292,68 @@ def refit_module_weights( get_decompositions(settings.enable_experimental_decompositions) ) new_gm = new_weight_module.module() + logger.debug("Input graph: " + str(new_gm.graph)) # Apply lowering on the graph module new_gm = post_lowering(new_gm, settings) - logger.info("Compilation Settings: %s\n", settings) + logger.debug("Lowered Input graph: " + 
str(new_gm.graph)) # Set torch-executed ops - CONVERTERS.set_disallowed_targets(settings.torch_executed_ops) + CONVERTERS.set_compilation_settings(settings) + + # Check the number of supported operations in the graph + num_supported_ops, total_ops = partitioning.get_graph_converter_support( + new_gm, settings.debug, settings.torch_executed_ops + ) + + if num_supported_ops == 0 or ( + num_supported_ops < settings.min_block_size and not settings.dryrun + ): + logger.warning( + f"{num_supported_ops} supported operations detected in subgraph containing {total_ops} computational nodes. " + f"Skipping this subgraph, since min_block_size was detected to be {settings.min_block_size}" + ) + return new_gm + else: + logger.debug( + f"Detected support for {num_supported_ops} operators out of {total_ops} in subgraph." + ) # If specified, try using the fast partitioner and fall back to the global one on failure if settings.use_fast_partitioner: try: + logger.info("Partitioning the graph via the fast partitioner") new_partitioned_module, supported_ops = partitioning.fast_partition( new_gm, verbose=settings.debug, min_block_size=settings.min_block_size, torch_executed_ops=settings.torch_executed_ops, + require_full_compilation=settings.require_full_compilation, + skip_fusion=(num_supported_ops == total_ops), ) + except torch.fx.passes.splitter_base.FxNetSplitterInternalError: logger.error( "Partitioning failed on the subgraph with fast partition. See trace above. " - + "Retrying with global partition.", + "Retrying with global partition.", exc_info=True, ) settings.use_fast_partitioner = False if not settings.use_fast_partitioner: + logger.info("Partitioning the graph via the global partitioner") new_partitioned_module, supported_ops = partitioning.global_partition( new_gm, verbose=settings.debug, min_block_size=settings.min_block_size, torch_executed_ops=settings.torch_executed_ops, + require_full_compilation=settings.require_full_compilation, ) + # Done Partition if inline_module: # Preprocess the partitioned module to be in the same format as the inline module inline_torch_modules(new_partitioned_module) @@ -341,7 +370,7 @@ def refit_module_weights( # Iterate over all components that can be accelerated # Generate the corresponding TRT Module for those - + new_weight_module.module().to(CPU_DEVICE) for name, new_submodule in new_partitioned_module.named_children(): # Refit each submodule # Extract engine from the submodule @@ -444,26 +473,33 @@ def refit_module_weights( settings=settings, weight_name_map=None, ) + deallocate_module(new_submodule) # clear EXCLUDE_WEIGHTS flag serialization_config = engine.create_serialization_config() serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) serialized_engine = engine.serialize_with_config(serialization_config) - if isinstance( - compiled_submodule, (PythonTorchTensorRTModule, TorchTensorRTModule) - ): + if isinstance(compiled_submodule, PythonTorchTensorRTModule): + compiled_submodule.serialized_engine = bytes(serialized_engine) + elif isinstance(compiled_submodule, TorchTensorRTModule): compiled_submodule.engine = None # Clear the engine for TorchTensorRTModule, otherwise it won't be updated compiled_submodule.serialized_engine = bytes(serialized_engine) compiled_submodule.setup_engine() - elif inline_module: new_engine_info = list(engine_info) new_engine_info[ENGINE_IDX] = bytes(serialized_engine) refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info)) setattr(compiled_module, f"{name}_engine", refitted_engine) + del engine 
+ gc.collect() + torch.cuda.empty_cache() + + deallocate_module(new_partitioned_module) + if verify_output and arg_inputs is not None: + new_gm.to(to_torch_device(settings.device)) if check_module_output( new_module=new_gm, refitted_module=compiled_module, @@ -471,6 +507,7 @@ def refit_module_weights( kwarg_inputs=torch_kwarg_inputs, ): logger.info("Refitting Succeed!") + new_gm.to(CPU_DEVICE) else: if weight_name_map: logger.warning( @@ -486,6 +523,7 @@ def refit_module_weights( in_place=in_place, ) logger.error("Refitting Failed! The outputs do not match.") + new_gm.to(CPU_DEVICE) else: logger.info("Refitting Completed! Output verification skipped.") diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index dd6bf346f8..cef00f3a2a 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -494,10 +494,8 @@ def _save_weight_mapping(self) -> None: _LOGGER.info("Building weight name mapping...") # Stage 1: Name mapping torch_device = to_torch_device(self.compilation_settings.device) - sd = { - k: v.reshape(-1).to(torch_device) - for k, v in self.module.state_dict().items() - } + self.module.to(torch_device) + sd = self.module.state_dict() weight_name_map: dict[str, Any] = {} np_map = self.ctx.weight_refit_map constant_mapping = {k: v for k, v in np_map.items() if v.size == 1} @@ -570,7 +568,6 @@ def _save_weight_mapping(self) -> None: weight_name_map["constant_mapping"] = constant_mapping self.weight_name_map = weight_name_map - del np_map, sd gc.collect() torch.cuda.empty_cache() @@ -721,6 +718,9 @@ def run( if not self.compilation_settings.immutable_weights: self._save_weight_mapping() + if self.compilation_settings.offload_module_to_cpu: + deallocate_module(self.module) + build_engine_start_time = datetime.now() _LOGGER.info("Not found cached TRT engines. 
Start building engine.") @@ -731,8 +731,6 @@ def run( self._create_timing_cache( builder_config, self.compilation_settings.timing_cache_path ) - if self.compilation_settings.offload_module_to_cpu: - deallocate_module(self.module) serialized_engine = self.builder.build_serialized_network( self.ctx.net, builder_config ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py index 46cb00f45c..9c1d95b585 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py @@ -6,12 +6,31 @@ from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion import impl from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor, to_torch from torch_tensorrt.fx.converters.converter_utils import set_layer_name from torch_tensorrt.fx.types import TRTTensor +def get_ir(target: Target) -> SourceIR: + target_module = getattr(target, "__module__", "None") + if any( + target_module.startswith(prefix) + for prefix in ("torch.ops.aten", "torch._ops.aten") + ): + return SourceIR.ATEN + elif any( + target_module.startswith(prefix) + for prefix in ("torch.ops.prims", "torch._ops.prims") + ): + return SourceIR.PRIM + elif target_module.startswith("torch.nn"): + return SourceIR.NN + + return SourceIR.UNKNOWN + + def quantize( ctx: ConversionContext, target: Target, @@ -56,7 +75,6 @@ def quantize( dtype = trt.DataType.FP8 max_bound = 448 - amax = to_torch(amax, None) axis = None # int8 weight quantization is per-channel quantization(it can have one or multiple amax values) if dtype == trt.DataType.INT8 and amax.numel() > 1: @@ -75,10 +93,32 @@ def quantize( assert ( amax.numel() == 1 ), f"{name=} is per-tensor quantization, expected amax is a singular value, but got {amax.shape=}" - scale = torch.divide(amax, max_bound) - scale.masked_fill_(scale == 0, 1.0) - scale = get_trt_tensor(ctx, scale, name + "_scale") - input_tensor = get_trt_tensor(ctx, input_tensor, name) + + if not isinstance(amax, trt.ITensor): + amax = to_torch(amax, None) + scale = torch.divide(amax, max_bound) + scale = get_trt_tensor(ctx, scale, name + "_scale", dtype=torch.float32) + else: + scale = impl.elementwise.div( + ctx, + target, + get_ir(target), + name, + amax, + max_bound, + ) + scale = get_trt_tensor(ctx, scale, name + "_scale", dtype=torch.float32) + + # Add Q node + if num_bits == 8 and exponent_bits == 0: + dtype = trt.DataType.INT8 + elif num_bits == 8 and exponent_bits == 4: + dtype = trt.DataType.FP8 + + if not isinstance(input_tensor, TRTTensor): + input_tensor = get_trt_tensor(ctx, input_tensor, name + "_quantize_input") + + quantize_layer = ctx.net.add_quantize(input_tensor, scale, dtype) # Add Q node quantize_layer = ctx.net.add_quantize(input_tensor, scale, dtype) @@ -90,7 +130,6 @@ def quantize( dequantize_layer = ctx.net.add_dequantize( q_output, scale, output_type=input_tensor.dtype ) - dequantize_layer.to_type = input_tensor.dtype if axis is not None: dequantize_layer.axis = axis set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py index 19e97ef099..928b7284fe 100644 --- 
a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py @@ -98,20 +98,21 @@ def replace_node_with_constant( class _TorchTensorRTConstantFolder(ConstantFolder): # type: ignore[misc] def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - - def is_impure(self, node: torch.fx.node.Node) -> bool: # Set of known quantization ops to be excluded from constant folding. # Currently, we exclude all quantization ops coming from modelopt library. - quantization_ops: Set[torch._ops.OpOverload] = set() + self.quantization_ops: Set[torch._ops.OpOverload] = set() try: # modelopt import ensures torch.ops.tensorrt.quantize_op.default is registered - import modelopt.torch.quantization as mtq # noqa: F401 + import modelopt.torch.quantization as mtq assert torch.ops.tensorrt.quantize_op.default - quantization_ops.add(torch.ops.tensorrt.quantize_op.default) - quantization_ops.add(torch.ops.tensorrt.dynamic_block_quantize_op.default) + self.quantization_ops.add(torch.ops.tensorrt.quantize_op.default) except Exception as e: pass - if quantization_ops and node.target in quantization_ops: + + # TODO: Update this function when quantization is added + def is_impure(self, node: torch.fx.node.Node) -> bool: + + if node.target in self.quantization_ops: return True return False diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py index eaeb6a8c28..cd732811b3 100644 --- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py @@ -2,19 +2,19 @@ import logging from copy import deepcopy from enum import Enum, auto -from typing import Any, Collection, Dict, Iterator, List, Optional, Set, Union +from typing import Any, Dict, Iterator, Optional, Union import numpy as np import torch -from torch.fx.node import Target +import torch_tensorrt +from torch.export._trace import _export from torch_tensorrt._Device import Device -from torch_tensorrt._enums import EngineCapability, dtype from torch_tensorrt.dynamo import _defaults from torch_tensorrt.dynamo._compiler import compile as dynamo_compile from torch_tensorrt.dynamo._refit import refit_module_weights -from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.utils import ( check_output_equal, + deallocate_module, to_torch_device, to_torch_tensorrt_device, ) @@ -63,35 +63,11 @@ def __init__( pytorch_model: torch.nn.Module, *, device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE, - disable_tf32: bool = _defaults.DISABLE_TF32, - assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT, - sparse_weights: bool = _defaults.SPARSE_WEIGHTS, - enabled_precisions: Set[ - Union[torch.dtype, dtype] - ] = _defaults.ENABLED_PRECISIONS, - engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, - immutable_weights: bool = False, - debug: bool = _defaults.DEBUG, - num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, - workspace_size: int = _defaults.WORKSPACE_SIZE, - dla_sram_size: int = _defaults.DLA_SRAM_SIZE, - dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE, - dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE, - truncate_double: bool = _defaults.TRUNCATE_DOUBLE, - require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION, - min_block_size: int = _defaults.MIN_BLOCK_SIZE, - torch_executed_ops: 
Optional[Collection[Target]] = None, - torch_executed_modules: Optional[List[str]] = None, - pass_through_build_failures: bool = _defaults.PASS_THROUGH_BUILD_FAILURES, - max_aux_streams: Optional[int] = _defaults.MAX_AUX_STREAMS, - version_compatible: bool = _defaults.VERSION_COMPATIBLE, - optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL, use_python_runtime: bool = _defaults.USE_PYTHON_RUNTIME, - use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER, - enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS, - dryrun: bool = _defaults.DRYRUN, - hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, - timing_cache_path: str = _defaults.TIMING_CACHE_PATH, + immutable_weights: bool = False, + strict: bool = True, + allow_complex_guards_as_runtime_asserts: bool = False, + weight_streaming_budget: Optional[int] = None, **kwargs: Any, ) -> None: """ @@ -154,53 +130,35 @@ def __init__( self.exp_program: Any = None self.arg_inputs: tuple[Any, ...] = tuple() self.kwarg_inputs: dict[str, Any] = {} - device = to_torch_tensorrt_device(device) - enabled_precisions = {dtype._from(p) for p in enabled_precisions} + self.additional_settings = kwargs + self.strict = strict + self.allow_complex_guards_as_runtime_asserts = ( + allow_complex_guards_as_runtime_asserts + ) + self.use_python_runtime = use_python_runtime + self.trt_device = to_torch_tensorrt_device(device) assert ( not immutable_weights - ), "`immutable_weights` has to be False for a MutableTorchTensorRTModule." - compilation_options = { - "enabled_precisions": ( - enabled_precisions - if enabled_precisions - else _defaults.ENABLED_PRECISIONS - ), - "debug": debug, - "device": device, - "assume_dynamic_shape_support": assume_dynamic_shape_support, - "workspace_size": workspace_size, - "min_block_size": min_block_size, - "torch_executed_ops": ( - torch_executed_ops if torch_executed_ops is not None else set() - ), - "pass_through_build_failures": pass_through_build_failures, - "max_aux_streams": max_aux_streams, - "version_compatible": version_compatible, - "optimization_level": optimization_level, - "use_python_runtime": use_python_runtime, - "truncate_double": truncate_double, - "use_fast_partitioner": use_fast_partitioner, - "num_avg_timing_iters": num_avg_timing_iters, - "enable_experimental_decompositions": enable_experimental_decompositions, - "require_full_compilation": require_full_compilation, - "disable_tf32": disable_tf32, - "sparse_weights": sparse_weights, - "immutable_weights": immutable_weights, - "engine_capability": engine_capability, - "dla_sram_size": dla_sram_size, - "dla_local_dram_size": dla_local_dram_size, - "dla_global_dram_size": dla_global_dram_size, - "dryrun": dryrun, - "hardware_compatible": hardware_compatible, - "timing_cache_path": timing_cache_path, - } + ), "`immutable_weights has to be False for a MutableTorchTensorRTModule" + self.arg_dynamic_shapes: Optional[tuple[Any]] = None self.kwarg_dynamic_shapes: Optional[dict[Any, Any]] = None - - self.settings = CompilationSettings(**compilation_options) + self.serializable_dynamic_shapes_dims: dict[str, tuple[str, int, int]] = {} self.run_info: Optional[tuple[Any, ...]] = None self.state_dict_metadata: dict[str, torch.Size] = {} self._store_state_dict_metadata() + self.enable_weight_streaming = ( + kwargs["enable_weight_streaming"] + if "enable_weight_streaming" in kwargs + else False + ) + self.weight_streaming_ctx = None + self.weight_streaming_budget = weight_streaming_budget + if 
self.enable_weight_streaming: + if weight_streaming_budget is None: + logger.warning( + "Weight stremaing budget is not set. Using auto weight streaming budget" + ) cls = self.__class__ self.__class__ = type( @@ -293,10 +251,9 @@ def update_refit_condition(self) -> None: # to determine whether refit/recompilation is needed. If the output is the same, no further process needed. if self.run_info: args, kwargs, result = self.run_info - self.original_model.to(to_torch_device(self.settings.device)) + self.original_model.to(to_torch_device(self.trt_device)) new_result = self.original_model(*args, **kwargs) - self.original_model.cpu() - torch.cuda.empty_cache() + deallocate_module(self.original_model, delete_module=False) if check_output_equal(result, new_result): self.refit_state.set_state(RefitFlag.LIVE) return @@ -325,17 +282,17 @@ def refit_gm(self) -> None: MutableTorchTensorRTModule automatically catches weight value updates and call this function to refit the module. If it fails to catch the changes, please call this function manually to update the TRT graph module. """ - self.original_model.to(to_torch_device(self.settings.device)) + if self.exp_program is None: - self.exp_program = torch.export.export( - self.original_model, self.arg_inputs, kwargs=self.kwarg_inputs - ) + self.original_model.to(to_torch_device(self.trt_device)) + self.exp_program = self.get_exported_program() else: self.exp_program._state_dict = ( MutableTorchTensorRTModule._transform_state_dict( self.original_model.state_dict() ) ) + self.exp_program.module().to(to_torch_device(self.trt_device)) self.gm = refit_module_weights( self.gm, self.exp_program, @@ -345,8 +302,45 @@ def refit_gm(self) -> None: in_place=True, ) - self.original_model.cpu() - torch.cuda.empty_cache() + deallocate_module(self.original_model, delete_module=False) + + def get_exported_program(self) -> torch.export.ExportedProgram: + + def export_fn() -> torch.export.ExportedProgram: + if self.allow_complex_guards_as_runtime_asserts: + return _export( + self.original_model, + self.arg_inputs, + kwargs=self.kwarg_inputs, + dynamic_shapes=self._get_total_dynamic_shapes(), + strict=self.strict, + allow_complex_guards_as_runtime_asserts=self.allow_complex_guards_as_runtime_asserts, + ) + else: + return torch.export.export( + self.original_model, + self.arg_inputs, + kwargs=self.kwarg_inputs, + dynamic_shapes=self._get_total_dynamic_shapes(), + strict=self.strict, + ) + + if ( + torch.float8_e4m3fn in self.additional_settings["enabled_precisions"] + or torch.int8 in self.additional_settings["enabled_precisions"] + ): + try: + from modelopt.torch.quantization.utils import export_torch_mode + + assert torch.ops.tensorrt.quantize_op.default + except Exception as e: + logger.warning( + "Unable to import quantization op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models" + ) + with export_torch_mode(): + return export_fn() + else: + return export_fn() def compile(self) -> None: """ @@ -356,25 +350,36 @@ def compile(self) -> None: If it fails to catch the changes, please call this function manually to recompile the TRT graph module. 
""" # Export the module - self.original_model.to(to_torch_device(self.settings.device)) - self.exp_program = torch.export.export( - self.original_model, - self.arg_inputs, - kwargs=self.kwarg_inputs, - dynamic_shapes=self._get_total_dynamic_shapes(), - ) + self.original_model.to(to_torch_device(self.trt_device)) + self.exp_program = self.get_exported_program() self.gm = dynamo_compile( self.exp_program, arg_inputs=self.arg_inputs, kwarg_inputs=self.kwarg_inputs, - **self.settings.__dict__, + immutable_weights=False, + use_python_runtime=self.use_python_runtime, + **self.additional_settings, + ) + deallocate_module(self.original_model, delete_module=False) + if self.enable_weight_streaming: + self.set_weight_streaming_ctx(self.weight_streaming_budget) + + def set_weight_streaming_ctx(self, requested_budget: Optional[int] = None) -> None: + """ + Set the weight streaming budget. If budget is not set, then automatic weight streaming budget + is used. + """ + self.weight_streaming_ctx = torch_tensorrt.runtime.weight_streaming(self.gm) + requested_budget = ( + requested_budget + if requested_budget is not None + else self.weight_streaming_ctx.get_automatic_weight_streaming_budget() ) - self.original_model.cpu() - torch.cuda.empty_cache() + self.weight_streaming_ctx.device_budget = requested_budget def _validate_inputs(self, *args: Any, **kwargs: Any) -> None: - if not self.arg_inputs: + if not self.arg_inputs and not self.kwarg_inputs: logger.info("First time compilation initiated. This may take some time.") self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE) self._store_inputs(args, kwargs) @@ -491,14 +496,24 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: self._store_state_dict_metadata() self.refit_state.set_state(RefitFlag.LIVE) + weight_streaming_ctx = ( + self.weight_streaming_ctx if self.enable_weight_streaming else None + ) result = self.gm(*args, **kwargs) # Storing inputs and outputs for verification when the state is unknown self.run_info = (args, kwargs, result) return result - def to(self, device: str) -> None: - logger.warning("Original PyTorch model is moved. CPU offload may failed.") - self.original_model.to(device) + def to(self, *args: Any, **kwargs: Any) -> None: + logger.warning( + "Trying to move the original PyTorch model. This will cause CPU offloading failing and increase GPU memory usage." + + "If this is absolute necessary, please call module.pytorch_model.to(...) \n" + + "The model is still on the original device." 
+ ) + + @property + def device(self) -> torch.device: + return to_torch_device(self.trt_device) def __deepcopy__(self, memo: Any) -> Any: cls = self.__class__ @@ -624,18 +639,58 @@ def _check_tensor_shapes_with_dynamic_shapes( return True + def serialize_dynamic_shapes(self) -> None: + dims = self.serializable_dynamic_shapes_dims + + def resursivly_serialize_dynamic_shape(obj: Any) -> None: + if isinstance(obj, dict): + for axis, v in obj.items(): + if isinstance(v, torch.export.dynamic_shapes._Dim): + name = str(v).split("'")[1].split(".")[-1] + # We use string of the hash to be the unique identifier of Dim object + dims.setdefault(str(hash(v)), (name, v.min, v.max)) + obj[axis] = str(hash(v)) + else: + resursivly_serialize_dynamic_shape(v) + if isinstance(obj, (tuple, list)): + for v in obj: + resursivly_serialize_dynamic_shape(v) + + resursivly_serialize_dynamic_shape(self.arg_dynamic_shapes) + resursivly_serialize_dynamic_shape(self.kwarg_dynamic_shapes) + + def deserialize_dynamic_shapes(self) -> None: + dims = self.serializable_dynamic_shapes_dims + + def resursivly_deserialize_dynamic_shape(obj: Any) -> None: + if isinstance(obj, dict): + for axis, v in obj.items(): + if isinstance(v, str): + obj[axis] = torch.export.Dim( + dims[v][0], min=dims[v][1], max=dims[v][2] + ) + else: + resursivly_deserialize_dynamic_shape(v) + if isinstance(obj, (tuple, list)): + for v in obj: + resursivly_deserialize_dynamic_shape(v) + + resursivly_deserialize_dynamic_shape(self.arg_dynamic_shapes) + resursivly_deserialize_dynamic_shape(self.kwarg_dynamic_shapes) + @staticmethod def save(module: Any, path: str) -> None: # Cast the object back to MutableTorchTensorRTModule to save assert ( - not module.settings.use_python_runtime + not module.use_python_runtime ), "Python runtime does not support serialization. Save failed." module.init_finished = False module.__class__ = MutableTorchTensorRTModule exp_program = module.exp_program module.pytorch_model = None module.exp_program = None - torch.save(module, path) + module.serialize_dynamic_shapes() + torch.save(module, path, pickle_protocol=4) # Restore deleted attributes module.exp_program = exp_program module.pytorch_model = _make_refit_change_trigger( @@ -658,20 +713,27 @@ def load(path: str) -> Any: module.pytorch_model = _make_refit_change_trigger( module.original_model, module.refit_state ) - module.original_model.to(to_torch_device(module.settings.device)) + module.original_model.to(to_torch_device(module.device)) module.exp_program = torch.export.export( module.original_model, module.arg_inputs, kwargs=module.kwarg_inputs ) - module.original_model.to("cpu") + deallocate_module(module.original_model, delete_module=False) cls = module.__class__ module.__class__ = type( module.original_model.__class__.__name__, (cls, module.original_model.__class__), {}, ) + module.deserialize_dynamic_shapes() module.init_finished = True return module + def _reset_stateful_cache(obj: Any) -> None: + """ + Does nothing. Support Huggingface CPU offload hooks. Override the huggingface cache reset function because we don't want the TRT module to be handled by HuggingFace. 
+ """ + return + def recursively_remove_trigger(obj: Any) -> Any: # Not safe: If the object has a circular reference (such as a doubly linkded list), this will cause infinite recursion diff --git a/py/torch_tensorrt/runtime/_cudagraphs.py b/py/torch_tensorrt/runtime/_cudagraphs.py index 346132145e..de0a7b9fdf 100644 --- a/py/torch_tensorrt/runtime/_cudagraphs.py +++ b/py/torch_tensorrt/runtime/_cudagraphs.py @@ -69,48 +69,16 @@ def __init__(self, compiled_module: torch.nn.Module) -> None: self.old_mode = _PY_RT_CUDAGRAPHS self.compiled_module = compiled_module self.cudagraphs_module: Optional[CudaGraphsTorchTensorRTModule] = None + self.old_module = None - def __enter__(self) -> torch.nn.Module: - global _PY_RT_CUDAGRAPHS - - num_torch_module = 0 - num_trt_module = 0 - for name, module in self.compiled_module.named_children(): - # need to disable cudagraphs if any model requires output allocator - if ( - hasattr(module, "requires_output_allocator") - and module.requires_output_allocator - ): - raise RuntimeError( - "The model contains submodules that require a dynamic output allocator at runtime, which is incompatible with CUDA Graphs. Please disable CUDA Graphs." - ) - if "_run_on_acc" in name: - num_trt_module += 1 - elif "_run_on_gpu" in name: - num_torch_module += 1 - - if num_torch_module > 0: - # Set whole cudagraphs mode and returns wrapped module - _PY_RT_CUDAGRAPHS = CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS - # Set new mode for C++ - if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime: - torch.ops.tensorrt.set_cudagraphs_mode(_PY_RT_CUDAGRAPHS) + def __enter__(self) -> torch.nn.Module | torch.fx.GraphModule: - logger.debug( - "Found pytorch subgraphs in module, wrapping module in CudaGraphsTorchTensorRTModule" - ) - self.cudagraphs_module = CudaGraphsTorchTensorRTModule(self.compiled_module) - return self.cudagraphs_module - else: - if num_trt_module > 0: - logger.debug("No graph breaks detected, using runtime cudagraphs mode") - else: - logger.debug( - "Please consider dynamo if there is graph breaks. Using runtime cudagraphs mode" - ) - # Enable cudagraphs for TRT submodule - set_cudagraphs_mode(True) + if isinstance(self.compiled_module, torch_tensorrt.MutableTorchTensorRTModule): + self.old_module = self.compiled_module.gm + self.compiled_module.gm = get_cuda_graph_module(self.compiled_module.gm) return self.compiled_module + else: + return get_cuda_graph_module(self.compiled_module) def __exit__(self, *args: Any) -> None: # Set cudagraphs back to old mode @@ -118,6 +86,52 @@ def __exit__(self, *args: Any) -> None: # __del__ is not entirely predictable, so we reset cudagraph here if self.cudagraphs_module: self.cudagraphs_module._reset_captured_graph() + if self.old_module: # MutableTorchTRTModule + self.compiled_module.gm = self.old_module + + +def get_cuda_graph_module( + compiled_module: torch.fx.GraphModule, +) -> torch.nn.Module | torch.fx.GraphModule: + global _PY_RT_CUDAGRAPHS + + num_torch_module = 0 + num_trt_module = 0 + for name, module in compiled_module.named_children(): + # need to disable cudagraphs if any model requires output allocator + if ( + hasattr(module, "requires_output_allocator") + and module.requires_output_allocator + ): + raise RuntimeError( + "The model contains submodules that require a dynamic output allocator at runtime, which is incompatible with CUDA Graphs. Please disable CUDA Graphs." 
+ ) + if "_run_on_acc" in name: + num_trt_module += 1 + elif "_run_on_gpu" in name: + num_torch_module += 1 + + if num_torch_module > 0: + # Set whole cudagraphs mode and returns wrapped module + _PY_RT_CUDAGRAPHS = CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS + # Set new mode for C++ + if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime: + torch.ops.tensorrt.set_cudagraphs_mode(_PY_RT_CUDAGRAPHS) + + logger.debug( + "Found pytorch subgraphs in module, wrapping module in CudaGraphsTorchTensorRTModule" + ) + return CudaGraphsTorchTensorRTModule(compiled_module) + else: + if num_trt_module > 0: + logger.debug("No graph breaks detected, using runtime cudagraphs mode") + else: + logger.debug( + "Please consider dynamo if there is graph breaks. Using runtime cudagraphs mode" + ) + # Enable cudagraphs for TRT submodule + set_cudagraphs_mode(True) + return compiled_module def enable_cudagraphs( diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index 068ba81473..d59a9482e2 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -101,7 +101,7 @@ def test_refit_one_engine_with_weightmap(): enabled_precisions = {torch.float} debug = False min_block_size = 1 - use_python_runtime = False + use_python_runtime = True exp_program = torch.export.export(model, tuple(inputs)) exp_program2 = torch.export.export(model2, tuple(inputs)) @@ -125,6 +125,7 @@ def test_refit_one_engine_with_weightmap(): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -179,6 +180,7 @@ def test_refit_one_engine_no_map_with_weightmap(): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -208,7 +210,7 @@ def test_refit_one_engine_with_wrong_weightmap(): enabled_precisions = {torch.float} debug = False min_block_size = 1 - use_python_runtime = False + use_python_runtime = True exp_program = torch.export.export(model, tuple(inputs)) exp_program2 = torch.export.export(model2, tuple(inputs)) @@ -237,6 +239,7 @@ def test_refit_one_engine_with_wrong_weightmap(): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -275,7 +278,7 @@ def test_refit_one_engine_bert_with_weightmap(): enabled_precisions = {torch.float} debug = False min_block_size = 1 - use_python_runtime = False + use_python_runtime = True exp_program = torch.export.export(model, tuple(inputs)) exp_program2 = torch.export.export(model2, tuple(inputs)) @@ -298,6 +301,7 @@ def test_refit_one_engine_bert_with_weightmap(): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -324,7 +328,7 @@ def test_refit_one_engine_bert_with_weightmap(): "Refit feature is not supported in Python 3.13 or higher", ) @pytest.mark.unit -def test_refit_one_engine_inline_runtime__with_weightmap(): +def test_refit_one_engine_inline_runtime_with_weightmap(): trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep") model = models.resnet18(pretrained=False).eval().to("cuda") model2 = models.resnet18(pretrained=True).eval().to("cuda") @@ -356,6 +360,7 @@ def test_refit_one_engine_inline_runtime__with_weightmap(): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -404,6 +409,7 @@ def 
test_refit_one_engine_python_runtime_with_weightmap(): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -479,6 +485,7 @@ def forward(self, x): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -599,6 +606,7 @@ def test_refit_one_engine_without_weightmap(): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -660,6 +668,7 @@ def test_refit_one_engine_bert_without_weightmap(): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -718,6 +727,7 @@ def test_refit_one_engine_inline_runtime_without_weightmap(): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -766,6 +776,7 @@ def test_refit_one_engine_python_runtime_without_weightmap(): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -841,6 +852,7 @@ def forward(self, x): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -899,6 +911,7 @@ def forward(self, x): ) # Check the output + model.to("cuda") pyt_outputs, trt_outputs = exp_program.module()(*inputs), trt_gm(*inputs) for pyt_output, trt_output in zip(pyt_outputs, trt_outputs): assertions.assertTrue( diff --git a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py index f1af1098b1..d30f064111 100644 --- a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py +++ b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py @@ -79,7 +79,7 @@ def test_check_input_shape_dynamic(): "Refit feature is not supported in Python 3.13 or higher", ) @pytest.mark.unit -def test_model_complex_dynamic_shape(): +def test_model_complex_dynamic_shape_with_saving(): device = "cuda:0" class Model(torch.nn.Module): @@ -115,6 +115,13 @@ def forward(self, a, b, c=None): # Run inference trt_gm(*inputs, **kwargs) + try: + save_path = os.path.join(tempfile.gettempdir(), "mutable_module.pkl") + torch_trt.MutableTorchTensorRTModule.save(mutable_module, save_path) + model = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl") + except Exception as e: + assert "Module saving and reloading with dynamic shape failed." + inputs_2 = [torch.rand(10, 9).to(device)] kwargs_2 = { "b": torch.rand(9, 30).to(device), diff --git a/tools/perf/Flux/benchmark.sh b/tools/perf/Flux/benchmark.sh new file mode 100644 index 0000000000..79f5e4b66c --- /dev/null +++ b/tools/perf/Flux/benchmark.sh @@ -0,0 +1,9 @@ +#TODO: Enter the HF Token +huggingface-cli login --token HF_TOKEN + +nvidia-smi --query-gpu=index,utilization.gpu,utilization.memory,temperature.gpu,temperature.memory,power.draw,clocks.sm,clocks.mem,memory.total,memory.used --format=csv,nounits -lms 500 >> fp8_gpu_utilization.txt & +NVIDIA_SMI_PID=$! 
+python flux_perf.py --dtype fp8 --low_vram_mode > fp8_benchmark.txt
+kill $NVIDIA_SMI_PID
+
+
diff --git a/tools/perf/Flux/create_env.sh b/tools/perf/Flux/create_env.sh
new file mode 100644
index 0000000000..24470be344
--- /dev/null
+++ b/tools/perf/Flux/create_env.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+git config --global --add safe.directory /home/TensorRT
+
+# Install bazel
+apt install apt-transport-https curl gnupg -y
+curl -fsSL https://bazel.build/bazel-release.pub.gpg | gpg --dearmor >bazel-archive-keyring.gpg
+mv bazel-archive-keyring.gpg /usr/share/keyrings
+echo "deb [arch=amd64 signed-by=/usr/share/keyrings/bazel-archive-keyring.gpg] https://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list
+
+
+apt update && apt install bazel-8.1.1
+apt install bazel
+bazel
+cd /home/TensorRT
+
+python -m pip install --pre -e . --extra-index-url https://download.pytorch.org/whl/nightly/cu128
+pip install tensorrt==10.9.0.34 --force-reinstall
+
+pip3 install --pre torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
+
+
+pip install sentencepiece=="0.2.0" transformers=="4.48.2" accelerate=="1.3.0" diffusers=="0.32.2" protobuf=="5.29.3"
+
+pip install notebook
+pip install gradio safetensors peft pyinstrument
+pip install nvidia-modelopt onnx torchprofile pulp onnxruntime
diff --git a/tools/perf/Flux/flux_perf.py b/tools/perf/Flux/flux_perf.py
new file mode 100644
index 0000000000..7c3eb56dd1
--- /dev/null
+++ b/tools/perf/Flux/flux_perf.py
@@ -0,0 +1,61 @@
+import argparse
+import os
+import sys
+from time import time
+
+sys.path.append(os.path.join(os.path.dirname(__file__), "../../../examples/apps"))
+from flux_demo import compile_model
+
+
+def benchmark(pipe, prompt, inference_step, batch_size=1, iterations=1):
+
+    start = time()
+    for i in range(iterations):
+        image = pipe(
+            prompt,
+            output_type="pil",
+            num_inference_steps=inference_step,
+            num_images_per_prompt=batch_size,
+        ).images
+    end = time()
+
+    print(f"Batch Size: {batch_size}")
+    print("Time Elapsed for", iterations, "iterations:", end - start)
+    print(
+        "Average Latency Per Step:",
+        (end - start) / inference_step / iterations / batch_size,
+    )
+    return image
+
+
+def main(args):
+    pipe, backbone, trt_gm = compile_model(args)
+    for batch_size in range(1, args.max_batch_size + 1):
+        benchmark(pipe, ["Test"], 20, batch_size=batch_size, iterations=3)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Run Flux quantization with different dtypes"
+    )
+
+    parser.add_argument(
+        "--dtype",
+        choices=["fp8", "int8", "fp16"],
+        default="fp16",
+        help="Select the data type to use (fp8 or int8 or fp16)",
+    )
+    parser.add_argument(
+        "--low_vram_mode",
+        action="store_true",
+        help="Use low VRAM mode when you have a small GPU (<=32GB)",
+    )
+    parser.add_argument(
+        "--dynamic_shapes",
+        "-d",
+        action="store_true",
+        help="Use dynamic shapes",
+    )
+    parser.add_argument("--max_batch_size", type=int, default=1)
+    args = parser.parse_args()
+    main(args)
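
For reference, a minimal way to exercise the new entry points added by this change might look like the sketch below. The flag names match the argparse definitions in flux_demo.py and flux_perf.py above; the working directory (repository root) and the chosen flag combinations are illustrative assumptions rather than part of the PR.

```bash
# Illustrative usage sketch (run from the repository root; flags match the
# argparse definitions in flux_demo.py and flux_perf.py added by this PR).

# Launch the interactive Gradio demo with FP8 quantization on a low-VRAM GPU.
python examples/apps/flux_demo.py --dtype fp8 --low_vram_mode

# Benchmark the compiled pipeline across batch sizes 1..4 using dynamic shapes.
python tools/perf/Flux/flux_perf.py --dtype fp16 --dynamic_shapes --max_batch_size 4
```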