
Commit abf4da0

sraikund16 authored and pytorchmergebot committed
[Profiler] Induce Inductor Import before Profiling (pytorch#155243)
Fixes pytorch#151829

Summary: Inductor currently has a lazy init that causes certain aten ops to run during a profiling run, which clutters the function events, especially for smaller traces. One attempted fix was to simply remove the inductor import from the profiler entirely, but the import happens somewhere downstream anyway and the events still flood the profile. Instead, we induce the inductor import during prepare_trace if inductor is present. This way, regardless of how the workload imports inductor, the lazy init completes before tracing starts, resulting in more accurate traces.

Test Plan: Added a test. Also ran N7316820 manually; instead of many extraneous events on the first run, it now produces the following output (the only remaining difference is the "Runtime Triggered Module Loading" entries, which are CUPTI overhead):

| Name | Self CPU % | Self CPU | CPU total % | CPU total | CPU time avg | Self CUDA | Self CUDA % | CUDA total | CUDA time avg | # of Calls |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| aten::mul_ | 1.40% | 340.638us | 99.92% | 24.390ms | 24.390ms | 1.535us | 100.00% | 4.605us | 4.605us | 1 |
| cudaLaunchKernel | 0.60% | 146.533us | 98.52% | 24.049ms | 24.049ms | 0.000us | 0.00% | 3.070us | 3.070us | 1 |
| Runtime Triggered Module Loading | 6.14% | 1.500ms | 6.14% | 1.500ms | 1.500ms | 1.535us | 100.00% | 1.535us | 1.535us | 1 |
| Runtime Triggered Module Loading | 91.78% | 22.403ms | 91.78% | 22.403ms | 22.403ms | 1.535us | 100.00% | 1.535us | 1.535us | 1 |
| void at::native::vectorized_elementwise_kernel<4, at... | 0.00% | 0.000us | 0.00% | 0.000us | 0.000us | 1.535us | 100.00% | 1.535us | 1.535us | 1 |
| cudaDeviceSynchronize | 0.08% | 20.031us | 0.08% | 20.031us | 20.031us | 0.000us | 0.00% | 0.000us | 0.000us | 1 |

| Name | Self CPU % | Self CPU | CPU total % | CPU total | CPU time avg | Self CUDA | Self CUDA % | CUDA total | CUDA time avg | # of Calls |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| aten::mul_ | 82.81% | 484.396us | 94.26% | 551.378us | 551.378us | 1.440us | 100.00% | 1.440us | 1.440us | 1 |
| cudaLaunchKernel | 11.45% | 66.982us | 11.45% | 66.982us | 66.982us | 0.000us | 0.00% | 0.000us | 0.000us | 1 |
| void at::native::vectorized_elementwise_kernel<4, at... | 0.00% | 0.000us | 0.00% | 0.000us | 0.000us | 1.440us | 100.00% | 1.440us | 1.440us | 1 |
| cudaDeviceSynchronize | 5.74% | 33.581us | 5.74% | 33.581us | 33.581us | 0.000us | 0.00% | 0.000us | 0.000us | 1 |

Rollback Plan:

Differential Revision: D76056511

Pull Request resolved: pytorch#155243
Approved by: https://github.com/ngimel
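As an illustration of the Test Plan, here is a minimal sketch of the kind of run being measured (this assumes a CUDA-capable build; the tensor size, synchronization point, and table options are illustrative and not part of the commit):

```python
import torch

# Profile a single in-place multiply on CUDA and inspect which events appear.
# Before this change, inductor's lazy init could add extra aten ops to the
# first profiling run; afterwards only the expected events remain.
x = torch.randn(4, device="cuda")
with torch.profiler.profile(with_stack=True) as p:
    x *= 2
    torch.cuda.synchronize()

print(p.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```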
1 parent f1f49e5 commit abf4da0

File tree

2 files changed (+35, -20 lines)


test/profiler/test_profiler.py

Lines changed: 13 additions & 0 deletions
@@ -2033,6 +2033,19 @@ def test_user_annotation(self):
             else:
                 self.assertFalse(evt.is_user_annotation)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
+    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
+    def test_basic_profile(self):
+        # test a really basic profile to make sure no erroneous aten ops are run
+        x = torch.randn(4, device="cuda")
+        with torch.profiler.profile(with_stack=True) as p:
+            x *= 2
+        names = [e.name for e in p.events()]
+        for name in names:
+            if name.startswith("aten") and name != "aten::mul_":
+                self.assertTrue(False, "Found unexpected event: " + name)
+        self.assertTrue("aten::mul_" in names)
+
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
     @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
     def test_dynamic_toggle(self):
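On builds that predate this change, a similar effect can be approximated in user code by forcing the inductor import before tracing starts, mirroring what prepare_trace now does (a sketch under that assumption; the workload and event check follow the test above):

```python
import torch

# Trigger inductor's import (and therefore its lazy init) before tracing
# starts, so its setup ops do not land in the profile. This mirrors the
# hasattr check that prepare_trace now performs.
if hasattr(torch, "_inductor"):
    import torch._inductor.config  # noqa: F401

x = torch.randn(4, device="cuda")
with torch.profiler.profile(with_stack=True) as p:
    x *= 2

names = [e.name for e in p.events()]
assert "aten::mul_" in names
assert not any(n.startswith("aten") and n != "aten::mul_" for n in names)
```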

torch/profiler/profiler.py

Lines changed: 22 additions & 20 deletions
@@ -157,6 +157,7 @@ def __init__(
         self.acc_events = acc_events
         self.custom_trace_id_callback = custom_trace_id_callback
         self.profiler: Optional[prof.profile] = None
+        self.has_cudagraphs = False
         self.mem_tl: Optional[MemoryProfileTimeline] = None
         self.use_device = None
         if ProfilerActivity.CUDA in self.activities:
@@ -181,6 +182,10 @@ def stop(self):
         self.stop_trace()
 
     def prepare_trace(self):
+        if hasattr(torch, "_inductor"):
+            import torch._inductor.config as inductor_config
+
+            self.has_cudagraphs = inductor_config.triton.cudagraphs
         if (self.profiler is None) or (not self.acc_events):
             self.profiler = prof.profile(
                 use_cpu=(ProfilerActivity.CPU in self.activities),
@@ -221,26 +226,23 @@ def start_trace(self):
                 "distributedInfo", json.dumps(dist_info, cls=_NumpyEncoder)
             )
 
-        if hasattr(torch, "_inductor"):
-            import torch._inductor.config as inductor_config
-
-            cuda_version = None
-            if hasattr(torch, "version"):
-                from torch.torch_version import TorchVersion
-
-                cuda_version = TorchVersion(getattr(torch.version, "cuda", "0.0"))
-
-            if inductor_config.triton.cudagraphs and (
-                (cuda_version and cuda_version < "12.6")
-                or not profiler_allow_cudagraph_cupti_lazy_reinit_cuda12()
-            ):
-                os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1"
-                self.add_metadata_json("DISABLE_CUPTI_LAZY_REINIT", "1")
-                # FIXME: CUDA Graph does not work well with CUPTI teardown.
-                #   1) crashes on 1st lazy CUPTI re-init after teardown (CUDA 11)
-                #   2) crashes on 2nd non-lazy CUPTI re-init after teardown (CUDA 12)
-                # Workaround: turn off CUPTI teardown when using CUDA Graphs.
-                os.environ["TEARDOWN_CUPTI"] = "0"
+        cuda_version = None
+        if hasattr(torch, "version"):
+            from torch.torch_version import TorchVersion
+
+            cuda_version = TorchVersion(getattr(torch.version, "cuda", "0.0"))
+
+        if self.has_cudagraphs and (
+            (cuda_version and cuda_version < "12.6")
+            or not profiler_allow_cudagraph_cupti_lazy_reinit_cuda12()
+        ):
+            os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1"
+            self.add_metadata_json("DISABLE_CUPTI_LAZY_REINIT", "1")
+            # FIXME: CUDA Graph does not work well with CUPTI teardown.
+            #   1) crashes on 1st lazy CUPTI re-init after teardown (CUDA 11)
+            #   2) crashes on 2nd non-lazy CUPTI re-init after teardown (CUDA 12)
+            # Workaround: turn off CUPTI teardown when using CUDA Graphs.
+            os.environ["TEARDOWN_CUPTI"] = "0"
 
         # Insert the preset user metadata to the trace
         for k, v in self.preset_metadata.items():
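For reference, the gating that remains in start_trace can be sketched in isolation. Here has_cudagraphs stands for the flag cached by prepare_trace, and allow_lazy_reinit stands in for profiler_allow_cudagraph_cupti_lazy_reinit_cuda12(); the free function below is purely illustrative, not the actual method:

```python
import os

import torch
from torch.torch_version import TorchVersion


def maybe_disable_cupti_teardown(has_cudagraphs: bool, allow_lazy_reinit: bool) -> None:
    # Sketch of the logic left in start_trace(): the cudagraphs flag was
    # captured earlier (in prepare_trace), so only the CUPTI workaround
    # remains here.
    cuda_version = None
    if hasattr(torch, "version"):
        # Fall back to "0.0" when torch.version.cuda is None (CPU-only build);
        # the real code passes the raw value through.
        cuda_version = TorchVersion(getattr(torch.version, "cuda", "0.0") or "0.0")

    if has_cudagraphs and (
        (cuda_version and cuda_version < "12.6") or not allow_lazy_reinit
    ):
        # CUDA Graphs do not work well with CUPTI teardown; keep CUPTI alive.
        os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1"
        os.environ["TEARDOWN_CUPTI"] = "0"
```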
