diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 3cd1c2f2e6..0c0bc711f0 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -69,7 +69,8 @@ unit_tests: - echo "Slurm job state $SLURM_STATE" - if [[ "$SLURM_STATE" != "COMPLETED" ]]; then echo "Slurm job did not complete. See ${SELENE_ADLR_CI_PATH}/${CI_PIPELINE_ID}/${RUN_NAME}/results directory for result logs. Skipping pytest."; exit 1; fi - source $PYTHON_VIRTUAL_ENV - - pytest $BUILD_DIR/tests/functional_tests/python_test_utils/test_resume_checkpoint_pipeline.py || echo "Pytest failed. See ${SELENE_ADLR_CI_PATH}/${CI_PIPELINE_ID}/${RUN_NAME}/results directory for result logs." + - cmd='pytest $BUILD_DIR/tests/functional_tests/python_test_utils/test_resume_checkpoint_pipeline.py' + - if $cmd; then echo "Pytest succeded"; else echo "Pytest failed. See ${SELENE_ADLR_CI_PATH}/${CI_PIPELINE_ID}/${RUN_NAME}/results directory for result logs"; fi - echo "Completed the job" rules: - if: $TEST_LEVEL =~ $TESTS_TO_RUN_ON_THIS_COMMIT || $CI_JOB_NAME =~ $TESTS_TO_RUN_ON_THIS_COMMIT || $CI_JOB_NAME =~ $TEST_REGEX_ON_THIS_COMMIT @@ -134,7 +135,8 @@ unit_tests: if [[ $USE_TE -ne 1 ]]; then echo "Checking against ground truth file" export EXPECTED_METRICS_FILE=$BUILD_DIR/tests/functional_tests/test_results/$RUN_MODEL/$RUN_NAME.json - pytest $BUILD_DIR/tests/functional_tests/python_test_utils/test_ci_pipeline.py || echo "Pytest failed. See ${SELENE_ADLR_CI_PATH}/${CI_PIPELINE_ID}/${RUN_NAME}/results directory for result logs." + cmd='pytest $BUILD_DIR/tests/functional_tests/python_test_utils/test_ci_pipeline.py' + if $cmd; then echo "Pytest succeded"; else echo "Pytest failed. See ${SELENE_ADLR_CI_PATH}/${CI_PIPELINE_ID}/${RUN_NAME}/results directory for result logs"; fi fi - echo "Completed the job" rules: diff --git a/README.md b/README.md index 6bb334e8e1..cdb5bd3f07 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,7 @@ The following table shows both model (MFU) and hardware (HFU) FLOPs utilization * [Datasets](#datasets) * [Collecting Wikipedia Training Data](#collecting-wikipedia-training-data) * [Collecting GPT Webtext Data](#collecting-gpt-webtext-data) + * [Reproducibility](#reproducibility) # Setup We strongly recommend using the latest release of [NGC's PyTorch container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) with DGX nodes. If you can't use this for some reason, use the latest pytorch, cuda, nccl, and NVIDIA [APEX](https://github.com/NVIDIA/apex#quick-start) releases. Data preprocessing requires [NLTK](https://www.nltk.org/install.html), though this is not required for training, evaluation, or downstream tasks. @@ -365,7 +366,7 @@ See [megatron/text_generation_server.py](megatron/text_generation_server.py) for ### Detoxify GPT via Self-generation We include an example in `examples/detxoify_lm/` to detoxify language models by leveraging the generative power of language models. -See [examples/detxoify_lm/README.md](examples/detxoify_lm/README.md) for step-by-step tutorials on how to perform domain-adaptive training and detoxify LM using self-generated corpus. +See [examples/detxoify_lm/README.md](examples/detxoify_lm/README.md) for step-by-step tutorials on how to perform domain-adaptive training and detoxify LM using self-generated corpus. ## GPT Evaluation @@ -513,3 +514,13 @@ We recommend using the `--json` argument when using WikiExtractor, which will du ## Collecting GPT Webtext Data We utilize the publicly available [OpenWebText](https://github.com/eukaryote31/openwebtext) library from [jcpeterson](https://github.com/jcpeterson/openwebtext) and [eukaryote31's](https://github.com/eukaryote31/openwebtext) work to download urls. We then filtered, cleaned, and deduplicated all downloaded content according to the procedure described in our [openwebtext](./tools/openwebtext) directory. For reddit URLs corresponding to content up to October 2018 we arrived at approximately 37GB of content. + +# Reproducibility +Megatron training is intended to be bitwise reproducible. This means that the same training config run twice in the same HW and SW environment should produce identical model checkpoints, losses and accuracy metric values (iteration time metrics may vary). + +There are currently three known Megatron optimizations that break reproducibility whilst still producing almost identical training runs. They are only applicable when using NGC containers >=22.05. The following workarounds should be applied in cases where reproducibility is required: +1. When training using the `--bf16` option the backward pass of `torch.nn.functional.embedding` is non-deterministic. If reproducibility is required you should also use the option `--embedding-weights-in-fp32`. The speed and memory impact of this change is negligible. +2. Also when training using `--bf16`, reproducbility is only obtained when the checkpointing and resume schedule of training is identical. If the checkpointing schedule will change, i.e. checkpointing and resume will occur at different iterations, the option `--no-bias-gelu-fusion` should be used. +3. Flash attention is non-deterministic. If reproducibility is required do not use `--use-flash-attn`. + +These sources of non-determinism are under active investigation. If you observe non-determinism in Megatron training under other circumstances please open an issue. diff --git a/megatron/arguments.py b/megatron/arguments.py index 644fbb7a51..671cdf270b 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -55,7 +55,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False): # Args from environment args.rank = int(os.getenv('RANK', '0')) args.world_size = int(os.getenv("WORLD_SIZE", '1')) - + return args def validate_args(args, defaults={}): @@ -626,6 +626,8 @@ def _add_network_size_args(parser): help='Number of Experts in Switch Transformer (None means no Switch)') group.add_argument('--untie-embeddings-and-output-weights', action='store_true', help='Untie embeddings and output weights.'), + group.add_argument('--embedding-weights-in-fp32', action='store_true', + help='Cast word embedding weights to fp32 before embedding fwd.'), return parser @@ -1020,6 +1022,10 @@ def _add_distributed_args(parser): '--tensor-model-parallel-size instead.') group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None, help='Number of layers per virtual pipeline stage') + group.add_argument('--overlap-p2p-communication', + action='store_true', + help='overlap pipeline parallel communication with forward and backward chunks', + dest='overlap_p2p_comm') group.add_argument('--distributed-backend', default='nccl', choices=['nccl', 'gloo'], help='Which backend to use for distributed training.') @@ -1212,6 +1218,8 @@ def __call__(self, parser, args, values, option_string=None): '1) a single data path, 2) multiple datasets in the' 'form: dataset1-weight dataset1-path dataset2-weight ' 'dataset2-path ...') + group.add_argument('--data-cache-path', default=None, + help='Path to a directory to hold cached index files.') group.add_argument('--vocab-size', type=int, default=None, help='Size of vocab before EOD or padding.') @@ -1385,14 +1393,14 @@ def _add_vision_args(parser): group.add_argument('--swin-backbone-type', type=str, default='tiny', choices=['tiny', 'base', 'h3'], help='pretraining objectives') - + # inpainting arguments group.add_argument('--mask-type', type=str, default='random', choices=['random', 'row'], help='mask types') group.add_argument('--mask-factor', type=float, default=1.0, help='mask size scaling parameter') - + # dino arguments group.add_argument('--iter-per-epoch', type=int, default=1250, help='iterations per epoch') diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index 5a30619cd8..bc2b757c17 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -38,11 +38,15 @@ def check_checkpoint_args(checkpoint_args): arguments and the one retrieved from checkpoint.""" args = get_args() - def _compare(arg_name, old_arg_name=None): + def _compare(arg_name, old_arg_name=None, default=None): if old_arg_name is not None: - checkpoint_value = getattr(checkpoint_args, old_arg_name) + ckpt_arg_name = old_arg_name else: - checkpoint_value = getattr(checkpoint_args, arg_name) + ckpt_arg_name = arg_name + if default is not None: + checkpoint_value = getattr(checkpoint_args, ckpt_arg_name, default) + else: + checkpoint_value = getattr(checkpoint_args, ckpt_arg_name) args_value = getattr(args, arg_name) error_message = '{} value from checkpoint ({}) is not equal to the ' \ 'input argument value ({}).'.format( @@ -52,7 +56,7 @@ def _compare(arg_name, old_arg_name=None): _compare('num_layers') _compare('hidden_size') _compare('num_attention_heads') - _compare('add_position_embedding') + _compare('add_position_embedding', default=True) try: _compare('position_embedding_type') except AttributeError as e: diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py index 301583132a..6a461ad8d4 100644 --- a/megatron/core/pipeline_parallel/p2p_communication.py +++ b/megatron/core/pipeline_parallel/p2p_communication.py @@ -9,6 +9,7 @@ from megatron import core from megatron.core.parallel_state import ( get_pipeline_model_parallel_group, + get_pipeline_model_parallel_rank, get_pipeline_model_parallel_prev_rank, get_pipeline_model_parallel_next_rank, ) @@ -63,28 +64,28 @@ def _communicate_shapes(tensor_send_next, tensor_send_prev, tensor_recv_prev=recv_prev_shape_tensor, tensor_send_next=send_next_shape_tensor, tensor_recv_next=recv_next_shape_tensor, - group=mpu.get_pipeline_model_parallel_group()) + group=get_pipeline_model_parallel_group()) else: ops = [] if send_prev_shape_tensor is not None: send_prev_op = torch.distributed.P2POp( torch.distributed.isend, send_prev_shape_tensor, - mpu.get_pipeline_model_parallel_prev_rank()) + get_pipeline_model_parallel_prev_rank()) ops.append(send_prev_op) if recv_prev_shape_tensor is not None: recv_prev_op = torch.distributed.P2POp( torch.distributed.irecv, recv_prev_shape_tensor, - mpu.get_pipeline_model_parallel_prev_rank()) + get_pipeline_model_parallel_prev_rank()) ops.append(recv_prev_op) if send_next_shape_tensor is not None: send_next_op = torch.distributed.P2POp( torch.distributed.isend, send_next_shape_tensor, - mpu.get_pipeline_model_parallel_next_rank()) + get_pipeline_model_parallel_next_rank()) ops.append(send_next_op) if recv_next_shape_tensor is not None: recv_next_op = torch.distributed.P2POp( torch.distributed.irecv, recv_next_shape_tensor, - mpu.get_pipeline_model_parallel_next_rank()) + get_pipeline_model_parallel_next_rank()) ops.append(recv_next_op) if len(ops) > 0: reqs = torch.distributed.batch_isend_irecv(ops) @@ -105,12 +106,125 @@ def _communicate_shapes(tensor_send_next, tensor_send_prev, return recv_prev_shape, recv_next_shape +def _batched_p2p_ops(*, + tensor_send_prev: Optional[torch.Tensor], + tensor_recv_prev: Optional[torch.Tensor], + tensor_send_next: Optional[torch.Tensor], + tensor_recv_next: Optional[torch.Tensor], + group: torch.distributed.ProcessGroup): + ops = [] + if tensor_send_prev is not None: + send_prev_op = torch.distributed.P2POp( + torch.distributed.isend, tensor_send_prev, + get_pipeline_model_parallel_prev_rank(), + group) + ops.append(send_prev_op) + if tensor_recv_prev is not None: + recv_prev_op = torch.distributed.P2POp( + torch.distributed.irecv, tensor_recv_prev, + get_pipeline_model_parallel_prev_rank(), + group) + ops.append(recv_prev_op) + if tensor_send_next is not None: + send_next_op = torch.distributed.P2POp( + torch.distributed.isend, tensor_send_next, + get_pipeline_model_parallel_next_rank(), + group) + ops.append(send_next_op) + if tensor_recv_next is not None: + recv_next_op = torch.distributed.P2POp( + torch.distributed.irecv, tensor_recv_next, + get_pipeline_model_parallel_next_rank(), + group) + ops.append(recv_next_op) + if len(ops) > 0: + reqs = torch.distributed.batch_isend_irecv(ops) + else: + reqs = [] + return reqs + +def _p2p_ops(*, + tensor_send_prev: Optional[torch.Tensor], + tensor_recv_prev: Optional[torch.Tensor], + tensor_send_next: Optional[torch.Tensor], + tensor_recv_next: Optional[torch.Tensor], + group: torch.distributed.ProcessGroup): + reqs = [] + rank = get_pipeline_model_parallel_rank() + if get_pipeline_model_parallel_rank() % 2 == 0: + if tensor_send_next is not None: + send_next_req = torch.distributed.isend( + tensor=tensor_send_next, + dst=get_pipeline_model_parallel_next_rank(), + group=group, + ) + reqs.append(send_next_req) + + if tensor_recv_prev is not None: + recv_prev_req = torch.distributed.irecv( + tensor=tensor_recv_prev, + src=get_pipeline_model_parallel_prev_rank(), + group=group, + ) + reqs.append(recv_prev_req) + + if tensor_send_prev is not None: + send_prev_req = torch.distributed.isend( + tensor=tensor_send_prev, + dst=get_pipeline_model_parallel_prev_rank(), + group=group, + ) + reqs.append(send_prev_req) + + if tensor_recv_next is not None: + recv_next_req = torch.distributed.irecv( + tensor=tensor_recv_next, + src=get_pipeline_model_parallel_next_rank(), + group=group, + ) + reqs.append(recv_next_req) + + else: + if tensor_recv_prev is not None: + recv_prev_req = torch.distributed.irecv( + tensor=tensor_recv_prev, + src=get_pipeline_model_parallel_prev_rank(), + group=group, + ) + reqs.append(recv_prev_req) + + if tensor_send_next is not None: + send_next_req = torch.distributed.isend( + tensor=tensor_send_next, + dst=get_pipeline_model_parallel_next_rank(), + group=group, + ) + reqs.append(send_next_req) + + if tensor_recv_next is not None: + recv_next_req = torch.distributed.irecv( + tensor=tensor_recv_next, + src=get_pipeline_model_parallel_next_rank(), + group=group, + ) + reqs.append(recv_next_req) + + if tensor_send_prev is not None: + send_prev_req = torch.distributed.isend( + tensor=tensor_send_prev, + dst=get_pipeline_model_parallel_prev_rank(), + group=group, + ) + reqs.append(send_prev_req) + return reqs def _communicate(*, tensor_send_next: Optional[torch.Tensor], tensor_send_prev: Optional[torch.Tensor], recv_prev: bool, recv_next: bool, tensor_shape: Shape, + batch_p2p_comm: bool = True, + wait_on_reqs: bool = True, dtype: Optional[torch.dtype], variable_seq_lengths: bool = False, use_ring_exchange_p2p: bool = False, @@ -136,6 +250,14 @@ def _communicate(*, tensor_send_next: Optional[torch.Tensor], tensors sent and received in a single function call are the same shape). + batch_p2p_comm (boolean, required): + If true use batch_isend_irecv, otherwise use individual + isend and irecv calls. + + wait_on_reqs (boolean, optional, default=False): + For non-batched p2p communication, wait on each request + before returning. + dtype (torch.dtype, required if either recv_{prev,next} is True): this must be the type of the tensors that will be received, will typically be params_dtype, but in the case @@ -167,6 +289,10 @@ def _communicate(*, tensor_send_next: Optional[torch.Tensor], tensor_recv_prev = None tensor_recv_next = None + # This will come from config in the next version, for now hard + # code it here to match existing functionality. + batch_p2p_sync = True + if not variable_seq_lengths: recv_prev_shape = tensor_shape recv_next_shape = tensor_shape @@ -204,46 +330,38 @@ def _communicate(*, tensor_send_next: Optional[torch.Tensor], # Send tensors in both the forward and backward directions as appropriate. if use_ring_exchange_p2p: - torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev, - tensor_recv_prev=tensor_recv_prev, - tensor_send_next=tensor_send_next, - tensor_recv_next=tensor_recv_next, - group=get_pipeline_model_parallel_group()) + def _ring_exchange_wrapper(**kwargs): + torch.distributed.ring_exchange(**kwargs) + return [] + p2p_func = _ring_exchange_wrapper + elif batch_p2p_comm: + assert wait_on_reqs + p2p_func = _batched_p2p_ops else: - ops = [] - if tensor_send_prev is not None: - send_prev_op = torch.distributed.P2POp( - torch.distributed.isend, tensor_send_prev, - get_pipeline_model_parallel_prev_rank()) - ops.append(send_prev_op) - if tensor_recv_prev is not None: - recv_prev_op = torch.distributed.P2POp( - torch.distributed.irecv, tensor_recv_prev, - get_pipeline_model_parallel_prev_rank()) - ops.append(recv_prev_op) - if tensor_send_next is not None: - send_next_op = torch.distributed.P2POp( - torch.distributed.isend, tensor_send_next, - get_pipeline_model_parallel_next_rank()) - ops.append(send_next_op) - if tensor_recv_next is not None: - recv_next_op = torch.distributed.P2POp( - torch.distributed.irecv, tensor_recv_next, - get_pipeline_model_parallel_next_rank()) - ops.append(recv_next_op) - if len(ops) > 0: - reqs = torch.distributed.batch_isend_irecv(ops) - for req in reqs: - req.wait() + p2p_func = _p2p_ops + + reqs = p2p_func(tensor_send_prev=tensor_send_prev, + tensor_recv_prev=tensor_recv_prev, + tensor_send_next=tensor_send_next, + tensor_recv_next=tensor_recv_next, + group=get_pipeline_model_parallel_group()) + + if wait_on_reqs and len(reqs) > 0: + for req in reqs: + req.wait() + reqs = None + + if batch_p2p_comm and batch_p2p_sync: # To protect against race condition when using batch_isend_irecv(). # User should assert that we have a modern enough PyTorch to not need this torch.cuda.synchronize() - return tensor_recv_prev, tensor_recv_next + return tensor_recv_prev, tensor_recv_next, reqs def recv_forward(tensor_shape: Shape, dtype: torch.dtype, + batch_p2p_comm: bool = True, timers: Callable = None) -> torch.Tensor: """ Receive tensor from previous rank in pipeline (forward receive). @@ -256,12 +374,13 @@ def recv_forward(tensor_shape: Shape, else: if timers is not None: timers('forward-recv', log_level=2).start() - input_tensor, _ = _communicate( + input_tensor, _, _ = _communicate( tensor_send_next=None, tensor_send_prev=None, recv_prev=True, recv_next=False, tensor_shape=tensor_shape, + batch_p2p_comm=batch_p2p_comm, dtype=dtype) if timers is not None: timers('forward-recv').stop() @@ -270,6 +389,7 @@ def recv_forward(tensor_shape: Shape, def recv_backward(tensor_shape: Shape, dtype: torch.dtype, + batch_p2p_comm: bool = True, timers: Callable = None) -> torch.Tensor: """Receive tensor from next rank in pipeline (backward receive). @@ -280,12 +400,13 @@ def recv_backward(tensor_shape: Shape, else: if timers is not None: timers('backward-recv', log_level=2).start() - _, output_tensor_grad = _communicate( + _, output_tensor_grad, _ = _communicate( tensor_send_next=None, tensor_send_prev=None, recv_prev=False, recv_next=True, tensor_shape=tensor_shape, + batch_p2p_comm=batch_p2p_comm, dtype=dtype) if timers is not None: timers('backward-recv').stop() @@ -293,6 +414,7 @@ def recv_backward(tensor_shape: Shape, def send_forward(output_tensor: torch.Tensor, + batch_p2p_comm: bool = True, timers: Callable = None) -> None: """Send tensor to next rank in pipeline (forward send). @@ -308,12 +430,14 @@ def send_forward(output_tensor: torch.Tensor, recv_prev=False, recv_next=False, tensor_shape=None, + batch_p2p_comm=batch_p2p_comm, dtype=None) if timers is not None: timers('forward-send').stop() def send_backward(input_tensor_grad: torch.Tensor, + batch_p2p_comm: bool = True, timers: Callable = None) -> None: """Send tensor to previous rank in pipeline (backward send). @@ -328,6 +452,7 @@ def send_backward(input_tensor_grad: torch.Tensor, recv_prev=False, recv_next=False, tensor_shape=None, + batch_p2p_comm=batch_p2p_comm, dtype=None) if timers is not None: timers('backward-send').stop() @@ -336,6 +461,7 @@ def send_backward(input_tensor_grad: torch.Tensor, def send_forward_recv_backward(output_tensor: torch.Tensor, tensor_shape: Shape, dtype: torch.dtype, + batch_p2p_comm: bool = True, timers: Callable = None) -> torch.Tensor: """Batched send and recv with next rank in pipeline. @@ -346,12 +472,13 @@ def send_forward_recv_backward(output_tensor: torch.Tensor, else: if timers is not None: timers('forward-send-backward-recv', log_level=2).start() - _, output_tensor_grad = _communicate( + _, output_tensor_grad,_ = _communicate( tensor_send_next=output_tensor, tensor_send_prev=None, recv_prev=False, recv_next=True, tensor_shape=tensor_shape, + batch_p2p_comm=batch_p2p_comm, dtype=dtype) if timers is not None: timers('forward-send-backward-recv').stop() @@ -361,6 +488,7 @@ def send_forward_recv_backward(output_tensor: torch.Tensor, def send_backward_recv_forward(input_tensor_grad: torch.Tensor, tensor_shape: Shape, dtype: torch.dtype, + batch_p2p_comm: bool = True, timers: Callable = None) -> torch.Tensor: """Batched send and recv with previous rank in pipeline. @@ -371,12 +499,13 @@ def send_backward_recv_forward(input_tensor_grad: torch.Tensor, else: if timers is not None: timers('backward-send-forward-recv', log_level=2).start() - input_tensor, _ = _communicate( + input_tensor, _, _ = _communicate( tensor_send_next=None, tensor_send_prev=input_tensor_grad, recv_prev=True, recv_next=False, tensor_shape=tensor_shape, + batch_p2p_comm=batch_p2p_comm, dtype=dtype) if timers is not None: timers('backward-send-forward-recv').stop() @@ -387,6 +516,8 @@ def send_forward_recv_forward(output_tensor: torch.Tensor, recv_prev: bool, tensor_shape: Shape, dtype: torch.dtype, + batch_p2p_comm: bool = True, + overlap_p2p_comm: bool = False, timers: Callable = None) -> torch.Tensor: """Batched recv from previous rank and send to next rank in pipeline. @@ -394,15 +525,19 @@ def send_forward_recv_forward(output_tensor: torch.Tensor, """ if timers is not None: timers('forward-send-forward-recv', log_level=2).start() - input_tensor, _ = _communicate( + input_tensor, _, wait_handles = _communicate( tensor_send_next=output_tensor, tensor_send_prev=None, recv_prev=recv_prev, recv_next=False, tensor_shape=tensor_shape, + batch_p2p_comm=batch_p2p_comm, + wait_on_reqs=(not overlap_p2p_comm), dtype=dtype) if timers is not None: timers('forward-send-forward-recv').stop() + if overlap_p2p_comm: + return input_tensor, wait_handles return input_tensor @@ -410,6 +545,8 @@ def send_backward_recv_backward(input_tensor_grad: torch.Tensor, recv_next: bool, tensor_shape: Shape, dtype: torch.dtype, + batch_p2p_comm: bool = True, + overlap_p2p_comm: bool = False, timers: Callable = None) -> torch.Tensor: """Batched recv from next rank and send to previous rank in pipeline. @@ -417,15 +554,19 @@ def send_backward_recv_backward(input_tensor_grad: torch.Tensor, """ if timers is not None: timers('backward-send-backward-recv', log_level=2).start() - _, output_tensor_grad = _communicate( + _, output_tensor_grad, wait_handles = _communicate( tensor_send_next=None, tensor_send_prev=input_tensor_grad, recv_prev=False, recv_next=recv_next, tensor_shape=tensor_shape, + batch_p2p_comm=batch_p2p_comm, + wait_on_reqs=(not overlap_p2p_comm), dtype=dtype) if timers is not None: timers('backward-send-backward-recv').stop() + if overlap_p2p_comm: + return output_tensor_grad, wait_handles return output_tensor_grad @@ -436,6 +577,7 @@ def send_forward_backward_recv_forward_backward( recv_next: bool, tensor_shape: Shape, dtype: torch.dtype, + batch_p2p_comm: bool = True, timers: Callable = None) -> Tuple[torch.Tensor, torch.Tensor]: """Batched send and recv with previous and next ranks in pipeline. @@ -444,12 +586,13 @@ def send_forward_backward_recv_forward_backward( if timers is not None: timers('forward-backward-send-forward-backward-recv', log_level=2).start() - input_tensor, output_tensor_grad = _communicate( + input_tensor, output_tensor_grad, _ = _communicate( tensor_send_next=output_tensor, tensor_send_prev=input_tensor_grad, recv_prev=recv_prev, recv_next=recv_next, tensor_shape=tensor_shape, + batch_p2p_comm=batch_p2p_comm, dtype=dtype) if timers is not None: timers('forward-backward-send-forward-backward-recv').stop() diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 5007a44cd2..484d398fd8 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -85,6 +85,15 @@ def forward_step(data_iterator, model): tensor\_model\_parallel\_world\_size`. TODO: Do we need this? Just roll into tensor_shape arg? + overlap_p2p_comm (optional, default=False): When True + some of the peer to peer communication for pipeline + parallelism will overlap with computation. Must be False if + batch_p2p_comm is true. + + batch_p2p_comm (optional, default=True): When true use + batch_isend_irecv, otherwise use individual isend and irecv + calls. Must be false if overlap_p2p_comm is True. + forward_only (optional, default=False): Perform only the forward step timers (optional, default=None): TODO @@ -94,11 +103,11 @@ def forward_step(data_iterator, model): enable_autocast (optional, default=False): If True, runs the forward_step_func call inside torch.autocast context - deallocate_pipeline_outputs (optional, default=False): If True, output data + deallocate_pipeline_outputs (optional, default=False): If True, output data is deallocated after the tensor is sent to the next pipeline stage. - Helps with saving memory, does nothing when pipeline parallel is + Helps with saving memory, does nothing when pipeline parallel is not used. - + no_sync_func (optional): Function that creates a context that suppresses asynchronous data-parallel communication. If the model is an instance of torch.nn.DistributedDataParallel, the @@ -276,8 +285,8 @@ def backward_step(grad_scaler, input_tensor, output_tensor, # Backward pass. if output_tensor_grad[0] is None and grad_scaler is not None: - output_tensor = grad_scaler(output_tensor[0]) - + output_tensor[0] = grad_scaler(output_tensor[0]) + if deallocate_pipeline_outputs: custom_backward(output_tensor[0], output_tensor_grad[0]) else: @@ -319,6 +328,8 @@ def forward_backward_no_pipelining(*, decoder_seq_length: Optional[int] = None, # unused grad_scaler: Callable = None, sequence_parallel: bool = False, # unused + overlap_p2p_comm: bool = False, # unused + batch_p2p_comm: bool = True, # unused forward_only: bool = False, timers: Callable = None, collect_non_loss_data: bool = False, @@ -387,6 +398,8 @@ def forward_backward_pipelining_with_interleaving(*, decoder_seq_length: Optional[int] = None, grad_scaler: Callable = None, sequence_parallel: bool = False, + overlap_p2p_comm: bool = False, + batch_p2p_comm: bool = True, forward_only: bool = False, timers: Callable = None, collect_non_loss_data: bool = False, @@ -407,6 +420,9 @@ def forward_backward_pipelining_with_interleaving(*, assert isinstance(data_iterator, list), \ "interleaved pipeline parallelism expected each model chunk to have a data iterator" + if overlap_p2p_comm and batch_p2p_comm: + raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm") + # Disable async grad reductions if no_sync_func is None and all(isinstance(chunk, torchDDP) for chunk in model): def multi_no_sync(): @@ -507,7 +523,7 @@ def get_model_chunk_id(microbatch_id, forward): def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool: """Check if an iteration is the first for a model chunk.""" microbatch_group_size = pipeline_parallel_size * num_model_chunks - num_microbatch_groups = num_microbatches // microbatch_group_size + num_microbatch_groups = total_num_microbatches // microbatch_group_size microbatch_group_id = microbatch_id // microbatch_group_size microbatch_id_in_group = microbatch_id % microbatch_group_size if microbatch_group_id == 0: @@ -518,7 +534,7 @@ def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool: def is_last_microbatch_for_model_chunk(microbatch_id: int) -> bool: """Check if an iteration is the last for a model chunk.""" microbatch_group_size = pipeline_parallel_size * num_model_chunks - num_microbatch_groups = num_microbatches // microbatch_group_size + num_microbatch_groups = total_num_microbatches // microbatch_group_size microbatch_group_id = microbatch_id // microbatch_group_size microbatch_id_in_group = microbatch_id % microbatch_group_size if microbatch_group_id == num_microbatch_groups - 1: @@ -617,8 +633,20 @@ def backward_step_helper(microbatch_id): # Run warmup forward passes. parallel_state.set_virtual_pipeline_model_parallel_rank(0) input_tensors[0].append( - p2p_communication.recv_forward(tensor_shape, dtype, timers=timers)) + p2p_communication.recv_forward(tensor_shape, + dtype=dtype, + batch_p2p_comm=batch_p2p_comm, + timers=timers)) + + fwd_wait_handles = None + bwd_wait_handles = None + for k in range(num_warmup_microbatches): + + if fwd_wait_handles is not None: + for req in fwd_wait_handles: + req.wait() + output_tensor = forward_step_helper(k) # Determine if tensor should be received from previous stage. @@ -636,91 +664,216 @@ def backward_step_helper(microbatch_id): # Send and receive tensors as appropriate (send tensors computed # in this iteration; receive tensors for next iteration). - if k == (num_warmup_microbatches - 1) and not forward_only and \ - not all_warmup_microbatches: - input_tensor_grad = None - recv_next = True - if parallel_state.is_pipeline_last_stage(ignore_virtual=True): - recv_next = False - input_tensor, output_tensor_grad = \ - p2p_communication.send_forward_backward_recv_forward_backward( + if not overlap_p2p_comm: + if k == (num_warmup_microbatches - 1) and not forward_only and \ + not all_warmup_microbatches: + input_tensor_grad = None + recv_next = True + if parallel_state.is_pipeline_last_stage(ignore_virtual=True): + recv_next = False + input_tensor, output_tensor_grad = \ + p2p_communication.send_forward_backward_recv_forward_backward( output_tensor, input_tensor_grad, recv_prev=recv_prev, recv_next=recv_next, - tensor_shape=tensor_shape, dtype=dtype, + tensor_shape=tensor_shape, + dtype=dtype, + batch_p2p_comm=batch_p2p_comm, timers=timers) - output_tensor_grads[num_model_chunks-1].append(output_tensor_grad) + output_tensor_grads[num_model_chunks-1].append(output_tensor_grad) + else: + input_tensor = \ + p2p_communication.send_forward_recv_forward( + output_tensor, recv_prev=recv_prev, + tensor_shape=tensor_shape, + dtype=dtype, + batch_p2p_comm=batch_p2p_comm, + timers=timers) + input_tensors[next_forward_model_chunk_id].append(input_tensor) else: - input_tensor = \ + input_tensor, fwd_wait_handles = \ p2p_communication.send_forward_recv_forward( output_tensor, recv_prev=recv_prev, - tensor_shape=tensor_shape, dtype=dtype, - timers=timers) - input_tensors[next_forward_model_chunk_id].append(input_tensor) + tensor_shape=tensor_shape, + dtype=dtype, + batch_p2p_comm=batch_p2p_comm, + timers=timers, + overlap_p2p_comm=True) + + if k == (num_warmup_microbatches - 1) and not forward_only and \ + not all_warmup_microbatches: + input_tensor_grad = None + recv_next = True + if parallel_state.is_pipeline_last_stage(ignore_virtual=True): + recv_next = False + + output_tensor_grad, bwd_wait_handles = p2p_communication.send_backward_recv_backward( + input_tensor_grad, recv_next=recv_next, + tensor_shape=tensor_shape, + batch_p2p_comm=batch_p2p_comm, + dtype=dtype, + timers=timers, + overlap_p2p_comm=True) + + output_tensor_grads[num_model_chunks-1].append(output_tensor_grad) + input_tensors[next_forward_model_chunk_id].append(input_tensor) + deallocate_output_tensor(output_tensor, deallocate_pipeline_outputs) # Run 1F1B in steady state. for k in range(num_microbatches_remaining): # Forward pass. forward_k = k + num_warmup_microbatches - output_tensor = forward_step_helper(forward_k) - # Backward pass. - backward_k = k - input_tensor_grad = backward_step_helper(backward_k) + if overlap_p2p_comm: + if fwd_wait_handles is not None: + for req in fwd_wait_handles: + req.wait() + + deallocate_output_tensor(output_tensor, deallocate_pipeline_outputs) + + output_tensor = forward_step_helper(forward_k) + + # Determine if current stage has anything to send in either direction, + # otherwise set tensor to None. + forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True) + parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id) + + # Last virtual stage no activation tensor to send + if parallel_state.is_pipeline_last_stage(): + output_tensor = None + + # Determine if peers are sending, and where in data structure to put + # received tensors. + recv_prev = True + if parallel_state.is_pipeline_first_stage(ignore_virtual=True): + # First stage is ahead of last stage by (pipeline_parallel_size - 1). + next_forward_model_chunk_id = get_model_chunk_id( + forward_k - (pipeline_parallel_size - 1), forward=True) + if next_forward_model_chunk_id == (num_model_chunks - 1): + recv_prev = False + next_forward_model_chunk_id += 1 + else: + next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, + forward=True) - # Send output_tensor and input_tensor_grad, receive input_tensor - # and output_tensor_grad. + # If last iteration, don't receive; we already received one extra + # before the start of the for loop. + if k == (num_microbatches_remaining - 1): + recv_prev = False - # Determine if current stage has anything to send in either direction, - # otherwise set tensor to None. - forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True) - parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id) - if parallel_state.is_pipeline_last_stage(): - output_tensor = None + # Send activation tensor to the next stage and receive activation tensor from the + # previous stage + input_tensor, fwd_wait_handles = \ + p2p_communication.send_forward_recv_forward( + output_tensor, recv_prev=recv_prev, + tensor_shape=tensor_shape, + dtype=dtype, + batch_p2p_comm=batch_p2p_comm, + timers=timers, + overlap_p2p_comm=True) + # assert fwd_wait_handles is not None - backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False) - parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id) - if parallel_state.is_pipeline_first_stage(): - input_tensor_grad = None + if bwd_wait_handles is not None: + for req in bwd_wait_handles: + req.wait() - # Determine if peers are sending, and where in data structure to put - # received tensors. - recv_prev = True - if parallel_state.is_pipeline_first_stage(ignore_virtual=True): - # First stage is ahead of last stage by (pipeline_parallel_size - 1). - next_forward_model_chunk_id = get_model_chunk_id( - forward_k - (pipeline_parallel_size - 1), forward=True) - if next_forward_model_chunk_id == (num_model_chunks - 1): - recv_prev = False - next_forward_model_chunk_id += 1 - else: - next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, - forward=True) - - recv_next = True - if parallel_state.is_pipeline_last_stage(ignore_virtual=True): - # Last stage is ahead of first stage by (pipeline_parallel_size - 1). - next_backward_model_chunk_id = get_model_chunk_id( - backward_k - (pipeline_parallel_size - 1), forward=False) - if next_backward_model_chunk_id == 0: - recv_next = False - next_backward_model_chunk_id -= 1 - else: - next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, - forward=False) + # Backward pass. + backward_k = k + input_tensor_grad = backward_step_helper(backward_k) - # If last iteration, don't receive; we already received one extra - # before the start of the for loop. - if k == (num_microbatches_remaining - 1): - recv_prev = False + backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False) + parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id) + + # First virtual stage no activation gradient tensor to send + if parallel_state.is_pipeline_first_stage(): + input_tensor_grad = None + + # Determine if the current virtual stage has an activation gradient tensor to receive + recv_next = True + if parallel_state.is_pipeline_last_stage(ignore_virtual=True): + # Last stage is ahead of first stage by (pipeline_parallel_size - 1). + next_backward_model_chunk_id = get_model_chunk_id( + backward_k - (pipeline_parallel_size - 1), forward=False + ) + if next_backward_model_chunk_id == 0: + recv_next = False + next_backward_model_chunk_id -= 1 + else: + next_backward_model_chunk_id = get_model_chunk_id( + backward_k + 1, forward=False + ) + + output_tensor_grad, bwd_wait_handles = p2p_communication.send_backward_recv_backward( + input_tensor_grad, recv_next=recv_next, + tensor_shape=tensor_shape, + dtype=dtype, + batch_p2p_comm=batch_p2p_comm, + timers=timers, + overlap_p2p_comm=True) + + else: # no p2p overlap + output_tensor = forward_step_helper(forward_k) + + # Backward pass. + backward_k = k + input_tensor_grad = backward_step_helper(backward_k) + + # Send output_tensor and input_tensor_grad, receive input_tensor + # and output_tensor_grad. + + # Determine if current stage has anything to send in either direction, + # otherwise set tensor to None. + forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True) + parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id) + if parallel_state.is_pipeline_last_stage(): + output_tensor = None + + backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False) + parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id) + if parallel_state.is_pipeline_first_stage(): + input_tensor_grad = None + + # Determine if peers are sending, and where in data structure to put + # received tensors. + recv_prev = True + if parallel_state.is_pipeline_first_stage(ignore_virtual=True): + # First stage is ahead of last stage by (pipeline_parallel_size - 1). + next_forward_model_chunk_id = get_model_chunk_id( + forward_k - (pipeline_parallel_size - 1), forward=True) + if next_forward_model_chunk_id == (num_model_chunks - 1): + recv_prev = False + next_forward_model_chunk_id += 1 + else: + next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, + forward=True) - # Communicate tensors. - input_tensor, output_tensor_grad = \ - p2p_communication.send_forward_backward_recv_forward_backward( + recv_next = True + if parallel_state.is_pipeline_last_stage(ignore_virtual=True): + # Last stage is ahead of first stage by (pipeline_parallel_size - 1). + next_backward_model_chunk_id = get_model_chunk_id( + backward_k - (pipeline_parallel_size - 1), forward=False) + if next_backward_model_chunk_id == 0: + recv_next = False + next_backward_model_chunk_id -= 1 + else: + next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, + forward=False) + + # If last iteration, don't receive; we already received one extra + # before the start of the for loop. + if k == (num_microbatches_remaining - 1): + recv_prev = False + + # Communicate tensors. + input_tensor, output_tensor_grad = \ + p2p_communication.send_forward_backward_recv_forward_backward( output_tensor, input_tensor_grad, recv_prev=recv_prev, recv_next=recv_next, - tensor_shape=tensor_shape, dtype=dtype, timers=timers) - deallocate_output_tensor(output_tensor, deallocate_pipeline_outputs) + tensor_shape=tensor_shape, + dtype=dtype, + batch_p2p_comm=batch_p2p_comm, + timers=timers) + deallocate_output_tensor(output_tensor, deallocate_pipeline_outputs) # Put input_tensor and output_tensor_grad in data structures in the # right location. @@ -730,11 +883,20 @@ def backward_step_helper(microbatch_id): output_tensor_grads[next_backward_model_chunk_id].append( output_tensor_grad) + deallocate_output_tensor(output_tensor, deallocate_pipeline_outputs) + # Run cooldown backward passes (flush out pipeline). if not forward_only: + if overlap_p2p_comm and bwd_wait_handles is not None: + for wait_handle in bwd_wait_handles: + wait_handle.wait() + if all_warmup_microbatches: output_tensor_grads[num_model_chunks-1].append( - p2p_communication.recv_backward(tensor_shape, dtype=dtype, timers=timers)) + p2p_communication.recv_backward(tensor_shape, + dtype=dtype, + batch_p2p_comm=batch_p2p_comm, + timers=timers)) for k in range(num_microbatches_remaining, total_num_microbatches): input_tensor_grad = backward_step_helper(k) next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False) @@ -747,7 +909,9 @@ def backward_step_helper(microbatch_id): output_tensor_grads[next_backward_model_chunk_id].append( p2p_communication.send_backward_recv_backward( input_tensor_grad, recv_next=recv_next, - tensor_shape=tensor_shape, dtype=dtype, + tensor_shape=tensor_shape, + dtype=dtype, + batch_p2p_comm=batch_p2p_comm, timers=timers)) # Launch any remaining grad reductions @@ -881,6 +1045,8 @@ def forward_backward_pipelining_without_interleaving(*, decoder_seq_length: Optional[int] = None, grad_scaler: Callable = None, sequence_parallel: bool = False, + overlap_p2p_comm: bool = False, + batch_p2p_comm: bool = True, forward_only: bool = False, timers: Callable = None, collect_non_loss_data: bool = False, @@ -904,6 +1070,12 @@ def forward_backward_pipelining_without_interleaving(*, "non-pipeline-parallel schedule does not support model chunking" data_iterator = data_iterator[0] + if overlap_p2p_comm: + raise ValueError("Non-interleaved pipeline parallelism does not support overlapping p2p communication") + + if not batch_p2p_comm: + raise ValueError("Non-interleaved pipeline parallelism only supports using batched p2p communication") + # Disable async grad reductions if no_sync_func is None and isinstance(model, torchDDP): no_sync_func = model.no_sync diff --git a/megatron/data/blendable_dataset.py b/megatron/data/blendable_dataset.py index 453b362f3e..8ff5ce3da8 100644 --- a/megatron/data/blendable_dataset.py +++ b/megatron/data/blendable_dataset.py @@ -2,17 +2,21 @@ """Blendable dataset.""" +import hashlib +import os import time import numpy as np import torch from megatron import print_rank_0 +from megatron.core import mpu class BlendableDataset(torch.utils.data.Dataset): - def __init__(self, datasets, weights, size): + def __init__(self, datasets, weights, size, *, + data_cache_path=None): self.datasets = datasets num_datasets = len(datasets) @@ -27,18 +31,74 @@ def __init__(self, datasets, weights, size): weights /= sum_weights # Build indicies. - start_time = time.time() - assert num_datasets < 255 - self.dataset_index = np.zeros(self.size, dtype=np.uint8) - self.dataset_sample_index = np.zeros(self.size, dtype=np.int64) - - from megatron.data import helpers - helpers.build_blending_indices(self.dataset_index, - self.dataset_sample_index, - weights, num_datasets, self.size, - torch.distributed.get_rank() == 0) - print_rank_0('> elapsed time for building blendable dataset indices: ' - '{:.2f} (sec)'.format(time.time() - start_time)) + def _build_indices(): + start_time = time.time() + assert num_datasets < 255 + dataset_index = np.zeros(self.size, dtype=np.uint8) + dataset_sample_index = np.zeros(self.size, dtype=np.int64) + + from megatron.data import helpers + helpers.build_blending_indices(dataset_index, dataset_sample_index, + weights, num_datasets, self.size, + torch.distributed.get_rank() == 0) + print_rank_0('> elapsed time for building blendable dataset indices: ' + '{:.2f} (sec)'.format(time.time() - start_time)) + return dataset_index, dataset_sample_index + + desc = "Blendable dataset\n\n" + desc += "Datasets:\n" + for dataset in datasets: + desc += dataset.desc + "\n\n" + desc += f"Weights: {weights}\n" + desc += f"Size: {size}\n" + self.desc = desc + + if data_cache_path: + desc_hash = hashlib.md5(desc.encode('utf-8')).hexdigest() + desc_path = os.path.join(data_cache_path, desc_hash + ".dsc") + index_path = os.path.join(data_cache_path, desc_hash + "_index.npy") + sample_index_path = os.path.join(data_cache_path, desc_hash + "_sample_index.npy") + cache_hit = os.path.isfile(index_path) and os.path.isfile(sample_index_path) + cache_success = True + if torch.distributed.get_rank() == 0 and not cache_hit: + print(' > WARNING: could not find index map files for blendable' + ' dataset, building indices on rank 0 ...', flush=True) + dataset_index, dataset_sample_index = _build_indices() + try: + os.makedirs(os.path.dirname(index_path), exist_ok=True) + with open(desc_path, 'wt') as fd: + fd.write(desc) + np.save(index_path, dataset_index, allow_pickle=True) + np.save(sample_index_path, dataset_sample_index, + allow_pickle=True) + except OSError: + print(f'There was an error trying to create the data cache directory ({data_cache_path})') + print('or a file in it. This is set with the --data-cache-path argument. Please') + print('ensure you have write access to this directory or specify one that you do have') + print('write access to.') + cache_success = False + + + counts = torch.cuda.LongTensor([cache_success]) + torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) + torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group()) + if counts[0].item() != ( + torch.distributed.get_world_size() // + torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())): + print_rank_0("Data index creation unsuccessful, exiting.") + exit() + + # Load on all ranks. + print_rank_0(f'> loading blendable dataset index: {index_path}') + self.dataset_index = np.load(index_path, allow_pickle=True, mmap_mode='r') + assert self.dataset_index.size == self.size + + print_rank_0(f'> loading blendable dataset sample index: {sample_index_path}') + self.dataset_sample_index = np.load(sample_index_path, allow_pickle=True, mmap_mode='r') + assert self.dataset_sample_index.size == self.size + else: + self.dataset_index, self.dataset_sample_index = _build_indices() + # Check size _ = self.__getitem__(self.size - 1) diff --git a/megatron/data/gpt_dataset.py b/megatron/data/gpt_dataset.py index 183f3cd460..4a572a9551 100644 --- a/megatron/data/gpt_dataset.py +++ b/megatron/data/gpt_dataset.py @@ -2,7 +2,7 @@ """GPT style dataset.""" -import itertools +import hashlib import os import time @@ -24,7 +24,8 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, train_data_prefix=None, valid_data_prefix=None, test_data_prefix=None, - return_doc_ids=False): + return_doc_ids=False, *, + data_cache_path=None): """Build train, valid, and test datasets.""" # Single dataset. @@ -33,7 +34,8 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, all_train_datasets, all_valid_datasets, all_test_datasets = _build_train_valid_test_datasets(data_prefix[0], data_impl, splits_string, train_valid_test_num_samples, - seq_length, seed, skip_warmup) + seq_length, seed, skip_warmup, + data_cache_path=data_cache_path) # Blending dataset. elif data_prefix: print_rank_0("Blending dataset for train, valid & test") @@ -55,18 +57,19 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, prefixes[i], data_impl, splits_string, datasets_train_valid_test_num_samples[i], seq_length, seed, skip_warmup, - return_doc_ids) + return_doc_ids, + data_cache_path=data_cache_path) if train_ds: train_datasets.append(train_ds) if valid_ds: valid_datasets.append(valid_ds) if test_ds: test_datasets.append(test_ds) - all_train_datasets = BlendableDataset(train_datasets, weights, train_num_samples) \ + all_train_datasets = BlendableDataset(train_datasets, weights, train_num_samples, data_cache_path=data_cache_path) \ if train_datasets else None - all_valid_datasets = BlendableDataset(valid_datasets, weights, valid_num_samples) \ + all_valid_datasets = BlendableDataset(valid_datasets, weights, valid_num_samples, data_cache_path=data_cache_path) \ if valid_datasets else None - all_test_datasets = BlendableDataset(test_datasets, weights, test_num_samples) \ + all_test_datasets = BlendableDataset(test_datasets, weights, test_num_samples, data_cache_path=data_cache_path) \ if test_datasets else None else: print_rank_0("Separate data paths provided for train, valid & test. Split string will be ignored.") @@ -75,18 +78,25 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, # Single dataset. if train_data_prefix is not None: train_dataset = build_dataset("train", train_data_prefix, data_impl, + splits_string, train_valid_test_num_samples[0], - seq_length, seed, skip_warmup) + seq_length, seed, skip_warmup, + data_cache_path=data_cache_path) if valid_data_prefix is not None: valid_dataset = build_dataset("valid", valid_data_prefix, data_impl, + splits_string, train_valid_test_num_samples[1], - seq_length, seed, False) + seq_length, seed, False, + data_cache_path=data_cache_path) + if test_data_prefix is not None: test_dataset = build_dataset("test", test_data_prefix, data_impl, + splits_string, train_valid_test_num_samples[2], - seq_length, seed, False) + seq_length, seed, False, + data_cache_path=data_cache_path) return (train_dataset, valid_dataset, test_dataset) @@ -96,7 +106,8 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, def build_dataset_group(dataset_group_name, paths, weights, splits, data_impl, train_valid_test_num_samples, - seq_length, seed, skip_warmup, train_valid_test): + seq_length, seed, skip_warmup, train_valid_test, + data_cache_path=None): ''' Build a single dataset group corresponding to Option 2 of data loading see arguments.py a dataset group is passed on the following form @@ -115,7 +126,8 @@ def build_dataset_group(dataset_group_name, paths, weights, splits, data_impl, data_impl, train_valid_test_num_samples, seq_length, seed, skip_warmup, - dataset_group_name, train_valid_test) + dataset_group_name, train_valid_test, + data_cache_path=data_cache_path) return dataset # Blending dataset. else: @@ -127,7 +139,7 @@ def build_dataset_group(dataset_group_name, paths, weights, splits, data_impl, data_prefix += [w,p] output = get_datasets_weights_and_num_samples(data_prefix, - train_valid_test_num_samples) + train_valid_test_num_samples) prefixes, weights, datasets_train_valid_test_num_samples = output # Build individual datasets. @@ -139,19 +151,25 @@ def build_dataset_group(dataset_group_name, paths, weights, splits, data_impl, datasets_train_valid_test_num_samples[i], seq_length, seed, skip_warmup, - dataset_group_name, train_valid_test) + dataset_group_name, train_valid_test, + data_cache_path=data_cache_path) # ds can be none if the dataset is so small that not a single document # is present in the split. assert ds is not None, \ f"Got an empty split when trying to create dataset: {prefixes[i], splits[i]}" datasets.append(ds) - all_datasets = BlendableDataset(datasets, weights, train_valid_test_num_samples[index]) + all_datasets = BlendableDataset(datasets, + weights, + train_valid_test_num_samples[index], + data_cache_path=data_cache_path) return all_datasets -def _build_single_datasets(data_prefix, range_string, data_impl, train_valid_test_num_samples, - seq_length, seed, skip_warmup, dataset_group_name, train_valid_test): +def _build_single_datasets(data_prefix, range_string, data_impl, + train_valid_test_num_samples, + seq_length, seed, skip_warmup, dataset_group_name, train_valid_test, + data_cache_path=None): """Build a single dataset""" assert train_valid_test in ["train","valid","test"] @@ -183,8 +201,10 @@ def build_dataset(name): step=1, dtype=np.int32) dataset = GPTDataset(name, data_prefix, documents, indexed_dataset, + range_string, train_valid_test_num_samples[index], - seq_length, seed) + seq_length, seed, + data_cache_path=data_cache_path) return dataset dataset = build_dataset(dataset_group_name) @@ -195,7 +215,8 @@ def build_dataset(name): def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, train_valid_test_num_samples, seq_length, seed, skip_warmup, - return_doc_ids=False): + return_doc_ids=False, *, + data_cache_path=None): """Build train, valid, and test datasets.""" # Indexed dataset. @@ -224,11 +245,12 @@ def build_dataset(index, name): if splits[index + 1] > splits[index]: documents = np.arange(start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32) - dataset = GPTDataset(name, data_prefix, - documents, indexed_dataset, + dataset = GPTDataset(name, data_prefix, documents, indexed_dataset, + splits_string, train_valid_test_num_samples[index], seq_length, seed, - return_doc_ids) + return_doc_ids, + data_cache_path=data_cache_path) return dataset train_dataset = build_dataset(0, 'train') @@ -238,14 +260,17 @@ def build_dataset(index, name): return (train_dataset, valid_dataset, test_dataset) -def build_dataset(dataset_name, data_prefix, data_impl, num_samples, - seq_length, seed, skip_warmup): +def build_dataset(dataset_name, data_prefix, data_impl, + splits_string, num_samples, + seq_length, seed, skip_warmup, + *, + data_cache_path=None): dataset = None if len(data_prefix) == 1: - dataset = _build_dataset(dataset_name, - data_prefix[0], data_impl, - num_samples, seq_length, - seed, skip_warmup) + dataset = _build_dataset(dataset_name, data_prefix[0], data_impl, + splits_string, num_samples, seq_length, + seed, skip_warmup, + data_cache_path=data_cache_path) else: # Blending dataset. # Parse the values. @@ -256,20 +281,24 @@ def build_dataset(dataset_name, data_prefix, data_impl, num_samples, # Build individual datasets. datasets = [] for i in range(len(prefixes)): - ds = _build_dataset(dataset_name, prefixes[i], - data_impl, dataset_num_samples[i], - seq_length, seed, skip_warmup) + ds = _build_dataset(dataset_name, prefixes[i], data_impl, + splits_string, dataset_num_samples[i], + seq_length, seed, skip_warmup, + data_cache_path=data_cache_path) if ds: datasets.append(ds) if datasets: - dataset = BlendableDataset(datasets, weights, num_samples) + dataset = BlendableDataset(datasets, weights, num_samples, + data_cache_path=data_cache_path) return dataset -def _build_dataset(dataset_name, data_prefix, data_impl, - num_samples, seq_length, seed, skip_warmup): +def _build_dataset(dataset_name, data_prefix, data_impl, splits_string, + num_samples, seq_length, seed, skip_warmup, + *, + data_cache_path=None): """ Build dataset. This method is called when individual train, valid, test datasets are provided @@ -289,9 +318,9 @@ def _build_dataset(dataset_name, data_prefix, data_impl, documents = np.arange(start=0, stop=total_num_of_documents, step=1, dtype=np.int32) - dataset = GPTDataset(dataset_name, data_prefix, - documents, indexed_dataset, - num_samples, seq_length, seed) + dataset = GPTDataset(dataset_name, data_prefix, documents, indexed_dataset, + splits_string, num_samples, seq_length, seed, + data_cache_path=data_cache_path) return dataset @@ -315,8 +344,9 @@ def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): class GPTDataset(torch.utils.data.Dataset): def __init__(self, name, data_prefix, documents, indexed_dataset, - num_samples, seq_length, seed, - return_doc_ids=False): + splits_string, num_samples, seq_length, seed, + return_doc_ids=False, *, + data_cache_path=None): self.name = name self.indexed_dataset = indexed_dataset @@ -327,10 +357,11 @@ def __init__(self, name, data_prefix, documents, indexed_dataset, assert np.max(documents) < indexed_dataset.sizes.shape[0] # Build index mappings. - self.doc_idx, self.sample_idx, self.shuffle_idx, self.index_prefix = \ + self.doc_idx, self.sample_idx, self.shuffle_idx, self.desc, self.desc_hash = \ _build_index_mappings(self.name, data_prefix, documents, self.indexed_dataset.sizes, - num_samples, seq_length, seed) + splits_string, num_samples, seq_length, seed, + data_cache_path=data_cache_path) self.args = get_args() self.tokenizer = get_tokenizer() @@ -435,7 +466,9 @@ def __getitem__(self, idx): def _build_index_mappings(name, data_prefix, documents, sizes, - num_samples, seq_length, seed): + splits_string, num_samples, seq_length, seed, + *, + data_cache_path): """Build doc-idx, sample-idx, and shuffle-idx. doc-idx: is an array (ordered) of documents to be used in training. sample-idx: is the start document index and document offset for each @@ -451,21 +484,45 @@ def _build_index_mappings(name, data_prefix, documents, sizes, np_rng = np.random.RandomState(seed=seed) # Filename of the index mappings. - index_prefix = '{}_indexmap'.format(name) - index_prefix += '_{}ns'.format(num_samples) - index_prefix += '_{}sl'.format(seq_length) - index_prefix += '_{}s'.format(seed) - _filename = data_prefix + '_' + index_prefix - doc_idx_filename = _filename + '_doc_idx.npy' - sample_idx_filename = _filename + '_sample_idx.npy' - shuffle_idx_filename = _filename + '_shuffle_idx.npy' + desc = "GPT Dataset\n\n" + desc += f"Data prefix {data_prefix}\n" + desc += f"Dataset name {name}\n" + desc += f"Number of samples {num_samples}\n" + desc += f"Sequence length {seq_length}\n" + desc += f"Random seed {seed}\n" + desc += f"Split {splits_string}\n" + desc_hash = hashlib.md5(desc.encode('utf-8')).hexdigest() + desc_filename = desc_hash + ".dsc" + doc_idx_filename = desc_hash + '_doc_idx.npy' + sample_idx_filename = desc_hash + '_sample_idx.npy' + shuffle_idx_filename = desc_hash + '_shuffle_idx.npy' + + # Look for cache in main data dir first to avoid unnecessary + # duplication, then look in data-cache-path if specified, + # If nothing is found, use the last path looked in + build_indices = True + prefixes = [os.path.join(os.path.dirname(data_prefix), 'index-cache')] + if data_cache_path is not None: + prefixes.append(data_cache_path) + for prefix in prefixes: + idx_path = { + 'desc': os.path.join(prefix, desc_filename), + 'doc': os.path.join(prefix, doc_idx_filename), + 'sample': os.path.join(prefix, sample_idx_filename), + 'shuffle': os.path.join(prefix, shuffle_idx_filename) + } + for f in idx_path.values(): + if not os.path.isfile(f): + break + else: + # Found our files! + build_indices = False + break + data_cache_dir = os.path.dirname(idx_path['desc']) + data_cache_success = True # Build the indexed mapping if not exist. - if torch.distributed.get_rank() == 0 and \ - (not os.path.isfile(doc_idx_filename) or - not os.path.isfile(sample_idx_filename) or - not os.path.isfile(shuffle_idx_filename)): - + if build_indices and torch.distributed.get_rank() == 0: print_rank_0(' > WARNING: could not find index map files, building ' 'the indices on rank 0 ...') @@ -490,7 +547,6 @@ def _build_index_mappings(name, data_prefix, documents, sizes, num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length # For very small datasets, `last_epoch_num_samples` can be equal to # (num_samples_per_epoch + 1). - # TODO: check that this is not problematic indeed assert last_epoch_num_samples <= (num_samples_per_epoch + 1), \ 'last epoch number of samples exceeded max value.' # If we have less than 80% of the samples for the last epoch, @@ -510,67 +566,81 @@ def _build_index_mappings(name, data_prefix, documents, sizes, print(string.format(last_epoch_num_samples, num_samples_per_epoch), flush=True) - # doc-idx. - start_time = time.time() - doc_idx = _build_doc_idx(documents, num_epochs, np_rng, - separate_last_epoch) - np.save(doc_idx_filename, doc_idx, allow_pickle=True) - print_rank_0(' > elasped time to build and save doc-idx mapping ' - '(seconds): {:4f}'.format(time.time() - start_time)) - # sample-idx. - start_time = time.time() - # Use C++ implementation for speed. - # First compile and then import. - from megatron.data import helpers - assert doc_idx.dtype == np.int32 - assert sizes.dtype == np.int32 - sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length, - num_epochs, tokens_per_epoch) - np.save(sample_idx_filename, sample_idx, allow_pickle=True) - print_rank_0(' > elasped time to build and save sample-idx mapping ' - '(seconds): {:4f}'.format(time.time() - start_time)) - # shuffle-idx. - start_time = time.time() - # -1 is due to data structure used to retieve the index: - # sample i --> [sample_idx[i], sample_idx[i+1]) - if separate_last_epoch: - num_samples_ = num_samples_from_epochs_minus_one - else: - num_samples_ = sample_idx.shape[0] - 1 - shuffle_idx = _build_shuffle_idx(num_samples_, - sample_idx.shape[0] - 1, np_rng) - np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) - print_rank_0(' > elasped time to build and save shuffle-idx mapping' - ' (seconds): {:4f}'.format(time.time() - start_time)) - - # This should be a barrier but nccl barrier assumes - # device_index=rank which is not the case for model - # parallel case - counts = torch.cuda.LongTensor([1]) + + try: + os.makedirs(data_cache_dir, exist_ok=True) + + # description + with open(idx_path['desc'], 'wt') as fd: + fd.write(desc) + + # doc-idx. + start_time = time.time() + doc_idx = _build_doc_idx(documents, num_epochs, np_rng, + separate_last_epoch) + np.save(idx_path['doc'], doc_idx, allow_pickle=True) + print_rank_0(' > elasped time to build and save doc-idx mapping ' + '(seconds): {:4f}'.format(time.time() - start_time)) + # sample-idx. + start_time = time.time() + # Use C++ implementation for speed. + # First compile and then import. + from megatron.data import helpers + assert doc_idx.dtype == np.int32 + assert sizes.dtype == np.int32 + sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length, + num_epochs, tokens_per_epoch) + np.save(idx_path['sample'], sample_idx, allow_pickle=True) + print_rank_0(' > elasped time to build and save sample-idx mapping ' + '(seconds): {:4f}'.format(time.time() - start_time)) + # shuffle-idx. + start_time = time.time() + # -1 is due to data structure used to retieve the index: + # sample i --> [sample_idx[i], sample_idx[i+1]) + if separate_last_epoch: + num_samples_ = num_samples_from_epochs_minus_one + else: + num_samples_ = sample_idx.shape[0] - 1 + shuffle_idx = _build_shuffle_idx(num_samples_, + sample_idx.shape[0] - 1, np_rng) + np.save(idx_path['shuffle'], shuffle_idx, allow_pickle=True) + print_rank_0(' > elasped time to build and save shuffle-idx mapping' + ' (seconds): {:4f}'.format(time.time() - start_time)) + except OSError: + print(f'There was an error trying to create the data cache directory ({data_cache_dir})') + print('or a file in it. This defaults to a directory "index-cache" within the directory') + print('the data files are in and can be set with the --data-cache-path argument. Please') + print('ensure you have write access to this directory or specify one that you do have') + print('write access to.') + data_cache_success = False + + counts = torch.cuda.LongTensor([data_cache_success]) torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group()) - assert counts[0].item() == ( + if counts[0].item() != ( torch.distributed.get_world_size() // - torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())) + torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())): + print_rank_0("Data index creation unsuccessful, exiting.") + exit() # Load mappings. start_time = time.time() - print_rank_0(' > loading doc-idx mapping from {}'.format( - doc_idx_filename)) - doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode='r') - print_rank_0(' > loading sample-idx mapping from {}'.format( - sample_idx_filename)) - sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode='r') - print_rank_0(' > loading shuffle-idx mapping from {}'.format( - shuffle_idx_filename)) - shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode='r') + print_rank_0(f" > loading doc-idx mapping from {idx_path['doc']}") + doc_idx = np.load(idx_path['doc'], allow_pickle=True, mmap_mode='r') + + print_rank_0(f" > loading sample-idx mapping from {idx_path['sample']}") + sample_idx = np.load(idx_path['sample'], allow_pickle=True, mmap_mode='r') + + print_rank_0(f" > loading shuffle-idx mapping from {idx_path['shuffle']}") + shuffle_idx = np.load(idx_path['shuffle'], allow_pickle=True, mmap_mode='r') + print_rank_0(' loaded indexed file in {:3.3f} seconds'.format( time.time() - start_time)) print_rank_0(' total number of samples: {}'.format( sample_idx.shape[0])) print_rank_0(' total number of epochs: {}'.format(num_epochs)) - return doc_idx, sample_idx, shuffle_idx, index_prefix + return doc_idx, sample_idx, shuffle_idx, desc, desc_hash def _num_tokens(documents, sizes): diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 831c7de6c3..ebe3fab81a 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -95,9 +95,9 @@ def write_longs(f, a): 3: np.int16, 4: np.int32, 5: np.int64, - 6: np.float32, - 7: np.float64, - 8: np.uint16 + 6: np.float64, + 7: np.float32, + 8: np.uint16, } @@ -269,7 +269,7 @@ class IndexedDatasetBuilder(object): np.int32: 4, np.int64: 8, np.float32: 4, - np.float64: 8 + np.float64: 8, } def __init__(self, out_file, dtype=np.int32): diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py index bd589fe6b6..fd19a74c6f 100644 --- a/megatron/model/language_model.py +++ b/megatron/model/language_model.py @@ -129,6 +129,10 @@ class Embedding(MegatronModule): init_method: weight initialization method num_tokentypes: size of the token-type embeddings. 0 value will ignore this embedding + embedding_weights_in_fp32: casts word embedding weights to + fp32 before sampling. Required to + maintain reproducibility when + training in bf16. """ def __init__(self, @@ -136,7 +140,8 @@ def __init__(self, vocab_size, embedding_dropout_prob, init_method, - num_tokentypes=0): + num_tokentypes=0, + embedding_weights_in_fp32=False): super(Embedding, self).__init__() self.hidden_size = hidden_size @@ -146,12 +151,14 @@ def __init__(self, args = get_args() # Word embeddings (parallel). + self.embedding_weights_in_fp32 = embedding_weights_in_fp32 + self.params_dtype = args.params_dtype self.word_embeddings = tensor_parallel.VocabParallelEmbedding( vocab_size, self.hidden_size, init_method=self.init_method, params_dtype=args.params_dtype, use_cpu_initialization=args.use_cpu_initialization, - perform_initialization=args.perform_initialization + perform_initialization=args.perform_initialization, ) self._word_embeddings_key = 'word_embeddings' @@ -184,7 +191,7 @@ def __init__(self, else: self.tokentype_embeddings = None - self.fp32_residual_connection = args.fp32_residual_connection + self.fp32_residual_connection = args.fp32_residual_connection self.sequence_parallel = args.sequence_parallel # Embeddings dropout self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob) @@ -219,7 +226,12 @@ def add_tokentype_embeddings(self, num_tokentypes): def forward(self, input_ids, position_ids, tokentype_ids=None): # Embeddings. + if self.embedding_weights_in_fp32: + self.word_embeddings = self.word_embeddings.to(torch.float32) words_embeddings = self.word_embeddings(input_ids) + if self.embedding_weights_in_fp32: + words_embeddings = words_embeddings.to(self.params_dtype) + self.word_embeddings = self.word_embeddings.to(self.params_dtype) if self.add_position_embedding and self.position_embedding_type == PositionEmbeddingType.absolute: assert self.position_embeddings is not None position_embeddings = self.position_embeddings(position_ids) @@ -365,7 +377,8 @@ def __init__(self, args.padded_vocab_size, args.hidden_dropout, self.init_method, - self.num_tokentypes) + self.num_tokentypes, + args.embedding_weights_in_fp32) self._embedding_key = 'embedding' # Rotary positional embeddings diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 3c1387a8c4..e47eaa45ea 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -1608,6 +1608,8 @@ def __init__(self, init_method, output_layer_init_method, # Transformer layers. if args.retro_add_retriever: + assert self.recompute_granularity != 'full', \ + "Full recompute not supported for Retro." assert args.transformer_impl == 'local', \ "Transformer engine does not support Retro layers." def build_layer(layer_number): @@ -1758,8 +1760,9 @@ def custom_forward(*args, **kwargs): hidden_states = tensor_parallel.checkpoint( custom(l, l + self.recompute_num_layers), self.distribute_saved_activations, - hidden_states, attention_mask, encoder_output, - enc_dec_attn_mask, rotary_pos_emb) + hidden_states, attention_mask, + encoder_output, enc_dec_attn_mask, + None, None, None, None, rotary_pos_emb) l += self.recompute_num_layers @@ -1781,8 +1784,9 @@ def custom_forward(*args, **kwargs): hidden_states = tensor_parallel.checkpoint( custom(l, l + 1), self.distribute_saved_activations, - hidden_states, attention_mask, encoder_output, - enc_dec_attn_mask, rotary_pos_emb) + hidden_states, attention_mask, + encoder_output, enc_dec_attn_mask, + None, None, None, None, rotary_pos_emb) else: if self.transformer_impl == 'transformer_engine': hidden_states = custom(l, l + 1)( @@ -1790,8 +1794,9 @@ def custom_forward(*args, **kwargs): enc_dec_attn_mask, **te_forward_kwargs) else: hidden_states = custom(l, l + 1)( - hidden_states, attention_mask, encoder_output, - enc_dec_attn_mask, rotary_pos_emb) + hidden_states, attention_mask, + encoder_output, enc_dec_attn_mask, + None, None, None, None, rotary_pos_emb) else: raise ValueError("Invalid activation recompute method.") @@ -1873,8 +1878,6 @@ def forward(self, hidden_states, attention_mask, # Forward pass. if self.recompute_granularity == 'full': - assert not self.retro_add_retriever, \ - "full recompute not supported for retro." hidden_states = self._checkpointed_forward(hidden_states, attention_mask, encoder_output, diff --git a/megatron/training.py b/megatron/training.py index 35e47c20d6..7b17f18c1d 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -437,6 +437,8 @@ def train_step(forward_step_func, data_iterator, tensor_shape=(args.seq_length, args.micro_batch_size, args.hidden_size), grad_scaler=optimizer.scale_loss, sequence_parallel=args.sequence_parallel, + overlap_p2p_comm=args.overlap_p2p_comm, + batch_p2p_comm=not args.overlap_p2p_comm, forward_only=False, timers=fwd_bwd_timers) timers('forward-backward').stop() @@ -937,9 +939,35 @@ def cyclic_iter(iter): yield x +def build_train_valid_test_datasets(build_train_valid_test_datasets_provider): + """Build pretraining datasets.""" + + args = get_args() + + # Number of train/valid/test samples. + if args.train_samples: + train_samples = args.train_samples + else: + train_samples = args.train_iters * args.global_batch_size + eval_iters = (args.train_iters // args.eval_interval + 1) * \ + args.eval_iters + test_iters = args.eval_iters + train_val_test_num_samples = [train_samples, + eval_iters * args.global_batch_size, + test_iters * args.global_batch_size] + print_rank_0(' > datasets target sizes (minimum size):') + print_rank_0(' train: {}'.format(train_val_test_num_samples[0])) + print_rank_0(' validation: {}'.format(train_val_test_num_samples[1])) + print_rank_0(' test: {}'.format(train_val_test_num_samples[2])) + + # Build the datasets. + return build_train_valid_test_datasets_provider(train_val_test_num_samples) + + def build_train_valid_test_data_loaders( build_train_valid_test_datasets_provider): - """XXX""" + """Build pretraining data loaders.""" + args = get_args() (train_dataloader, valid_dataloaders, test_dataloaders) = (None, None, None) @@ -959,25 +987,9 @@ def build_train_valid_test_data_loaders( # Data loader only on rank 0 of each model parallel group. if mpu.get_tensor_model_parallel_rank() == 0: - # Number of train/valid/test samples. - if args.train_samples: - train_samples = args.train_samples - else: - train_samples = args.train_iters * args.global_batch_size - eval_iters = (args.train_iters // args.eval_interval + 1) * \ - args.eval_iters - test_iters = args.eval_iters - train_val_test_num_samples = [train_samples, - eval_iters * args.global_batch_size, - test_iters * args.global_batch_size] - print_rank_0(' > datasets target sizes (minimum size):') - print_rank_0(' train: {}'.format(train_val_test_num_samples[0])) - print_rank_0(' validation: {}'.format(train_val_test_num_samples[1])) - print_rank_0(' test: {}'.format(train_val_test_num_samples[2])) - - # Build the datasets. - train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider( - train_val_test_num_samples) + # Build datasets. + train_ds, valid_ds, test_ds = build_train_valid_test_datasets( + build_train_valid_test_datasets_provider) # if dataloading option is not 2 convert to list to allow # same interface for multiple data groups # for validation and testing in option 2 @@ -1039,6 +1051,7 @@ def build_train_valid_test_data_loaders( def build_train_valid_test_data_iterators( build_train_valid_test_datasets_provider): + """Build pretraining data iterators.""" args = get_args() diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 7b73239271..4f10205990 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. """Pretrain GPT""" @@ -107,7 +107,8 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): skip_warmup=(not args.mmap_warmup), train_data_prefix=args.train_data_path, valid_data_prefix=args.valid_data_path, - test_data_prefix=args.test_data_path) + test_data_prefix=args.test_data_path, + data_cache_path=args.data_cache_path) # Option 2 of data loading using --(train|valid|test)-weighted-split-paths elif args.train_weighted_split_paths: assigned_train_valid_test = [] @@ -132,7 +133,8 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): train_val_test_num_samples, args.seq_length, args.seed, (not args.mmap_warmup), - train_valid_test=s) + train_valid_test=s, + data_cache_path=args.data_cache_path) assert d is not None, \ f"Got an empty split when trying to create dataset: {paths, weights, splits, name}" eval(f"{s}_ds").append(d) diff --git a/tools/retro/main.py b/tools/retro/main.py index 3cebdc8ab7..f7850087c8 100644 --- a/tools/retro/main.py +++ b/tools/retro/main.py @@ -55,15 +55,40 @@ def add_retro_args(parser): "a separate file.") # GPT args. + group.add_argument('--retro-gpt-seed', type=int, default=1234, + help='Random seed used for python, numpy, ' + 'pytorch, and cuda.') + group.add_argument('--retro-gpt-data-impl', type=str, default='infer', + choices=['lazy', 'cached', 'mmap', 'infer'], + help='Implementation of indexed datasets.') + group.add_argument('--retro-gpt-data-path', nargs='*', required=True, + help='Path to the training dataset. Accepted format:' + '1) a single data path, 2) multiple datasets in the' + 'form: dataset1-weight dataset1-path dataset2-weight ' + 'dataset2-path ... It is used with --split when a ' + 'single dataset used for all three: train, valid ' + 'and test. It is exclusive to the other ' + '--*-data-path args') + group.add_argument('--retro-gpt-split', type=str, default='969,30,1', + help='Comma-separated list of proportions for training,' + ' validation, and test split. For example the split ' + '`90,5,5` will use 90%% of data for training, 5%% for ' + 'validation and 5%% for test.') + group.add_argument('--retro-gpt-mmap-warmup', action='store_true', + help='Warm up mmap files.') + group.add_argument("--retro-gpt-eval-interval", type=int, required=True, + help="GPT evaluation interval.") + group.add_argument("--retro-gpt-eval-iters", type=int, required=True, + help="GPT evaluation iterations.") group.add_argument("--retro-gpt-tokenizer-type", required=True, help="GPT tokenizer type.") group.add_argument("--retro-gpt-vocab-file", help="GPT vocab file.") group.add_argument("--retro-gpt-merge-file", help="GPT merge file.") group.add_argument("--retro-gpt-tokenizer-model", help="GPT tokenizer model file.") - group.add_argument("--retro-gpt-seq-length", type=int, default=2048, + group.add_argument("--retro-gpt-seq-length", type=int, required=True, help="GPT sequence length.") - group.add_argument("--retro-gpt-global-batch-size", type=int, default=2048, + group.add_argument("--retro-gpt-global-batch-size", type=int, required=True, help="GPT global batch size.") group.add_argument("--retro-gpt-chunk-length", type=int, default=64, help="GPT chunk length.") diff --git a/tools/retro/query/chunk_dataset.py b/tools/retro/query/chunk_dataset.py index f9cc4d5120..841788fe80 100644 --- a/tools/retro/query/chunk_dataset.py +++ b/tools/retro/query/chunk_dataset.py @@ -4,15 +4,16 @@ import torch from megatron import get_retro_args, print_rank_0 -from megatron.data.gpt_dataset import build_train_valid_test_datasets +from megatron.data.gpt_dataset import build_train_valid_test_datasets \ + as build_gpt_train_valid_test_datasets from megatron.training import ( - build_train_valid_test_data_loaders, + build_train_valid_test_datasets as build_pretraining_train_valid_test_datasets, update_train_iters, ) from tools.retro.db.utils import get_indexed_dataset_infos from tools.retro.utils import get_num_chunks_per_sample -from .utils import get_query_workdir +from .utils import get_neighbor_dirname, get_query_workdir class ChunkDataset(torch.utils.data.Dataset): @@ -86,14 +87,14 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): print_rank_0('> building train, validation, and test datasets ' 'for GPT ...') - train_ds, valid_ds, test_ds = build_train_valid_test_datasets( - data_prefix=args.data_path, - data_impl=args.data_impl, - splits_string=args.split, + train_ds, valid_ds, test_ds = build_gpt_train_valid_test_datasets( + data_prefix=args.retro_gpt_data_path, + data_impl=args.retro_gpt_data_impl, + splits_string=args.retro_gpt_split, train_valid_test_num_samples=train_val_test_num_samples, seq_length=args.retro_gpt_seq_length, - seed=args.seed, - skip_warmup=(not args.mmap_warmup), + seed=args.retro_gpt_seed, + skip_warmup=(not args.retro_gpt_mmap_warmup), return_doc_ids=args.retro_return_doc_ids) print_rank_0("> finished creating pretrained GPT datasets ...") @@ -115,28 +116,23 @@ def get_chunk_dataset_map(): verify_indexed_dataset_order() # Datasets. - print_rank_0(" > data loader.") - train_data_loader, valid_data_loader, test_data_loader \ - = build_train_valid_test_data_loaders( - train_valid_test_datasets_provider) - - data_loader_map = { - "train" : train_data_loader, - "valid" : valid_data_loader, - "test" : test_data_loader, + print_rank_0(" > datasets.") + train_ds, valid_ds, test_ds = build_pretraining_train_valid_test_datasets( + train_valid_test_datasets_provider) + + sample_dataset_map = { + "train" : train_ds, + "valid" : valid_ds, + "test" : test_ds, } # Info dict. - workdir = get_query_workdir() - dataset_map = { + chunk_dataset_map = { key : { - "neighbor_dir" : os.path.join( - workdir, - os.path.basename(loader.dataset.datasets[0].index_prefix), - ), - "data" : ChunkDataset(loader.dataset, args.retro_gpt_chunk_length), + "neighbor_dir" : get_neighbor_dirname(key, sample_ds), + "data" : ChunkDataset(sample_ds, args.retro_gpt_chunk_length), } - for key, loader in data_loader_map.items() if loader + for key, sample_ds in sample_dataset_map.items() if sample_ds } - return dataset_map + return chunk_dataset_map diff --git a/tools/retro/query/retro_dataset.py b/tools/retro/query/retro_dataset.py index e89a47007a..0879d5d5fc 100644 --- a/tools/retro/query/retro_dataset.py +++ b/tools/retro/query/retro_dataset.py @@ -10,6 +10,7 @@ from tools.retro.external_libs import h5py from .chunk_dataset import get_chunk_dataset_map +from .utils import get_neighbor_dirname class RetroDataset(torch.utils.data.Dataset): @@ -120,11 +121,10 @@ def get_retro_datasets(verify_sizes=True): retro_args.retro_block_size) # Verify dataset prefixes. - sample_prefix = chunk_dataset.sample_dataset.datasets[0].index_prefix - neighbor_prefix = os.path.basename(neighbor_dir) - assert sample_prefix == neighbor_prefix, \ + expected_dir = get_neighbor_dirname(data_key, chunk_dataset.sample_dataset) + assert expected_dir == neighbor_dir, \ "inconsistent dataset source; '%s' vs. '%s'." % \ - (sample_prefix, neighbor_prefix) + (expected_dir, neighbor_dir) # Verify num chunks. n_sample_chunks = len(chunk_dataset) diff --git a/tools/retro/query/utils.py b/tools/retro/query/utils.py index a4ea2a5ca1..f6557abf1f 100644 --- a/tools/retro/query/utils.py +++ b/tools/retro/query/utils.py @@ -1,5 +1,6 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +import hashlib import os from megatron import get_retro_args @@ -8,3 +9,9 @@ def get_query_workdir(): args = get_retro_args() return os.path.join(args.retro_workdir, "query") + + +def get_neighbor_dirname(key, dataset): + hashes = ",".join([ d.desc_hash for d in dataset.datasets ]) + hash = hashlib.md5(hashes.encode()).hexdigest() + return os.path.join(get_query_workdir(), os.path.basename(f"{key}_{hash}"))