6 changes: 5 additions & 1 deletion tests/e2e/test_data_parallel.py
@@ -81,9 +81,11 @@ def _run_inference_with_config(model_name: str,
time.sleep(5)


@pytest.mark.parametrize("model_impl_type", ["vllm", "flax_nnx"])
def test_model_data_parallelism(
test_prompts: list,
sampling_params: SamplingParams,
model_impl_type: str,
):
"""
Test model-wise data parallelism where data=2 in the mesh axis.
@@ -95,6 +97,7 @@ def test_model_data_parallelism(
"""
# Use Llama 1B for this test
test_model = "meta-llama/Llama-3.2-1B-Instruct"
os.environ['MODEL_IMPL_TYPE'] = model_impl_type

# Test with data parallelism enabled
outputs = _run_inference_with_config(
@@ -103,6 +106,7 @@ def test_model_data_parallelism(
sampling_params=sampling_params,
tensor_parallel_size=1,
data_parallel_size=2,
async_scheduling=True,
)

# Verify we got outputs for all prompts
@@ -175,7 +179,7 @@ def test_data_parallelism_correctness(
"""
os.environ['SKIP_JAX_PRECOMPILE'] = '1'
os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '0'
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
model_name = "meta-llama/Llama-3.2-1B-Instruct"
# Use a smaller subset of prompts for correctness testing
small_prompts = test_prompts[:10]

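For reference, a stripped-down sketch of the new parametrization (hypothetical test name and body; the real helper _run_inference_with_config and its arguments are as shown in the diff above): the same data-parallel test now runs once per model implementation, selected through the MODEL_IMPL_TYPE environment variable before the engine is constructed.

import os
import pytest

@pytest.mark.parametrize("model_impl_type", ["vllm", "flax_nnx"])
def test_model_data_parallelism_sketch(model_impl_type: str):
    # Select the model implementation before the engine is constructed.
    os.environ["MODEL_IMPL_TYPE"] = model_impl_type
    # The real test then calls _run_inference_with_config(...) with
    # tensor_parallel_size=1, data_parallel_size=2, async_scheduling=True
    # and asserts that one output is returned per prompt.
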
3 changes: 2 additions & 1 deletion tpu_inference/layers/common/sharding.py
@@ -166,9 +166,10 @@ def validate(cls, vllm_config, sharding_strategy):
f"LoRA is not supported with data parallelism "
f"(DP size: {total_dp_size}). Please disable LoRA or "
f"set data parallelism to 1.")
if sharding_strategy.attention_data_parallelism > 1:
if not os.environ.get("NEW_MODEL_DESIGN", False):
raise ValueError(
"Must run DP with NEW_MODEL_DESIGN enabled. Please set the "
"Must run Attention DP with NEW_MODEL_DESIGN enabled. Please set the "
"NEW_MODEL_DESIGN=True.")

@property
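A minimal standalone sketch of the tightened guard (assumed function name, mirroring the validate() check above): attention DP now refuses to run unless NEW_MODEL_DESIGN is set in the environment, so deployments enabling data parallelism need to export the flag before model construction.

import os

def check_attention_dp(attention_data_parallelism: int) -> None:
    # Any attention DP degree above 1 requires the NEW_MODEL_DESIGN flag.
    if attention_data_parallelism > 1 and not os.environ.get("NEW_MODEL_DESIGN"):
        raise ValueError(
            "Must run Attention DP with NEW_MODEL_DESIGN enabled. "
            "Please set NEW_MODEL_DESIGN=True.")

os.environ["NEW_MODEL_DESIGN"] = "True"
check_attention_dp(2)  # passes once the flag is exported
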
99 changes: 74 additions & 25 deletions tpu_inference/layers/vllm/fused_moe.py
@@ -110,7 +110,8 @@ def tensor_sharded_gmm_merged_column_parallel(
# adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
tm, tk, tn = _get_tiling_size_for_gmm_kernel(m // mesh.shape["data"], k, n,
g)

_gmm = functools.partial(
gmm,
@@ -123,14 +124,26 @@
gmm_result = shard_map(
_gmm,
mesh=mesh,
in_specs=(P(), P(None, "model", None), P()),
out_specs=(P(None, "model")),
in_specs=(P("data", None), P(None, "model", None), P("data")),
out_specs=(P("data", "model")),
check_rep=False,
)(lhs, rhs, group_sizes)

if rhs_bias is not None:
rhs_bis = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
gmm_result = (gmm_result + rhs_bis).astype(gmm_result.dtype)

def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
rhs_bis = jnp.repeat(rhs_bias_local,
group_sizes_global,
0,
total_repeat_length=m // mesh.shape["data"])
return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)

gmm_result = shard_map(
_add_bias,
mesh=mesh,
in_specs=(P("data", "model"), P(None, "model"), P("data")),
out_specs=(P("data", "model")),
)(gmm_result, rhs_bias, group_sizes)

n_shards = mesh.shape["model"]
output_sizes = [intermediate_size, intermediate_size]
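The tiling change follows from how shard_map presents its inputs: with in_specs=(P("data", None), ...), each program only sees its local slice of the token dimension, so the kernel's row count, and therefore the tile sizes, must be derived from m // mesh.shape["data"] rather than m. A toy sketch of that local view (standalone shapes, no GMM kernel, not part of the diff):

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), ("data", "model"))
m, k = 16, 8
lhs = jnp.ones((m, k))

def local_view(lhs_local):
    # Each program sees only its slice along the "data" axis.
    assert lhs_local.shape == (m // mesh.shape["data"], k)
    return lhs_local * 2

out = shard_map(local_view, mesh=mesh,
                in_specs=(P("data", None),),
                out_specs=P("data", None))(lhs)
assert out.shape == (m, k)
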
@@ -150,7 +163,8 @@ def tensor_sharded_gmm_row_parallel(
# adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
tm, tk, tn = _get_tiling_size_for_gmm_kernel(m // mesh.shape["data"], k, n,
g)

_gmm = functools.partial(
gmm,
@@ -167,14 +181,25 @@ def _gmm_all_reduce(lhs, rhs, group_sizes):
gmm_result = shard_map(
_gmm_all_reduce,
mesh=mesh,
in_specs=(P(None, "model"), P(None, None, "model"), P()),
out_specs=(P()),
in_specs=(P("data", "model"), P(None, None, "model"), P("data")),
out_specs=(P("data")),
check_rep=False,
)(lhs, rhs, group_sizes)

if rhs_bias is not None:
rhs_bias = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
gmm_result = (gmm_result + rhs_bias).astype(gmm_result.dtype)

def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
rhs_bis = jnp.repeat(rhs_bias_local,
group_sizes_global,
0,
total_repeat_length=m // mesh.shape["data"])
return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)

gmm_result = shard_map(
_add_bias,
mesh=mesh,
in_specs=(P("data"), P(), P("data")),
out_specs=(P("data")),
)(gmm_result, rhs_bias, group_sizes)

return gmm_result

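The bias path is likewise applied per data shard: each shard repeats the per-expert bias according to its own slice of group_sizes (the per-shard counts produced by _process_tokens_locally further down), with total_repeat_length fixed at the local row count m // mesh.shape["data"]. A small standalone illustration of the repeat-then-add step, outside shard_map and with made-up sizes:

import jax.numpy as jnp

num_experts, n = 3, 4
group_sizes = jnp.array([2, 1, 3])                 # local tokens per expert
rows = int(group_sizes.sum())                      # local GMM output rows
gmm_out = jnp.zeros((rows, n))
bias = jnp.arange(num_experts * n, dtype=jnp.float32).reshape(num_experts, n)

# Expand the per-expert bias to one row per (sorted) token, then add.
# total_repeat_length must be static; here it is the local row count.
expanded = jnp.repeat(bias, group_sizes, 0, total_repeat_length=rows)
out = (gmm_out + expanded).astype(gmm_out.dtype)
assert out.shape == (rows, n)
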
@@ -366,15 +391,27 @@ def fused_moe_func(
topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdims=True)
topk_weights = topk_weights.astype(dtype)

topk_indices_flat = topk_indices.flatten()
topk_argsort_indices = jnp.argsort(topk_indices_flat)
topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
token_indices = jnp.arange(num_tokens, dtype=jnp.int32).repeat(topk)
token_indices_sorted = token_indices[topk_argsort_indices]
group_sizes = jnp.bincount(topk_indices_flat, length=global_num_experts)

x = hidden_states[token_indices_sorted]

def _process_tokens_locally(hidden_states_local, topk_indices_local):
num_tokens_local = hidden_states_local.shape[0]
topk_indices_flat = topk_indices_local.flatten()
topk_argsort_indices = jnp.argsort(topk_indices_flat)
topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
token_indices = jnp.arange(num_tokens_local,
dtype=jnp.int32).repeat(topk)
token_indices_sorted = token_indices[topk_argsort_indices]
group_sizes_local = jnp.bincount(topk_indices_flat,
length=global_num_experts)

x = hidden_states_local[token_indices_sorted]
return x, group_sizes_local, topk_argsort_revert_indices

x, group_sizes, topk_argsort_revert_indices = shard_map(
_process_tokens_locally,
mesh=mesh,
in_specs=(P("data", None), P("data", None)),
out_specs=(P("data", None), P("data"), P("data")),
check_rep=False,
)(hidden_states, topk_indices)
if use_ep:
x = expert_sharded_gmm(
x,
@@ -411,7 +448,7 @@ def fused_moe_func(
)
else:
x = jax.lax.with_sharding_constraint(
x, NamedSharding(mesh, P(None, "model")))
x, NamedSharding(mesh, P("data", "model")))
x = tensor_sharded_gmm_row_parallel(
x,
w2,
@@ -421,13 +458,25 @@ def fused_moe_func(
mesh=mesh,
)

x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
x = x * jnp.expand_dims(topk_weights, axis=-1)
x = x.sum(axis=-2)
def _finalize_output(x_local, topk_argsort_revert_indices_local,
topk_weights_local):
x_local = x_local[topk_argsort_revert_indices_local].reshape(
-1, topk, hidden_size)
x_local = x_local * jnp.expand_dims(topk_weights_local, axis=-1)
x_local = x_local.sum(axis=-2)
return x_local

x = shard_map(
_finalize_output,
mesh=mesh,
in_specs=(P("data", None), P("data"), P("data", None)),
out_specs=(P("data", None)),
check_rep=False,
)(x, topk_argsort_revert_indices, topk_weights)
x = x.reshape(orig_shape)

if reduce_results:
x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))
x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P("data")))
return x


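For orientation, here is a toy, unsharded sketch of the routing that _process_tokens_locally and _finalize_output now perform per data shard: sort the (token, expert) pairs by expert id, run the grouped expert matmuls on the sorted rows (elided), then invert the sort and combine the top-k expert outputs per token. Shapes and values below are made up; the real code runs this inside shard_map over the "data" axis.

import jax.numpy as jnp

num_tokens, topk, hidden, num_experts = 4, 2, 8, 3
hidden_states = jnp.ones((num_tokens, hidden))
topk_indices = jnp.array([[0, 2], [1, 1], [2, 0], [0, 1]])
topk_weights = jnp.full((num_tokens, topk), 0.5)

flat = topk_indices.flatten()
order = jnp.argsort(flat)                  # group rows by expert id
revert = jnp.argsort(order)                # permutation that undoes the sort
token_ids = jnp.arange(num_tokens, dtype=jnp.int32).repeat(topk)
x = hidden_states[token_ids[order]]        # (num_tokens * topk, hidden)
group_sizes = jnp.bincount(flat, length=num_experts)

# ... expert GMMs would transform `x` here, guided by `group_sizes` ...

x = x[revert].reshape(num_tokens, topk, hidden)
out = (x * topk_weights[..., None]).sum(axis=-2)
assert out.shape == (num_tokens, hidden)
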
7 changes: 6 additions & 1 deletion tpu_inference/layers/vllm/quantization/common.py
@@ -61,7 +61,12 @@ def __init__(self, vllm_config: VllmConfig, mesh: Mesh, layer: LinearBase):
" bad performance.", type(layer))

self.bias_sharding = P(self.weight_sharding[0])
self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
if isinstance(self.weight_sharding[0], tuple):
self.n_shards = 1
for axis in self.weight_sharding[0]:
self.n_shards *= self.mesh.shape.get(axis, 1)
else:
self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)

def get_input_sharding(self, x: torchax.tensor.Tensor):
if self.enable_sequence_parallelism:
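Finally, a worked example of the new n_shards computation (a plain dict stands in for mesh.shape, and the axis sizes are made up): when the weight's first partition entry is a tuple of mesh axes, e.g. P(("data", "model"), None), the shard count for that dimension is the product of the named axis sizes, not the size of a single axis.

mesh_shape = {"data": 2, "model": 4}      # assumed mesh axis sizes
weight_axis = ("data", "model")           # first entry of the weight PartitionSpec

if isinstance(weight_axis, tuple):
    n_shards = 1
    for axis in weight_axis:
        n_shards *= mesh_shape.get(axis, 1)
else:
    n_shards = mesh_shape.get(weight_axis, 1)

assert n_shards == 2 * 4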