From e86c4a2f272390fb80a1da1a3ab3957c92292d2c Mon Sep 17 00:00:00 2001
From: Shangdi Yu
Date: Tue, 4 Nov 2025 10:13:52 -0800
Subject: [PATCH] Example: augment llama profiler traces with stack traces

---
 .../models/llama3/train_configs/debug_model.toml |  2 +-
 torchtitan/tools/profiling.py                    | 16 ++++++++++++++++
 2 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/torchtitan/models/llama3/train_configs/debug_model.toml b/torchtitan/models/llama3/train_configs/debug_model.toml
index 7760667edd..89b90a1166 100644
--- a/torchtitan/models/llama3/train_configs/debug_model.toml
+++ b/torchtitan/models/llama3/train_configs/debug_model.toml
@@ -4,7 +4,7 @@ description = "Llama 3 debug training"
 print_config = false
 
 [profiling]
-enable_profiling = false
+enable_profiling = true
 save_traces_folder = "profile_trace"
 profile_freq = 10
 enable_memory_snapshot = false
diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py
index f398dba9b5..fcd5dd427e 100644
--- a/torchtitan/tools/profiling.py
+++ b/torchtitan/tools/profiling.py
@@ -8,11 +8,13 @@
+import json
 import os
 import pickle
 import time
 
 import torch
+from torch.profiler._utils import map_recorded_events_to_aten_ops_with_stack_trace
 
 from torchtitan.config import Profiling as ProfilingConfig
 from torchtitan.tools.logging import logger
 
 # how much memory allocation/free ops to record in memory snapshots
 MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
@@ -54,6 +56,20 @@ def trace_handler(prof):
                 f"Finished dumping profiler traces in {time.monotonic() - begin:.2f} seconds"
             )
 
+            # attach source stack traces to recorded aten ops; save augmented copy
+            with open(output_file) as f:
+                trace_data = json.load(f)
+            begin = time.monotonic()
+            map_recorded_events_to_aten_ops_with_stack_trace(trace_data)
+            output_file = os.path.join(
+                curr_trace_dir, f"rank{rank}_trace_augmented.json"
+            )
+            with open(output_file, "w") as f:
+                json.dump(trace_data, f)
+            logger.info(
+                f"Finished dumping augmented profiler traces in {time.monotonic() - begin:.2f} seconds"
+            )
+
         logger.info(f"Profiling active. Traces will be saved at {trace_dir}")
 
         if not os.path.exists(trace_dir):
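
Note on the new import: torch.profiler._utils is a private module, so
map_recorded_events_to_aten_ops_with_stack_trace may move or disappear
between torch releases. A minimal defensive variant (a sketch, assuming
only that a missing helper surfaces as ImportError):

    try:
        from torch.profiler._utils import (
            map_recorded_events_to_aten_ops_with_stack_trace,
        )
    except ImportError:
        # private helper absent on this torch build; skip augmentation
        map_recorded_events_to_aten_ops_with_stack_trace = None

trace_handler would then guard the augmentation block with
"if map_recorded_events_to_aten_ops_with_stack_trace is not None:" so raw
trace dumping keeps working either way.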
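
A quick sanity check after a run is to load the augmented file back and
tally which keys appear under each event's args. This is a sketch, not
part of the patch: it assumes the augmented file is still the Chrome-trace
JSON that export_chrome_trace writes (a "traceEvents" list), and the path
below is illustrative (with profile_freq = 10, the first dump should land
under an iteration_10 directory inside the configured traces folder).

    import collections
    import json

    def summarize(path: str) -> None:
        with open(path) as f:
            trace = json.load(f)
        # export_chrome_trace writes an object with a "traceEvents" list;
        # tolerate a bare event list too, just in case
        events = trace["traceEvents"] if isinstance(trace, dict) else trace
        # tally keys under each event's "args" so whatever metadata the
        # augmentation step attached is easy to spot
        arg_keys = collections.Counter()
        for event in events:
            if isinstance(event, dict):
                arg_keys.update((event.get("args") or {}).keys())
        print(f"{len(events)} events; args keys seen: {dict(arg_keys)}")

    if __name__ == "__main__":
        summarize("profile_trace/iteration_10/rank0_trace_augmented.json")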