diff --git a/dlrm_data_pytorch.py b/dlrm_data_pytorch.py index b815205c..95a74d59 100644 --- a/dlrm_data_pytorch.py +++ b/dlrm_data_pytorch.py @@ -442,7 +442,7 @@ def make_criteo_data_and_loaders(args, offset_to_length_converter=False): ) mlperf_logger.log_event(key=mlperf_logger.constants.TRAIN_SAMPLES, - value=train_data.num_samples) + value=train_data.num_entries) train_loader = torch.utils.data.DataLoader( train_data, @@ -464,7 +464,7 @@ def make_criteo_data_and_loaders(args, offset_to_length_converter=False): ) mlperf_logger.log_event(key=mlperf_logger.constants.EVAL_SAMPLES, - value=test_data.num_samples) + value=test_data.num_entries) test_loader = torch.utils.data.DataLoader( test_data,