Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion torchtitan/models/llama3/train_configs/debug_model.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ description = "Llama 3 debug training"
print_config = false

[profiling]
enable_profiling = false
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = false
Expand Down
16 changes: 16 additions & 0 deletions torchtitan/tools/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
import os
import pickle
import time
import json

import torch

from torchtitan.config import Profiling as ProfilingConfig
from torchtitan.tools.logging import logger
from torch.profiler._utils import map_recorded_events_to_aten_ops_with_stack_trace

# how much memory allocation/free ops to record in memory snapshots
MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
Expand Down Expand Up @@ -54,6 +56,20 @@ def trace_handler(prof):
f"Finished dumping profiler traces in {time.monotonic() - begin:.2f} seconds"
)

with open(output_file) as f:
trace_data = json.load(f)

begin = time.monotonic()
map_recorded_events_to_aten_ops_with_stack_trace(
trace_data
)
output_file = os.path.join(curr_trace_dir, f"rank{rank}_trace_augmented.json")

json.dump(trace_data, open(output_file, 'w'))
logger.info(
    f"Finished dumping augmented profiler traces in {time.monotonic() - begin:.2f} seconds"
)

logger.info(f"Profiling active. Traces will be saved at {trace_dir}")

if not os.path.exists(trace_dir):
Expand Down
Loading