enable multiple ssm groups duplication #19924

Open
wants to merge 7 commits into base: main
38 changes: 22 additions & 16 deletions vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -82,7 +82,6 @@ def forward_native(
x = x * nn.functional.silu(gate.to(torch.float32))
if not self.use_rms_norm:
return x.to(input_dtype)

if self.n_groups == 1:
if self.tp_size > 1:
# Compute local sum and then reduce to obtain global sum
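
For context, the single-group path above normalizes a hidden dimension that is sharded across TP ranks, so the local sum of squares is all-reduced before computing the variance. A minimal standalone sketch of that step, assuming x is this rank's slice and reusing vLLM's tensor_model_parallel_all_reduce helper (the function name and eps value here are illustrative, not part of the diff):

import torch
from vllm.distributed import tensor_model_parallel_all_reduce

def rms_norm_single_group_tp(x: torch.Tensor, tp_size: int, eps: float = 1e-5) -> torch.Tensor:
    # Per-rank partial sum of squares over this rank's slice of the hidden dim.
    local_sums = x.pow(2).sum(dim=-1, keepdim=True)
    # Combine partial sums across TP ranks to obtain the global sum of squares.
    global_sums = tensor_model_parallel_all_reduce(local_sums)
    # Mean over the full (unsharded) hidden dimension, then normalize.
    variance = global_sums / (tp_size * x.shape[-1])
    return x * torch.rsqrt(variance + eps)
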
@@ -152,12 +151,13 @@ def extra_groups_for_head_shards(ngroups: int, tp_size: int):
if ngroups % tp_size == 0:
return 0

# for n_groups == 1, this is exactly tp_size - n_groups
# for n_groups == 1 or tp_size % ngroups == 0,
# this is exactly tp_size - n_groups
return tp_size - ngroups
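
To make the padding rule concrete, here is a standalone restatement with a few illustrative inputs (the numbers are assumptions, not taken from this PR):

def _extra_groups_example(ngroups: int, tp_size: int) -> int:
    # No padding needed when tp_size already divides ngroups.
    if ngroups % tp_size == 0:
        return 0
    # Otherwise pad the group count up to tp_size so each rank owns one group slot.
    return tp_size - ngroups

assert _extra_groups_example(ngroups=8, tp_size=4) == 0  # tp_size divides ngroups
assert _extra_groups_example(ngroups=1, tp_size=4) == 3  # previously supported n_groups == 1 case
assert _extra_groups_example(ngroups=2, tp_size=4) == 2  # new case: n_groups divides tp_size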


def mamba_v2_sharded_weight_loader(
shard_spec: list[tuple[int, int, float]],
shard_spec: list[tuple[int, int, int, bool]],
tp_size: int,
tp_rank: int,
) -> LoaderFunction:
@@ -173,7 +173,7 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
boundary, loaded_boundary = 0, 0

# - iterate over the shard specs
for full_dim, extra, duplicate_groups in shard_spec:
for full_dim, extra, n_groups, duplicate_groups in shard_spec:
# - full dim is the model dim (before TP).
# - extra > 0, means there is expected overall increase
# of dimensions. This is so because of replication.
@@ -185,12 +185,19 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
shard_size = full_dim // tp_size

# - compute the rank into the loaded shard.
# - if there is replication, different TP shards will
# take from the same rank.
# NOTE: currently we only support duplication
# in the case where num_groups == 1
# in the case where num_groups == 1 or num_groups divides tp size
rank = 0 if duplicate_groups else tp_rank

if duplicate_groups:
    rank = 0
elif n_groups % tp_size == 0:
    rank = tp_rank
else:
    assert tp_size % n_groups == 0, (
        "n_groups must divide tp_size when tp_size does not "
        "divide n_groups and n_groups is not equal to 1.")
    rank = tp_rank // n_groups
# - leftmost boundary index into loaded weight.
loaded_skip = rank * shard_size
loaded_start_idx = loaded_boundary + loaded_skip
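
As a sanity check on the rank selection above, here is a standalone replay of that mapping with assumed values tp_size=4 and n_groups=2, so two consecutive ranks share each original group (values are illustrative, not from the PR's tests):

tp_size, n_groups, duplicate_groups = 4, 2, False
for tp_rank in range(tp_size):
    if duplicate_groups:
        rank = 0                      # n_groups == 1: every rank reads group 0
    elif n_groups % tp_size == 0:
        rank = tp_rank                # tp_size divides n_groups: one group slice per rank
    else:
        assert tp_size % n_groups == 0
        rank = tp_rank // n_groups    # n_groups divides tp_size: ranks share a group
    print(f"tp_rank {tp_rank} -> loaded group slice {rank}")  # 0->0, 1->0, 2->1, 3->1
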
@@ -261,18 +268,16 @@ def __init__(
# - HOWEVER IF, world_size DOES NOT divide groups, then we need
# to allocate extra space in the shard, such that groups
# may be replicated to follow the head shard.
# - NOTE: currently for the world size DOES NOT divide groups
# case, we only support the case when n_groups == 1
# - NOTE: currently, when tp_size DOES NOT divide n_groups,
# we only support the case where n_groups divides tp_size
# - IF n_groups divides tp_size, each group is replicated so the padded
# group count equals tp_size and each block of consecutive ranks shares one group
self.tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()

assert (num_heads % self.tp_size == 0
), "Tensor parallel world size must divide num heads."

assert (n_groups % self.tp_size) == 0 or n_groups == 1, (
"If tensor parallel world size does not divide num_heads, "
"then num_groups must equal 1.")

assert (
self.tp_size == 1 or quant_config is None
), "Tensor parallel currently not supported for quantized models."
@@ -322,10 +327,11 @@ def __init__(
self.n_groups * self.ssm_state_size, # expected model size
(self.n_groups - n_groups) *
self.ssm_state_size, # extra dims assigned
n_groups, # original n_groups to check for multi-groups duplication
n_groups == 1, # if there was only one group
)
intermediate_settings = (intermediate_size, 0, False)
head_settings = (self.num_heads, 0, False)
intermediate_settings = (intermediate_size, 0, 0, False)
head_settings = (self.num_heads, 0, 0, False)
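
For a concrete picture of the four-element shard spec, the values below assume ssm_state_size=128, n_groups=2, and tp_size=4 (illustrative numbers, not from the PR); padding adds tp_size - n_groups = 2 extra groups, so the padded group count equals tp_size:

ssm_state_size, n_groups, tp_size = 128, 2, 4
padded_n_groups = n_groups + (tp_size - n_groups)   # 4 groups after duplication
groups_shard_settings = (
    padded_n_groups * ssm_state_size,                # expected model size: 512
    (padded_n_groups - n_groups) * ssm_state_size,   # extra dims added by replication: 256
    n_groups,                                        # original n_groups, drives the rank mapping
    n_groups == 1,                                   # duplicate_groups: False here
)
intermediate_settings = (4096, 0, 0, False)          # assumed intermediate_size; no replication
head_settings = (64, 0, 0, False)                    # assumed num_heads; no replication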

# - the weight already has a "weight_loader" attribute
# which set_weight_attrs will raise if we do not