Skip to content

Commit 94f18e8

Browse files
briancoutinho authored and facebook-github-bot committed
Add a trace validator helper and simple unit test for
Summary: ## Summary * Adds a trace validation tool that can check PyTorch host execution traces. This is helpful for schema changes and integration testing. * Adds a unit test to check that execution_trace.py works correctly on preset traces. * Minor: adds helpers to read the semantic versions of PyTorch and Chakra. Reviewed By: shengbao-zheng Differential Revision: D56325885
1 parent 7868e09 commit 94f18e8

File tree

5 files changed

+164
-2
lines changed

5 files changed

+164
-2
lines changed

train/comms/pt/commsTraceParser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,8 @@ def _parseExecutionTrace(
217217
218218
"""
219219
# Execution Trace PG_ID types availability
220-
ET_PG_NAME_TUPLE = True if in_trace.schema == "1.0.3-chakra.0.0.4" else False
221-
ET_BACKENDID = True if in_trace.schema != "1.0.3-chakra.0.0.4" else False
220+
ET_PG_NAME_TUPLE = True if in_trace.schema_pytorch() >= (1, 0, 3) else False
221+
ET_BACKENDID = True if in_trace.schema_pytorch() >= (1, 0, 3) else False
222222

223223
initOps = []
224224
newCommsTrace = []
Binary file not shown.
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import gzip
2+
import json
3+
import os
4+
import unittest
5+
6+
from param_bench.train.compute.python.tools.execution_trace import ExecutionTrace
7+
from param_bench.train.compute.python.tools.validate_trace import TraceValidator
8+
9+
# Directory containing this test module; the trace fixtures live beneath it.
CURR_DIR = os.path.dirname(os.path.realpath(__file__))


class TestTraceLoadAndValidate(unittest.TestCase):
    """Load checked-in execution traces and run the trace validator on them."""

    def setUp(self):
        # All fixture traces are stored under "<this dir>/data".
        self.trace_base = os.path.join(CURR_DIR, "data")

    def _test_and_validate_trace(self, trace_file):
        """Load ``trace_file``, assert that it validates, and return the results.

        Files whose name ends in "gz" are opened with gzip; anything else is
        opened as a regular text file.

        Returns:
            A ``(validator, execution_trace)`` pair for further assertions.
        """
        if trace_file.endswith("gz"):
            trace_stream = gzip.open(trace_file, "rb")
        else:
            trace_stream = open(trace_file, "r")

        with trace_stream as execution_data:
            raw_trace = json.load(execution_data)
            execution_trace: ExecutionTrace = ExecutionTrace(raw_trace)
            validator = TraceValidator(execution_trace)
            self.assertTrue(validator.validate())
            return validator, execution_trace

    def test_trace_load_resnet_1gpu(self):
        """Check operator counts for the single-GPU resnet fixture trace."""
        et_file = os.path.join(
            self.trace_base, "1.0.3-chakra.0.0.4/resnet_1gpu_et.json.gz"
        )
        validator, _trace = self._test_and_validate_trace(et_file)
        self.assertGreater(validator.num_ops(), 1000)
        self.assertEqual(validator.num_comm_ops(), 12)
        self.assertEqual(validator.num_triton_ops(), 0)


if __name__ == "__main__":
    unittest.main()

train/compute/python/tools/execution_trace.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,17 @@ def __init__(self, json):
350350
# remove all dataloader ops
351351
self.remove_dataloader_ops()
352352

353+
def _versiontuple(self, v: str) -> Tuple[int, ...]:
    """Convert a dotted version string into a tuple of integers.

    Args:
        v: Version string made of integer components separated by ".",
            for example "1.0.3".

    Returns:
        One integer per component, in order, e.g. ``(1, 0, 3)``. Tuples
        compare element by element, so results can be ordered with
        ``<``, ``>=`` and friends.

    Raises:
        ValueError: If any component is not a valid integer literal.
    """
    return tuple(int(component) for component in v.split("."))
355+
356+
def schema_pytorch(self) -> Tuple[int, ...]:
    """Return the PyTorch part of the trace schema version as an integer tuple.

    Only the text of ``self.schema`` before the first "-" is parsed, so a
    schema of "1.0.3-chakra.0.0.4" yields ``(1, 0, 3)``. A schema without a
    "-" is parsed as a whole.

    Raises:
        ValueError: If the PyTorch part is not a dotted list of integers.
    """
    pytorch_version, _, _ = self.schema.partition("-")
    return self._versiontuple(pytorch_version)
358+
359+
def schema_chakra(self) -> Tuple[int, ...]:
    """Return the Chakra part of the trace schema version as an integer tuple.

    The Chakra part is the text between the first and second "-" of
    ``self.schema``. It may carry a leading "chakra." label (as in
    "1.0.3-chakra.0.0.4"), which is dropped before parsing so that the
    example yields ``(0, 0, 4)``.

    Returns:
        The parsed Chakra version, or ``(0, 0, 0)`` when the schema has no
        "-" separated Chakra part.

    Raises:
        ValueError: If the remaining text is not a dotted list of integers.
    """
    if "-" not in self.schema:
        return (0, 0, 0)

    chakra_version = self.schema.split("-")[1]
    # Without stripping the label, int("chakra") would raise ValueError for
    # every schema string of the form "<pytorch>-chakra.<version>".
    label = "chakra."
    if chakra_version.startswith(label):
        chakra_version = chakra_version[len(label):]
    return self._versiontuple(chakra_version)
363+
353364
@staticmethod
354365
def _read_attrs(node: Dict[str, Any]) -> Tuple:
355366
attr_types = {
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
from __future__ import (
2+
absolute_import,
3+
annotations,
4+
division,
5+
print_function,
6+
unicode_literals,
7+
)
8+
9+
import gzip
10+
import json
11+
12+
from .execution_trace import ExecutionTrace
13+
14+
15+
class TraceValidator:
    """Run basic sanity checks over a loaded execution trace.

    The validator reads ``execution_trace.nodes`` (a mapping whose values
    expose ``is_op()``, ``name``, ``id`` and ``get_inputs()``) and
    ``execution_trace.schema_pytorch()``. Problems are reported with
    ``print`` and reflected in the boolean returned by ``validate()``.
    """

    def __init__(self, execution_trace: "ExecutionTrace"):
        self.et = execution_trace

    def _ops(self):
        """Return a generator over the trace nodes for which ``is_op()`` is true."""
        return (n for n in self.et.nodes.values() if n.is_op())

    def _validate_ops(self) -> bool:
        """Make sure the pytorch operators are valid.

        Every operator must have a non-empty name. All offending operators
        are reported before the result is returned.

        Returns:
            True when no operator has an empty name, False otherwise.
        """
        valid = True
        for op in self._ops():
            if op.name == "":
                print(f"op should have valid name, node id = {op.id}")
                # Keep scanning so every unnamed op is reported, but make
                # the check fail instead of silently passing.
                valid = False

            # TODO: also require each op to have inputs or outputs, printing
            # "op should have outputs or inputs" otherwise. That check is
            # skipped for now; see
            # "autograd::engine::evaluate_function: DivBackward1", which
            # presumably trips it (confirm before enabling).
        return valid

    def _validate_tree(self) -> bool:
        """TBD validate that the generated datastructure is a tree
        with parent/child relationship. We can use pydot or networkx libs for this
        """
        return True

    def _validate_param_comms(self) -> bool:
        """Check if param comms has correct attributes.

        Traces with a PyTorch schema older than 1.0.2 are not checked. For
        newer traces, every "record_param_comms" operator must have at least
        one input whose type is "Tuple[String,String]".

        Returns:
            True when the check is skipped or every comms operator passes.
        """
        # This should use the comms parser, for now something simple will be fine
        # https://github.com/facebookresearch/param/blob/main/train/comms/pt/commsTraceParser.py#L256

        if self.et.schema_pytorch() < (1, 0, 2):
            return True

        def check_comms_node(n) -> bool:
            """TODO use comms parser"""
            has_pg_id = False
            # Slightly hacky: treat any tuple-of-strings argument as the
            # process group argument.
            for arg in n.get_inputs():
                if arg[0] == "Tuple[String,String]":
                    print(f" {n.name}, process group args = {arg}")
                    has_pg_id = True
            return has_pg_id

        return all(
            check_comms_node(n)
            for n in self._ops()
            if n.name == "record_param_comms"
        )

    def _validate_triton(self) -> bool:
        """Make sure triton kernels have correct values
        TODO update for checking if kernel files are captured.
        """
        return True

    def validate(self) -> bool:
        """Run every check and return True only if all of them pass."""
        # Build the list first so each check runs (and prints its findings)
        # even when an earlier one has already failed.
        return all(
            [
                self._validate_ops(),
                self._validate_tree(),
                self._validate_param_comms(),
                self._validate_triton(),
            ]
        )

    def num_ops(self) -> int:
        """Return the number of operator nodes in the trace."""
        return sum(1 for _ in self._ops())

    def num_comm_ops(self) -> int:
        """Return the number of operators named "record_param_comms"."""
        return sum(1 for op in self._ops() if op.name == "record_param_comms")

    def num_triton_ops(self) -> int:
        """Return the number of operators whose name contains "triton"."""
        return sum(1 for op in self._ops() if "triton" in op.name)
91+
92+
93+
def main():
    """Validate the execution trace file named on the command line.

    Expects the trace path as the first command-line argument. A path ending
    in "gz" is opened with gzip; anything else is opened as a regular text
    file. Prints the operator counts followed by the validation result.

    Raises:
        SystemExit: With a usage message when no trace path is given.
    """
    import sys

    # Fail with a readable usage line instead of a bare IndexError.
    if len(sys.argv) < 2:
        sys.exit("usage: validate_trace.py <execution_trace.json[.gz]>")

    execution_json = sys.argv[1]

    if execution_json.endswith("gz"):
        trace_stream = gzip.open(execution_json, "rb")
    else:
        trace_stream = open(execution_json, "r")

    with trace_stream as execution_data:
        execution_trace: ExecutionTrace = ExecutionTrace(json.load(execution_data))
        t = TraceValidator(execution_trace)
        print(
            f"num ops = {t.num_ops()}, num comms = {t.num_comm_ops()}, "
            f"num triton ops = {t.num_triton_ops()}"
        )
        print("Trace validation result = ", t.validate())


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)