diff --git a/configs/_base_/models/internlm2_1B.py b/configs/_base_/models/internlm2_1B.py
index 5d050da92..833251b60 100644
--- a/configs/_base_/models/internlm2_1B.py
+++ b/configs/_base_/models/internlm2_1B.py
@@ -63,6 +63,7 @@
     1. size: int, the size of pipeline parallel.
     2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
         defaults to False.
+    4. batch_p2p_comm: bool, enable/disable batch p2p communication, defaults to False.
 weight parallel (dict):
     1. size: int, the size of weight parallel.
     2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
diff --git a/internlm/core/scheduler/comm/p2p.py b/internlm/core/scheduler/comm/p2p.py
index 54fb587c0..82bc7f4f2 100644
--- a/internlm/core/scheduler/comm/p2p.py
+++ b/internlm/core/scheduler/comm/p2p.py
@@ -44,6 +44,15 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) ->
     return tensor_chunk_shape, chunk_tensor
 
 
+def _p2p_func(_comm_op, _obj, _comm_rank):
+    if getattr(gpc.config.parallel.pipeline, "batch_p2p_comm", False) is True:
+        op_or_handle = dist.P2POp(_comm_op, _obj, _comm_rank)
+    else:
+        op_or_handle = _comm_op(_obj, _comm_rank)
+
+    return op_or_handle
+
+
 def create_recv_buffer_with_shapes(recv_shapes, dtype, scatter_gather_tensors):
     if isinstance(recv_shapes, torch.Size):
         recv_chunk_shape, recv_split = _get_tensor_shape(recv_shapes, scatter_gather_tensors)
@@ -78,12 +87,10 @@ def process_object_to_send(object_send, scatter_gather_tensors):
 
 def filling_ops_queue(obj, comm_op, comm_rank, ops_queue):
     if isinstance(obj, torch.Tensor):
-        op_to_add = dist.P2POp(comm_op, obj, comm_rank)
-        ops_queue.append(op_to_add)
+        ops_queue.append(_p2p_func(comm_op, obj, comm_rank))
     else:
         for tensor_to_comm in obj:
-            op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank)
-            ops_queue.append(op_to_add)
+            ops_queue.append(_p2p_func(comm_op, tensor_to_comm, comm_rank))
 
 
 def _communicate(
@@ -156,23 +163,42 @@ def _communicate(
         object_send_next = process_object_to_send(object_send_next, scatter_gather_tensors)
 
     ops = []
-    if object_send_prev is not None:
-        filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops)
-    if tensor_recv_prev is not None:
-        filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops)
+    if gpc.get_local_rank(ParallelMode.PIPELINE) % 2 == 0:
+        if object_send_next is not None:
+            filling_ops_queue(object_send_next, dist.isend, next_rank, ops)
 
-    if tensor_recv_next is not None:
-        filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops)
+        if tensor_recv_prev is not None:
+            filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops)
+
+        if object_send_prev is not None:
+            filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops)
+
+        if tensor_recv_next is not None:
+            filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops)
+    else:
+        if tensor_recv_prev is not None:
+            filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops)
+
+        if object_send_next is not None:
+            filling_ops_queue(object_send_next, dist.isend, next_rank, ops)
+
+        if tensor_recv_next is not None:
+            filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops)
+
+        if object_send_prev is not None:
+            filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops)
 
-    if object_send_next is not None:
-        filling_ops_queue(object_send_next, dist.isend, next_rank, ops)
 
     if len(ops) > 0:
-        reqs = dist.batch_isend_irecv(ops)
-        for req in reqs:
-            req.wait()
-        # To protect against race condition when using batch_isend_irecv().
-        internlm_accelerator.synchronize()
+        if getattr(gpc.config.parallel.pipeline, "batch_p2p_comm", False) is True:
+            reqs = dist.batch_isend_irecv(ops)
+            for req in reqs:
+                req.wait()
+            # To protect against race condition when using batch_isend_irecv().
+            internlm_accelerator.synchronize()
+        else:
+            for req in ops:
+                req.wait()
 
     if recv_prev and recv_prev_split:
         if isinstance(tensor_recv_prev, torch.Tensor):
@@ -265,29 +291,47 @@ def _communicate_async(
         object_send_next = process_object_to_send(object_send_next, scatter_gather_tensors)
 
     ops = []
-    if object_send_prev is not None:
-        filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops)
-    if tensor_recv_prev is not None:
-        filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops)
+    if gpc.get_local_rank(ParallelMode.PIPELINE) % 2 == 0:
+        if object_send_next is not None:
+            filling_ops_queue(object_send_next, dist.isend, next_rank, ops)
 
-    if tensor_recv_next is not None:
-        filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops)
+        if tensor_recv_prev is not None:
+            filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops)
 
-    if object_send_next is not None:
-        filling_ops_queue(object_send_next, dist.isend, next_rank, ops)
+        if object_send_prev is not None:
+            filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops)
 
-    if len(ops) > 0:
+        if tensor_recv_next is not None:
+            filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops)
+    else:
+        if tensor_recv_prev is not None:
+            filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops)
+
+        if object_send_next is not None:
+            filling_ops_queue(object_send_next, dist.isend, next_rank, ops)
+
+        if tensor_recv_next is not None:
+            filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops)
+
+        if object_send_prev is not None:
+            filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops)
+
+    if len(ops) > 0 and getattr(gpc.config.parallel.pipeline, "batch_p2p_comm", False) is True:
         reqs = dist.batch_isend_irecv(ops)
 
     # return and do other things
     yield
 
     if len(ops) > 0:
-        for req in reqs:  # pylint: disable=E0601
-            req.wait()
-        # To protect against race condition when using batch_isend_irecv().
-        internlm_accelerator.synchronize()
+        if getattr(gpc.config.parallel.pipeline, "batch_p2p_comm", False) is True:
+            for req in reqs:
+                req.wait()
+            # To protect against race condition when using batch_isend_irecv().
+            internlm_accelerator.synchronize()
+        else:
+            for req in ops:
+                req.wait()
 
     if recv_prev and recv_prev_split:
         if isinstance(tensor_recv_prev, torch.Tensor):
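Note on the p2p.py changes above: when batch_p2p_comm is disabled, each isend/irecv is posted as an individual call instead of going through dist.batch_isend_irecv, so neighbouring pipeline stages have to order their sends and receives so that they pair up. The patch does this by rank parity: even stages enqueue the send to the next stage first, odd stages post the matching receive first. The following standalone sketch is not part of the patch (run_worker and the 2-process gloo group are illustrative) and shows the same parity rule in isolation:

    # Minimal sketch of the parity ordering used when batch_p2p_comm is disabled.
    import os

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp


    def run_worker(rank: int, world_size: int) -> None:
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29511"
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        next_rank = (rank + 1) % world_size
        prev_rank = (rank - 1) % world_size
        send_buf = torch.full((4,), float(rank))
        recv_buf = torch.empty(4)

        # Mirror the rule from _communicate: even ranks enqueue the send to the next
        # stage first, odd ranks enqueue the matching receive first.
        if rank % 2 == 0:
            reqs = [dist.isend(send_buf, next_rank), dist.irecv(recv_buf, prev_rank)]
        else:
            reqs = [dist.irecv(recv_buf, prev_rank), dist.isend(send_buf, next_rank)]

        for req in reqs:
            req.wait()
        print(f"rank {rank} received from rank {prev_rank}: {recv_buf.tolist()}")
        dist.destroy_process_group()


    if __name__ == "__main__":
        mp.spawn(run_worker, args=(2,), nprocs=2)

With batch_p2p_comm enabled the old behaviour is kept: the whole ops list goes through dist.batch_isend_irecv followed by an accelerator synchronize, as in the hunks above.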
diff --git a/internlm/core/scheduler/pipeline_scheduler_1f1b.py b/internlm/core/scheduler/pipeline_scheduler_1f1b.py
index 289bc37d3..6f46d819e 100644
--- a/internlm/core/scheduler/pipeline_scheduler_1f1b.py
+++ b/internlm/core/scheduler/pipeline_scheduler_1f1b.py
@@ -201,7 +201,7 @@ def _call_engine(engine, data):  # pylint: disable=W0237
 
     def load_batch(self, engine, data_iter):
         # Pipeline schedule just puts data in memory,
-        batch_data, actual_batch_size = engine.load_batch(data_iter, to_gpu=False)
+        batch_data, actual_batch_size = engine.load_batch(data_iter, to_gpu=True)
 
         # Even if 'use_flash_attn' is False, the data seen when the 'load_batch' is called is still packed,
         # because internlm's current train dataset is packed, even using dummy data.
@@ -313,17 +313,18 @@ def _forward_step(
                 accum_loss.add_(loss_reduced.detach())
                 output_obj = loss_reduced
 
-        moe_loss = (
-            sum(moe_losses) * gpc.config.loss.moe_loss_coeff  # pylint: disable=E0606
-            if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1
-            else torch.tensor(0.0, device=get_current_device(), dtype=gpc.config.model.get("dtype"))
-        )
-        # the moe_loss is computed among the "tensor" group if sequence parallel is enabled, so we need to do allreduce
-        if gpc.config.parallel.sequence_parallel or gpc.config.parallel.expert.no_tp:
-            dist.all_reduce(moe_loss, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR))
-            moe_loss.div_(gpc.get_world_size(ParallelMode.TENSOR))
-        moe_loss /= self.num_microbatches
-        accum_moe_loss.add_(moe_loss.detach())
+        if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
+            moe_loss = sum(moe_losses) * gpc.config.loss.moe_loss_coeff
+
+            # the moe_loss is computed among the "tensor" group if sequence parallel is enabled,
+            # so we need to do allreduce
+            if gpc.config.parallel.sequence_parallel or gpc.config.parallel.expert.no_tp:
+                dist.all_reduce(moe_loss, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR))
+                moe_loss.div_(gpc.get_world_size(ParallelMode.TENSOR))
+            moe_loss /= self.num_microbatches
+            accum_moe_loss.add_(moe_loss.detach())
+        else:
+            moe_loss = None
 
         return output_obj, moe_loss
 
@@ -417,7 +418,11 @@ def _forward_only_step(self, engine, return_loss=True, return_output_label=True)
             if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True)
             else None
         )
-        accum_moe_loss = torch.zeros(1, device=get_current_device())
+
+        if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
+            accum_moe_loss = torch.zeros(1, device=get_current_device())
+        else:
+            accum_moe_loss = None
 
         # Used for tensor meta information communication
         forward_recv_shapes = self.tensor_shape
@@ -460,8 +465,8 @@ def _forward_only_step(self, engine, return_loss=True, return_output_label=True)
 
         if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
             dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
-        if accum_loss is not None:
-            accum_loss += accum_moe_loss
+            if accum_loss is not None:
+                accum_loss += accum_moe_loss
 
         return output, label, accum_loss, accum_moe_loss
 
@@ -518,7 +523,11 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T
             if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True)
             else None
         )
-        accum_moe_loss = torch.zeros(1, device=get_current_device())
+
+        if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
+            accum_moe_loss = torch.zeros(1, device=get_current_device())
+        else:
+            accum_moe_loss = None
 
         # Used for tensor meta information communication
         forward_recv_shapes = self.tensor_shape
@@ -664,8 +673,8 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T
 
         if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
             dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
-        if accum_loss is not None:
-            accum_loss += accum_moe_loss
+            if accum_loss is not None:
+                accum_loss += accum_moe_loss
 
         return output, label, accum_loss, accum_moe_loss
 
@@ -780,6 +789,7 @@ def __init__(
         self._output_obj_grads = [[] for _ in range(num_chunks)]
         self._moe_losses = [[] for _ in range(num_chunks)]
 
+        self._preload_micro_data = [None for _ in range(self.num_microbatches)]
         self._input_obj_shapes = [self.tensor_shape for _ in range(num_chunks)]
         self._output_obj_shapes = [None for _ in range(num_chunks)]
         self._send_tensor_shape_flags = [self.tensor_shape is None for _ in range(num_chunks)]
@@ -803,26 +813,37 @@ def _clear_state(self) -> None:
         self._output_obj_grads = [[] for _ in range(self._num_chunks)]
         self._moe_losses = [[] for _ in range(self._num_chunks)]
 
+        self._preload_micro_data = [None for _ in range(self.num_microbatches)]
         self._input_obj_shapes = [self.tensor_shape for _ in range(self._num_chunks)]
         self._output_obj_shapes = [None for _ in range(self._num_chunks)]
         self._send_tensor_shape_flags = [self.tensor_shape is None for _ in range(self._num_chunks)]
 
     def load_batch(self, engine, data_iter):
         super().load_batch(engine, data_iter)
+
+        for mbs in range(self.num_microbatches):
+            micro_batch_data, micro_batch_label = self._load_micro_batch(
+                data=self.batch_data,
+                label=self.batch_label,
+                offset=mbs * self.bsz_stride,
+                bsz_stride=self.bsz_stride,
+            )
+
+            if self.data_process_func:
+                micro_batch_data, micro_batch_label = self.data_process_func(micro_batch_data, micro_batch_label)
+
+            micro_batch_data["label"] = micro_batch_label
+            self._preload_micro_data[mbs] = micro_batch_data
+
         # overwrite microbatch_offset, since model chunks load the same microbatch, and should tract the offset
         self.microbatch_offset = [0 for _ in range(self._num_chunks)]
 
     def load_micro_batch(self, model_chunk_id):
-        micro_batch_data, micro_batch_label = self._load_micro_batch(
-            data=self.batch_data,
-            label=self.batch_label,
-            offset=self.microbatch_offset[model_chunk_id],
-            bsz_stride=self.bsz_stride,
-        )
-        if self.data_process_func:
-            micro_batch_data, micro_batch_label = self.data_process_func(micro_batch_data, micro_batch_label)
-        micro_batch_data["label"] = micro_batch_label
-        self.microbatch_offset[model_chunk_id] += self.bsz_stride
+        offset = self.microbatch_offset[model_chunk_id]
+        assert self._preload_micro_data[offset] is not None, "preload micro batch data is None"
+
+        micro_batch_data = self._preload_micro_data[offset]
+        self.microbatch_offset[model_chunk_id] += 1
 
         result = move_to_device(micro_batch_data)
         return result
@@ -876,18 +897,19 @@ def _forward_step(self, engine, chunk_id, input_obj=None):
                 self._accum_loss.add_(loss_reduced.detach())
                 output_obj = loss_reduced
 
-        moe_loss = (
-            sum(moe_losses) * gpc.config.loss.moe_loss_coeff  # pylint: disable=E0606
-            if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1
-            else torch.tensor(0.0, device=get_current_device(), dtype=gpc.config.model.get("dtype"))
-        )
-        # the moe_loss is computed among the "tensor" group if sequence parallel is enabled, so we need to do allreduce
-        if gpc.config.parallel.sequence_parallel or gpc.config.parallel.expert.no_tp:
-            dist.all_reduce(moe_loss, op=dist.ReduceOp.AVG, group=gpc.get_group(ParallelMode.TENSOR))
-        moe_loss /= self.num_microbatches
+        if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
+            moe_loss = sum(moe_losses) * gpc.config.loss.moe_loss_coeff
 
-        if self._accum_moe_loss is not None:
-            self._accum_moe_loss.add_(moe_loss.detach())
+            # the moe_loss is computed among the "tensor" group if sequence parallel is enabled,
+            # so we need to do allreduce
+            if gpc.config.parallel.sequence_parallel or gpc.config.parallel.expert.no_tp:
+                dist.all_reduce(moe_loss, op=dist.ReduceOp.AVG, group=gpc.get_group(ParallelMode.TENSOR))
+            moe_loss /= self.num_microbatches
+
+            if self._accum_moe_loss is not None:
+                self._accum_moe_loss.add_(moe_loss.detach())
+        else:
+            moe_loss = None
 
         self._output_objs[chunk_id].append(output_obj)
         self._moe_losses[chunk_id].append(moe_loss)
@@ -1398,7 +1420,9 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo
 
         if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
             self._accum_loss = torch.zeros(1, device=get_current_device())
-        self._accum_moe_loss = torch.zeros(1, device=get_current_device())
+
+        if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
+            self._accum_moe_loss = torch.zeros(1, device=get_current_device())
 
         if return_output_label:
             self._return_tensors = []
@@ -1413,13 +1437,15 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo
         else:
             output, label = (None, None)
 
+        accum_loss = self._accum_loss
+        accum_moe_loss = self._accum_moe_loss
+
         if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
             dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
-        accum_moe_loss = self._accum_moe_loss
+            accum_moe_loss = self._accum_moe_loss
 
-        accum_loss = self._accum_loss
-        if accum_loss is not None:
-            accum_loss += self._accum_moe_loss
+            if accum_loss is not None:
+                accum_loss += self._accum_moe_loss
 
         self._clear_state()
 
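Note on the interleaved-scheduler changes above: load_batch now slices every micro batch once, caches it in self._preload_micro_data, and load_micro_batch only indexes that cache and advances a per-chunk offset (the batch is already on device because the base load_batch is now called with to_gpu=True). A minimal standalone sketch of the same preload pattern, using plain dict batches and illustrative names rather than the real scheduler API:

    # Sketch only: preload all micro batches once, then hand them out by offset.
    import torch


    class MicroBatchPreloader:
        def __init__(self, num_microbatches, num_chunks, bsz_stride):
            self.num_microbatches = num_microbatches
            self.bsz_stride = bsz_stride
            self._preload = [None] * num_microbatches
            self.microbatch_offset = [0 for _ in range(num_chunks)]

        def load_batch(self, batch_data, batch_label):
            # Cut the full batch into micro batches up front.
            for mbs in range(self.num_microbatches):
                start = mbs * self.bsz_stride
                micro = {k: v[start : start + self.bsz_stride] for k, v in batch_data.items()}
                micro["label"] = batch_label[start : start + self.bsz_stride]
                self._preload[mbs] = micro
            self.microbatch_offset = [0 for _ in range(len(self.microbatch_offset))]

        def load_micro_batch(self, chunk_id):
            offset = self.microbatch_offset[chunk_id]
            assert self._preload[offset] is not None, "preload micro batch data is None"
            self.microbatch_offset[chunk_id] += 1
            return self._preload[offset]


    if __name__ == "__main__":
        loader = MicroBatchPreloader(num_microbatches=4, num_chunks=2, bsz_stride=1)
        data = {"input_ids": torch.arange(4 * 8).reshape(4, 8)}
        labels = torch.arange(4)
        loader.load_batch(data, labels)
        print(loader.load_micro_batch(0)["input_ids"].shape)  # torch.Size([1, 8])

The real scheduler additionally runs data_process_func on each cached micro batch and calls move_to_device in load_micro_batch, as shown in the hunk above.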
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index e1cb2f0d2..6cd61b8bd 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -97,9 +97,6 @@ def args_sanity_check():
     if "pipeline" not in gpc.config.parallel:
         gpc.config.parallel._add_item("pipeline", dict(size=1, interleaved_overlap=False, mode="1F1B"))
 
-    if isinstance(gpc.config.parallel.pipeline, dict) and "mode" not in gpc.config.parallel.pipeline:
-        gpc.config.parallel.pipeline._add_item("mode", "1F1B")
-
     if "tensor" not in gpc.config.parallel:
         gpc.config.parallel._add_item("tensor", dict(size=1, mode=TensorParallelMode.mtp.name))
 
@@ -118,9 +115,16 @@ def args_sanity_check():
 
     if isinstance(gpc.config.parallel.pipeline, int):
         pp = gpc.config.parallel.pipeline
+        gpc.config.parallel._add_item("pipeline", dict(size=pp, interleaved_overlap=False))
     else:
         pp = gpc.config.parallel.pipeline.size
 
+    if isinstance(gpc.config.parallel.pipeline, dict) and "mode" not in gpc.config.parallel.pipeline:
+        gpc.config.parallel.pipeline._add_item("mode", "1F1B")
+
+    if "batch_p2p_comm" not in gpc.config.parallel.pipeline:
+        gpc.config.parallel.pipeline["batch_p2p_comm"] = False
+
     if isinstance(gpc.config.parallel.pipeline, dict):
         gpc.config.parallel.pipeline["mode"] = gpc.config.parallel.pipeline["mode"].upper()
         assert gpc.config.parallel.pipeline["mode"] in [
diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py
index 784a5305a..e5847cc2b 100644
--- a/internlm/train/pipeline.py
+++ b/internlm/train/pipeline.py
@@ -538,6 +538,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList], isp_communicato
     if (
         zero_cfg.overlap_sync_grad
         and gpc.is_using_parallel_mode(ParallelMode.PIPELINE)
+        and getattr(gpc.config.parallel.pipeline, "batch_p2p_comm", False) is True
         and gpc.is_pipeline_first_stage() is False
     ):
         # When pipeline parallelism is enabled, we prefer to only enable optimizer
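Note on the configuration handling above: args_sanity_check now expands an integer pipeline setting into a dict and fills in mode="1F1B" and batch_p2p_comm=False when they are missing, so existing configs keep the unbatched default. A hedged sketch of how the new flag would sit in a user config (only the pipeline entry comes from this patch; the surrounding keys are placeholders):

    # Sketch of a parallel config; everything except batch_p2p_comm is a placeholder.
    parallel = dict(
        zero1=dict(size=-1),
        tensor=dict(size=1, mode="mtp"),
        pipeline=dict(size=4, interleaved_overlap=True, mode="1F1B", batch_p2p_comm=False),
        weight=dict(size=1, overlap=True),
    )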
diff --git a/tests/test_training/test_loss.py b/tests/test_training/test_loss.py
index 967398e17..3aa4d6149 100644
--- a/tests/test_training/test_loss.py
+++ b/tests/test_training/test_loss.py
@@ -21,7 +21,12 @@
     initialize_parallel_communicator,
     load_new_batch,
 )
-from internlm.utils.common import BatchSkipper, launch_time
+from internlm.utils.common import (
+    BatchSkipper,
+    check_cuda_env,
+    enable_pytorch_expandable_segments,
+    launch_time,
+)
 from internlm.utils.gputest import empty_cache_and_diag
 from internlm.utils.megatron_timers import megatron_timer as timer
 
@@ -33,15 +38,15 @@
 # dp_size = 4
 BASELINE_LOSS_LIST = [
     12.362918853759766,
-    12.404379844665527,
-    12.348219871520996,
-    12.194982528686523,
-    11.80469036102295,
-    11.573806762695312,
-    10.045475006103516,
-    9.660882949829102,
-    9.172087669372559,
-    4.799427032470703,
+    12.404375076293945,
+    12.348180770874023,
+    12.1947021484375,
+    11.804483413696289,
+    11.573527336120605,
+    10.04533576965332,
+    9.66073989868164,
+    9.172025680541992,
+    4.798973560333252
 ]
 
 
@@ -71,7 +76,7 @@ def train(
     config.data.total_steps = 50000
     config.data.fixed_random_dataset_seqlen = False
     config.data.micro_num = 4
-    config.data.micro_bsz = 2
+    config.data.micro_bsz = 1
     config.lr_scheduler.total_steps = config.data.total_steps
     config.model_type = model_type
     config.ckpt.load_ckpt_folder = None
@@ -167,6 +172,12 @@ def train(
     dist.broadcast_object_list(objs, src=0)
     current_time = objs[0]
 
+    # check cuda env
+    check_cuda_env()
+
+    # set torch expandable_segments
+    enable_pytorch_expandable_segments()
+
     # initialize model
     model = initialize_model()
 
@@ -472,15 +483,15 @@ def test_training_with_isp():
     CONFIG_FILE_PATH = "./configs/7B_isp_sft.py"
     BASELINE_LOSS_LIST = [
         12.225811004638672,
-        12.103824615478516,
-        12.223844528198242,
-        11.87704849243164,
-        11.651590347290039,
-        11.629219055175781,
-        10.242591857910156,
-        9.768388748168945,
-        9.330610275268555,
-        5.505439758300781,
+        12.10380744934082,
+        12.223655700683594,
+        11.877079963684082,
+        11.651113510131836,
+        11.629385948181152,
+        10.242776870727539,
+        9.768218040466309,
+        9.330422401428223,
+        5.505432605743408
     ]
 
     # model training
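One behavioural note that spans the scheduler changes in this patch: moe_loss and accum_moe_loss are now None for models without experts instead of zero tensors, so code that logs or aggregates them should guard for None. An illustrative helper, an assumption about how a caller might adapt rather than part of the patch:

    # Sketch only: treat a missing MoE loss as zero when reporting metrics.
    from typing import Optional

    import torch


    def moe_loss_to_scalar(accum_moe_loss: Optional[torch.Tensor]) -> float:
        return 0.0 if accum_moe_loss is None else accum_moe_loss.item()


    print(moe_loss_to_scalar(None))                # 0.0
    print(moe_loss_to_scalar(torch.tensor(0.5)))   # 0.5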