diff --git a/configs/7B_internlm2.py b/configs/7B_internlm2.py index 891885c37..b77a49d1d 100644 --- a/configs/7B_internlm2.py +++ b/configs/7B_internlm2.py @@ -128,6 +128,7 @@ use_fp32_norm = False model = dict( checkpoint=False, + # checkpoint_tp_no_comm=True, # whether to use TP recomputation communication optimization num_chunks=1, num_attention_heads=NUM_ATTENTION_HEAD, embed_split_hidden=True, diff --git a/configs/7B_sft.py b/configs/7B_sft.py index eba87bcd9..f3e28221a 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -141,6 +141,7 @@ use_fp32_norm = False model = dict( checkpoint=False, # The proportion of layers for activation checkpointing, the optional values are True/False/[0-1] + # checkpoint_tp_no_comm=True, # whether to use TP recomputation communication optimization num_attention_heads=NUM_ATTENTION_HEAD, embed_split_hidden=True, vocab_size=VOCAB_SIZE, diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py index 2c8089727..0b0ab604b 100644 --- a/internlm/core/context/parallel_context.py +++ b/internlm/core/context/parallel_context.py @@ -161,6 +161,7 @@ def __init__(self): self.virtual_pipeline_parallel_rank = None self._expert_parallel_group_names = [] self.is_evaluating = False + self.recompute_forward_no_comm = False @property def config(self): diff --git a/internlm/core/parallel/comm/tensor.py b/internlm/core/parallel/comm/tensor.py index 2dfc8bd28..6b5f46652 100644 --- a/internlm/core/parallel/comm/tensor.py +++ b/internlm/core/parallel/comm/tensor.py @@ -66,7 +66,9 @@ def input_hook( @abstractmethod def grad_output_hook( - self, grad_output: torch.Tensor, async_op: bool = False + self, + grad_output: torch.Tensor, + async_op: bool = False, ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ communication for grad_output when backward. @@ -81,7 +83,11 @@ def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> T pass @abstractmethod - def output_hook(self, output: torch.Tensor, async_op: bool = False) -> Tuple[torch.Tensor, AsyncCommHandle]: + def output_hook( + self, + output: torch.Tensor, + async_op: bool = False, + ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ communication for output when forward. """ @@ -93,13 +99,14 @@ class TensorParallelCommunicator(TPCommunicator): tensor parallel communicator for linear """ - def __init__(self, process_group: dist.ProcessGroup, role: LinearRole) -> None: + def __init__(self, process_group: dist.ProcessGroup, role: LinearRole, last_block_layer=False) -> None: assert role in (LinearRole.COLUMN, LinearRole.ROW), f"Unknown linear role: {role}" self._process_group = process_group self._role = role self._save_total_input = False + self.last_block_layer = last_block_layer def save_total_input(self) -> bool: return self._save_total_input @@ -116,7 +123,9 @@ def input_hook( return _input, DUMMY_HANDLE_CONST def grad_output_hook( - self, grad_output: torch.Tensor, async_op: bool = False # pylint: disable=W0613 + self, + grad_output: torch.Tensor, + async_op: bool = False, # pylint: disable=W0613 ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ tensor parallel should do nothing for grad_output.
@@ -132,11 +141,19 @@ def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> T return all_reduce_raw(grad_input, process_group=self._process_group, async_op=async_op) - def output_hook(self, output: torch.Tensor, async_op: bool = False) -> Tuple[torch.Tensor, AsyncCommHandle]: + def output_hook( + self, + output: torch.Tensor, + async_op: bool = False, + ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ all reduce output only for row parallel linear when forward. """ - if dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN: + if ( + (self.last_block_layer and gpc.recompute_forward_no_comm) + or dist.get_world_size(self._process_group) <= 1 + or self._role == LinearRole.COLUMN + ): return output, DUMMY_HANDLE_CONST return all_reduce_raw(output, process_group=self._process_group, async_op=async_op) @@ -148,7 +165,11 @@ class SequenceParallelCommunicator(TPCommunicator): """ def __init__( - self, process_group: dist.ProcessGroup, role: LinearRole, save_total_input_as_activation: bool = False + self, + process_group: dist.ProcessGroup, + role: LinearRole, + save_total_input_as_activation: bool = False, + last_block_layer=False, ) -> None: assert role in (LinearRole.COLUMN, LinearRole.ROW), f"Unknown linear role: {role}" @@ -156,6 +177,8 @@ def __init__( self._role = role self._save_total_input = save_total_input_as_activation + self.last_block_layer = last_block_layer + self.no_communication = False def save_total_input(self) -> bool: return self._save_total_input @@ -182,12 +205,19 @@ def input_hook( return all_gather_raw(_input, process_group=self._process_group, async_op=async_op, gather_dim=_GATHER_DIM) def grad_output_hook( - self, grad_output: torch.Tensor, async_op: bool = False + self, + grad_output: torch.Tensor, + async_op: bool = False, ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ all gather grad_output only for row parallel linear when backward. """ - if dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN: + if ( + (self.last_block_layer and self.no_communication) + or dist.get_world_size(self._process_group) <= 1 + or self._role == LinearRole.COLUMN + ): + self.no_communication = False return grad_output, DUMMY_HANDLE_CONST return all_gather_raw(grad_output, process_group=self._process_group, async_op=async_op, gather_dim=_GATHER_DIM) @@ -203,11 +233,20 @@ def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> T grad_input, process_group=self._process_group, async_op=async_op, reduce_dim=_REDUCE_DIM ) - def output_hook(self, output: torch.Tensor, async_op: bool = False) -> Tuple[torch.Tensor, AsyncCommHandle]: + def output_hook( + self, + output: torch.Tensor, + async_op: bool = False, + ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ reduce scatter output only for row parallel linear when forward. 
""" - if dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN: + self.no_communication = gpc.recompute_forward_no_comm + if ( + (self.last_block_layer and self.no_communication) + or dist.get_world_size(self._process_group) <= 1 + or self._role == LinearRole.COLUMN + ): return output, DUMMY_HANDLE_CONST return reduce_scatter_raw(output, process_group=self._process_group, async_op=async_op, reduce_dim=_REDUCE_DIM) @@ -225,7 +264,9 @@ def __init__(self, parallel_mode: ParallelMode, retain_out_sharded: bool = True) self._retain_out_sharded = retain_out_sharded def grad_output_hook( - self, grad_output: torch.Tensor, async_op: bool = False # pylint: disable=W0613 + self, + grad_output: torch.Tensor, + async_op: bool = False, # pylint: disable=W0613 ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ split grad_output if retain_out_sharded is False. @@ -236,7 +277,9 @@ def grad_output_hook( return _split(grad_output, parallel_mode=self._parallel_mode, dim=-1), DUMMY_HANDLE_CONST def output_hook( - self, output: torch.Tensor, async_op: bool = False # pylint: disable=W0613 + self, + output: torch.Tensor, + async_op: bool = False, # pylint: disable=W0613 ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ all gather output for head layer if retain_out_sharded is False. @@ -266,7 +309,9 @@ def __init__( # rewrite grad_output communication hook def grad_output_hook( - self, grad_output: torch.Tensor, async_op: bool = False # pylint: disable=W0613 + self, + grad_output: torch.Tensor, + async_op: bool = False, # pylint: disable=W0613 ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ split grad_output if retain_out_sharded is False. @@ -278,7 +323,9 @@ def grad_output_hook( # rewrite ouput communication hook def output_hook( - self, output: torch.Tensor, async_op: bool = False # pylint: disable=W0613 + self, + output: torch.Tensor, + async_op: bool = False, # pylint: disable=W0613 ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ all gather output for head layer if retain_out_sharded is False. 
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index f9df1b0b8..b7335939c 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -296,10 +296,13 @@ def args_sanity_check(): ] if "checkpoint" in model: + if "checkpoint_tp_no_comm" not in model: + gpc.config.model._add_item("checkpoint_tp_no_comm", True) if model.checkpoint is True: model.checkpoint = 1 elif model.checkpoint is False: model.checkpoint = 0 + model.checkpoint_tp_no_comm = False else: assert ( model.checkpoint >= 0 and model.checkpoint <= 1 @@ -419,6 +422,14 @@ def args_sanity_check(): gpc.config.parallel["pipeline"].get("interleaved_overlap", False) is True ), "only support interleaved pipeline scheduler with overlap" + # when tp or sp is not used, checkpoint_tp_no_comm should always be False + if ( + gpc.config.parallel["tensor"]["mode"] == "isp" + or gpc.config.parallel["tensor"]["size"] <= 1 + or gpc.config.model_type not in ["INTERNLM", "INTERNLM2_PUBLIC"] + ) and getattr(gpc.config.model, "checkpoint_tp_no_comm", False): + gpc.config.model.checkpoint_tp_no_comm = False + # monitoring default config monitor_default_config = { "alert_address": None, # compatible with old alert config diff --git a/internlm/model/builder.py b/internlm/model/builder.py index b50a1fdbb..ba1ec08fa 100644 --- a/internlm/model/builder.py +++ b/internlm/model/builder.py @@ -31,6 +31,9 @@ def create_model(model_type) -> Union[nn.Module, List[nn.Module]]: kwargs["checkpoint"] = float(kwargs.get("checkpoint", False)) kwargs["device"] = get_current_device() + if "checkpoint_tp_no_comm" in kwargs: + kwargs.pop("checkpoint_tp_no_comm") + model_buidler = model_initializer.get_module(module_name=model_type) if not gpc.is_using_parallel_mode(ParallelMode.PIPELINE): diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py index 5994e15d5..3cf9c14af 100644 --- a/internlm/model/modeling_internlm.py +++ b/internlm/model/modeling_internlm.py @@ -21,9 +21,11 @@ convert_attn_kwargs_to_args, internlm1_mha_pre_load_convert, internlm1_mha_save_convert, + padding_residual, ) from internlm.solver.activation_checkpoint import activation_checkpoint from internlm.utils.logger import get_logger +from internlm.utils.parallel import is_using_sequence_parallel logger = get_logger(__file__) @@ -213,6 +215,10 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): hidden_states = self.mlp(hidden_states) + # pad residual + if gpc.recompute_forward_no_comm and is_using_sequence_parallel(): + residual = padding_residual(residual) + return hidden_states + residual diff --git a/internlm/model/modeling_internlm2.py b/internlm/model/modeling_internlm2.py index c3b894120..303125c2e 100644 --- a/internlm/model/modeling_internlm2.py +++ b/internlm/model/modeling_internlm2.py @@ -21,9 +21,11 @@ from internlm.model.utils import ( convert_attn_args_to_kwargs, convert_attn_kwargs_to_args, + padding_residual, ) from internlm.solver.activation_checkpoint import activation_checkpoint from internlm.utils.logger import get_logger +from internlm.utils.parallel import is_using_sequence_parallel logger = get_logger(__file__) @@ -255,8 +257,13 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): if self.residual_in_fp32: residual = residual.to(torch.float32) + hidden_states = self.feed_forward(hidden_states) + # pad residual + if gpc.recompute_forward_no_comm and is_using_sequence_parallel(): + residual = padding_residual(residual) + return hidden_states + residual else: assert residual is None diff --git
a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py index 820df33be..43065e764 100644 --- a/internlm/model/modules/linear.py +++ b/internlm/model/modules/linear.py @@ -346,7 +346,6 @@ def __init__( def forward(self, input: torch.Tensor) -> torch.Tensor: # pylint: disable=W0622 _class_name = self.__class__.__name__ assert self._communicator is not None, f"{_class_name} should register with a communicator first." - return fused_dense_func( input, self.weight, @@ -589,7 +588,7 @@ def new_linear( dtype, ) elif split_mode == "row": - return RowParallelLinear( + linear = RowParallelLinear( in_features, out_features, bias, @@ -597,6 +596,9 @@ def new_linear( device, dtype, ) + if name == "w2": + setattr(linear, "last_block_layer", True) + return linear else: err_msg = ( f"Parallel strategies for linear is unsupported, which is named as {name}.\n" diff --git a/internlm/model/utils.py b/internlm/model/utils.py index e4a40dabb..a9fcb892c 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -1,7 +1,10 @@ from typing import Any, Dict, List + +import torch + +from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.model.modules.mha import MHA def internlm1_mha_pre_load_convert( @@ -54,6 +57,23 @@ def convert_attn_args_to_kwargs(args, kwargs) -> Dict[str, Any]: return kwargs + +def padding_residual(residual): + requires_grad = residual.requires_grad + _GATHER_DIM = 1 + total_size = gpc.get_world_size(ParallelMode.TENSOR) * residual.shape[_GATHER_DIM] + zero_padding_tensor = torch.zeros( + (*residual.shape[:_GATHER_DIM], total_size, *residual.shape[_GATHER_DIM + 1 :]), + dtype=residual.dtype, + device=residual.device, + ) + start_idx = gpc.get_local_rank(ParallelMode.TENSOR) * residual.shape[_GATHER_DIM] + end_idx = start_idx + residual.shape[_GATHER_DIM] + zero_padding_tensor[:, start_idx:end_idx, :] = residual + residual = zero_padding_tensor.requires_grad_(requires_grad) + + return residual + def convert_hf_config(config): gpc.config.model.vocab_size = gpc.config.VOCAB_SIZE = config.vocab_size gpc.config.model.hidden_size = gpc.config.HIDDEN_SIZE = config.hidden_size @@ -64,3 +84,4 @@ def convert_hf_config(config): # For models that use GQA if hasattr(config, "num_key_value_heads"): gpc.config.model.num_kv_attention_heads = gpc.config.NUM_KV_ATTENTION_HEAD = config.num_key_value_heads + diff --git a/internlm/solver/activation_checkpoint.py b/internlm/solver/activation_checkpoint.py index 870557714..b46e91fb6 100644 --- a/internlm/solver/activation_checkpoint.py +++ b/internlm/solver/activation_checkpoint.py @@ -2,11 +2,14 @@ # -*- encoding: utf-8 -*- import weakref +from contextlib import contextmanager import torch from torch.utils.checkpoint import check_backward_validity, detach_variable from internlm.accelerator import get_accelerator +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc from internlm.core.context.random import ( get_current_mode, get_states, @@ -14,6 +17,8 @@ set_seed_states, sync_states, ) +from internlm.core.parallel.comm.tensor import _GATHER_DIM, all_gather_raw +from internlm.utils.parallel import is_using_sequence_parallel from ..utils.common import get_current_device @@ -37,6 +42,29 @@ def copy_to_device(obj, device): return obj +@contextmanager +def recompute_forward_context(args, no_communication): + handle = None + try: + # Set True when entering the context + if no_communication: + gpc.recompute_forward_no_comm = True + if 
is_using_sequence_parallel(): + # overlap all_gather + grad_output = args[0] + grad_output, handle = all_gather_raw( + grad_output, process_group=gpc.get_group(ParallelMode.TENSOR), async_op=True, gather_dim=_GATHER_DIM + ) + yield + finally: + # Set False when exiting the context + gpc.recompute_forward_no_comm = False + + if handle: + handle.wait() + args[0] = grad_output + + class CheckpointFunction(torch.autograd.Function): """ Checkpoint Function @@ -122,13 +150,20 @@ def backward(ctx, *args): # Fill in inputs with appropriate saved tensors. for i, idx in enumerate(tensor_indices): inputs[idx] = tensors[i] + + # when checkpoint_tp_no_comm is True, use the TP recomputation communication optimization + no_communication = getattr(gpc.config.model, "checkpoint_tp_no_comm", False) + detached_inputs = detach_variable(tuple(inputs)) - if ctx.had_autocast_in_fwd: - with torch.enable_grad(), internlm_accelerator.amp.autocast(): - outputs = ctx.run_function(*detached_inputs) - else: - with torch.enable_grad(): - outputs = ctx.run_function(*detached_inputs) + + args = list(args) + with recompute_forward_context(args, no_communication): + if ctx.had_autocast_in_fwd: + with torch.enable_grad(), internlm_accelerator.amp.autocast(): + outputs = ctx.run_function(*detached_inputs) + else: + with torch.enable_grad(): + outputs = ctx.run_function(*detached_inputs) if isinstance(outputs, torch.Tensor): outputs = (outputs,) diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index 47ab70ce2..31ebc1505 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -276,6 +276,15 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]): ) _head_communicator = HeadTensorParallelCommunicator(ParallelMode.TENSOR, _retain_out_sharded) _embedding_communicator = EmbeddingTensorParallelCommunicator(ParallelMode.TENSOR) + + # for tp recompute communication optimization, mark the last block layer + for row_parallel_linear in _submodule_filter(model, RowParallelLinear): + if getattr(row_parallel_linear, "last_block_layer", False): + row_parallel_linear.register_communicator( + TensorParallelCommunicator( + process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.ROW, last_block_layer=True + ) + ) # sequence parallel if gpc.config.parallel.tensor.mode in (TensorParallelMode.msp.name, TensorParallelMode.fsp.name): save_total_input_as_activation = gpc.config.parallel.tensor.mode == TensorParallelMode.msp.name @@ -295,6 +304,18 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]): ) ) + # for tp recompute communication optimization, mark the last block layer + for row_parallel_linear in _submodule_filter(model, RowParallelLinear): + if getattr(row_parallel_linear, "last_block_layer", False): + row_parallel_linear.register_communicator( + SequenceParallelCommunicator( + gpc.get_group(ParallelMode.TENSOR), + role=LinearRole.ROW, + save_total_input_as_activation=save_total_input_as_activation, + last_block_layer=True, + ) + ) + _head_communicator = HeadSequenceParallelCommunicator( ParallelMode.TENSOR, _retain_out_sharded, save_total_input_as_activation )
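Usage sketch (illustrative, not part of the patch): per the args_sanity_check() changes above, checkpoint_tp_no_comm defaults to True whenever a checkpoint field is present in the model config, is forced to False when checkpoint is False, and is silently reset to False when the tensor-parallel mode is isp, the tensor-parallel size is 1, or model_type is not INTERNLM/INTERNLM2_PUBLIC. A minimal config excerpt that actually exercises the optimization might look like the following, assuming the remaining fields of configs/7B_sft.py stay unchanged:

# hypothetical excerpt of a training config such as configs/7B_sft.py
model = dict(
    checkpoint=1.0,  # activation checkpointing must be enabled (True or a proportion in (0, 1])
    checkpoint_tp_no_comm=True,  # skip the last w2 output communication during recomputation
    # ... remaining model fields unchanged ...
)

parallel = dict(
    # msp or fsp tensor parallelism with size > 1; with mode="isp" or size <= 1,
    # args_sanity_check() resets checkpoint_tp_no_comm back to False
    tensor=dict(size=8, mode="msp"),
    # ... remaining parallel fields unchanged ...
)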
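On the residual padding added to modeling_internlm.py and modeling_internlm2.py: when the reduce-scatter on the last w2 output is skipped during recomputation under msp/fsp sequence parallel, the recomputed block output keeps the full sequence length, while the residual branch is still sharded along the sequence dimension; padding_residual() therefore zero-pads the residual into this rank's sequence slice before the final add. A self-contained, single-process sketch of that shape bookkeeping, where tp_world_size and tp_rank stand in for the gpc.get_world_size/gpc.get_local_rank lookups used by the real helper:

import torch

# stand-ins for gpc.get_world_size(ParallelMode.TENSOR) / gpc.get_local_rank(ParallelMode.TENSOR)
tp_world_size, tp_rank = 4, 1
batch, seq_per_rank, hidden = 2, 8, 16

def padding_residual_sketch(residual: torch.Tensor) -> torch.Tensor:
    """Mirror of internlm.model.utils.padding_residual with explicit tp_world_size/tp_rank."""
    gather_dim = 1  # sequence dimension
    total_size = tp_world_size * residual.shape[gather_dim]
    padded = torch.zeros(
        (*residual.shape[:gather_dim], total_size, *residual.shape[gather_dim + 1 :]),
        dtype=residual.dtype,
        device=residual.device,
    )
    start = tp_rank * residual.shape[gather_dim]
    end = start + residual.shape[gather_dim]
    padded[:, start:end, :] = residual  # place the local shard in this rank's slice, zeros elsewhere
    return padded.requires_grad_(residual.requires_grad)

# sequence-sharded residual held by this rank: [batch, seq / tp, hidden]
residual = torch.randn(batch, seq_per_rank, hidden, requires_grad=True)
padded = padding_residual_sketch(residual)
print(padded.shape)  # torch.Size([2, 32, 16]): full sequence length, matching the un-scattered w2 output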