From 6ab42aad9e0d0edc7e68702a7f3bc8c39d05f538 Mon Sep 17 00:00:00 2001
From: inkcherry
Date: Thu, 13 Mar 2025 05:32:35 +0000
Subject: [PATCH 1/3] update

Signed-off-by: inkcherry
---
 deepspeed/module_inject/auto_tp.py        |  7 +++-
 deepspeed/module_inject/layers.py         | 45 ++++++++++++++++++++++-
 deepspeed/module_inject/replace_module.py |  6 +--
 3 files changed, 53 insertions(+), 5 deletions(-)

diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index 8d8381ed0428..69e15e02f69b 100755
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -382,7 +382,12 @@ def _replace(self, child, name, conv_linear_layer):
         if self.conv_linear_layer:
             return Conv_LinearALlreduce(child, self.mp_group, name=name)
         elif name == "lm_head" or name == 'embed_out':
-            return LmHeadLinearAllreduce(child, self.mp_group)
+            if is_autotp_training_mode():
+                # pass
+                # return child
+                return LinearLayer(child, self.mp_group, name=name, gather_output=True)
+            else:
+                return LmHeadLinearAllreduce(child, self.mp_group)

         return LinearAllreduce(child, self.mp_group, name=name)
     else:
diff --git a/deepspeed/module_inject/layers.py b/deepspeed/module_inject/layers.py
index 3c7491e99999..f5a6434a654a 100644
--- a/deepspeed/module_inject/layers.py
+++ b/deepspeed/module_inject/layers.py
@@ -109,6 +109,39 @@ def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor]:
         dist.all_reduce(grad_output.contiguous(), group=ctx.group)
         return None, grad_output

+class GatherTensor(torch.autograd.Function):
+    """Gather the input from model parallel region and concatinate."""
+
+    # @staticmethod
+    # def symbolic(graph, input_):
+    #     """Symbolic function for tracing."""
+    #     return _gather_along_last_dim(input_)
+    @staticmethod
+    def forward(ctx, group, input_):
+        """Forward function."""
+        # gather along last dim
+        world_size=dist.get_world_size(group)
+        if world_size==1:
+            return
+        ctx.group=group
+        ctx.world_size=world_size
+
+        gather_shape = (world_size,) + input_.shape
+        output =torch.empty(gather_shape, dtype=input_.dtype, device=get_accelerator().current_device_name() )
+        dist.all_gather_into_tensor(output, input_.contiguous(), group)
+        tensor_list = output.chunk(world_size, dim=0)
+        output = torch.cat(tensor_list, dim=-1).squeeze(0).contiguous()
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        #split along last_dim
+        """Backward function."""
+        rank = dist.get_rank(ctx.group)
+        input_list = torch.chunk(grad_output, ctx.world_size, -1)
+        grad_output = input_list[rank].contiguous()
+        return None, grad_output
+
 class TensorParallel_Layer(nn.Module, ABC):
     """
     A base class for model layers with tensor parallelism support.
@@ -394,16 +427,20 @@ def uneven_partition(self, params_list):
 #remove kwargs from partition.
 class LinearLayer(TensorParallel_Layer):

-    def __init__(self, module, mp_group=None, skip_partition=False, **kwargs):
+    def __init__(self, module, mp_group=None, skip_partition=False, gather_output=False, **kwargs):
         super(LinearLayer, self).__init__(mp_group, **kwargs)
         self.weight = module.weight
         self.bias = module.bias
+        if gather_output:
+            b=0
         if not skip_partition:
             self._tp_partition([self.weight, self.bias])
         self.support_training = True
         self.config_tp_params(self.weight)
         if self.bias is not None:
             self.config_tp_params(self.bias)
+        self.gather_output=gather_output
+

     def forward(self, input):
         if getattr(self, 'mp_group', None) is not None:
@@ -411,6 +448,10 @@ def forward(self, input):
         output = torch.matmul(input, self.weight.transpose(-1, -2))
         if self.bias is not None:
             output += self.bias
+
+        if self.gather_output:
+            output = GatherTensor.apply(self.mp_group,output)
+
         return output

     @torch.no_grad()
@@ -598,6 +639,8 @@ def __init__(self, module, mp_group, **kwargs):
     def forward(self, input):
         input_shard_size = get_shard_size(input.shape[-1], self.tp_world_size, "lm_head")
         input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.tp_world_size, "lm_head")[0:self.tp_index])
+        input= input[:, :, input_shard_offset:input_shard_offset + input_shard_size]
+
         output = torch.matmul(input[:, :, input_shard_offset:input_shard_offset + input_shard_size],
                               self.weight.transpose(-1, -2))
         if self.mp_group is not None:
diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index ed94a5021fee..69629b6d096e 100644
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -335,9 +335,9 @@ def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None):
         return new_module


 def set_lm_head(module):
-    if is_autotp_training_mode():
-        # we need to handle autoTP training mode separately.
-        return
+    # if is_autotp_training_mode():
+    #     # we need to handle autoTP training mode separately.
+    #     return
     embedding_weight = None
     for n, p in module.named_parameters():

From 1cf3038fa050229cd80230c19247c0559444715c Mon Sep 17 00:00:00 2001
From: inkcherry
Date: Fri, 28 Mar 2025 03:14:21 +0000
Subject: [PATCH 2/3] update

Signed-off-by: inkcherry
---
 deepspeed/module_inject/auto_tp.py        | 7 ++++---
 deepspeed/module_inject/layers.py         | 8 +-------
 deepspeed/module_inject/replace_module.py | 3 ---
 3 files changed, 5 insertions(+), 13 deletions(-)

diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index 69e15e02f69b..5829e910b014 100755
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -383,9 +383,10 @@ def _replace(self, child, name, conv_linear_layer):
             return Conv_LinearALlreduce(child, self.mp_group, name=name)
         elif name == "lm_head" or name == 'embed_out':
             if is_autotp_training_mode():
-                # pass
-                # return child
-                return LinearLayer(child, self.mp_group, name=name, gather_output=True)
+                return child
+
+                ## gather output column parallel
+                ## return LinearLayer(child, self.mp_group, name=name, gather_output=True)
             else:
                 return LmHeadLinearAllreduce(child, self.mp_group)

diff --git a/deepspeed/module_inject/layers.py b/deepspeed/module_inject/layers.py
index f5a6434a654a..db615c29e46d 100644
--- a/deepspeed/module_inject/layers.py
+++ b/deepspeed/module_inject/layers.py
@@ -112,10 +112,6 @@ class GatherTensor(torch.autograd.Function):
     """Gather the input from model parallel region and concatinate."""

-    # @staticmethod
-    # def symbolic(graph, input_):
-    #     """Symbolic function for tracing."""
-    #     return _gather_along_last_dim(input_)
     @staticmethod
     def forward(ctx, group, input_):
         """Forward function."""
         # gather along last dim
@@ -431,8 +427,7 @@ def __init__(self, module, mp_group=None, skip_partition=False, gather_output=Fa
         super(LinearLayer, self).__init__(mp_group, **kwargs)
         self.weight = module.weight
         self.bias = module.bias
-        if gather_output:
-            b=0
+
         if not skip_partition:
             self._tp_partition([self.weight, self.bias])
         self.support_training = True
@@ -639,7 +634,6 @@ def __init__(self, module, mp_group, **kwargs):
     def forward(self, input):
         input_shard_size = get_shard_size(input.shape[-1], self.tp_world_size, "lm_head")
         input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.tp_world_size, "lm_head")[0:self.tp_index])
-        input= input[:, :, input_shard_offset:input_shard_offset + input_shard_size]

         output = torch.matmul(input[:, :, input_shard_offset:input_shard_offset + input_shard_size],
                               self.weight.transpose(-1, -2))
         if self.mp_group is not None:
diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index 69629b6d096e..ef2ae2394152 100644
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -335,9 +335,6 @@ def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None):
         return new_module


 def set_lm_head(module):
-    # if is_autotp_training_mode():
-    #     # we need to handle autoTP training mode separately.
-    #     return
     embedding_weight = None
     for n, p in module.named_parameters():

From f6fa3845895fdcae36df35ec41d0c297af8ee22d Mon Sep 17 00:00:00 2001
From: inkcherry
Date: Fri, 28 Mar 2025 11:15:14 +0800
Subject: [PATCH 3/3] format

Signed-off-by: inkcherry
---
 deepspeed/module_inject/auto_tp.py |  2 +-
 deepspeed/module_inject/layers.py  | 32 ++++++++++++++++----------------
 2 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index 5829e910b014..3ddd010a97d3 100755
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -384,7 +384,7 @@ def _replace(self, child, name, conv_linear_layer):
         elif name == "lm_head" or name == 'embed_out':
             if is_autotp_training_mode():
                 return child
-
+
                 ## gather output column parallel
                 ## return LinearLayer(child, self.mp_group, name=name, gather_output=True)
             else:
diff --git a/deepspeed/module_inject/layers.py b/deepspeed/module_inject/layers.py
index db615c29e46d..2e494f82cfa3 100644
--- a/deepspeed/module_inject/layers.py
+++ b/deepspeed/module_inject/layers.py
@@ -109,23 +109,23 @@ def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor]:
         dist.all_reduce(grad_output.contiguous(), group=ctx.group)
         return None, grad_output

+
 class GatherTensor(torch.autograd.Function):
     """Gather the input from model parallel region and concatinate."""
-
     @staticmethod
     def forward(ctx, group, input_):
         """Forward function."""
         # gather along last dim
-        world_size=dist.get_world_size(group)
-        if world_size==1:
-            return
-        ctx.group=group
-        ctx.world_size=world_size
-
-        gather_shape = (world_size,) + input_.shape
-        output =torch.empty(gather_shape, dtype=input_.dtype, device=get_accelerator().current_device_name() )
-        dist.all_gather_into_tensor(output, input_.contiguous(), group)
+        world_size = dist.get_world_size(group)
+        if world_size == 1:
+            return
+        ctx.group = group
+        ctx.world_size = world_size
+
+        gather_shape = (world_size, ) + input_.shape
+        output = torch.empty(gather_shape, dtype=input_.dtype, device=get_accelerator().current_device_name())
+        dist.all_gather_into_tensor(output, input_.contiguous(), group)
         tensor_list = output.chunk(world_size, dim=0)
         output = torch.cat(tensor_list, dim=-1).squeeze(0).contiguous()
         return output
@@ -139,6 +139,7 @@ def backward(ctx, grad_output):
         #split along last_dim
         """Backward function."""
         rank = dist.get_rank(ctx.group)
         input_list = torch.chunk(grad_output, ctx.world_size, -1)
         grad_output = input_list[rank].contiguous()
         return None, grad_output
+
 class TensorParallel_Layer(nn.Module, ABC):
     """
     A base class for model layers with tensor parallelism support.
@@ -434,8 +435,7 @@ def __init__(self, module, mp_group=None, skip_partition=False, gather_output=Fa
         self.config_tp_params(self.weight)
         if self.bias is not None:
             self.config_tp_params(self.bias)
-        self.gather_output=gather_output
-
+        self.gather_output = gather_output

     def forward(self, input):
         if getattr(self, 'mp_group', None) is not None:
@@ -443,10 +443,10 @@ def forward(self, input):
         output = torch.matmul(input, self.weight.transpose(-1, -2))
         if self.bias is not None:
             output += self.bias
-
+
         if self.gather_output:
-            output = GatherTensor.apply(self.mp_group,output)
-
+            output = GatherTensor.apply(self.mp_group, output)
+
         return output

     @torch.no_grad()
@@ -634,7 +634,7 @@ def __init__(self, module, mp_group, **kwargs):
     def forward(self, input):
         input_shard_size = get_shard_size(input.shape[-1], self.tp_world_size, "lm_head")
         input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.tp_world_size, "lm_head")[0:self.tp_index])
-
+
         output = torch.matmul(input[:, :, input_shard_offset:input_shard_offset + input_shard_size],
                               self.weight.transpose(-1, -2))
         if self.mp_group is not None:
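
Note on the gather-output path that patch 2 leaves commented out: GatherTensor all-gathers the rank-local output shards along the last (output-feature) dimension in forward and slices the incoming gradient back down to the local shard in backward, which is what turns a column-parallel lm_head into full-vocabulary logits on every rank. The snippet below is a minimal, self-contained sketch of that pattern written against plain torch.distributed rather than deepspeed.comm and get_accelerator(), so the helper name gather_along_last_dim and its group argument are illustrative and not part of the patch. Unlike the patched forward, it hands the input straight back when the group has a single rank, where the early return in the patch would yield None.

import torch
import torch.distributed as dist


class _GatherAlongLastDim(torch.autograd.Function):
    """All-gather shards along the last dim in forward; keep only the local shard of the gradient in backward."""

    @staticmethod
    def forward(ctx, group, input_):
        world_size = dist.get_world_size(group)
        ctx.group = group
        ctx.world_size = world_size
        if world_size == 1:
            # Single-rank group: nothing to gather, return the tensor unchanged.
            return input_
        # Collect every rank's shard into a (world_size, *input_.shape) buffer, then
        # concatenate the shards along the last dimension.
        gathered = torch.empty((world_size, ) + input_.shape, dtype=input_.dtype, device=input_.device)
        dist.all_gather_into_tensor(gathered, input_.contiguous(), group=group)
        return torch.cat(gathered.unbind(0), dim=-1).contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.world_size == 1:
            return None, grad_output
        # The gradient of an all-gather is a split: each rank keeps the chunk that
        # corresponds to its own shard of the output features.
        rank = dist.get_rank(ctx.group)
        return None, grad_output.chunk(ctx.world_size, dim=-1)[rank].contiguous()


def gather_along_last_dim(input_, group=None):
    """Convenience wrapper mirroring GatherTensor.apply(group, input_)."""
    return _GatherAlongLastDim.apply(group, input_)

With a weight shard of shape [vocab_size // world_size, hidden], each rank's matmul produces its slice of the logits, and gather_along_last_dim(local_logits, mp_group) reassembles the full [batch, seq, vocab] tensor. That is the column-parallel behavior the first revision enabled with gather_output=True for lm_head, before the series settled on leaving lm_head unpartitioned (return child) in autoTP training mode and running set_lm_head unconditionally again.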