33import os
44import pprint
55import sys
6- from typing import Any , Callable , Dict , List , Optional , Tuple , Union
6+ from typing import Any , Callable , Dict , List , Optional , Sequence , Tuple , Union
77import time
88import numpy as np
99import onnx
@@ -994,6 +994,26 @@ def _validate_do_run_exported_program(data, summary, verbose, quiet):
994994 )
995995
996996
997+ _cache_export_times = []
998+ _main_export_function = torch .export .export
999+
1000+
1001+ def _torch_export_export (* args , _export = _main_export_function , ** kwargs ):
1002+ begin = time .perf_counter ()
1003+ res = _export (* args , ** kwargs )
1004+ duration = time .perf_counter () - begin
1005+ _cache_export_times .append (duration )
1006+ return res
1007+
1008+
1009+ def _restore_torch_export_export (summary ):
1010+ torch .export .export = _main_export_function
1011+ if _cache_export_times :
1012+ summary ["time_torch_export_export" ] = sum (_cache_export_times )
1013+ summary ["time_torch_export_export_n" ] = len (_cache_export_times )
1014+ _cache_export_times .clear ()
1015+
1016+
9971017def call_exporter (
9981018 data : Dict [str , Any ],
9991019 exporter : str ,
@@ -1019,6 +1039,9 @@ def call_exporter(
10191039 :return: two dictionaries, one with some metrics,
10201040 another one with whatever the function produces
10211041 """
1042+ _cache_export_times .clear ()
1043+ torch .export .export = _torch_export_export
1044+
10221045 if exporter == "export" or exporter .startswith ("export-" ):
10231046 # torch export
10241047 summary , data = call_torch_export_export (
@@ -1029,6 +1052,7 @@ def call_exporter(
10291052 optimization = optimization ,
10301053 do_run = do_run ,
10311054 )
1055+ _restore_torch_export_export (summary )
10321056 return summary , data
10331057 if exporter .startswith ("onnx-" ):
10341058 # torch export
@@ -1040,6 +1064,7 @@ def call_exporter(
10401064 optimization = optimization ,
10411065 output_names = output_names ,
10421066 )
1067+ _restore_torch_export_export (summary )
10431068 return summary , data
10441069 if exporter == "custom" or exporter .startswith ("custom" ):
10451070 # torch export
@@ -1052,6 +1077,7 @@ def call_exporter(
10521077 dump_folder = dump_folder ,
10531078 output_names = output_names ,
10541079 )
1080+ _restore_torch_export_export (summary )
10551081 return summary , data
10561082 if exporter == "modelbuilder" :
10571083 # torch export
@@ -1063,6 +1089,7 @@ def call_exporter(
10631089 optimization = optimization ,
10641090 output_names = output_names ,
10651091 )
1092+ _restore_torch_export_export (summary )
10661093 return summary , data
10671094 raise NotImplementedError (
10681095 f"export with { exporter !r} and optimization={ optimization !r} not implemented yet, "
@@ -1634,6 +1661,97 @@ def call_torch_export_model_builder(
16341661 return summary , data
16351662
16361663
1664+ def process_statistics (data : Sequence [Dict [str , float ]]) -> Dict [str , Any ]:
1665+ """
1666+ Processes statistics coming from the exporters.
1667+ It takes a sequence of dictionaries (like a data frame)
1668+ and extracts some metrics.
1669+ """
1670+
1671+ def _simplify (p ):
1672+ for s in [
1673+ "remove_unused" ,
1674+ "constant_folding" ,
1675+ "remove_identity" ,
1676+ "remove_duplicated_initializer" ,
1677+ "dynamic_dimension_naming" ,
1678+ "inline" ,
1679+ "check" ,
1680+ "build_graph_for_pattern" ,
1681+ "pattern_optimization" ,
1682+ ]:
1683+ if s in p or s .replace ("_" , "-" ) in p :
1684+ return s
1685+ if p .startswith (("apply_" , "match_" )):
1686+ return p
1687+ return "other"
1688+
1689+ def _add (d , a , v , use_max = False ):
1690+ if v :
1691+ if a not in d :
1692+ d [a ] = v
1693+ elif use_max :
1694+ d [a ] = max (d [a ], v )
1695+ else :
1696+ d [a ] += v
1697+
1698+ counts : Dict [str , Any ] = {}
1699+ applied_pattern_time : Dict [str , Any ] = {}
1700+ applied_pattern_n : Dict [str , Any ] = {}
1701+ matching_pattern_time : Dict [str , Any ] = {}
1702+ matching_pattern_n : Dict [str , Any ] = {}
1703+
1704+ for obs in data :
1705+ pattern = _simplify (obs ["pattern" ])
1706+ _add (counts , "opt_nodes_added" , obs .get ("added" , 0 ))
1707+ _add (counts , "opt_nodes_removed" , obs .get ("removed" , 0 ))
1708+ _add (counts , "opt_time_steps" , obs .get ("time_in" , 0 ))
1709+ _add (counts , "opt_n_steps" , 1 )
1710+ _add (
1711+ counts ,
1712+ "opt_n_iteration" ,
1713+ max (counts .get ("opt_n_iteration" , 0 ), obs .get ("iteration" , 0 )),
1714+ use_max = True ,
1715+ )
1716+
1717+ if pattern .startswith ("apply_" ):
1718+ _add (counts , "opt_n_applied_patterns" , 1 )
1719+ _add (counts , "opt_time_applied_patterns" , obs .get ("time_in" , 0 ))
1720+ _add (applied_pattern_time , pattern , obs .get ("time_in" , 0 ))
1721+ _add (applied_pattern_n , pattern , 1 )
1722+ elif pattern .startswith ("match_" ):
1723+ _add (counts , "opt_n_matching_patterns" , 1 )
1724+ _add (counts , "opt_time_matching_patterns" , obs .get ("time_in" , 0 ))
1725+ _add (matching_pattern_time , pattern , obs .get ("time_in" , 0 ))
1726+ _add (matching_pattern_n , pattern , 1 )
1727+ else :
1728+ _add (counts , f"opt_time_{ pattern } " , obs .get ("time_in" , 0 ))
1729+ _add (counts , f"opt_n_{ pattern } " , 1 )
1730+ _add (counts , f"opt_nodes_added_{ pattern } " , obs .get ("added" , 0 ))
1731+ _add (counts , f"opt_nodes_removed_{ pattern } " , obs .get ("removed" , 0 ))
1732+
1733+ if applied_pattern_time :
1734+ longest = max ((v , k ) for k , v in applied_pattern_time .items ())
1735+ counts ["opt_top_time_applied_pattern" ], counts ["opt_top_time_applied_pattern_arg" ] = (
1736+ longest
1737+ )
1738+ longest = max ((v , k ) for k , v in applied_pattern_n .items ())
1739+ counts ["opt_top_n_applied_pattern" ], counts ["opt_top_n_applied_pattern_arg" ] = longest
1740+
1741+ if matching_pattern_time :
1742+ longest = max ((v , k ) for k , v in matching_pattern_time .items ())
1743+ (
1744+ counts ["opt_top_time_matching_pattern" ],
1745+ counts ["opt_top_time_matching_pattern_arg" ],
1746+ ) = longest
1747+ longest = max ((v , k ) for k , v in matching_pattern_n .items ())
1748+ counts ["opt_top_n_matching_pattern" ], counts ["opt_top_n_matching_pattern_arg" ] = (
1749+ longest
1750+ )
1751+ counts ["onnx_opt_optimized" ] = 1
1752+ return counts
1753+
1754+
16371755def call_torch_export_custom (
16381756 data : Dict [str , Any ],
16391757 exporter : str ,
@@ -1763,67 +1881,10 @@ def call_torch_export_custom(
17631881 if "ERR_export_onnx_c" in summary :
17641882 return summary , data
17651883
1766- new_stat = {}
1884+ new_stat : Dict [str , Any ] = {k : v for k , v in opt_stats .items () if k .startswith ("time_" )}
1885+ new_stat .update ({k [5 :]: v for k , v in opt_stats .items () if k .startswith ("stat_time_" )})
17671886 if "optimization" in opt_stats :
1768- added , removed , time_in = 0 , 0 , 0.0
1769- max_iter = 0
1770- applied = {}
1771- matched = set ()
1772- n_applied = 0
1773- by_pattern = {}
1774- by_pattern_n = {}
1775- by_iter = {}
1776- cst_added , cst_removed , cst_time_in = 0 , 0 , 0.0
1777-
1778- for obs in opt_stats ["optimization" ]:
1779- pattern = obs ["pattern" ]
1780- if pattern == "constant_folding" :
1781- cst_added += obs .get ("added" , 0 )
1782- cst_removed += obs .get ("removed" , 0 )
1783- cst_time_in += obs .get ("time_in" , 0 )
1784- if pattern not in by_pattern :
1785- by_pattern [pattern ] = 0
1786- by_pattern_n [pattern ] = 0
1787- by_iter [pattern ] = 0
1788- time_in += obs .get ("time_in" , 0 )
1789- added += obs .get ("added" , 0 )
1790- removed += obs .get ("removed" , 0 )
1791- max_iter = max (max_iter , obs .get ("iteration" , 0 ))
1792- by_pattern [pattern ] += obs .get ("time_in" , 0 )
1793- by_pattern_n [pattern ] += obs .get ("added" , 0 ) - obs .get ("removed" , 0 )
1794- if not pattern .startswith ("match" ):
1795- by_iter [pattern ] = max (by_iter [pattern ], obs .get ("iteration" , 0 ))
1796- p = obs ["pattern" ]
1797- if p .startswith ("match_" ):
1798- matched .add (p )
1799- elif p .startswith ("apply_" ):
1800- key = f"op_opt_{ p } "
1801- key2 = f"op_opt_maxiter_{ p } "
1802- if key not in applied :
1803- applied [key ] = 1
1804- applied [key2 ] = obs ["iteration" ]
1805- else :
1806- applied [key ] += 1
1807- applied [key2 ] = max (obs ["iteration" ], applied [key2 ])
1808- n_applied += 1
1809-
1810- new_stat .update (
1811- dict (
1812- onnx_opt_optimized = 1 ,
1813- op_opt_all_time_in = time_in ,
1814- op_opt_all_added = added ,
1815- op_opt_all_removed = removed ,
1816- op_opt_max_iter = max_iter ,
1817- op_opt_unique_matched = len (matched ),
1818- op_opt_unique_applied = len (applied ),
1819- op_opt_n_applied = n_applied ,
1820- time_export_optimization = time_in ,
1821- op_opt_export_optimization = time_in ,
1822- op_opt_cst_time_in = cst_time_in ,
1823- op_opt_cst_added = cst_added ,
1824- op_opt_cst_removed = cst_removed ,
1825- )
1826- )
1887+ new_stat .update (process_statistics (opt_stats ["optimization" ]))
18271888
18281889 summary .update (new_stat )
18291890 assert epo is not None , "no onnx export was found"
0 commit comments