fix: replace add_identity by add_cast for type cast #3563


Open · wants to merge 2 commits into main

Conversation

junstar92 (Contributor) commented Jun 9, 2025

Description

This PR updates the type_cast helper function to ensure compatibility with TensorRT's strongly typed network mode.

Previously, type_cast used add_identity() followed by set_output_type() to perform the data type cast. However, in strongly typed mode, calling set_output_type() on the identity layer produces the following error:

ILayer::setOutputType: Error Code 3: API Usage Error (Parameter check failed, condition: !mNetwork->usingStronglyTyped(). INetworkLayer::setOutputType cannot be called for a strongly typed network.)
[graphShapeAnalyzer.cpp::checkCalculationStatusSanity::1962] Error Code 2: Internal Error (Assertion !isInFlight(p.second.symbolicRep) failed. )
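
For context, here is a minimal sketch (not part of the PR) of the failing pattern versus the fix, assuming a TensorRT version that exposes the STRONGLY_TYPED network creation flag and ICastLayer:

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
flags = 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
network = builder.create_network(flags)

inp = network.add_input("input", trt.float32, (1, 40))

# Old type_cast pattern: strongly typed networks reject set_output_type,
# logging the API usage error shown above.
identity = network.add_identity(inp)
identity.set_output_type(0, trt.int32)

# New pattern: an explicit cast layer is valid in both typing modes.
cast = network.add_cast(inp, trt.int32)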

type_cast is reached via the expand function in torch_tensorrt/dynamo/conversion/impl/slice/ops.py when a dimension index is dynamic:

input_t = prepend_ones(
    ctx.net,
    input_t,
    name + "_expand_broadcast",
    shape_rank - initial_tensor_rank,
)

The following code snippet reproduces the error:

import torch
import torch_tensorrt
from torch.export._trace import _export
from torch_tensorrt.dynamo._compiler import CompilationSettings
from torch_tensorrt.dynamo.conversion import TRTInterpreter
from torch_tensorrt.dynamo.lowering import get_decompositions


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.visual = torch.nn.Linear(10, 10)

    def forward(self, input: torch.Tensor):
        return input.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0)


model = Model().to("cuda")
x = torch.randn(1, 40).to("cuda")
ep = _export(model, (x,))
ep = ep.run_decompositions(get_decompositions(False))
gm = ep.module()


interpreter = TRTInterpreter(
    gm,
    [torch_tensorrt.Input(name="input", min_shape=(1, 40), opt_shape=(4, 40), max_shape=(8, 40), dtype=torch.float32)],
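    # use_explicit_typing=True builds the network in strongly typed mode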
    compilation_settings=CompilationSettings(use_explicit_typing=True),
)
results = interpreter.run()

To address this, the function now uses add_cast() to explicitly insert a cast layer that converts the input tensor to the desired cast_type.
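
After the change, the helper reads roughly as follows (a sketch reconstructed from the diff reviewed below; the type aliases are assumed to come from torch_tensorrt.fx.types and the exact signature may differ):

from torch.fx.node import Target
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor

def type_cast(
    network: TRTNetwork,
    target: Target,
    name: str,
    input: TRTTensor,
    cast_type: TRTDataType,
) -> TRTTensor:
    """
    This function helps to cast the input type to cast_type
    """
    # Explicit cast layer instead of identity + set_output_type, so the
    # conversion also works for strongly typed networks.
    layer_i = network.add_cast(input, cast_type)
    set_layer_name(layer_i, target, f"{name}_dtype_change")
    return layer_i.get_output(0)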

If there was a specific reason for using add_identity(), please let me know, as this change assumes that the identity layer was not essential beyond type casting.

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

peri044 (Collaborator) commented Jun 9, 2025

Thanks @junstar92 for the contribution. Instead of modifying the FX path, we should import these utilities from the dynamo path, since it is actively being developed. Can you modify this change so that prepend_ones is imported from dynamo/conversion/converter_utils instead?

from torch_tensorrt.dynamo.conversion.converter_utils import (
    has_dynamic_shape,
    prepend_ones,
    set_layer_name,
)

zewenli98 (Collaborator) left a comment:

@junstar92 Thanks for your contribution! As @peri044 mentioned, we have switched our attention to the Dynamo path. In this PR, instead of importing from fx, can you change

from torch_tensorrt.fx.converters.converter_utils import (
    has_dynamic_shape,
    prepend_ones,
    set_layer_name,
)

to

from torch_tensorrt.dynamo.conversion.converter_utils import (
    has_dynamic_shape,
    prepend_ones,
    set_layer_name,
)

and change the first argument from ctx.net to ctx accordingly?

Besides, I noticed that you are using from torch.export._trace import _export instead of from torch.export import export in your repro. May I know the reason?

apbose (Collaborator) commented Jun 9, 2025

LGTM apart from the changes mentioned above.

github-actions bot added labels on Jun 10, 2025: component: conversion (Issues re: Conversion stage), component: converters (Issues re: Specific op converters), component: dynamo (Issues relating to the `torch.compile` or `torch._dynamo.export` paths)
junstar92 (Contributor, Author) commented:

@peri044 @zewenli98 Thanks for the suggestion. As you suggested, I switched the imports from fx's converter utilities to dynamo's.

junstar92 (Contributor, Author) commented:

@zewenli98

> Besides, I noticed that you are using from torch.export._trace import _export instead of from torch.export import export in your repro. May I know the reason?

There's no special reason; it's just how I've been doing it.

@@ -909,7 +909,6 @@ def type_cast(
     """
     This function helps to cast the input type to cast_type
     """
-    layer_i = network.add_identity(input)
-    layer_i.set_output_type(0, cast_type)
+    layer_i = network.add_cast(input, cast_type)
     set_layer_name(layer_i, target, f"{name}_dtype_change")
A collaborator commented on the diff:
Thanks for the quick change @junstar92. LGTM as such. One minor change: since we now use cast_trt_tensor in py/torch_tensorrt/dynamo/conversion/converter_utils.py and the above change is related to it, you could change the docstring there from

Adds an Identity layer to the network which performs the conversion
if the input's dtype is different from the cast type. Otherwise returns
input unchanged

to something like

Adds a Cast layer to the network to convert the input tensor to the specified dtype.

If the input tensor already has the desired dtype, it is returned unchanged.
Otherwise, a Cast layer is added to perform the conversion

junstar92 (Contributor, Author) replied:

Thanks for the feedback. I updated the comment for cast_trt_tensor as you mentioned.

zewenli98 (Collaborator) left a comment:
LGTM

peri044 (Collaborator) left a comment:

One minor comment; mostly looks good.

@@ -909,7 +909,6 @@ def type_cast(
     """
     This function helps to cast the input type to cast_type
     """
-    layer_i = network.add_identity(input)
-    layer_i.set_output_type(0, cast_type)
+    layer_i = network.add_cast(input, cast_type)
peri044 (Collaborator) commented on the diff:

Can you use the cast_trt_tensor function for this instead?

zewenli98 (Collaborator) replied:

This is a patch for FX, but it looks like cast_trt_tensor is only in dynamo?

junstar92 (Contributor, Author) replied:

@peri044 As @zewenli98 mentioned, cast_trt_tensor is in the Dynamo path, so using it here would require the FX path to import dynamo.conversion.converter_utils. Is this what you intended? If not, would you prefer me to implement cast_trt_tensor in the FX path, like the Dynamo one, and use it instead of type_cast?
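
For reference, using dynamo's cast_trt_tensor from inside a converter looks roughly like this (a sketch; the exact parameter list in dynamo/conversion/converter_utils.py may differ by version, and the to_fp32 wrapper is a hypothetical name for illustration):

import tensorrt as trt
from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor

def to_fp32(ctx, input_val, name):
    # `ctx` is the dynamo conversion context and `input_val` a TRT ITensor;
    # cast_trt_tensor returns the input unchanged if it already has the
    # requested dtype, otherwise it inserts a cast layer.
    return cast_trt_tensor(ctx, input_val, trt.float32, f"{name}_dtype_change")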

peri044 (Collaborator) commented Jun 13, 2025

Also, @junstar92, please rebase with main. Some of the CI failures should be resolved.

Labels
cla signed · component: api [Python] · component: conversion · component: converters · component: dynamo · component: fx

5 participants