
Commit c8b1aa5

Merge pull request #88 from CentML/johncalesp/add-pl
[Pytorch Lightning support] added pytorch lightning
2 parents 5d4fd07 + 346e099

File tree

6 files changed: +460 -10 lines changed


deepview_profile/export_converter.py

Lines changed: 158 additions & 0 deletions
@@ -0,0 +1,158 @@
import json


def convert(message):
    new_message = {}
    # Dump the raw analysis message to disk for debugging/inspection.
    with open("message.json", "w") as fp:
        json.dump(message, fp, indent=4)

    new_message["ddp"] = {}
    new_message["message_type"] = message["message_type"]
    new_message["project_root"] = message["project_root"]
    new_message["project_entry_point"] = message["project_entry_point"]

    new_message["hardware_info"] = {
        "hostname": message["hardware_info"]["hostname"],
        "os": message["hardware_info"]["os"],
        "gpus": message["hardware_info"]["gpus"],
    }

    # run_time_ms and peak_usage_bytes are flattened from {slope, bias}
    # objects to [slope, bias] pairs; [0, 0] when the measurement is absent.
    new_message["throughput"] = {
        "samples_per_second": message["throughput"]["samples_per_second"],
        "predicted_max_samples_per_second": message["throughput"][
            "predicted_max_samples_per_second"
        ],
        "run_time_ms": (
            [
                message["throughput"]["run_time_ms"]["slope"],
                message["throughput"]["run_time_ms"]["bias"],
            ]
            if "run_time_ms" in message["throughput"]
            else [0, 0]
        ),
        "peak_usage_bytes": (
            [
                message["throughput"]["peak_usage_bytes"]["slope"],
                message["throughput"]["peak_usage_bytes"]["bias"],
            ]
            if "peak_usage_bytes" in message["throughput"]
            else [0, 0]
        ),
        "batch_size_context": None,
        "can_manipulate_batch_size": False,
    }

    new_message["utilization"] = message["utilization"]

    # Recursively rename snake_case keys ("cpu_forward", "gpu_backward_span",
    # ...) to the camelCase keys the report expects ("cpuForward",
    # "gpuBackwardSpan", ...), defaulting missing entries to 0.
    def fix(a):
        for d in ["cpu", "gpu"]:
            for s in ["Forward", "Backward"]:
                if f"{d}_{s.lower()}" in a:
                    a[f"{d}{s}"] = a[f"{d}_{s.lower()}"]
                    del a[f"{d}_{s.lower()}"]
                else:
                    a[f"{d}{s}"] = 0

                if f"{d}_{s.lower()}_span" in a:
                    a[f"{d}{s}Span"] = a[f"{d}_{s.lower()}_span"]
                    del a[f"{d}_{s.lower()}_span"]
                else:
                    a[f"{d}{s}Span"] = 0

        # Leaf nodes get an empty children list; otherwise recurse.
        if "children" not in a:
            a["children"] = []
            return

        if a:
            for c in a["children"]:
                fix(c)

    if new_message["utilization"].get("rootNode", None):
        fix(new_message["utilization"]["rootNode"])

    try:
        new_message["utilization"]["tensor_core_usage"] = message["utilization"][
            "tensor_utilization"
        ]
    except KeyError:
        new_message["utilization"]["tensor_core_usage"] = 0

    new_message["habitat"] = {
        "predictions": [
            (
                [prediction["device_name"], prediction["runtime_ms"]]
                if prediction["device_name"] != "unavailable"
                else ["default_device", 0]
            )
            for prediction in message["habitat"]["predictions"]
        ]
    }

    new_message["breakdown"] = {
        "peak_usage_bytes": int(message["breakdown"]["peak_usage_bytes"]),
        "memory_capacity_bytes": int(message["breakdown"]["memory_capacity_bytes"]),
        "iteration_run_time_ms": message["breakdown"]["iteration_run_time_ms"],
        # TODO change these hardcoded numbers
        "batch_size": 48,
        "num_nodes_operation_tree": len(message["breakdown"]["operation_tree"]),
        "num_nodes_weight_tree": 0,
        "operation_tree": [
            {
                "name": op["name"],
                "num_children": op["num_children"] if "num_children" in op else 0,
                "forward_ms": op["operation"]["forward_ms"],
                "backward_ms": op["operation"]["backward_ms"],
                "size_bytes": (
                    int(op["operation"]["size_bytes"])
                    if "size_bytes" in op["operation"]
                    else 0
                ),
                "file_refs": (
                    [
                        {
                            "path": "/".join(ctx["context"]["file_path"]["components"]),
                            "line_no": ctx["context"]["line_number"],
                            "run_time_ms": ctx["run_time_ms"],
                            "size_bytes": (
                                int(ctx["size_bytes"]) if "size_bytes" in ctx else 0
                            ),
                        }
                        for ctx in op["operation"]["context_info_map"]
                    ]
                    if "context_info_map" in op["operation"]
                    else list()
                ),
            }
            for op in message["breakdown"]["operation_tree"]
        ],
    }

    # Normalize energy components: rename consumption_joules -> consumption
    # and map the NVIDIA-specific component type to the generic GPU type.
    def fix_components(m):
        for c in m["components"]:
            if "consumption_joules" not in c:
                c["consumption"] = 0
            else:
                c["consumption"] = c["consumption_joules"]
                del c["consumption_joules"]
            c["type"] = c["component_type"]
            if c["type"] == "ENERGY_NVIDIA":
                c["type"] = "ENERGY_GPU"
            del c["component_type"]

    new_message["energy"] = {
        "current": {
            "total_consumption": message["energy"]["total_consumption"],
            "components": message["energy"]["components"],
            "batch_size": 48,
        },
        "past_measurements": message["energy"].get("past_measurements", None),
    }

    fix_components(new_message["energy"]["current"])
    if new_message["energy"].get("past_measurements", None):
        for m in new_message["energy"]["past_measurements"]:
            fix_components(m)

    return new_message
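
For illustration, here is what the key renaming in fix() does to a single utilization node. Since fix() is nested inside convert(), this is just a before/after sketch of the shape; the input keys mirror the ones handled above, and the numeric values are invented:

# A toy utilization node, before and after fix() (values are made up).
node = {"cpu_forward": 1.5, "cpu_forward_span": 3, "gpu_backward": 2.0}
# fix(node) would mutate it in place to:
# {
#     "cpuForward": 1.5,  "cpuForwardSpan": 3,
#     "cpuBackward": 0,   "cpuBackwardSpan": 0,
#     "gpuForward": 0,    "gpuForwardSpan": 0,
#     "gpuBackward": 2.0, "gpuBackwardSpan": 0,
#     "children": [],
# }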
Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
from typing import Callable, Tuple

import time
import os
import json
import torch
import sys

try:
    import pytorch_lightning as pl
except ImportError:
    sys.exit("Please install pytorch-lightning:\nuse: pip install lightning\nExiting...")

from termcolor import colored
from deepview_profile.pl.deepview_interface import trigger_profiling


class DeepViewProfilerCallback(pl.Callback):
    def __init__(self, profile_name: str):
        super().__init__()
        self.profiling_triggered = False
        self.output_filename = f"{profile_name}_{int(time.time())}.json"

    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs,
        batch,
        batch_idx,
    ):
        # Only profile once per training run.
        if self.profiling_triggered:
            return

        print(colored("DeepViewProfiler: Running profiling.", "green"))

        """
        Profiling needs three things:

        input_provider: just returns the batch, replicated to the requested batch size
        model_provider: just returns pl_module
        iteration_provider: a function that (a) calls pl_module.training_step and (b) calls loss.backward
        """
        initial_batch_size = batch[0].shape[0]

        def input_provider(batch_size: int = initial_batch_size) -> Tuple:
            model_inputs = list()
            for elem in batch:
                # We assume the first dimension is the batch dimension:
                # replicate the first sample batch_size times.
                model_inputs.append(
                    elem[:1].repeat([batch_size] + [1 for _ in elem.shape[1:]])
                )
            return (tuple(model_inputs), 0)

        model_provider = lambda: pl_module

        def iteration_provider(module: torch.nn.Module) -> Callable:
            def iteration(*args, **kwargs):
                loss = module.training_step(*args, **kwargs)
                loss.backward()

            return iteration

        project_root = os.getcwd()

        output = trigger_profiling(
            project_root,
            "entry_point.py",
            initial_batch_size,
            input_provider,
            model_provider,
            iteration_provider,
        )

        with open(self.output_filename, "w") as fp:
            json.dump(output, fp, indent=4)

        print(
            colored("DeepViewProfiler: Profiling complete! Report written to ", "green")
            + colored(self.output_filename, "green", attrs=["bold"])
        )
        print(
            colored(
                "DeepViewProfiler: View your report at https://deepview.centml.ai",
                "green",
            )
        )
        self.profiling_triggered = True
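
A minimal usage sketch for the callback follows. Only DeepViewProfilerCallback and its profile_name argument come from this commit; the import path, toy module, and dataset below are hypothetical placeholders:

# Hypothetical usage: attach the profiler callback to a Lightning Trainer.
# The import path is assumed; this commit does not show the callback file's name.
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset
from deepview_profile.pl.deepview_callback import DeepViewProfilerCallback

class ToyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
trainer = pl.Trainer(
    max_epochs=1,
    callbacks=[DeepViewProfilerCallback(profile_name="my_model")],
)
trainer.fit(ToyModule(), DataLoader(dataset, batch_size=16))
# Profiling runs once, right after the first training batch, and writes
# my_model_<unix timestamp>.json to the current working directory.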
Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
import sys
from typing import Callable
import platform

from deepview_profile.analysis.session import AnalysisSession
from deepview_profile.exceptions import AnalysisError
from deepview_profile.nvml import NVML
from deepview_profile.utils import release_memory, files_encoded_unique
from deepview_profile.error_printing import print_analysis_error

from google.protobuf.json_format import MessageToDict


# Each analysis step is a generator so that memory can be released after
# its result has been yielded.
def measure_breakdown(session, nvml):
    print("analysis: running measure_breakdown()")
    yield session.measure_breakdown(nvml)
    release_memory()


def measure_throughput(session):
    print("analysis: running measure_throughput()")
    yield session.measure_throughput()
    release_memory()


def habitat_predict(session):
    print("analysis: running deepview_predict()")
    yield session.habitat_predict()
    release_memory()


def measure_utilization(session):
    print("analysis: running measure_utilization()")
    yield session.measure_utilization()
    release_memory()


def energy_compute(session):
    print("analysis: running energy_compute()")
    yield session.energy_compute()
    release_memory()


def ddp_analysis(session):
    print("analysis: running ddp_computation()")
    yield session.ddp_computation()
    release_memory()


def hardware_information(nvml):
    hardware_info = {
        "hostname": platform.node(),
        "os": " ".join(list(platform.uname())),
        "gpus": nvml.get_device_names(),
    }
    return hardware_info


class DummyStaticAnalyzer:
    # Minimal stand-in static analyzer: no batch size location is reported.
    def batch_size_location(self):
        return None


def next_message_to_dict(a):
    message = next(a)
    return MessageToDict(message, preserving_proto_field_name=True)


def trigger_profiling(
    project_root: str,
    entry_point: str,
    initial_batch_size: int,
    input_provider: Callable,
    model_provider: Callable,
    iteration_provider: Callable,
):
    try:
        data = {
            "analysis": {
                "message_type": "analysis",
                "project_root": project_root,
                "project_entry_point": entry_point,
                "hardware_info": {},
                "throughput": {},
                "breakdown": {},
                "habitat": {},
                "additionalProviders": "",
                "energy": {},
                "utilization": {},
                "ddp": {},
            },
            "epochs": 50,
            "iterations": 1000,
            "encodedFiles": [],
        }

        session = AnalysisSession(
            project_root,
            entry_point,
            project_root,
            model_provider,
            input_provider,
            iteration_provider,
            initial_batch_size,
            DummyStaticAnalyzer(),
        )
        release_memory()

        exclude_source = False

        with NVML() as nvml:
            data["analysis"]["hardware_info"] = hardware_information(nvml)
            data["analysis"]["breakdown"] = next_message_to_dict(
                measure_breakdown(session, nvml)
            )

            operation_tree = data["analysis"]["breakdown"]["operation_tree"]
            if not exclude_source and operation_tree is not None:
                data["encodedFiles"] = files_encoded_unique(operation_tree)

            data["analysis"]["throughput"] = next_message_to_dict(
                measure_throughput(session)
            )
            data["analysis"]["habitat"] = next_message_to_dict(habitat_predict(session))
            data["analysis"]["utilization"] = next_message_to_dict(
                measure_utilization(session)
            )
            data["analysis"]["energy"] = next_message_to_dict(energy_compute(session))
            # data["analysis"]["ddp"] = next_message_to_dict(ddp_analysis(session))

        # Reshape the raw analysis payload into the exported report format.
        from deepview_profile.export_converter import convert

        data["analysis"] = convert(data["analysis"])

        return data

    except AnalysisError as ex:
        print_analysis_error(ex)
        sys.exit(1)
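
trigger_profiling can also, in principle, be driven outside of Lightning. The sketch below mirrors the provider shapes used by DeepViewProfilerCallback above; the exact calling convention between AnalysisSession and these providers is internal to DeepView, so treat this as an assumption-laden illustration rather than a supported API:

# Hypothetical direct call, with toy providers shaped like the callback's.
import os
import torch
from deepview_profile.pl.deepview_interface import trigger_profiling

model = torch.nn.Linear(8, 1)

def model_provider():
    # Return the module to profile (the callback returns pl_module).
    return model

def input_provider(batch_size: int = 4):
    # (inputs tuple, batch index), matching the callback's return shape.
    return ((torch.randn(batch_size, 8),), 0)

def iteration_provider(module):
    # One training iteration: forward pass plus backward pass.
    def iteration(inputs, batch_idx):
        loss = module(*inputs).sum()
        loss.backward()
    return iteration

report = trigger_profiling(
    os.getcwd(),
    "entry_point.py",
    4,
    input_provider,
    model_provider,
    iteration_provider,
)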
