WIP Draft: DINOv2 supports dynamic shapes in ONNX or TensorRT format #1

Draft · wants to merge 3 commits into base: export-to-onnx-fixes
100 changes: 100 additions & 0 deletions scripts/convert-to-onnx-dynamic.py
@@ -0,0 +1,100 @@
"""DINOV2 model converter to onnx."""
import torch
import argparse
import os
import sys
from pathlib import Path
current_path = Path(__file__).resolve()
parent_path = current_path.parent.parent.as_posix()
sys.path.insert(0, parent_path)
import hubconf
import tensorrt as trt



class Wrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, tensor):
        return self.model(tensor)

parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, default="dinov2_vits14", help="dinov2 model name")
parser.add_argument(
    "--image_height", type=int, default=518, help="input image height, must be a multiple of patch_size"
)
parser.add_argument(
    "--image_width", type=int, default=518, help="input image width, must be a multiple of patch_size"
)
parser.add_argument(
    "--patch_size", type=int, default=14, help="dinov2 model patch size, default is 14"
)
args = parser.parse_args()


if __name__ == "__main__":

    assert args.image_height % args.patch_size == 0, f"image height must be a multiple of {args.patch_size}, but got {args.image_height}"
    assert args.image_width % args.patch_size == 0, f"image width must be a multiple of {args.patch_size}, but got {args.image_width}"

    # Resolve the requested model by name instead of hardcoding dinov2_vits14
    model = Wrapper(getattr(hubconf, args.model_name)(for_onnx=True)).to("cpu")
    model.eval()

    dummy_input = torch.rand([1, 3, args.image_height, args.image_width]).to("cpu")
    dummy_output = model(dummy_input)

    torch.onnx.export(
        model,
        dummy_input,
        "dinov2_dynamic.onnx",
        export_params=True,
        opset_version=12,
        do_constant_folding=True,
        training=torch.onnx.TrainingMode.EVAL,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch_size", 3: "width"}  # Dynamic batch size and width in input
            # "input": {3: "width"}  # Dynamic width only
            # "input": {0: "batch_size", 2: "height", 3: "width"}
        },
    )

    logger = trt.Logger(trt.Logger.VERBOSE)

    onnx_model_path = "dinov2_dynamic.onnx"


## NOTE: onnx to tensorRT engine, failed attempt 1
# Create builder and network
#with open(onnx_model_path, "rb") as f, trt.Builder(logger) as builder, builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) as network, trt.OnnxParser(network, logger) as parser:
#
# # Parse ONNX model
# if not parser.parse(f.read()):
# print("Failed to parse the ONNX file.")
# for error in range(parser.num_errors):
# print(parser.get_error(error))
# raise ValueError("ONNX parsing failed.")
#
# # Create optimization profile for dynamic shapes
# config = builder.create_builder_config()
# profile = builder.create_optimization_profile()
#
# # Specify dynamic shape range for input with fixed height and dynamic width
# profile.set_dynamic_shape_profile("input", min=(1, 3, 518, 518), opt=(1, 3, 518, 518), max=(1, 3, 518, 1022)) # 1022, 1554
#
# config.add_optimization_profile(profile)
# config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
#
#
# # Build tensorRT engine
# #engine = builder.build_engine(network, config)
# engine = builder.build_serialized_network(network, config)
#
# # Save the engine to a file
# with open("dinov2_dynamic_trt.engine", "wb") as engine_file:
# engine_file.write(engine.serialize())
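
For reference, a working version of the commented-out attempt above would likely look like the sketch below (an assumption on my part, targeting the TensorRT 8.x Python API): `IOptimizationProfile` has no `set_dynamic_shape_profile` method (the call is `set_shape`), and `build_serialized_network` already returns serialized engine bytes, so there is no `.serialize()` to call on the result.

```python
# Hedged sketch of the ONNX -> TensorRT build (assumes the TensorRT 8.x Python API)
import tensorrt as trt

logger = trt.Logger(trt.Logger.VERBOSE)
flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

with trt.Builder(logger) as builder, \
        builder.create_network(flags) as network, \
        trt.OnnxParser(network, logger) as parser, \
        open("dinov2_dynamic.onnx", "rb") as f:
    if not parser.parse(f.read()):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise ValueError("ONNX parsing failed.")

    config = builder.create_builder_config()
    profile = builder.create_optimization_profile()
    # set_shape (not set_dynamic_shape_profile) declares the dynamic range
    profile.set_shape("input", min=(1, 3, 518, 518), opt=(1, 3, 518, 518), max=(1, 3, 518, 1022))
    config.add_optimization_profile(profile)
    config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED

    # build_serialized_network returns already-serialized engine bytes
    serialized_engine = builder.build_serialized_network(network, config)
    with open("dinov2_dynamic_trt.engine", "wb") as engine_file:
        engine_file.write(serialized_engine)
```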
53 changes: 53 additions & 0 deletions scripts/convert-via-torch2trt-dynamic.py
@@ -0,0 +1,53 @@
from torch2trt_dynamic import module2trt, BuildEngineConfig
import torch


from typing import Literal

MODEL_NAMES = {
    'giant': 'dinov2_vitg14',
    'large': 'dinov2_vitl14',
    'base': 'dinov2_vitb14',
    'small': 'dinov2_vits14',
}

dev = torch.device('cuda')

class Dino(torch.nn.Module):
    def __init__(self, model_size: Literal['small', 'base', 'large', 'giant']) -> None:
        super().__init__()
        self.model = torch.hub.load('facebookresearch/dinov2', MODEL_NAMES[model_size]).to(dev).eval()
        # To apply the DINOv2-to-ONNX fixes, load the model from a repo that includes
        # the commits on 'RRoundTable/dinov2/tree/dinov2-onnx', or simply load it from
        # the current repo (hubconf.dinov2_vitb14(for_onnx=True)):
        # self.model = torch.hub.load('RRoundTable/dinov2:dinov2-onnx', MODEL_NAMES[model_size]).to(dev).eval()

    def forward(self, x):
        return self.model.get_intermediate_layers(x)[0]


model = Dino(model_size='base') # Load and move model to CUDA

# create example data
x = torch.ones((1, 3, 518, 518)).cuda()

# convert to TensorRT feeding sample data as input
config = BuildEngineConfig(
    shape_ranges=dict(
        x=dict(
            min=(1, 3, 518, 518),
            opt=(2, 3, 518, 742),
            max=(4, 3, 518, 1022),
        )
    ))
trt_model = module2trt(
    model,
    args=[x],
    config=config)

x = torch.rand(1, 3, 518, 518).cuda()
with torch.no_grad():
    y = model(x)
    y_trt = trt_model(x)

# check the output against PyTorch
torch.testing.assert_close(y, y_trt)
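
Since the profile above allows batch sizes up to 4 and widths up to 1022, it may be worth also exercising a non-default shape before trusting the engine. A minimal follow-up sketch (width 742 = 53 × 14 keeps the patch-size constraint and matches the `opt` shape):

```python
# Exercise a non-default shape inside the configured profile
x_wide = torch.rand(2, 3, 518, 742).cuda()
with torch.no_grad():
    y_wide = model(x_wide)
    y_wide_trt = trt_model(x_wide)
torch.testing.assert_close(y_wide, y_wide_trt)
```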
47 changes: 47 additions & 0 deletions scripts/test-onnx-dynamic.py
@@ -0,0 +1,47 @@
import onnxruntime as ort
import numpy as np
import warnings


def test_onnx_model(onnx_model_path, patch_size, test_shapes):
    """
    Loads an ONNX model and tests it with multiple input shapes.

    :param onnx_model_path: Path to the ONNX model file
    :param patch_size: Patch size to ensure height and width are multiples
    :param test_shapes: List of (height, width) tuples to test
    """
    # Load ONNX model
    session = ort.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    for h, w in test_shapes:
        if not (h % patch_size == 0 and w % patch_size == 0):
            # raise would abort the whole loop; warn and skip this shape instead
            warnings.warn(f"Height and width must be multiples of {patch_size}, got {h} and {w}; skipping")
            continue

        # Create a random input tensor
        input_tensor = np.random.rand(1, 3, h, w).astype(np.float32)

        # Run inference
        outputs = session.run([output_name], {input_name: input_tensor})

        print(f"Tested shape: (1, 3, {h}, {w}) -> Output shape: {outputs[0].shape}")

if __name__ == "__main__":
    onnx_model_path = "dinov2_dynamic.onnx"  # Path to your ONNX model
    patch_size = 14  # Ensure height/width are multiples of this

    # Define test shapes (ensure they match the TensorRT profile)
    test_shapes = [
        (518, 518),  # Minimum shape
        # (518, 742),   # Mid-range shape
        # (518, 1022),  # Maximum shape (as per the TensorRT profile)
    ]

    # NOTE: The exported model does not actually support a dynamic shape range in
    # ONNXRuntime; non-default widths fail with: "element_wise_ops.h:560 void
    # onnxruntime::BroadcastIterator::Append(ptrdiff_t, ptrdiff_t) axis == 1 ||
    # axis == largest was false. Attempting to broadcast an axis by a dimension
    # other than 1."

    test_onnx_model(onnx_model_path, patch_size, test_shapes)
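
One way to narrow down the failure mentioned in the NOTE is to check whether the exported graph actually declares symbolic dimensions: even when the declared input dim is dynamic, some internal tensor (e.g. the interpolated positional embeddings) may have been constant-folded to its export-time size. A minimal sketch, assuming the `onnx` package is available:

```python
import onnx

model = onnx.load("dinov2_dynamic.onnx")
onnx.checker.check_model(model)

# A dynamic axis appears as a dim_param (e.g. "batch_size", "width");
# a fixed axis appears as a baked-in dim_value.
for value_info in list(model.graph.input) + list(model.graph.output):
    dims = [d.dim_param or d.dim_value for d in value_info.type.tensor_type.shape.dim]
    print(value_info.name, dims)
```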
50 changes: 50 additions & 0 deletions scripts/test-trtengine-from-onnx.py
@@ -0,0 +1,50 @@
import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit  # noqa: F401 -- creates the CUDA context needed by mem_alloc/Stream
import tensorrt as trt
import os

# NOTE: meant to test a TRT engine created from the dynamic ONNX model, yet the exported ONNX model does not support a dynamic shape range.
# This file can be thought of as pseudo-code for testing a dynamic shape range in TensorRT.

def load_engine(engine_path):
    logger = trt.Logger(trt.Logger.WARNING)
    with open(engine_path, "rb") as f, trt.Runtime(logger) as runtime:
        return runtime.deserialize_cuda_engine(f.read())

current_directory = os.getcwd()
print(f"Current working directory: {current_directory}")
engine = load_engine("scripts/dinov2_dynamic_trt.engine")
context = engine.create_execution_context()

batch_size = 1
height = 518
# Random test width within the range [518, 1554], kept a multiple of the 14-pixel patch size
width = 14 * np.random.randint(37, 112)

input_shape = (batch_size, 3, height, width)
input_data = np.random.randn(*input_shape).astype(np.float32)

# Adjust the context for the new width before allocating the output buffer,
# so the output shape reflects the dynamic input
context.set_binding_shape(0, (batch_size, 3, height, width))
output_shape = tuple(context.get_binding_shape(1))

d_input = cuda.mem_alloc(int(np.prod(input_shape)) * np.float32().itemsize)
d_output = cuda.mem_alloc(int(np.prod(output_shape)) * np.float32().itemsize)

# Transfer data to GPU
cuda.memcpy_htod(d_input, input_data)

# Run inference
bindings = [int(d_input), int(d_output)]
context.execute_v2(bindings)

# Copy the output back to CPU
output_data = np.empty(output_shape, dtype=np.float32)
cuda.memcpy_dtoh(output_data, d_output)

print(f"Inference output shape: {output_data.shape}")
print("Inference completed successfully.")