Skip to content

Commit 455d44d

Browse files
authored
[MLP IMPROVEMENTS] Add new sampling for bmm recordings (#57)
* predict linear using kernel time * include __matmul__ in bmm predictor * linear model mixed dataset * reverted to test all models * reshape __matmul__ to fit into bmm * test reshape of __matmul__ args * new sampling strategy for bmm * additional verifications for bmm * cap memory for bmm sampling * readjust bmm mem ceil * trained bmm mlp model with new data * deleted temporary debug print * restore rest of experiments * fix bmm max memory consumption * add more L4 samples to train bmm mlp --------- Co-authored-by: John Calderon <[email protected]>
1 parent 08a280a commit 455d44d

File tree

5 files changed

+48
-8
lines changed

5 files changed

+48
-8
lines changed

analyzer/habitat/analysis/predictor.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import logging
33
import operator
44
import numpy as np
5+
import math
56

67
from habitat.analysis import SPECIAL_OPERATIONS
78
from habitat.analysis.operation import PredictedOperation
@@ -116,9 +117,9 @@ def predict_operation(self, operation, dest_device, unscaled=False):
116117
return self._special_scale(operation, dest_device, self._conv2d_scale, unscaled)
117118
elif operation.name == 'lstm':
118119
return self._special_scale(operation, dest_device, self._lstm_scale, unscaled)
119-
elif operation.name in ['linear','__matmul__']:
120+
elif operation.name == 'linear':
120121
return self._special_scale(operation, dest_device, self._linear_scale, unscaled)
121-
elif operation.name == 'bmm':
122+
elif operation.name in ['bmm', '__matmul__']:
122123
return self._special_scale(operation, dest_device, self._bmm_scale, unscaled)
123124
elif operation.name == 'conv_transpose2d':
124125
return self._special_scale(operation, dest_device, self._conv_transpose2d_scale, unscaled)
@@ -284,6 +285,7 @@ def _linear_scale(self, operation, dest_device, unscaled=False):
284285
arguments = [arguments[x] for x in self.linear_pred.model.features]
285286

286287
pred_dest = self.linear_pred.predict(arguments, dest_device.name)
288+
287289
pred_orig = self.linear_pred.predict(arguments, operation.device.name)
288290

289291
if unscaled:
@@ -295,18 +297,30 @@ def _linear_scale(self, operation, dest_device, unscaled=False):
295297
return operation.run_time_ms * pred_dest / pred_orig
296298

297299
def _bmm_scale(self, operation, dest_device, unscaled=False):
300+
# nn.Linear may call __matmul__ which in turn calls bmm
301+
# but the shape of the arguments may be [a,b,c,d].
302+
# So we need to reshape them into [a*b,c,d]
303+
reshape_args = []
304+
for arg in operation.arguments.args:
305+
if len(arg) > 3:
306+
reshape_args.append([math.prod(arg[:-2]),arg[-2], arg[-1]])
307+
else:
308+
reshape_args.append(arg)
309+
operation.arguments.args = reshape_args
310+
298311
merged = name_all_arguments(
299312
BMM_PARAMS,
300313
operation.arguments.args,
301314
operation.arguments.kwargs,
302315
)
303-
316+
304317
arguments = dict(
305318
batch=merged['input'][0],
306319
left=merged['input'][1],
307320
middle=merged['input'][2],
308321
right=merged['mat2'][2],
309322
)
323+
310324
arguments = [arguments[x] for x in self.bmm_pred.model.features]
311325

312326
pred_dest = self.bmm_pred.predict(arguments, dest_device.name)

analyzer/habitat/data/bmm/model.pth

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:628cd9ecca8cda59e0b5277580c996a72bae9b29bf3c5bdabccd9dfa6fc34389
2+
oid sha256:70c172469e8c1244e7fb53444e324bc8ac2cd4f8552e86a7e2d3444c02f43128
33
size 33634474
870 KB
Binary file not shown.

tools/recording/parameter_generator.py

+29-3
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,14 @@
44
import sys
55
import random
66
from typing import Dict, List
7+
import psutil
78

9+
# SET CEIL FOR AVAILABLE RAM (avoid running-out-mem for sampling bmm)
10+
CURR_MEM = psutil.virtual_memory()[1]
11+
BMM_MEM_CEIL = int(0.9 * CURR_MEM)
812

913
class main_generator:
10-
"Special distribution for conv2d and linear records"
14+
"Special distribution for conv2d, bmm, batch_norm, and linear"
1115

1216
def __init__(self, ops):
1317

@@ -16,6 +20,10 @@ def __init__(self, ops):
1620

1721
if ops == "conv2d" or ops == "batch_norm":
1822
filename = "conv2d_sampled_params.pkl"
23+
24+
elif ops == "bmm":
25+
filename = "bmm_sampled_params.pkl"
26+
1927
elif ops == "linear":
2028
filename = "linear_sampled_params.pkl"
2129

@@ -25,7 +33,7 @@ def __init__(self, ops):
2533
param_dict: Dict[str, int] = dict()
2634
dist_arr: List[List[int, int]] = []
2735

28-
if ops == "conv2d" or ops == "batch_norm":
36+
if ops in ["conv2d", "bmm", "batch_norm"]:
2937
# weight by model count
3038
model_counts: Dict[str, int] = dict()
3139
for row in data:
@@ -73,7 +81,7 @@ def generate_sample(self):
7381
]
7482
if round_sample[2] != 0 and round_sample[3] != 0:
7583
return round_sample
76-
84+
7785
elif self._ops == "batch_norm":
7886
round_sample = [
7987
self.round(sample[0][0]), # in_channels
@@ -85,6 +93,24 @@ def generate_sample(self):
8593
if round_sample[1] != 0:
8694
return [round_sample[1]]
8795

96+
elif self._ops == "bmm":
97+
round_sample = [
98+
self.round(sample[0][0]), # bs
99+
self.round(sample[1][0]), # left
100+
self.round(sample[2][0]), # middle
101+
self.round(sample[3][0]), # right
102+
]
103+
# validate non-zeros
104+
# check if available memory (RuntimeError DefaultCPUAllocator: can't allocate memory)
105+
# 4 for FP32
106+
matrix_a_size = 4 * round_sample[0] * round_sample[1] * round_sample[2]
107+
matrix_b_size = 4 * round_sample[0] * round_sample[2] * round_sample[3]
108+
if (
109+
np.all(round_sample)
110+
and matrix_a_size + matrix_b_size < BMM_MEM_CEIL
111+
):
112+
return round_sample
113+
88114
elif self._ops == "linear":
89115
in_features = self.round(sample[0][0])
90116
out_features = self.round(sample[1][0])

tools/recording/record_common.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
Some operators such as conv2d and linear need to be sampled from a different distribution (gaussian + uniform)
1919
main_generator generates these new samples
2020
"""
21-
SPECIAL_SAMPLING_OPS = ['conv2d','linear', 'batch_norm']
21+
SPECIAL_SAMPLING_OPS = ['conv2d','linear', 'batch_norm', 'bmm']
2222

2323
class Measurer:
2424
def __init__(

0 commit comments

Comments
 (0)