@@ -7,8 +7,7 @@
 
 # pyre-strict
 
-import gzip
-import yaml
+
 import logging
 import os
 import tempfile
@@ -1012,15 +1011,7 @@ def context_factory(on_trace_ready: Callable[[profile], None]):
 @TbeBenchClickInterface.common_options
 @TbeBenchClickInterface.device_options
 @TbeBenchClickInterface.vbe_options
-@click.option("--save", type=str, default=None)
-@click.option("--load", type=str, default=None)
-@click.option("--random-weights", is_flag=True, default=False)
-@click.option("--compressed", is_flag=True, default=False)
-@click.option("--slice-min", type=int, default=None)
-@click.option("--slice-max", type=int, default=None)
-@click.pass_context
 def device_with_spec(  # noqa C901
-    ctx,
     alpha: float,
     bag_size_list: str,
     bag_size_sigma_list: str,
@@ -1040,39 +1031,7 @@ def device_with_spec( # noqa C901
     bounds_check_mode: int,
     flush_gpu_cache_size_mb: int,
     output_dtype: SparseType,
-    save: str,
-    load: str,
-    random_weights: bool,
-    compressed: bool,
-    slice_min: int,
-    slice_max: int,
 ) -> None:
-    if load:
-        with open(f"{load}/params.yaml", "r") as f:
-            ctx.params = yaml.load(f, Loader=yaml.UnsafeLoader)
-        alpha = ctx.params["alpha"]
-        bag_size_list = ctx.params["bag_size_list"]
-        bag_size_sigma_list = ctx.params["bag_size_sigma_list"]
-        batch_size = ctx.params["batch_size"]
-        embedding_dim_list = ctx.params["embedding_dim_list"]
-        weights_precision = ctx.params["weights_precision"]
-        cache_precision = ctx.params["cache_precision"]
-        stoc = ctx.params["stoc"]
-        iters = ctx.params["iters"]
-        warmup_runs = ctx.params["warmup_runs"]
-        managed = ctx.params["managed"]
-        num_embeddings_list = ctx.params["num_embeddings_list"]
-        reuse = ctx.params["reuse"]
-        row_wise = ctx.params["row_wise"]
-        weighted = ctx.params["weighted"]
-        pooling = ctx.params["pooling"]
-        bounds_check_mode = ctx.params["bounds_check_mode"]
-        flush_gpu_cache_size_mb = ctx.params["flush_gpu_cache_size_mb"]
-        output_dtype = ctx.params["output_dtype"]
-        random_weights = ctx.params["random_weights"]
-        compressed = ctx.params["compressed"]
-        slice_min = ctx.params["slice_min"]
-        slice_max = ctx.params["slice_max"]
     np.random.seed(42)
     torch.manual_seed(42)
     B = batch_size
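
The two hunks above strip out the `--save`/`--load` flags and the `ctx.params` replay that restored every option from a previous run. For reference, a minimal sketch of the click pattern being removed; the command name and echo output are illustrative, not part of the benchmark's real interface:

```python
# Minimal sketch of the removed click wiring: optional flags plus
# @click.pass_context, so the handler can read (or overwrite) every parsed
# option through the ctx.params dict. Command name and output are illustrative.
import click


@click.command()
@click.option("--save", type=str, default=None)
@click.option("--load", type=str, default=None)
@click.option("--compressed", is_flag=True, default=False)
@click.pass_context
def bench(ctx: click.Context, save: str, load: str, compressed: bool) -> None:
    # ctx.params is a plain dict of all parsed options,
    # e.g. {'save': None, 'load': 'run1', 'compressed': True}
    click.echo(ctx.params)


if __name__ == "__main__":
    bench()
```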
@@ -1081,11 +1040,6 @@ def device_with_spec( # noqa C901
     T = len(Ds)
 
     use_variable_bag_sizes = bag_size_sigma_list != "None"
-    params = ctx.params
-    if save:
-        os.makedirs(f"{save}", exist_ok=True)
-        with open(f"{save}/params.yaml", "w") as f:
-            yaml.dump(params, f, sort_keys=False)
 
     if use_variable_bag_sizes:
         Ls = [int(mu) for mu in bag_size_list.split(",")]
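
This hunk drops the other half of the round trip: dumping `ctx.params` to `params.yaml`. Combined with the `yaml.load(..., Loader=yaml.UnsafeLoader)` call removed above, the pattern looked roughly like the sketch below. `UnsafeLoader` was presumably needed because `ctx.params` holds non-primitive values such as the `SparseType` enum, which `yaml.safe_load` cannot reconstruct; it should only ever read files the benchmark itself wrote. Helper names and paths here are illustrative:

```python
# Sketch of the deleted params round trip. UnsafeLoader can execute arbitrary
# constructors, so it is only safe on trusted files; it is used here because
# the params dict may contain Python objects (e.g. enums), not just scalars.
import os

import yaml


def save_params(params: dict, save_dir: str) -> None:
    os.makedirs(save_dir, exist_ok=True)
    with open(f"{save_dir}/params.yaml", "w") as f:
        yaml.dump(params, f, sort_keys=False)  # keep declaration order


def load_params(load_dir: str) -> dict:
    with open(f"{load_dir}/params.yaml", "r") as f:
        return yaml.load(f, Loader=yaml.UnsafeLoader)
```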
@@ -1164,22 +1118,6 @@ def device_with_spec( # noqa C901
 
     if weights_precision == SparseType.INT8:
         emb.init_embedding_weights_uniform(-0.0003, 0.0003)
-    elif random_weights:
-        emb.init_embedding_weights_uniform(-1.0, 1.0)
-
-    if save:
-        if compressed:
-            with gzip.open(f"{save}/model_state.pth.gz", "wb") as f:
-                torch.save(emb.state_dict(), f)
-        else:
-            torch.save(emb.state_dict(), f"{save}/model_state.pth")
-
-    if load:
-        if compressed:
-            with gzip.open(f"{load}/model_state.pth.gz", "rb") as f:
-                emb.load_state_dict(torch.load(f))
-        else:
-            emb.load_state_dict(torch.load(f"{load}/model_state.pth"))
 
     nparams = sum(w.numel() for w in emb.split_embedding_weights())
     param_size_multiplier = weights_precision.bit_rate() / 8.0
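
The deleted block above checkpointed the embedding weights, optionally gzip-compressed. The underlying pattern relies on `torch.save`/`torch.load` accepting file-like objects; a self-contained sketch, with illustrative helper names and paths:

```python
# Sketch of the removed compressed-checkpoint pattern: torch.save/torch.load
# accept any file-like object, so a gzip stream trades I/O time for file size.
import gzip

import torch


def save_state(model: torch.nn.Module, path: str, compressed: bool) -> None:
    if compressed:
        with gzip.open(f"{path}.gz", "wb") as f:
            torch.save(model.state_dict(), f)
    else:
        torch.save(model.state_dict(), path)


def load_state(model: torch.nn.Module, path: str, compressed: bool) -> None:
    if compressed:
        with gzip.open(f"{path}.gz", "rb") as f:
            model.load_state_dict(torch.load(f))
    else:
        model.load_state_dict(torch.load(path))
```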
@@ -1192,66 +1130,53 @@ def device_with_spec( # noqa C901
11921130 "weights" : [[] for _ in range (iters )],
11931131 }
11941132 # row = iter, column = tensor
1195- if load :
1196- requests = []
1197- for i in range (iters ):
1198- indices = torch .load (f"{ load } /{ i } _indices.pt" )
1199- offsets = torch .load (f"{ load } /{ i } _offsets.pt" )
1200- per_sample_weights = torch .load (f"{ load } /{ i } _per_sample_weights.pt" )
1201- Bs_per_feature_per_rank = torch .load (f"{ load } /{ i } _Bs_per_feature_per_rank.pt" )
1202- requests .append (TBERequest (indices , offsets , per_sample_weights , Bs_per_feature_per_rank ))
1203- else :
1204- for t , e in enumerate (Es ):
1205- # (indices, offsets, weights)
1206- requests = generate_requests (
1207- iters ,
1208- B ,
1209- 1 ,
1210- Ls [t ],
1211- e ,
1212- reuse = reuse ,
1213- alpha = alpha ,
1214- weighted = weighted ,
1215- # pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined.
1216- sigma_L = sigma_Ls [t ] if use_variable_bag_sizes else None ,
1217- zipf_oversample_ratio = 3 if Ls [t ] > 5 else 5 ,
1218- use_cpu = get_available_compute_device () == ComputeDevice .CPU ,
1219- index_dtype = torch .long ,
1220- offset_dtype = torch .long ,
1221- )
1222- for i , req in enumerate (requests ):
1223- indices , offsets , weights = req .unpack_3 ()
1224- all_requests ["indices" ][i ].append (indices )
1225- if t > 0 :
1226- offsets = offsets [1 :] # remove the first element
1227- offsets += all_requests ["offsets" ][i ][t - 1 ][- 1 ]
1228- all_requests ["offsets" ][i ].append (offsets )
1229- all_requests ["weights" ][i ].append (weights )
1230-
1231- prev_indices_len = - 1
1232- requests = []
1233- for i in range (iters ):
1234- indices = torch .concat (all_requests ["indices" ][i ])
1235- if prev_indices_len == - 1 :
1236- prev_indices_len = indices .numel ()
1237- assert (
1238- prev_indices_len == indices .numel ()
1239- ), "Number of indices for every iteration must be the same"
1240- offsets = torch .concat (all_requests ["offsets" ][i ])
1241- if weighted :
1242- weights = torch .concat (all_requests ["weights" ][i ])
1243- else :
1244- weights = None
1245- requests .append (TBERequest (indices , offsets , weights ))
1246- del all_requests
1133+ for t , e in enumerate (Es ):
1134+ # (indices, offsets, weights)
1135+ requests = generate_requests (
1136+ iters ,
1137+ B ,
1138+ 1 ,
1139+ Ls [t ],
1140+ e ,
1141+ reuse = reuse ,
1142+ alpha = alpha ,
1143+ weighted = weighted ,
1144+ # pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined.
1145+ sigma_L = sigma_Ls [t ] if use_variable_bag_sizes else None ,
1146+ zipf_oversample_ratio = 3 if Ls [t ] > 5 else 5 ,
1147+ use_cpu = get_available_compute_device () == ComputeDevice .CPU ,
1148+ index_dtype = torch .long ,
1149+ offset_dtype = torch .long ,
1150+ )
1151+ for i , req in enumerate (requests ):
1152+ indices , offsets , weights = req .unpack_3 ()
1153+ all_requests ["indices" ][i ].append (indices )
1154+ if t > 0 :
1155+ offsets = offsets [1 :] # remove the first element
1156+ offsets += all_requests ["offsets" ][i ][t - 1 ][- 1 ]
1157+ all_requests ["offsets" ][i ].append (offsets )
1158+ all_requests ["weights" ][i ].append (weights )
1159+
1160+ prev_indices_len = - 1
1161+ requests = []
1162+ for i in range (iters ):
1163+ indices = torch .concat (all_requests ["indices" ][i ])
1164+ if prev_indices_len == - 1 :
1165+ prev_indices_len = indices .numel ()
1166+ assert (
1167+ prev_indices_len == indices .numel ()
1168+ ), "Number of indices for every iteration must be the same"
1169+ offsets = torch .concat (all_requests ["offsets" ][i ])
1170+ if weighted :
1171+ weights = torch .concat (all_requests ["weights" ][i ])
1172+ else :
1173+ weights = None
1174+ requests .append (TBERequest (indices , offsets , weights ))
1175+
1176+ del all_requests
1177+
12471178 assert len (requests ) == iters
1248- if save :
1249- for i in range (iters ):
1250- req = requests [i ]
1251- torch .save (req .indices , f"{ save } /{ i } _indices.pt" )
1252- torch .save (req .offsets , f"{ save } /{ i } _offsets.pt" )
1253- torch .save (req .per_sample_weights , f"{ save } /{ i } _per_sample_weights.pt" )
1254- torch .save (req .Bs_per_feature_per_rank , f"{ save } /{ i } _Bs_per_feature_per_rank.pt" )
1179+
12551180 sum_DLs = sum ([d * l for d , l in zip (Ds , Ls )])
12561181 if do_pooling :
12571182 read_write_bytes = (
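
The surviving generation path stitches per-table requests into one request per iteration. The key step is the offsets merge: each table's offsets are a CSR-style prefix sum starting at 0, so every table after the first drops its leading 0 and shifts by the previous running total before concatenation. A self-contained illustration with made-up numbers:

```python
# Toy illustration of the offsets merge above: two tables' CSR-style offsets
# become one prefix sum over all bags. All values here are made up.
import torch

t0 = torch.tensor([0, 2, 5])  # table 0: two bags, sizes 2 and 3
t1 = torch.tensor([0, 1, 4])  # table 1: two bags, sizes 1 and 3

merged = torch.concat([t0, t1[1:] + t0[-1]])  # drop leading 0, shift by 5
print(merged)  # tensor([0, 2, 5, 6, 9])
```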
@@ -1299,22 +1224,13 @@ def device_with_spec( # noqa C901
         # backward bench not representative
         return
 
-    if load:
-        grad_output = torch.load(f"{load}/grad_output.pt")
+    if do_pooling:
+        grad_output = torch.randn(B, sum(Ds)).to(get_device())
     else:
         # Obtain B * L from indices len
         # pyre-ignore[19]
         # pyre-fixme[61]: `D` is undefined, or not always defined.
-        if do_pooling:
-            grad_output = torch.randn(B, sum(Ds)).to(get_device())
-        else:
-            # Obtain B * L from indices len
-            # pyre-ignore[19]
-            # pyre-fixme[61]: `D` is undefined, or not always defined.
-            grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device())
-
-    if save:
-        torch.save(grad_output, f"{save}/grad_output.pt")
+        grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device())
     # backward
     time_per_iter = benchmark_requests(
         requests,
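
The restored branch sizes the synthetic gradient to the TBE output shape: pooled output is `(B, sum(Ds))`, one row per sample with all tables' dimensions concatenated, while non-pooled (sequence) output has one row of width `D` per lookup index. A quick shape sketch with illustrative numbers:

```python
# Shape sketch for the synthetic grad_output above; all sizes are illustrative.
import torch

B, Ds = 4, [8, 16]
pooled_grad = torch.randn(B, sum(Ds))          # pooled: shape (4, 24)

total_indices, D = 12, 8                       # cf. requests[0].indices.numel()
sequence_grad = torch.randn(total_indices, D)  # non-pooled: shape (12, 8)
```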
@@ -1328,12 +1244,6 @@ def device_with_spec( # noqa C901
         bwd_only=True,
         grad=grad_output,
         num_warmups=warmup_runs,
-        emb=emb,
-        save=save,
-        load=load,
-        compressed=compressed,
-        slice_min=slice_min,
-        slice_max=slice_max,
     )
     logging.info(
         f"Backward, B: {B}, Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, "