diff --git a/pytorch_translate/data/data.py b/pytorch_translate/data/data.py index 03eba1fd..ef0fb189 100644 --- a/pytorch_translate/data/data.py +++ b/pytorch_translate/data/data.py @@ -261,7 +261,11 @@ def create_from_file(path, is_npz=True, num_examples_limit: Optional[int] = None return result else: # idx, bin format - return InMemoryIndexedDataset(path) + impl = data.indexed_dataset.infer_dataset_impl(path) + if impl == "mmap": + return data.indexed_dataset.MMapIndexedDataset(path) + else: + return InMemoryIndexedDataset(path) def subsample(self, indices): """ diff --git a/pytorch_translate/options.py b/pytorch_translate/options.py index c20cfe44..c3769b98 100644 --- a/pytorch_translate/options.py +++ b/pytorch_translate/options.py @@ -293,6 +293,12 @@ def add_preprocessing_args(parser): default="", help="Path for the binary file containing target side monolingual data", ) + group.add_argument( + "--fairseq-binary-data-format", + default=False, + action="store_true", + help="Binary data paths are prefixes of .bin and .idx files", + ) # TODO(T43045193): Move this to multilingual_task.py eventually group.add_argument( diff --git a/pytorch_translate/preprocess.py b/pytorch_translate/preprocess.py index 6667b27e..031ae6d3 100644 --- a/pytorch_translate/preprocess.py +++ b/pytorch_translate/preprocess.py @@ -15,7 +15,7 @@ from pytorch_translate.data.dictionary import Dictionary -def maybe_generate_temp_file_path(output_path=None): +def maybe_generate_temp_file_path(output_path=None, is_npz=True): """ This function generates a temp file path if output_path is empty or None. This is useful to do before calling any preprocessing function that has a @@ -28,7 +28,7 @@ def maybe_generate_temp_file_path(output_path=None): os.close(fd) # numpy silently appends this suffix if it is not present, so this ensures # that the correct path is returned - if not output_path.endswith(".npz"): + if is_npz and not output_path.endswith(".npz"): output_path += ".npz" return output_path @@ -148,16 +148,18 @@ def preprocess_corpora(args, dictionary_cls=Dictionary): utils.maybe_parse_collection_argument(args.train_target_binary_path), str ): args.train_source_binary_path = maybe_generate_temp_file_path( - args.train_source_binary_path + args.train_source_binary_path, + is_npz=not args.fairseq_binary_data_format, ) args.train_target_binary_path = maybe_generate_temp_file_path( - args.train_target_binary_path + args.train_target_binary_path, + is_npz=not args.fairseq_binary_data_format, ) args.eval_source_binary_path = maybe_generate_temp_file_path( - args.eval_source_binary_path + args.eval_source_binary_path, is_npz=not args.fairseq_binary_data_format ) args.eval_target_binary_path = maybe_generate_temp_file_path( - args.eval_target_binary_path + args.eval_target_binary_path, is_npz=not args.fairseq_binary_data_format ) # Additional text preprocessing options could be added here before diff --git a/pytorch_translate/train.py b/pytorch_translate/train.py index d8a644c6..9c335407 100644 --- a/pytorch_translate/train.py +++ b/pytorch_translate/train.py @@ -300,6 +300,7 @@ def setup_training_model(args): src_bin_path=args.train_source_binary_path, tgt_bin_path=args.train_target_binary_path, weights_file=getattr(args, "train_weights_path", None), + is_npz=not args.fairseq_binary_data_format, ) if args.task == "dual_learning_task": @@ -311,6 +312,7 @@ def setup_training_model(args): split=args.valid_subset, src_bin_path=args.eval_source_binary_path, tgt_bin_path=args.eval_target_binary_path, + is_npz=not args.fairseq_binary_data_format, ) return task, model, criterion