Commit f8ee546

Remove sanity check
1 parent 8306a04 commit f8ee546

File tree

2 files changed (+70, -243 lines)


fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py

Lines changed: 50 additions & 140 deletions
@@ -7,8 +7,7 @@
 
 # pyre-strict
 
-import gzip
-import yaml
+
 import logging
 import os
 import tempfile
@@ -1012,15 +1011,7 @@ def context_factory(on_trace_ready: Callable[[profile], None]):
 @TbeBenchClickInterface.common_options
 @TbeBenchClickInterface.device_options
 @TbeBenchClickInterface.vbe_options
-@click.option("--save", type=str, default=None)
-@click.option("--load", type=str, default=None)
-@click.option("--random-weights", is_flag=True, default=False)
-@click.option("--compressed", is_flag=True, default=False)
-@click.option("--slice-min", type=int, default=None)
-@click.option("--slice-max", type=int, default=None)
-@click.pass_context
 def device_with_spec( # noqa C901
-    ctx,
     alpha: float,
     bag_size_list: str,
     bag_size_sigma_list: str,
@@ -1040,39 +1031,7 @@ def device_with_spec( # noqa C901
     bounds_check_mode: int,
     flush_gpu_cache_size_mb: int,
     output_dtype: SparseType,
-    save: str,
-    load: str,
-    random_weights: bool,
-    compressed: bool,
-    slice_min: int,
-    slice_max: int,
 ) -> None:
-    if load:
-        with open(f"{load}/params.yaml", "r") as f:
-            ctx.params = yaml.load(f, Loader=yaml.UnsafeLoader)
-        alpha = ctx.params["alpha"]
-        bag_size_list = ctx.params["bag_size_list"]
-        bag_size_sigma_list = ctx.params["bag_size_sigma_list"]
-        batch_size = ctx.params["batch_size"]
-        embedding_dim_list = ctx.params["embedding_dim_list"]
-        weights_precision = ctx.params["weights_precision"]
-        cache_precision = ctx.params["cache_precision"]
-        stoc = ctx.params["stoc"]
-        iters = ctx.params["iters"]
-        warmup_runs = ctx.params["warmup_runs"]
-        managed = ctx.params["managed"]
-        num_embeddings_list = ctx.params["num_embeddings_list"]
-        reuse = ctx.params["reuse"]
-        row_wise = ctx.params["row_wise"]
-        weighted = ctx.params["weighted"]
-        pooling = ctx.params["pooling"]
-        bounds_check_mode = ctx.params["bounds_check_mode"]
-        flush_gpu_cache_size_mb = ctx.params["flush_gpu_cache_size_mb"]
-        output_dtype = ctx.params["output_dtype"]
-        random_weights = ctx.params["random_weights"]
-        compressed = ctx.params["compressed"]
-        slice_min = ctx.params["slice_min"]
-        slice_max = ctx.params["slice_max"]
     np.random.seed(42)
     torch.manual_seed(42)
     B = batch_size
@@ -1081,11 +1040,6 @@ def device_with_spec( # noqa C901
     T = len(Ds)
 
     use_variable_bag_sizes = bag_size_sigma_list != "None"
-    params = ctx.params
-    if save:
-        os.makedirs(f"{save}", exist_ok=True)
-        with open(f"{save}/params.yaml", "w") as f:
-            yaml.dump(params, f, sort_keys=False)
 
     if use_variable_bag_sizes:
         Ls = [int(mu) for mu in bag_size_list.split(",")]
@@ -1164,22 +1118,6 @@ def device_with_spec( # noqa C901
 
     if weights_precision == SparseType.INT8:
         emb.init_embedding_weights_uniform(-0.0003, 0.0003)
-    elif random_weights:
-        emb.init_embedding_weights_uniform(-1.0, 1.0)
-
-    if save:
-        if compressed:
-            with gzip.open(f"{save}/model_state.pth.gz", "wb") as f:
-                torch.save(emb.state_dict(), f)
-        else:
-            torch.save(emb.state_dict(), f"{save}/model_state.pth")
-
-    if load:
-        if compressed:
-            with gzip.open(f"{load}/model_state.pth.gz", "rb") as f:
-                emb.load_state_dict(torch.load(f))
-        else:
-            emb.load_state_dict(torch.load(f"{load}/model_state.pth"))
 
     nparams = sum(w.numel() for w in emb.split_embedding_weights())
     param_size_multiplier = weights_precision.bit_rate() / 8.0
@@ -1192,66 +1130,53 @@ def device_with_spec( # noqa C901
         "weights": [[] for _ in range(iters)],
     }
     # row = iter, column = tensor
-    if load:
-        requests = []
-        for i in range(iters):
-            indices = torch.load(f"{load}/{i}_indices.pt")
-            offsets = torch.load(f"{load}/{i}_offsets.pt")
-            per_sample_weights = torch.load(f"{load}/{i}_per_sample_weights.pt")
-            Bs_per_feature_per_rank = torch.load(f"{load}/{i}_Bs_per_feature_per_rank.pt")
-            requests.append(TBERequest(indices, offsets, per_sample_weights, Bs_per_feature_per_rank))
-    else:
-        for t, e in enumerate(Es):
-            # (indices, offsets, weights)
-            requests = generate_requests(
-                iters,
-                B,
-                1,
-                Ls[t],
-                e,
-                reuse=reuse,
-                alpha=alpha,
-                weighted=weighted,
-                # pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined.
-                sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None,
-                zipf_oversample_ratio=3 if Ls[t] > 5 else 5,
-                use_cpu=get_available_compute_device() == ComputeDevice.CPU,
-                index_dtype=torch.long,
-                offset_dtype=torch.long,
-            )
-            for i, req in enumerate(requests):
-                indices, offsets, weights = req.unpack_3()
-                all_requests["indices"][i].append(indices)
-                if t > 0:
-                    offsets = offsets[1:]  # remove the first element
-                    offsets += all_requests["offsets"][i][t - 1][-1]
-                all_requests["offsets"][i].append(offsets)
-                all_requests["weights"][i].append(weights)
-
-        prev_indices_len = -1
-        requests = []
-        for i in range(iters):
-            indices = torch.concat(all_requests["indices"][i])
-            if prev_indices_len == -1:
-                prev_indices_len = indices.numel()
-            assert (
-                prev_indices_len == indices.numel()
-            ), "Number of indices for every iteration must be the same"
-            offsets = torch.concat(all_requests["offsets"][i])
-            if weighted:
-                weights = torch.concat(all_requests["weights"][i])
-            else:
-                weights = None
-            requests.append(TBERequest(indices, offsets, weights))
-        del all_requests
+    for t, e in enumerate(Es):
+        # (indices, offsets, weights)
+        requests = generate_requests(
+            iters,
+            B,
+            1,
+            Ls[t],
+            e,
+            reuse=reuse,
+            alpha=alpha,
+            weighted=weighted,
+            # pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined.
+            sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None,
+            zipf_oversample_ratio=3 if Ls[t] > 5 else 5,
+            use_cpu=get_available_compute_device() == ComputeDevice.CPU,
+            index_dtype=torch.long,
+            offset_dtype=torch.long,
+        )
+        for i, req in enumerate(requests):
+            indices, offsets, weights = req.unpack_3()
+            all_requests["indices"][i].append(indices)
+            if t > 0:
+                offsets = offsets[1:]  # remove the first element
+                offsets += all_requests["offsets"][i][t - 1][-1]
+            all_requests["offsets"][i].append(offsets)
+            all_requests["weights"][i].append(weights)
+
+    prev_indices_len = -1
+    requests = []
+    for i in range(iters):
+        indices = torch.concat(all_requests["indices"][i])
+        if prev_indices_len == -1:
+            prev_indices_len = indices.numel()
+        assert (
+            prev_indices_len == indices.numel()
+        ), "Number of indices for every iteration must be the same"
+        offsets = torch.concat(all_requests["offsets"][i])
+        if weighted:
+            weights = torch.concat(all_requests["weights"][i])
+        else:
+            weights = None
+        requests.append(TBERequest(indices, offsets, weights))
+
+    del all_requests
+
     assert len(requests) == iters
-    if save:
-        for i in range(iters):
-            req = requests[i]
-            torch.save(req.indices, f"{save}/{i}_indices.pt")
-            torch.save(req.offsets, f"{save}/{i}_offsets.pt")
-            torch.save(req.per_sample_weights, f"{save}/{i}_per_sample_weights.pt")
-            torch.save(req.Bs_per_feature_per_rank, f"{save}/{i}_Bs_per_feature_per_rank.pt")
+
     sum_DLs = sum([d * l for d, l in zip(Ds, Ls)])
     if do_pooling:
         read_write_bytes = (
@@ -1299,22 +1224,13 @@ def device_with_spec( # noqa C901
         # backward bench not representative
         return
 
-    if load:
-        grad_output = torch.load(f"{load}/grad_output.pt")
+    if do_pooling:
+        grad_output = torch.randn(B, sum(Ds)).to(get_device())
     else:
         # Obtain B * L from indices len
         # pyre-ignore[19]
         # pyre-fixme[61]: `D` is undefined, or not always defined.
-        if do_pooling:
-            grad_output = torch.randn(B, sum(Ds)).to(get_device())
-        else:
-            # Obtain B * L from indices len
-            # pyre-ignore[19]
-            # pyre-fixme[61]: `D` is undefined, or not always defined.
-            grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device())
-
-    if save:
-        torch.save(grad_output, f"{save}/grad_output.pt")
+        grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device())
     # backward
     time_per_iter = benchmark_requests(
         requests,
@@ -1328,12 +1244,6 @@ def device_with_spec( # noqa C901
         bwd_only=True,
         grad=grad_output,
         num_warmups=warmup_runs,
-        emb=emb,
-        save=save,
-        load=load,
-        compressed=compressed,
-        slice_min=slice_min,
-        slice_max=slice_max,
     )
     logging.info(
         f"Backward, B: {B}, Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, "
