6 changes: 6 additions & 0 deletions lmdeploy/messages.py
@@ -289,6 +289,9 @@ class PytorchEngineConfig:
session_len (int): Max session length. Default None.
max_batch_size (int): Max batch size. If it is not specified,
the engine will automatically set it according to the device
attn_tp_size (int): tp size for the attention module. Only takes effect when dp > 1.
mlp_tp_size (int): tp size for the mlp module. Only takes effect when dp > 1.
moe_tp_size (int): tp size for the moe module. Only takes effect when dp > 1.
cache_max_entry_count (float): the percentage of gpu memory occupied
by the k/v cache. For lmdeploy versions greater than `v0.2.1`,
it defaults to 0.8, signifying the percentage of FREE GPU memory
@@ -350,6 +353,9 @@ class PytorchEngineConfig:
ep: int = 1
session_len: int = None
max_batch_size: int = None
attn_tp_size: int = None
mlp_tp_size: int = None
moe_tp_size: int = None
cache_max_entry_count: float = 0.8
prefill_interval: int = 16
block_size: int = 64
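For context, a minimal configuration sketch showing how the new fields are meant to be used (hypothetical values; the `dp` field already exists on `PytorchEngineConfig` and is not part of this diff):

```python
from lmdeploy import PytorchEngineConfig

# Hypothetical setup: data parallel over 4 ranks, with attention/mlp/moe each
# tensor-parallel over 2 ranks. The three *_tp_size fields only take effect
# when dp > 1, per the docstring above.
config = PytorchEngineConfig(
    dp=4,
    attn_tp_size=2,
    mlp_tp_size=2,
    moe_tp_size=2,
)
```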
7 changes: 6 additions & 1 deletion lmdeploy/pytorch/backends/awq_modules.py
@@ -17,7 +17,12 @@ def update_weights(self,
return qweight, scales, qzeros, bias

@abstractmethod
def forward(self, x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, all_reduce: bool = False):
def forward(self,
x,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False,
group: Optional[torch.distributed.ProcessGroup] = None):
"""forward."""
raise NotImplementedError

2 changes: 2 additions & 0 deletions lmdeploy/pytorch/backends/blockedf8_modules.py
@@ -3,6 +3,7 @@
from typing import List, Optional

import torch
import torch.distributed as dist


class LinearBlockedF8Impl(ABC):
@@ -19,6 +20,7 @@ def forward(self,
scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False,
group: Optional[dist.ProcessGroup] = None,
rank: int = 0,
scatter_size: List[int] = None):
"""forward."""
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/backends/cuda/attention.py
@@ -90,7 +90,7 @@ def __init__(
self.flash_attention_fwd = flash_attention_fwd

# for alibi attention
world_size, rank = get_tp_world_rank()
world_size, rank = get_tp_world_rank('attn')
self.alibi_head_offset = self.num_heads * rank
self.alibi_num_heads = self.num_heads * world_size
self.block_sparse_size = block_sparse_size
5 changes: 3 additions & 2 deletions lmdeploy/pytorch/backends/cuda/awq_modules.py
@@ -55,12 +55,13 @@ def forward(self,
scales: torch.Tensor,
qzeros: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False):
all_reduce: bool = False,
group: Optional[torch.distributed.ProcessGroup] = None):
"""forward."""
out_features = scales.size(1)
out = wq_gemm_forward(x, qweight, qzeros, scales, self.w_bit, self.group_size, bias, out_features)
if all_reduce:
dist.all_reduce(out)
dist.all_reduce(out, group=group)
return out


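The recurring change across these backend modules is the new optional `group` argument, which scopes the trailing collective to an explicit process group instead of the default world group. A minimal standalone sketch of the pattern (hypothetical helper name; assumes `torch.distributed` is already initialized):

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

def linear_all_reduce(x: torch.Tensor,
                      weight: torch.Tensor,
                      bias: torch.Tensor = None,
                      group: dist.ProcessGroup = None) -> torch.Tensor:
    """Row-parallel linear: local matmul, then sum the partial outputs.

    With group=None the reduction spans the whole world, matching the old
    behaviour; passing e.g. an attention-TP sub-group restricts it.
    """
    out = F.linear(x, weight, bias)
    dist.all_reduce(out, group=group)
    return out
```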
20 changes: 6 additions & 14 deletions lmdeploy/pytorch/backends/cuda/blockedf8_modules.py
@@ -13,15 +13,6 @@
logger = get_logger('lmdeploy')


def _reduce_scatter_input(out: torch.Tensor, rank: int, tp_sizes: List[int]):
"""Reduce scatter."""
outs = out.split(tp_sizes, -2)
out = outs[rank]
outs = list(outs)
dist.reduce_scatter(out, outs)
return out


class TritonLinearBlockedF8Impl(LinearBlockedF8Impl):
"""Triton linear blocked f8 implementation."""

@@ -37,6 +28,7 @@ def forward(self,
scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False,
group: Optional[dist.ProcessGroup] = None,
rank: int = 0,
scatter_size: List[int] = None):
"""forward."""
@@ -52,7 +44,7 @@

if all_reduce:
if scatter_size is not None:
out = _reduce_scatter_input(out, rank, scatter_size)
out = dist.reduce_scatter_by_tp_sizes(out, rank, scatter_size, group=group)
else:
dist.all_reduce(out)
return out
@@ -117,6 +109,7 @@ def forward(self,
scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False,
group: Optional[dist.ProcessGroup] = None,
rank: int = 0,
scatter_size: List[int] = None):
"""forward."""
@@ -128,12 +121,11 @@ def forward(self,
out = out[:x.size(0)]
if bias is not None:
out += bias
out = out.unflatten(0, x_shape[:-1])

if all_reduce:
if scatter_size is not None:
out = _reduce_scatter_input(out, rank, scatter_size)
out = dist.reduce_scatter_by_tp_sizes(out, rank, scatter_size, group=group)
else:
dist.all_reduce(out)

out = out.unflatten(0, x_shape[:-1])
dist.all_reduce(out, group=group)
return out
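The local `_reduce_scatter_input` helper is removed in favour of a shared `dist.reduce_scatter_by_tp_sizes`. A sketch of what that helper presumably does, reconstructed from the removed code (the actual implementation lives in `lmdeploy.pytorch.distributed` and may differ in detail):

```python
from typing import List, Optional

import torch
import torch.distributed as dist

def reduce_scatter_by_tp_sizes(out: torch.Tensor,
                               rank: int,
                               tp_sizes: List[int],
                               group: Optional[dist.ProcessGroup] = None):
    """Reduce-scatter along dim -2 with uneven per-rank chunk sizes."""
    outs = list(out.split(tp_sizes, dim=-2))
    local = outs[rank]
    dist.reduce_scatter(local, outs, group=group)
    return local
```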
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/backends/cuda/graph_runner.py
@@ -258,7 +258,7 @@ def update_inputs(self, inputs):
meta = self.get_meta()
padding_batch_size = meta.padding_batch_size
tp_size = self._get_capture_tokens(padding_batch_size)
dp_meta.tp_sizes = [tp_size] * len(dp_meta.tp_sizes)
dp_meta.sync_tp_size(tp_size)
return inputs

def get_capture_batch_sizes(self) -> List[int]:
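`sync_tp_size` replaces the direct list assignment that used to live here. A presumed sketch of the method, based on the removed line (its real definition is not shown in this diff):

```python
# Hypothetical sketch of the dp meta object's sync_tp_size; the removed code
# was `dp_meta.tp_sizes = [tp_size] * len(dp_meta.tp_sizes)`.
def sync_tp_size(self, tp_size: int) -> None:
    """Set every per-rank tp size to the padded capture size."""
    self.tp_sizes = [tp_size] * len(self.tp_sizes)
```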
44 changes: 44 additions & 0 deletions lmdeploy/pytorch/backends/cuda/moe.py
@@ -619,6 +619,45 @@ def ep_expert_list(self, world_size: int, rank: int):
else:
return super().ep_expert_list(world_size=world_size, rank=rank)

def _split_inputs_by_attn_tp(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.LongTensor,
):
"""Split input by attn tp."""
dist_ctx = get_dist_manager().current_context()
attn_tp = dist_ctx.dist_config.attn_tp
attn_rank = dist_ctx.attn_tp_group.rank
num_states = hidden_states.size(0)

if attn_tp == 1 or attn_tp > num_states:
return hidden_states, topk_weights, topk_ids, None

# split size
base = num_states // attn_tp
remain = num_states % attn_tp
split_size = [base + 1] * remain + [base] * (attn_tp - remain)

# split inputs
hidden_states = torch.split(hidden_states, split_size, dim=0)[attn_rank]
topk_weights = torch.split(topk_weights, split_size, dim=0)[attn_rank]
topk_ids = torch.split(topk_ids, split_size, dim=0)[attn_rank]

return hidden_states, topk_weights, topk_ids, split_size

def _gather_outputs_by_attn_tp(self, out_states: torch.Tensor, split_size: List[int]):
"""Gather output by attn tp."""
if split_size is None:
return out_states

dist_ctx = get_dist_manager().current_context()
gpu_group = dist_ctx.attn_tp_group.gpu_group
new_out_states = out_states.new_empty((sum(split_size), out_states.shape[1]))
new_out_states_list = list(new_out_states.split(split_size, dim=0))
dist.all_gather(new_out_states_list, out_states, group=gpu_group)
return new_out_states

def forward(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
@@ -633,12 +672,17 @@ def forward(self,
act_func: Callable = None,
**kwargs):
"""forward."""
hidden_states, topk_weights, topk_ids, split_size = self._split_inputs_by_attn_tp(
hidden_states, topk_weights, topk_ids)

topk_weights = self.do_renormalize(topk_weights)
step_ctx = get_step_ctx_manager().current_context()
low_latency_mode = step_ctx.is_decoding and self.use_deep_gemm
moe = self.fusedmoe_build(low_latency_mode)
out_states = moe.forward(hidden_states, topk_weights, topk_ids, gate_up_weights, gate_up_scale, down_weights,
down_scale, expert_list)

out_states = self._gather_outputs_by_attn_tp(out_states, split_size)
return out_states

def do_renormalize(self, topk_weights):
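For reference, the split-size arithmetic used by `_split_inputs_by_attn_tp`, in isolation (hypothetical standalone helper; the first `num_states % attn_tp` ranks receive one extra row so the shards cover all tokens):

```python
from typing import List

def attn_tp_split_sizes(num_states: int, attn_tp: int) -> List[int]:
    """Uneven split: earlier ranks absorb the remainder."""
    base, remain = divmod(num_states, attn_tp)
    return [base + 1] * remain + [base] * (attn_tp - remain)

# e.g. 10 tokens across attn_tp=4 ranks -> [3, 3, 2, 2]; rank 2 processes
# hidden_states.split([3, 3, 2, 2], dim=0)[2]
assert attn_tp_split_sizes(10, 4) == [3, 3, 2, 2]
assert attn_tp_split_sizes(8, 4) == [2, 2, 2, 2]
```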
5 changes: 3 additions & 2 deletions lmdeploy/pytorch/backends/cuda/qmodules.py
@@ -63,7 +63,8 @@ def forward(self,
weight: torch.Tensor,
scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False):
all_reduce: bool = False,
group: Optional[torch.distributed.ProcessGroup] = None):
"""forward."""
if isinstance(x, torch.Tensor):
input_quant, input_scale = per_token_quant_int8(x, 1e-7, quant_dtype=self.quant_dtype)
@@ -79,7 +80,7 @@
bias=bias)

if all_reduce:
dist.all_reduce(out)
dist.all_reduce(out, group=group)
return out


5 changes: 3 additions & 2 deletions lmdeploy/pytorch/backends/default/awq_modules.py
@@ -62,7 +62,8 @@ def forward(self,
scales: torch.Tensor,
qzeros: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False):
all_reduce: bool = False,
group: Optional[torch.distributed.ProcessGroup] = None):
"""forward."""
out_shape = x.shape[:-1] + (self.out_features, )
input_dtype = x.dtype
@@ -77,7 +78,7 @@
if input_dtype != torch.float16:
out = out.to(dtype=input_dtype)
if all_reduce:
dist.all_reduce(out)
dist.all_reduce(out, group=group)
return out


22 changes: 5 additions & 17 deletions lmdeploy/pytorch/backends/default/linear.py
@@ -2,26 +2,12 @@
from typing import List, Optional

import torch
import torch.distributed as dist
import torch.nn.functional as F

import lmdeploy.pytorch.distributed as dist

from ..linear import LinearBuilder, LinearImpl


def _reduce_scatter_input(out: torch.Tensor, rank: int, tp_sizes: List[int]):
"""Reduce scatter."""
out = out.transpose(0, -2)
if not out.is_contiguous():
out = out.contiguous()
outs = out.split(tp_sizes, 0)
out = outs[rank]
outs = list(outs)
dist.reduce_scatter(out, outs)
out = out.transpose(0, -2)
return out


class DefaultLinearImpl(LinearImpl):
"""Linear implementation api."""

@@ -30,15 +16,17 @@ def forward(self,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False,
group: dist.ProcessGroup = None,
rank: int = 0,
scatter_size: List[int] = None):
"""forward."""
out = F.linear(x, weight, bias)
if all_reduce:
if scatter_size is not None:
out = _reduce_scatter_input(out, rank, scatter_size)
from lmdeploy.pytorch.distributed import reduce_scatter_by_tp_sizes
out = reduce_scatter_by_tp_sizes(out, rank, scatter_size, group=group)
else:
dist.all_reduce(out)
dist.all_reduce(out, group=group)
return out


3 changes: 2 additions & 1 deletion lmdeploy/pytorch/backends/dlinfer/awq_modules.py
@@ -23,7 +23,8 @@ def forward(self,
scales: torch.Tensor,
qzeros: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False):
all_reduce: bool = False,
group: Optional[torch.distributed.ProcessGroup] = None):
"""forward."""
out = awq_linear(x, qweight, scales, qzeros, bias, all_reduce, self.group_size)
return out
5 changes: 3 additions & 2 deletions lmdeploy/pytorch/backends/dlinfer/linear.py
@@ -3,8 +3,8 @@
from typing import List, Optional

import torch
import torch.distributed as dist

import lmdeploy.pytorch.distributed as dist
from lmdeploy.pytorch.kernels.dlinfer import linear

from ..linear import LinearBuilder, LinearImpl
@@ -32,12 +32,13 @@ def forward(self,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False,
group: dist.ProcessGroup = None,
rank: int = 0,
scatter_size: List[int] = None):
"""forward."""
out = linear(x, weight, bias, False)
if all_reduce:
dist.all_reduce(out)
dist.all_reduce(out, group=group)
return out


5 changes: 3 additions & 2 deletions lmdeploy/pytorch/backends/dlinfer/qmodules.py
@@ -36,7 +36,8 @@ def forward(self,
weight: torch.Tensor,
scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False):
all_reduce: bool = False,
group: Optional[torch.distributed.ProcessGroup] = None):
"""forward."""
if isinstance(x, torch.Tensor):
input_quant, input_scale = dynamic_quant(x, self.quant_dtype)
Expand All @@ -46,7 +47,7 @@ def forward(self,

out = linear_w8a8(input_quant, weight, input_scale, scale, self.out_dtype, self.quant_dtype, bias)
if all_reduce:
dist.all_reduce(out)
dist.all_reduce(out, group=group)
return out


2 changes: 2 additions & 0 deletions lmdeploy/pytorch/backends/linear.py
@@ -3,6 +3,7 @@
from typing import List, Optional

import torch
import torch.distributed as dist


class LinearImpl(ABC):
@@ -18,6 +19,7 @@ def forward(self,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False,
group: dist.ProcessGroup = None,
rank: int = 0,
scatter_size: List[int] = None):
"""forward."""
3 changes: 2 additions & 1 deletion lmdeploy/pytorch/backends/qmodules.py
@@ -47,7 +47,8 @@ def forward(self,
weight: torch.Tensor,
scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False):
all_reduce: bool = False,
group: Optional[torch.distributed.ProcessGroup] = None):
"""forward."""
raise NotImplementedError

4 changes: 2 additions & 2 deletions lmdeploy/pytorch/check_env/dist.py
@@ -44,9 +44,9 @@ def check(self):
if self.device_type == 'cuda' and not is_dlblas_installed():
self.log_and_exit(mod_name='Dist',
message='ep>1 requires install dlblas(https://github.com/DeepLink-org/dlBLAS).')
if self.dp % self.ep != 0:
if self.ep % self.dp != 0:
self.log_and_exit(mod_name='Dist',
message=f'ep>1 requires dp % ep == 0. Get dp={self.dp} and ep={self.ep}.')
message=f'ep>1 requires ep % dp == 0. Get dp={self.dp} and ep={self.ep}.')
elif self.dist_config.enable_eplb:
self.log_and_exit(mod_name='Dist', message=f'Enable eplb requires ep > 1. Get ep={self.ep}.')

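The corrected condition requires `ep` to be a multiple of `dp` rather than the other way round. A small illustration of the check in isolation (hypothetical helper):

```python
def ep_dp_compatible(ep: int, dp: int) -> bool:
    """ep > 1 is only allowed when ep is an integer multiple of dp."""
    return ep <= 1 or ep % dp == 0

assert ep_dp_compatible(ep=4, dp=2)       # ok: 4 % 2 == 0
assert not ep_dp_compatible(ep=4, dp=3)   # rejected: 4 % 3 != 0
```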