feat: Hierarchical Partitioner to support multi-backends #3539

Open · wants to merge 2 commits into main
114 changes: 114 additions & 0 deletions examples/hierarchical_partitioner_example.py
@@ -0,0 +1,114 @@
import torch
import torch.nn as nn
import torch_tensorrt
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    DYNAMO_ATEN_CONVERTERS,
)
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    DYNAMO_CONVERTERS as CONVERTERS,
)
from torch_tensorrt.dynamo.lowering import (
    get_decompositions,
    pre_export_lowering,
)
from torch_tensorrt.dynamo.partitioning._adjacency_partitioner import partition
from torch_tensorrt.dynamo.partitioning._hierarchical_partitioner import (
    hierarchical_adjacency_partition,
)


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = torch.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = torch.relu(x)
        return x


def main():
    # Create model
    model = SimpleModel().cuda()
    # model = models.efficientnet_b0(pretrained=True).cuda()
    model = model.eval()

    # Create example input
    example_input = torch.randn(1, 3, 224, 224).cuda()

    exported_program = torch.export.export(model, (example_input,))
    exported_program = pre_export_lowering(exported_program)
    exported_program = exported_program.run_decompositions(get_decompositions())

    gm = exported_program.module()

    print(gm.graph)

    original_output = model(example_input)

    # Partition the model using the adjacency partitioner
    # partitioned_model, op_support = partition(
    #     gm,
    #     verbose=True,
    #     min_block_size=1,
    #     torch_executed_ops=[
    #         torch.ops.aten.relu.default,
    #     ],
    # )

    partitioned_model, op_support = hierarchical_adjacency_partition(
        gm,
        verbose=True,
        min_block_size=1,
        backend_priority=["inductor", "tensorrt"],
        backend_support_map={
            "inductor": {
                # operator.getitem,
                torch.ops.aten.conv2d.default,
                torch.ops.aten.convolution.default,
            },
            "tensorrt": set(DYNAMO_ATEN_CONVERTERS.keys()),
        },
        torch_executed_ops=[
            torch.ops.aten._native_batch_norm_legit_no_training.default
        ],
        require_full_compilation=False,
        skip_fusion=False,
Collaborator

skip_fusion=False slows down the partitioning a lot. Can you check if it's really needed?

Collaborator Author

> the min_block_size and torch_executed_ops need to be re-thought or deprecated as the min_block_size can apply to all backends

My understanding is that if the number of ops in a GraphModule is less than min_block_size, the GraphModule is not compiled, regardless of whether the backend is TensorRT or another one. Do you mean it shouldn't apply to other backends?

> torch_executed_ops will be replaced by backend_support_map

My understanding is that torch-executed ops are not treated as a backend because they need no compilation; they simply run in eager mode. So if an op is listed in torch_executed_ops, it ignores backend_support_map and runs in torch eager anyway.

Collaborator Author

> skip_fusion=False slows down the partitioning a lot. Can you check if it's really needed?

Since the adjacency partitioner uses this flag, I just kept it here. Yes, I can definitely switch it to True in the example.
    )

    print("\nPartitioned Model Structure:")
    print(partitioned_model)

    print("0. Original_output:", original_output)

    with torch.no_grad():
        partitioned_output = partitioned_model(example_input)
        print("1. Partitioned output:", partitioned_output)
        print(
            "Partitioned output == Original output:",
            torch.allclose(original_output, partitioned_output, 1e-2, 1e-2),
        )

    compiled_model = torch_tensorrt.compile(
        model, inputs=[example_input], min_block_size=1
    )
    with torch.no_grad():
        compiled_output = compiled_model(example_input)
        print("2. Compiled_output:", compiled_output)

    print(
        "Compiled output == Original output:",
        torch.allclose(original_output, compiled_output, 1e-2, 1e-2),
    )


if __name__ == "__main__":
    main()
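A small addition one might make to this example (illustrative only, not part of the PR): report which backend each partitioned submodule was routed to, based on the submodule-name convention used by the hierarchical partitioner ("_run_on_acc_inductor", "_run_on_acc_tensorrt", and "_run_on_gpu" for the eager fallback).

def report_backends(partitioned_model: torch.nn.Module) -> None:
    # Walk the partitioned module and print the backend encoded in each submodule name.
    for name, _ in partitioned_model.named_children():
        if "_run_on_acc_inductor" in name:
            backend = "inductor"
        elif "_run_on_acc_tensorrt" in name:
            backend = "tensorrt"
        else:
            backend = "torch eager (no compilation)"
        print(f"{name}: {backend}")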
120 changes: 98 additions & 22 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -28,6 +28,9 @@
     interpret_module_to_result,
     repair_double_inputs,
 )
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
+    DYNAMO_ATEN_CONVERTERS,
+)
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     DYNAMO_CONVERTERS as CONVERTERS,
 )
@@ -788,20 +791,49 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
             "Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments."
         )

+    ############ TODO: testing only ############
+    use_hierarchical_partitioner = False
+    backend_priority = ["inductor", "tensorrt"]
+    backend_support_map = {
+        "inductor": {
+            # operator.getitem,
+            torch.ops.aten.conv2d.default,
+            torch.ops.aten.convolution.default,
+        },
+        "tensorrt": set(DYNAMO_ATEN_CONVERTERS.keys()),
+    }
+    #############################################
     # Partition module into components that can be TRT-accelerated
     fast_partitioner_failed = False
     # If specified, try using the fast partitioner and fall back to the global one on failure
     if settings.use_fast_partitioner:
         try:
-            logger.info("Partitioning the graph via the fast partitioner")
-            partitioned_module, supported_ops = partitioning.fast_partition(
-                gm,
-                verbose=settings.debug,
-                min_block_size=settings.min_block_size,
-                torch_executed_ops=settings.torch_executed_ops,
-                require_full_compilation=settings.require_full_compilation,
-                skip_fusion=(num_supported_ops == total_ops),
-            )
+            if use_hierarchical_partitioner:
+                logger.info(
+                    "Partitioning the graph via the fast hierarchical partitioner"
+                )
+                partitioned_module, supported_ops = (
+                    partitioning.hierarchical_adjacency_partition(
+                        gm,
+                        verbose=settings.debug,
+                        min_block_size=settings.min_block_size,
+                        torch_executed_ops=settings.torch_executed_ops,
+                        require_full_compilation=settings.require_full_compilation,
+                        skip_fusion=(num_supported_ops == total_ops),
+                        backend_priority=backend_priority,
+                        backend_support_map=backend_support_map,
+                    )
+                )
+            else:
+                logger.info("Partitioning the graph via the fast partitioner")
+                partitioned_module, supported_ops = partitioning.fast_partition(
+                    gm,
+                    verbose=settings.debug,
+                    min_block_size=settings.min_block_size,
+                    torch_executed_ops=settings.torch_executed_ops,
+                    require_full_compilation=settings.require_full_compilation,
+                    skip_fusion=(num_supported_ops == total_ops),
+                )

         except torch.fx.passes.splitter_base.FxNetSplitterInternalError:
             logger.error(
@@ -836,7 +868,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
         submodule_node_dict[node.name] = node

     # Store TRT replicas of Torch subgraphs
-    trt_modules = {}
+    compiled_modules = {}
     # Iterate over all components that can be accelerated
     # Generate the corresponding TRT Module for those

@@ -913,26 +945,61 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
         dryrun_tracker.tensorrt_graph_count += 1
         dryrun_tracker.per_subgraph_data.append(subgraph_data)

-        # Create TRT engines from submodule
+        # Create TRT engines / compiled models from submodule
+        # torch._logging.set_logs(inductor=logging.DEBUG)
         if not settings.dryrun:
-            trt_module = convert_module(
-                submodule,
-                submodule_inputs,
-                settings=settings,
-                name=name,
-                engine_cache=engine_cache,
-            )
+            if use_hierarchical_partitioner:
+                # compile submodule with pytorch inductor
+                if "_run_on_acc_inductor" in name:
+                    sub_inputs = []
+                    for input in submodule_inputs:
+                        sub_input = (
+                            torch.randn(input.shape)
+                            .to(dtype.to(input.dtype, t=torch.dtype))
+                            .cuda()
+                        )
+                        sub_inputs.append(sub_input)
+
+                    compiled_func = torch._inductor.compile(
+                        submodule,
+                        sub_inputs,
+                    )
Comment on lines +956 to +966
Collaborator

The torch._dynamo.mark_dynamic API is used to set dynamic shapes for the torch.compile workflow. Reference: https://docs.pytorch.org/TensorRT/user_guide/dynamic_shapes.html. So you can use the construct_submodule_inputs() API to give you dynamic inputs (if that's the case) and set them on the inductor segment accordingly.

Collaborator Author

submodule_inputs is already the return value of construct_submodule_inputs(). What I did here is convert the torch-trt Input objects to torch Tensors. I changed it to:

sub_input = (
    input.torch_tensor
    .to(dtype.to(input.dtype, t=torch.dtype))
    .cuda()
)
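For the dynamic-shape case mentioned above, a minimal sketch of what marking a dynamic dimension could look like before handing a tensor to the inductor segment (the batch dimension and sizes here are assumptions for illustration, not code from this PR):

import torch

# Assume the inductor-segment input has a dynamic batch dimension (dim 0).
example = torch.randn(8, 3, 224, 224, device="cuda")
# Ask dynamo/inductor not to specialize the compiled graph on batch size 8.
torch._dynamo.mark_dynamic(example, 0)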

+                    # Wrap the compiled function to be a torch.nn.Module
+                    compiled_submodule = FunctionWrapper(compiled_func)
+
+                elif "_run_on_acc_tensorrt" in name:
Collaborator

Is there some sort of design where the capability and conversion parts can be grouped and registered? We can add this concept to the RFC for later.

Collaborator

That way we don't have a long conditional case set; we just look up the appropriate conversion function through a standardized API.

Collaborator Author

Yeah, I agree that the conversion parts of different backends should be grouped, but for now I don't have much information about other backends (like how to convert an op to that backend). We can definitely do this when we are ready to support other backends.
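One possible shape for such a registry (a hypothetical sketch for the RFC, not part of this PR): each backend registers its supported ops together with a conversion callable, so the compiler can look both up through one API instead of branching on the submodule name.

import torch
from dataclasses import dataclass
from typing import Any, Callable, Dict, Set


@dataclass
class BackendEntry:
    # Ops the backend can accept (capability) and the function that compiles a submodule for it (conversion).
    supported_ops: Set[Any]
    convert: Callable[..., torch.nn.Module]


BACKEND_REGISTRY: Dict[str, BackendEntry] = {}


def register_backend(name: str, supported_ops: Set[Any], convert: Callable[..., torch.nn.Module]) -> None:
    # e.g. register_backend("tensorrt", set(DYNAMO_ATEN_CONVERTERS.keys()), convert_module)
    BACKEND_REGISTRY[name] = BackendEntry(supported_ops, convert)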

+                    compiled_submodule = convert_module(
+                        submodule,
+                        submodule_inputs,
+                        settings=settings,
+                        name=name,
+                        engine_cache=engine_cache,
+                    )
+                else:
+                    raise ValueError(f"Unknown backend for submodule: {name}")
Comment on lines +978 to +979
Collaborator

Should there be a _run_on_gpu segment here? How would torch.ops.aten._native_batch_norm_legit_no_training.default fall back to native PyTorch in your example from Slack?

Collaborator Author

No. As I mentioned above, _run_on_gpu is not considered a backend; we just keep the module as is. Only _run_on_acc_backend modules need to be replaced with the compiled module. This aligns with our existing implementation.

+            else:
+                compiled_submodule = convert_module(
+                    submodule,
+                    submodule_inputs,
+                    settings=settings,
+                    name=name,
+                    engine_cache=engine_cache,
+                )

-            trt_modules[name] = trt_module
+            compiled_modules[name] = compiled_submodule

     # Parse the graph I/O and store it in dryrun tracker
     parse_graph_io(gm, dryrun_tracker)

     # Replace all FX Modules with TRT Modules
-    for name, trt_module in trt_modules.items():
-        setattr(partitioned_module, name, trt_module)
+    for name, compiled_module in compiled_modules.items():
+        setattr(partitioned_module, name, compiled_module)
         if settings.lazy_engine_init and not settings.enable_cross_compile_for_windows:
-            getattr(partitioned_module, name).setup_engine()
+            if use_hierarchical_partitioner:
Collaborator

Similarly here, we should standardize the post-processing as well.

if "_run_on_acc_tensorrt" in name:
getattr(partitioned_module, name).setup_engine()
else:
getattr(partitioned_module, name).setup_engine()

# Reset settings object to user specification after fallback to global partitioning mode
if fast_partitioner_failed:
@@ -1276,3 +1343,12 @@ def load_cross_compiled_exported_program(file_path: str = "") -> Any:
         )

     return replace_execute_engine_no_op_node(exp_program)
+
+
+class FunctionWrapper(torch.nn.Module):
+    def __init__(self, func):
+        super().__init__()
+        self.func = func
+
+    def forward(self, *args, **kwargs):
+        return self.func(*args, **kwargs)
Comment on lines +1348 to +1354
Collaborator

Consider renaming this to InductorModule and moving it to utils.
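A sketch of what that suggestion could look like if the wrapper were renamed and moved into a shared utils module (the name and location follow the reviewer's suggestion and are not code in this PR):

import torch


class InductorModule(torch.nn.Module):
    """Wraps an inductor-compiled callable so it can stand in for an fx submodule."""

    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, *args, **kwargs):
        return self.func(*args, **kwargs)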

1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/partitioning/__init__.py
@@ -1,5 +1,6 @@
 from ._adjacency_partitioner import partition as fast_partition
 from ._global_partitioner import partition as global_partition
+from ._hierarchical_partitioner import hierarchical_adjacency_partition
 from .common import (
     construct_submodule_inputs,
     get_graph_converter_support,
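With this re-export in place, callers can reach the new partitioner through the partitioning package rather than the private module path, for example (usage sketch; gm stands for a lowered torch.fx.GraphModule produced as in the example file above):

from torch_tensorrt.dynamo import partitioning

partitioned_gm, op_support = partitioning.hierarchical_adjacency_partition(
    gm,
    min_block_size=1,
)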