
 import copy
 import io
+import os
 import random
 from contextlib import nullcontext, redirect_stdout
 from dataclasses import dataclass, field
-from pathlib import Path
+import pathlib
 from typing import Callable, Optional

 import fire
 import pandas as pd

+# disable inductor FX cache, so we can always study the inductor output logs
+os.environ['TORCHINDUCTOR_FORCE_DISABLE_CACHES'] = '1'
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
     parse_bw_and_kernel_name,
     profiler_output_to_gpu_time_for_key,
     profiler_output_to_filtered_time_by_kernel_name,
+    update_triton_kernels_in_prof_chome_trace_with_torch_logs,
 )

 # don't truncate long kernel names
 pd.set_option("display.float_format", "{:.3f}".format)


+
 class LNLinear(torch.nn.Module):
     def __init__(self, fc_dim1, fc_dim2):
         super().__init__()
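
A standalone sketch of the cache-disabling pattern used above: `TORCHINDUCTOR_FORCE_DISABLE_CACHES` is read by inductor's config, so the script sets it before `import torch` to make sure every run recompiles and therefore re-emits the generated Triton code. The toy function and the CUDA device below are illustrative assumptions, not part of this script.

```python
import os

# must be set before inductor's config is loaded, so generated Triton code is
# recompiled (and therefore logged) on every run instead of served from cache
os.environ["TORCHINDUCTOR_FORCE_DISABLE_CACHES"] = "1"

import torch

# ask torch.compile to log the inductor output code (the generated Triton kernels)
torch._logging.set_logs(output_code=True)

@torch.compile
def toy_fn(x):
    return torch.nn.functional.relu(x) * 2.0

if torch.cuda.is_available():
    # triggers compilation, which emits the generated code through the logs
    toy_fn(torch.randn(1024, 1024, device="cuda"))
```
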
@@ -151,7 +157,9 @@ def forward(self, h):

 @dataclass
 class ProfileConfig:
-    file_path: Optional[str] = None
+    trace_file_path: Optional[str] = None
+    logs_file_path: Optional[str] = None
+    trace_modified_file_path: Optional[str] = None
     name: Optional[str] = None
     cuda: bool = True
     iters: int = 0
@@ -162,13 +170,33 @@ class ProfileConfig:


 def profile_function(
-    config: ProfileConfig, func: Callable, *args, **kwargs
+    config: ProfileConfig,
+    func: Callable,
+    add_inductor_metadata_to_trace: bool,
+    *args,
+    **kwargs,
 ) -> torch.profiler.profile:
     """Profile a torch function and save the result to a file"""
     seed = 123
     random.seed(seed)
     torch.manual_seed(seed)

+    if add_inductor_metadata_to_trace:
+        # ensure we aren't interfering with other torch_log settings
+        if os.environ.get('TORCH_LOGS', '') != '':
+            raise AssertionError('using TORCH_LOGS together with add_inductor_metadata_to_trace is not supported yet')
+
+        # save torch.compile logs to a file specific to this benchmark run
+        # TODO(future): can we hack torch.compile to print to file only and not stdout?
+        # or maybe just use tlparse?
+        torch._logging.set_logs(output_code=True)
+        # by default torch.compile appends to log_file_name, so we delete it
+        # if it exists
+        if os.path.isfile(config.logs_file_path):
+            pathlib.Path(config.logs_file_path).unlink()
+        torch._logging._init_logs(log_file_name=config.logs_file_path)
+
+
     activities = [ProfilerActivity.CPU]
     if config.cuda:
         activities.append(ProfilerActivity.CUDA)
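
A minimal sketch of the log-file lifecycle added above, using the same calls as this PR. Note that `torch._logging._init_logs` is a private helper rather than a stable API, and the log path below is a made-up example.

```python
import os
import pathlib
import torch

logs_file_path = "/tmp/example_inductor_logs.txt"  # hypothetical path

# route inductor's generated code through the logging system
torch._logging.set_logs(output_code=True)

# the log handler appends to an existing file, so remove any stale copy first
if os.path.isfile(logs_file_path):
    pathlib.Path(logs_file_path).unlink()
torch._logging._init_logs(log_file_name=logs_file_path)

# ... run the torch.compile'd workload here ...

# undo the custom settings so later code sees default logging behavior
torch._logging.set_logs(output_code=False)
torch._logging._init_logs(log_file_name=None)
```
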
@@ -182,6 +210,10 @@ def profile_function(
         nullcontext() if config.name is None else record_function(config.name)
     )
     profile_memory = config.memory_profile_path is not None
+
+    # warm up
+    func(*args, **kwargs)
+
     with profile(
         activities=activities,
         profile_memory=profile_memory,
@@ -195,20 +227,35 @@ def profile_function(
     if config.sync:
         torch.cuda.synchronize()

-    if config.file_path is not None:
-        prof.export_chrome_trace(config.file_path)
+    if config.trace_file_path is not None:
+        prof.export_chrome_trace(config.trace_file_path)
+
+    if add_inductor_metadata_to_trace:
+        # modify the trace to have the triton kernel metadata and code
+        # visible inline
+        update_triton_kernels_in_prof_chome_trace_with_torch_logs(
+            config.trace_file_path,
+            config.logs_file_path,
+            config.trace_modified_file_path,
+        )
+
+        # undo custom log settings
+        torch._logging.set_logs(output_code=False)
+        torch._logging._init_logs(log_file_name=None)

     return prof


 def main(
-    profile_path_prefix: Path,
+    profile_path_prefix: pathlib.Path,
     compile: bool = True,
     scaling_type_input: str = "dynamic",
     scaling_type_weight: str = "dynamic",
     scaling_type_grad_output: str = "dynamic",
     model_type: str = "linear",
     dtype_filter: str = "both",
+    add_inductor_metadata_to_trace: bool = True,
+    enable_sync_amax_history: bool = True,
 ):
     assert model_type in ("linear", "ln_linear", "norm_ffn_norm", "norm_ffn_norm_small"), "unsupported"
     assert dtype_filter in ("both", "float8", "bfloat16")
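
With the new keyword arguments, a direct call to `main()` (normally it is driven through `fire` on the command line) might look like the sketch below; the output prefix is a made-up example.

```python
main(
    profile_path_prefix="/tmp/float8_profile",  # hypothetical output prefix
    compile=True,
    model_type="linear",
    dtype_filter="float8",
    add_inductor_metadata_to_trace=True,
    enable_sync_amax_history=True,
)
```
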
@@ -220,6 +267,8 @@ def main(
         cast_config_input=CastConfig(scaling_type=scaling_type_input),
         cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
         cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
+        enable_amax_init=False,
+        enable_pre_and_post_forward=False,
     )
     scaling_repr = "_".join(
         [
@@ -290,7 +339,7 @@ def float8_forw_backward_wrapper(x):
         # inspection of the fw+bw torch.compile without the scale
         # syncing code
         # TODO(future): make this better
-        if linear_requires_sync(config):
+        if linear_requires_sync(config) and enable_sync_amax_history:
             with record_function("scale_amax_and_scales"):
                 sync_amax_history(m_float8)
         out = float8_forw(x)
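
As a reminder of why the sync is wrapped in `record_function` above: it gives the amax-sync work its own named region in the exported trace, so its cost can be read off separately from the fw+bw kernels. A tiny self-contained sketch (the tensor op below stands in for `sync_amax_history`):

```python
import torch
from torch.profiler import profile, record_function

x = torch.randn(128, 128)
with profile() as prof:
    with record_function("scale_amax_and_scales"):
        _ = x.abs().max()  # stand-in for the real amax sync work
# the named region shows up as its own row in the profiler summary
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```
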
@@ -311,16 +360,14 @@ def float8_forw_backward_wrapper(x):

     # if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output
     # to populate triton kernel bandwidth further down in the script
-    f = io.StringIO()
+    if os.environ.get("TORCHINDUCTOR_PROFILE", "") == "":
+        context = nullcontext()
+        f = None
+    else:
+        f = io.StringIO()
+        context = redirect_stdout(f)
     try:
-        with redirect_stdout(f):
-            # warm up
-            for _ in range(1):
-                if dtype_filter != "float8":
-                    ref_forw_backward(input_tensor)
-                if dtype_filter != "bfloat16":
-                    float8_forw_backward_wrapper(input_tensor)
-
+        with context:
             profile_iters = 5
             ref_times, float8_times = None, None
             data = []
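
A standalone sketch of the stdout handling set up at the top of this hunk: capture stdout only when `TORCHINDUCTOR_PROFILE` is set, since that is when inductor prints the per-kernel bandwidth lines the script parses later.

```python
import io
import os
from contextlib import nullcontext, redirect_stdout

if os.environ.get("TORCHINDUCTOR_PROFILE", "") == "":
    # nothing to parse, so let output go straight to the console
    f = None
    context = nullcontext()
else:
    # capture stdout so the bandwidth printout can be parsed afterwards
    f = io.StringIO()
    context = redirect_stdout(f)

with context:
    print("captured only when TORCHINDUCTOR_PROFILE is set")

if f is not None:
    print(f.getvalue())  # echo the captured output back to the real stdout
```
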
@@ -330,13 +377,19 @@ def float8_forw_backward_wrapper(x):
             if dtype_filter != "float8":
                 # Profile Reference Model
                 print("profiling ref")
-                ref_suffix = f"_{model_type}_ref_compile_{compile}.json"
-                ref_path = profile_path_prefix + ref_suffix
+                ref_trace_suffix = f"_{model_type}_ref_compile_{compile}.json"
+                ref_logs_suffix = f"_{model_type}_ref_compile_{compile}.txt"
+                trace_ref_path = profile_path_prefix + ref_trace_suffix
+                log_ref_path = profile_path_prefix + ref_logs_suffix
+                trace_ref_modified_path = trace_ref_path.replace(".json", "_modified.json")
                 profile_config = ProfileConfig(
-                    ref_path, ref_suffix, iters=profile_iters, warmup_iters=2, sync=True
+                    trace_ref_path, log_ref_path, trace_ref_modified_path, ref_trace_suffix, iters=profile_iters, warmup_iters=2, sync=True
                 )
-                p = profile_function(profile_config, ref_forw_backward, input_tensor)
-                print(f"saved {ref_path}")
+                p = profile_function(profile_config, ref_forw_backward, add_inductor_metadata_to_trace, input_tensor)
+                print(f"saved profiling trace to {trace_ref_path}")
+                if add_inductor_metadata_to_trace:
+                    print(f"saved torch logs to {log_ref_path}")
+                    print(f"saved modified trace to {trace_ref_modified_path}")
                 ref_times = profiler_output_to_filtered_time_by_kernel_name(p, profile_iters, num_leaf_tensors)
                 total_time_ms = sum(v for v in ref_times.values()) / 1e3 / profile_iters
                 for k, v in ref_times.items():
@@ -355,21 +408,31 @@ def float8_forw_backward_wrapper(x):
             if dtype_filter != "bfloat16":
                 # Profile Float8 Model
                 print("profiling float8")
-                float8_suffix = (
+                float8_trace_suffix = (
                     f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json"
                 )
-                float8_path = profile_path_prefix + float8_suffix
+                float8_log_suffix = (
+                    f"_{model_type}_float8_compile_{compile}_{scaling_repr}.txt"
+                )
+                trace_float8_path = profile_path_prefix + float8_trace_suffix
+                log_float8_path = profile_path_prefix + float8_log_suffix
+                trace_float8_modified_path = trace_float8_path.replace(".json", "_modified.json")
                 profile_config = ProfileConfig(
-                    float8_path,
-                    float8_suffix,
+                    trace_float8_path,
+                    log_float8_path,
+                    trace_float8_modified_path,
+                    float8_trace_suffix,
                     iters=profile_iters,
                     warmup_iters=2,
                     sync=True,
                 )
                 p = profile_function(
-                    profile_config, float8_forw_backward_wrapper, input_tensor
+                    profile_config, float8_forw_backward_wrapper, add_inductor_metadata_to_trace, input_tensor
                 )
-                print(f"saved {float8_path}")
+                print(f"saved profiling trace to {trace_float8_path}")
+                if add_inductor_metadata_to_trace:
+                    print(f"saved torch logs to {log_float8_path}")
+                    print(f"saved modified trace to {trace_float8_modified_path}")
                 float8_times = profiler_output_to_filtered_time_by_kernel_name(p, profile_iters, num_leaf_tensors)
                 total_time_ms = sum(v for v in float8_times.values()) / 1e3 / profile_iters
                 for k, v in float8_times.items():
@@ -393,17 +456,19 @@ def float8_forw_backward_wrapper(x):
                 print(f"Sync time ms: {sync_time_ms}")

     finally:
-        # print the redirected stdout back to regular stdout
-        print(f.getvalue())
-
-        # populate the triton kernel bandwidth
-        for line in f.getvalue().split("\n"):
-            maybe_bw, maybe_kernel_name = parse_bw_and_kernel_name(line)
-            if maybe_kernel_name is not None:
-                # O(N) search, but it's ok since lists are small
-                for datum in data:
-                    if datum[1] == maybe_kernel_name:
-                        datum[-1] = maybe_bw
+        if f is not None:
+            # print the redirected stdout back to regular stdout
+            print(f.getvalue())
+
+        if os.environ.get("TORCHINDUCTOR_PROFILE", "") != "":
+            # populate the triton kernel bandwidth
+            for line in f.getvalue().split("\n"):
+                maybe_bw, maybe_kernel_name = parse_bw_and_kernel_name(line)
+                if maybe_kernel_name is not None:
+                    # O(N) search, but it's ok since lists are small
+                    for datum in data:
+                        if datum[1] == maybe_kernel_name:
+                            datum[-1] = maybe_bw

     df = pd.DataFrame(
         data,
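
The implementation of `update_triton_kernels_in_prof_chome_trace_with_torch_logs` lives in the benchmark utilities and is not shown in this diff. Purely as an illustration of the general technique, and not the repo's actual code, the sketch below attaches generated Triton source to matching kernel events in a chrome trace; the helper name and the kernel-name-to-code mapping are hypothetical.

```python
import json

def attach_triton_code_to_trace(trace_file, out_file, kernel_name_to_code):
    """Hypothetical sketch: add generated Triton source to matching kernel events.

    kernel_name_to_code is assumed to be parsed from the torch logs elsewhere,
    e.g. {"triton_poi_fused_relu_0": "<generated Triton source>"}.
    """
    with open(trace_file) as f:
        trace = json.load(f)
    for event in trace.get("traceEvents", []):
        code = kernel_name_to_code.get(event.get("name", ""))
        if code is not None:
            # chrome trace events accept arbitrary metadata under "args",
            # which trace viewers show when the event is selected
            event.setdefault("args", {})["triton_code"] = code
    with open(out_file, "w") as f:
        json.dump(trace, f)
```
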