Skip to content

Commit fb5dc81

Browse files
committed
Fixed the comments
1 parent 5683644 commit fb5dc81

File tree

10 files changed

+61
-97
lines changed

10 files changed

+61
-97
lines changed

core/runtime/TRTEngineProfiler.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ namespace runtime {
1212

1313
enum TraceFormat { kPERFETTO, kTREX };
1414

15-
// Forward declare the function
16-
1715
struct TRTEngineProfiler : public nvinfer1::IProfiler {
1816
struct Record {
1917
float time{0};

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ def cross_compile_for_windows(
6666
Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
6767
] = _defaults.ENABLED_PRECISIONS,
6868
engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
69-
debug: bool = False,
7069
num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
7170
workspace_size: int = _defaults.WORKSPACE_SIZE,
7271
dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
@@ -140,7 +139,6 @@ def cross_compile_for_windows(
140139
assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False
141140
sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
142141
enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
143-
debug (bool): Enable debuggable engine
144142
capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
145143
num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
146144
workspace_size (int): Maximum size of workspace given to TensorRT
@@ -187,9 +185,9 @@ def cross_compile_for_windows(
187185
f"Cross compile for windows is only supported on x86-64 Linux architecture, current platform: {platform.system()=}, {platform.architecture()[0]=}"
188186
)
189187

190-
if debug:
188+
if kwargs.get("debug", False):
191189
warnings.warn(
192-
"`debug` is deprecated. Please use `torch_tensorrt.dynamo.Debugger` to configure debugging options.",
190+
"`debug` is deprecated. Please use with torch_tensorrt.dynamo.Debugger(...) to wrap your compilation call to enable debugging functionality.",
193191
DeprecationWarning,
194192
stacklevel=2,
195193
)
@@ -404,7 +402,6 @@ def compile(
404402
Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
405403
] = _defaults.ENABLED_PRECISIONS,
406404
engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
407-
debug: bool = False,
408405
num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
409406
workspace_size: int = _defaults.WORKSPACE_SIZE,
410407
dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
@@ -480,7 +477,6 @@ def compile(
480477
assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False
481478
sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
482479
enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
483-
debug (bool): Enable debuggable engine
484480
capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
485481
num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
486482
workspace_size (int): Maximum size of workspace given to TensorRT
@@ -523,9 +519,9 @@ def compile(
523519
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
524520
"""
525521

526-
if debug:
522+
if kwargs.get("debug", False):
527523
warnings.warn(
528-
"`debug` is deprecated. Please use `torch_tensorrt.dynamo.Debugger` for debugging functionality",
524+
"`debug` is deprecated. Please use with torch_tensorrt.dynamo.Debugger(...) to wrap your compilation call to enable debugging functionality",
529525
DeprecationWarning,
530526
stacklevel=2,
531527
)
@@ -732,7 +728,7 @@ def compile_module(
732728
settings: CompilationSettings = CompilationSettings(),
733729
engine_cache: Optional[BaseEngineCache] = None,
734730
*,
735-
_debugger_settings: Optional[DebuggerConfig] = None,
731+
_debugger_config: Optional[DebuggerConfig] = None,
736732
) -> torch.fx.GraphModule:
737733
"""Compile a traced FX module
738734
@@ -935,29 +931,30 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
935931

936932
trt_modules[name] = trt_module
937933

938-
if _debugger_settings:
934+
if _debugger_config:
939935

940-
if _debugger_settings.save_engine_profile:
936+
if _debugger_config.save_engine_profile:
941937
if settings.use_python_runtime:
942-
if _debugger_settings.profile_format == "trex":
938+
if _debugger_config.profile_format == "trex":
943939
logger.warning(
944940
"Profiling with TREX can only be enabled when using the C++ runtime. Python runtime profiling only support cudagraph visualization."
945941
)
946942
trt_module.enable_profiling()
947943
else:
948944
path = os.path.join(
949-
_debugger_settings.logging_dir, "engine_visualization"
945+
_debugger_config.logging_dir,
946+
"engine_visualization_profile",
950947
)
951948
os.makedirs(path, exist_ok=True)
952949
trt_module.enable_profiling(
953950
profiling_results_dir=path,
954-
profile_format=_debugger_settings.profile_format,
951+
profile_format=_debugger_config.profile_format,
955952
)
956953

957-
if _debugger_settings.save_layer_info:
954+
if _debugger_config.save_layer_info:
958955
with open(
959956
os.path.join(
960-
_debugger_settings.logging_dir, "engine_layer_info.json"
957+
_debugger_config.logging_dir, "engine_layer_info.json"
961958
),
962959
"w",
963960
) as f:
@@ -990,7 +987,6 @@ def convert_exported_program_to_serialized_trt_engine(
990987
enabled_precisions: (
991988
Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
992989
) = _defaults.ENABLED_PRECISIONS,
993-
debug: bool = False,
994990
assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT,
995991
workspace_size: int = _defaults.WORKSPACE_SIZE,
996992
min_block_size: int = _defaults.MIN_BLOCK_SIZE,
@@ -1052,7 +1048,6 @@ def convert_exported_program_to_serialized_trt_engine(
10521048
torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
10531049
]
10541050
enabled_precisions (Optional[Set[torch.dtype | _enums.dtype]]): The set of datatypes that TensorRT can use
1055-
debug (bool): Whether to print out verbose debugging information
10561051
workspace_size (int): Workspace TRT is allowed to use for the module (0 is default)
10571052
min_block_size (int): Minimum number of operators per TRT-Engine Block
10581053
torch_executed_ops (Set[str]): Set of operations to run in Torch, regardless of converter coverage
@@ -1092,9 +1087,9 @@ def convert_exported_program_to_serialized_trt_engine(
10921087
Returns:
10931088
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
10941089
"""
1095-
if debug:
1090+
if kwargs.get("debug", False):
10961091
warnings.warn(
1097-
"`debug` is deprecated. Please use `torch_tensorrt.dynamo.Debugger` to configure debugging options.",
1092+
"`debug` is deprecated. Please use with torch_tensorrt.dynamo.Debugger(...) to wrap your compilation call to enable debugging functionality.",
10981093
DeprecationWarning,
10991094
stacklevel=2,
11001095
)
@@ -1181,7 +1176,6 @@ def convert_exported_program_to_serialized_trt_engine(
11811176
compilation_options = {
11821177
"assume_dynamic_shape_support": assume_dynamic_shape_support,
11831178
"enabled_precisions": enabled_precisions,
1184-
"debug": debug,
11851179
"workspace_size": workspace_size,
11861180
"min_block_size": min_block_size,
11871181
"torch_executed_ops": torch_executed_ops,

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
L2_LIMIT_FOR_TILING = -1
5050
USE_DISTRIBUTED_MODE_TRACE = False
5151
OFFLOAD_MODULE_TO_CPU = False
52+
DEBUG_LOGGING_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt/debug_logs")
5253

5354

5455
def default_device() -> Device:

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@
4545
get_trt_tensor,
4646
to_torch,
4747
)
48-
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, deallocate_module, to_torch_device
4948
from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
5049
from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger
50+
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, deallocate_module, to_torch_device
5151
from torch_tensorrt.fx.observer import Observer
5252
from torch_tensorrt.logging import TRT_LOGGER
5353

@@ -82,13 +82,13 @@ def __init__(
8282
compilation_settings: CompilationSettings = CompilationSettings(),
8383
engine_cache: Optional[BaseEngineCache] = None,
8484
*,
85-
_debugger_settings: Optional[DebuggerConfig] = None,
85+
_debugger_config: Optional[DebuggerConfig] = None,
8686
):
8787
super().__init__(module)
8888

8989
self.logger = TRT_LOGGER
9090
self.builder = trt.Builder(self.logger)
91-
self._debugger_settings = _debugger_settings
91+
self._debugger_config = _debugger_config
9292
flag = 0
9393
if compilation_settings.use_explicit_typing:
9494
STRONGLY_TYPED = 1 << (int)(
@@ -209,7 +209,7 @@ def _populate_trt_builder_config(
209209
) -> trt.IBuilderConfig:
210210
builder_config = self.builder.create_builder_config()
211211

212-
if self._debugger_settings and self._debugger_settings.engine_builder_monitor:
212+
if self._debugger_config and self._debugger_config.engine_builder_monitor:
213213
builder_config.progress_monitor = TRTBulderMonitor()
214214

215215
if self.compilation_settings.workspace_size != 0:
@@ -220,8 +220,7 @@ def _populate_trt_builder_config(
220220
if version.parse(trt.__version__) >= version.parse("8.2"):
221221
builder_config.profiling_verbosity = (
222222
trt.ProfilingVerbosity.DETAILED
223-
if self._debugger_settings
224-
and self._debugger_settings.save_engine_profile
223+
if self._debugger_config and self._debugger_config.save_engine_profile
225224
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
226225
)
227226

py/torch_tensorrt/dynamo/debug/_Debugger.py

Lines changed: 20 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from unittest import mock
99

1010
import torch
11+
from torch_tensorrt.dynamo._defaults import DEBUG_LOGGING_DIR
1112
from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
1213
from torch_tensorrt.dynamo.debug._supports_debugger import (
1314
_DEBUG_ENABLED_CLS,
@@ -32,7 +33,7 @@ def __init__(
3233
save_engine_profile: bool = False,
3334
profile_format: str = "perfetto",
3435
engine_builder_monitor: bool = True,
35-
logging_dir: str = tempfile.gettempdir(),
36+
logging_dir: str = DEBUG_LOGGING_DIR,
3637
save_layer_info: bool = False,
3738
):
3839
"""Initialize a debugger for TensorRT conversion.
@@ -92,7 +93,7 @@ def __init__(
9293
def __enter__(self) -> None:
9394
self.original_lvl = _LOGGER.getEffectiveLevel()
9495
self.rt_level = torch.ops.tensorrt.get_logging_level()
95-
dictConfig(self.get_customized_logging_config())
96+
dictConfig(self.get_logging_config(self.log_level))
9697

9798
if self.capture_fx_graph_before or self.capture_fx_graph_after:
9899
self.old_pre_passes, self.old_post_passes = (
@@ -126,22 +127,22 @@ def __enter__(self) -> None:
126127
self._context_stack = contextlib.ExitStack()
127128

128129
for f in _DEBUG_ENABLED_FUNCS:
129-
f.__kwdefaults__["_debugger_settings"] = self.cfg
130+
f.__kwdefaults__["_debugger_config"] = self.cfg
130131

131132
[
132133
self._context_stack.enter_context(
133134
mock.patch.object(
134135
c,
135136
"__init__",
136-
functools.partialmethod(c.__init__, _debugger_settings=self.cfg),
137+
functools.partialmethod(c.__init__, _debugger_config=self.cfg),
137138
)
138139
)
139140
for c in _DEBUG_ENABLED_CLS
140141
]
141142

142143
def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
143144

144-
dictConfig(self.get_default_logging_config())
145+
dictConfig(self.get_logging_config(None))
145146
torch.ops.tensorrt.set_logging_level(self.rt_level)
146147
if self.capture_fx_graph_before or self.capture_fx_graph_after:
147148
ATEN_PRE_LOWERING_PASSES.passes, ATEN_POST_LOWERING_PASSES.passes = (
@@ -151,50 +152,13 @@ def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
151152
self.debug_file_dir = tempfile.TemporaryDirectory().name
152153

153154
for f in _DEBUG_ENABLED_FUNCS:
154-
f.__kwdefaults__["_debugger_settings"] = None
155+
f.__kwdefaults__["_debugger_config"] = None
155156

156157
self._context_stack.close()
157158

158-
def get_customized_logging_config(self) -> dict[str, Any]:
159-
config = {
160-
"version": 1,
161-
"disable_existing_loggers": False,
162-
"formatters": {
163-
"brief": {
164-
"format": "%(asctime)s - %(levelname)s - %(message)s",
165-
"datefmt": "%H:%M:%S",
166-
},
167-
"standard": {
168-
"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
169-
"datefmt": "%Y-%m-%d %H:%M:%S",
170-
},
171-
},
172-
"handlers": {
173-
"file": {
174-
"level": self.log_level,
175-
"class": "logging.FileHandler",
176-
"filename": f"{self.cfg.logging_dir}/torch_tensorrt_logging.log",
177-
"formatter": "standard",
178-
},
179-
"console": {
180-
"level": self.log_level,
181-
"class": "logging.StreamHandler",
182-
"formatter": "brief",
183-
},
184-
},
185-
"loggers": {
186-
"": { # root logger
187-
"handlers": ["file", "console"],
188-
"level": self.log_level,
189-
"propagate": True,
190-
},
191-
},
192-
"force": True,
193-
}
194-
return config
195-
196-
def get_default_logging_config(self) -> dict[str, Any]:
197-
config = {
159+
def get_logging_config(self, log_level: Optional[int] = None) -> dict[str, Any]:
160+
level = log_level if log_level is not None else self.original_lvl
161+
config: dict[str, Any] = {
198162
"version": 1,
199163
"disable_existing_loggers": False,
200164
"formatters": {
@@ -209,18 +173,26 @@ def get_default_logging_config(self) -> dict[str, Any]:
209173
},
210174
"handlers": {
211175
"console": {
212-
"level": self.original_lvl,
176+
"level": level,
213177
"class": "logging.StreamHandler",
214178
"formatter": "brief",
215179
},
216180
},
217181
"loggers": {
218182
"": { # root logger
219183
"handlers": ["console"],
220-
"level": self.original_lvl,
184+
"level": level,
221185
"propagate": True,
222186
},
223187
},
224188
"force": True,
225189
}
190+
if log_level is not None:
191+
config["handlers"]["file"] = {
192+
"level": level,
193+
"class": "logging.FileHandler",
194+
"filename": f"{self.cfg.logging_dir}/torch_tensorrt_logging.log",
195+
"formatter": "standard",
196+
}
197+
config["loggers"][""]["handlers"].append("file")
226198
return config
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
import tempfile
21
from dataclasses import dataclass
32

3+
from torch_tensorrt.dynamo._defaults import DEBUG_LOGGING_DIR
4+
45

56
@dataclass
67
class DebuggerConfig:
78
log_level: str = "debug"
89
save_engine_profile: bool = False
910
engine_builder_monitor: bool = True
10-
logging_dir: str = tempfile.gettempdir()
11+
logging_dir: str = DEBUG_LOGGING_DIR
1112
profile_format: str = "perfetto"
1213
save_layer_info: bool = False

py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import os
2-
import tempfile
32
from typing import Any, Callable, List, Optional
43

54
import torch
65
from torch.fx import passes
76
from torch.fx.passes.pass_manager import PassManager
7+
from torch_tensorrt.dynamo._defaults import DEBUG_LOGGING_DIR
88
from torch_tensorrt.dynamo._settings import CompilationSettings
99

1010

@@ -70,7 +70,7 @@ def remove_pass_with_index(self, index: int) -> None:
7070
del self.passes[index]
7171

7272
def insert_debug_pass_before(
73-
self, passes: List[str], output_path_prefix: str = tempfile.gettempdir()
73+
self, passes: List[str], output_path_prefix: str = DEBUG_LOGGING_DIR
7474
) -> None:
7575
"""Insert debug passes in the PassManager pass sequence prior to the execution of a particular pass.
7676
@@ -96,7 +96,7 @@ def insert_debug_pass_before(
9696
self._validated = False
9797

9898
def insert_debug_pass_after(
99-
self, passes: List[str], output_path_prefix: str = tempfile.gettempdir()
99+
self, passes: List[str], output_path_prefix: str = DEBUG_LOGGING_DIR
100100
) -> None:
101101
"""Insert debug passes in the PassManager pass sequence after the execution of a particular pass.
102102

0 commit comments

Comments
 (0)