Commit 4866853

Log the time for torch.export.export (#235)
* more about export time
* types
* fix max
1 parent 54afb59 commit 4866853

1 file changed: +122 -61 lines changed

onnx_diagnostic/torch_models/validate.py

Lines changed: 122 additions & 61 deletions
@@ -3,7 +3,7 @@
 import os
 import pprint
 import sys
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 import time
 import numpy as np
 import onnx
@@ -994,6 +994,26 @@ def _validate_do_run_exported_program(data, summary, verbose, quiet):
     )


+_cache_export_times = []
+_main_export_function = torch.export.export
+
+
+def _torch_export_export(*args, _export=_main_export_function, **kwargs):
+    begin = time.perf_counter()
+    res = _export(*args, **kwargs)
+    duration = time.perf_counter() - begin
+    _cache_export_times.append(duration)
+    return res
+
+
+def _restore_torch_export_export(summary):
+    torch.export.export = _main_export_function
+    if _cache_export_times:
+        summary["time_torch_export_export"] = sum(_cache_export_times)
+        summary["time_torch_export_export_n"] = len(_cache_export_times)
+    _cache_export_times.clear()
+
+
 def call_exporter(
     data: Dict[str, Any],
     exporter: str,
@@ -1019,6 +1039,9 @@ def call_exporter(
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
+    _cache_export_times.clear()
+    torch.export.export = _torch_export_export
+
     if exporter == "export" or exporter.startswith("export-"):
         # torch export
         summary, data = call_torch_export_export(
@@ -1029,6 +1052,7 @@ def call_exporter(
             optimization=optimization,
             do_run=do_run,
         )
+        _restore_torch_export_export(summary)
         return summary, data
     if exporter.startswith("onnx-"):
         # torch export
@@ -1040,6 +1064,7 @@ def call_exporter(
             optimization=optimization,
             output_names=output_names,
         )
+        _restore_torch_export_export(summary)
         return summary, data
     if exporter == "custom" or exporter.startswith("custom"):
         # torch export
@@ -1052,6 +1077,7 @@ def call_exporter(
             dump_folder=dump_folder,
             output_names=output_names,
         )
+        _restore_torch_export_export(summary)
        return summary, data
     if exporter == "modelbuilder":
         # torch export
@@ -1063,6 +1089,7 @@ def call_exporter(
             optimization=optimization,
             output_names=output_names,
         )
+        _restore_torch_export_export(summary)
         return summary, data
     raise NotImplementedError(
         f"export with {exporter!r} and optimization={optimization!r} not implemented yet, "
@@ -1634,6 +1661,97 @@ def call_torch_export_model_builder(
     return summary, data


+def process_statistics(data: Sequence[Dict[str, float]]) -> Dict[str, Any]:
+    """
+    Processes statistics coming from the exporters.
+    It takes a sequence of dictionaries (like a data frame)
+    and extracts some metrics.
+    """
+
+    def _simplify(p):
+        for s in [
+            "remove_unused",
+            "constant_folding",
+            "remove_identity",
+            "remove_duplicated_initializer",
+            "dynamic_dimension_naming",
+            "inline",
+            "check",
+            "build_graph_for_pattern",
+            "pattern_optimization",
+        ]:
+            if s in p or s.replace("_", "-") in p:
+                return s
+        if p.startswith(("apply_", "match_")):
+            return p
+        return "other"
+
+    def _add(d, a, v, use_max=False):
+        if v:
+            if a not in d:
+                d[a] = v
+            elif use_max:
+                d[a] = max(d[a], v)
+            else:
+                d[a] += v
+
+    counts: Dict[str, Any] = {}
+    applied_pattern_time: Dict[str, Any] = {}
+    applied_pattern_n: Dict[str, Any] = {}
+    matching_pattern_time: Dict[str, Any] = {}
+    matching_pattern_n: Dict[str, Any] = {}
+
+    for obs in data:
+        pattern = _simplify(obs["pattern"])
+        _add(counts, "opt_nodes_added", obs.get("added", 0))
+        _add(counts, "opt_nodes_removed", obs.get("removed", 0))
+        _add(counts, "opt_time_steps", obs.get("time_in", 0))
+        _add(counts, "opt_n_steps", 1)
+        _add(
+            counts,
+            "opt_n_iteration",
+            max(counts.get("opt_n_iteration", 0), obs.get("iteration", 0)),
+            use_max=True,
+        )
+
+        if pattern.startswith("apply_"):
+            _add(counts, "opt_n_applied_patterns", 1)
+            _add(counts, "opt_time_applied_patterns", obs.get("time_in", 0))
+            _add(applied_pattern_time, pattern, obs.get("time_in", 0))
+            _add(applied_pattern_n, pattern, 1)
+        elif pattern.startswith("match_"):
+            _add(counts, "opt_n_matching_patterns", 1)
+            _add(counts, "opt_time_matching_patterns", obs.get("time_in", 0))
+            _add(matching_pattern_time, pattern, obs.get("time_in", 0))
+            _add(matching_pattern_n, pattern, 1)
+        else:
+            _add(counts, f"opt_time_{pattern}", obs.get("time_in", 0))
+            _add(counts, f"opt_n_{pattern}", 1)
+            _add(counts, f"opt_nodes_added_{pattern}", obs.get("added", 0))
+            _add(counts, f"opt_nodes_removed_{pattern}", obs.get("removed", 0))
+
+    if applied_pattern_time:
+        longest = max((v, k) for k, v in applied_pattern_time.items())
+        counts["opt_top_time_applied_pattern"], counts["opt_top_time_applied_pattern_arg"] = (
+            longest
+        )
+        longest = max((v, k) for k, v in applied_pattern_n.items())
+        counts["opt_top_n_applied_pattern"], counts["opt_top_n_applied_pattern_arg"] = longest
+
+    if matching_pattern_time:
+        longest = max((v, k) for k, v in matching_pattern_time.items())
+        (
+            counts["opt_top_time_matching_pattern"],
+            counts["opt_top_time_matching_pattern_arg"],
+        ) = longest
+        longest = max((v, k) for k, v in matching_pattern_n.items())
+        counts["opt_top_n_matching_pattern"], counts["opt_top_n_matching_pattern_arg"] = (
+            longest
+        )
+    counts["onnx_opt_optimized"] = 1
+    return counts
+
+
 def call_torch_export_custom(
     data: Dict[str, Any],
     exporter: str,
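
The new process_statistics consumes the per-pattern records produced during ONNX optimization: dictionaries with at least a "pattern" key and optional "added", "removed", "time_in" and "iteration" entries. A small usage sketch; the record values below are made up, and the import assumes this commit is applied:

from onnx_diagnostic.torch_models.validate import process_statistics

# Made-up optimization records; the real ones come from opt_stats["optimization"].
records = [
    {"pattern": "apply_SomePattern", "added": 1, "removed": 3, "time_in": 0.004, "iteration": 1},
    {"pattern": "match_SomePattern", "time_in": 0.001, "iteration": 1},
    {"pattern": "constant_folding", "removed": 5, "time_in": 0.010, "iteration": 2},
]

counts = process_statistics(records)
print(counts["opt_nodes_removed"])                 # 8 (3 + 5)
print(counts["opt_n_iteration"])                   # 2 (highest iteration seen)
print(counts["opt_time_constant_folding"])         # 0.01
print(counts["opt_top_time_applied_pattern_arg"])  # 'apply_SomePattern'
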
@@ -1763,67 +1881,10 @@ def call_torch_export_custom(
     if "ERR_export_onnx_c" in summary:
         return summary, data

-    new_stat = {}
+    new_stat: Dict[str, Any] = {k: v for k, v in opt_stats.items() if k.startswith("time_")}
+    new_stat.update({k[5:]: v for k, v in opt_stats.items() if k.startswith("stat_time_")})
     if "optimization" in opt_stats:
-        added, removed, time_in = 0, 0, 0.0
-        max_iter = 0
-        applied = {}
-        matched = set()
-        n_applied = 0
-        by_pattern = {}
-        by_pattern_n = {}
-        by_iter = {}
-        cst_added, cst_removed, cst_time_in = 0, 0, 0.0
-
-        for obs in opt_stats["optimization"]:
-            pattern = obs["pattern"]
-            if pattern == "constant_folding":
-                cst_added += obs.get("added", 0)
-                cst_removed += obs.get("removed", 0)
-                cst_time_in += obs.get("time_in", 0)
-            if pattern not in by_pattern:
-                by_pattern[pattern] = 0
-                by_pattern_n[pattern] = 0
-                by_iter[pattern] = 0
-            time_in += obs.get("time_in", 0)
-            added += obs.get("added", 0)
-            removed += obs.get("removed", 0)
-            max_iter = max(max_iter, obs.get("iteration", 0))
-            by_pattern[pattern] += obs.get("time_in", 0)
-            by_pattern_n[pattern] += obs.get("added", 0) - obs.get("removed", 0)
-            if not pattern.startswith("match"):
-                by_iter[pattern] = max(by_iter[pattern], obs.get("iteration", 0))
-            p = obs["pattern"]
-            if p.startswith("match_"):
-                matched.add(p)
-            elif p.startswith("apply_"):
-                key = f"op_opt_{p}"
-                key2 = f"op_opt_maxiter_{p}"
-                if key not in applied:
-                    applied[key] = 1
-                    applied[key2] = obs["iteration"]
-                else:
-                    applied[key] += 1
-                    applied[key2] = max(obs["iteration"], applied[key2])
-                n_applied += 1
-
-        new_stat.update(
-            dict(
-                onnx_opt_optimized=1,
-                op_opt_all_time_in=time_in,
-                op_opt_all_added=added,
-                op_opt_all_removed=removed,
-                op_opt_max_iter=max_iter,
-                op_opt_unique_matched=len(matched),
-                op_opt_unique_applied=len(applied),
-                op_opt_n_applied=n_applied,
-                time_export_optimization=time_in,
-                op_opt_export_optimization=time_in,
-                op_opt_cst_time_in=cst_time_in,
-                op_opt_cst_added=cst_added,
-                op_opt_cst_removed=cst_removed,
-            )
-        )
+        new_stat.update(process_statistics(opt_stats["optimization"]))

     summary.update(new_stat)
     assert epo is not None, "no onnx export was found"
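
The removed aggregation loop is replaced by two comprehensions that forward timing entries from opt_stats (keeping "time_*" keys and stripping the "stat_" prefix from "stat_time_*" keys) plus the call to process_statistics. A tiny sketch of the prefix handling, with made-up keys:

from typing import Any, Dict

# Illustrative opt_stats only; the real dictionary comes from the custom exporter.
opt_stats: Dict[str, Any] = {
    "time_export_graph": 1.5,
    "stat_time_inline": 0.2,
    "other_info": "ignored",
}

new_stat = {k: v for k, v in opt_stats.items() if k.startswith("time_")}
# k[5:] drops the leading "stat_" so the key becomes "time_inline".
new_stat.update({k[5:]: v for k, v in opt_stats.items() if k.startswith("stat_time_")})
print(new_stat)  # {'time_export_graph': 1.5, 'time_inline': 0.2}
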
