
[V1 Scheduler] BatchScheduler to balance token-based microbatches and reduce GPU pipeline bubbles #19873


Open · wants to merge 5 commits into main
14 changes: 14 additions & 0 deletions vllm/config.py
@@ -2110,6 +2110,20 @@ class SchedulerConfig:
    default scheduler. Can be a class directly or the path to a class of form
    "mod.custom_class"."""

    use_batch_scheduler: bool = False
    """Whether to use the BatchScheduler instead of the default scheduler.

    If set to True, the engine will use
    "vllm.v1.core.sched.scheduler.BatchScheduler" as the scheduler class unless
    a custom `scheduler_cls` is explicitly provided.

    If both `use_batch_scheduler=True` and a non-default `scheduler_cls` are
    specified, the `scheduler_cls` will take precedence and
    `use_batch_scheduler` will be ignored.

    Default is False.
    """

    disable_hybrid_kv_cache_manager: bool = False
    """If set to True, KV cache manager will allocate the same size of KV cache
    for all attention layers even if there are multiple types of attention layers
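As a quick illustration of the new option, here is a minimal sketch of setting it on `SchedulerConfig` directly. The `use_batch_scheduler` and `scheduler_cls` fields come from this diff; the standalone construction and the custom class path are illustrative only:

```python
from vllm.config import SchedulerConfig

# Sketch: enable the BatchScheduler via the new flag (other fields keep their
# defaults). The actual scheduler-class swap happens later, in EngineArgs.
config = SchedulerConfig(use_batch_scheduler=True)

# Per the docstring above, an explicitly provided scheduler_cls wins and
# use_batch_scheduler is ignored in that case.
config_with_custom_cls = SchedulerConfig(
    use_batch_scheduler=True,
    scheduler_cls="my_pkg.sched.MyScheduler",  # hypothetical custom class
)
```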
16 changes: 16 additions & 0 deletions vllm/engine/arg_utils.py
@@ -424,6 +424,7 @@ class EngineArgs:
    disable_async_output_proc: bool = not ModelConfig.use_async_output_proc
    scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
    scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
    use_batch_scheduler: bool = SchedulerConfig.use_batch_scheduler

    override_neuron_config: dict[str, Any] = \
        get_field(ModelConfig, "override_neuron_config")
@@ -868,6 +869,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
**scheduler_kwargs["disable_chunked_mm_input"])
scheduler_group.add_argument("--scheduler-cls",
**scheduler_kwargs["scheduler_cls"])
scheduler_group.add_argument("--use-batch-scheduler",
**scheduler_kwargs["use_batch_scheduler"])
scheduler_group.add_argument(
"--disable-hybrid-kv-cache-manager",
**scheduler_kwargs["disable_hybrid_kv_cache_manager"])
@@ -1195,6 +1198,7 @@ def create_engine_config(
and parallel_config.use_ray),
policy=self.scheduling_policy,
scheduler_cls=self.scheduler_cls,
use_batch_scheduler=self.use_batch_scheduler,
max_num_partial_prefills=self.max_num_partial_prefills,
max_long_partial_prefills=self.max_long_partial_prefills,
long_prefill_token_threshold=self.long_prefill_token_threshold,
@@ -1570,6 +1574,18 @@ def _set_default_args_v1(self, usage_context: UsageContext,
if not self.enable_chunked_prefill:
self.max_num_batched_tokens = model_config.max_model_len

if self.use_batch_scheduler:
if self.scheduler_cls == EngineArgs.scheduler_cls:
self.scheduler_cls = \
"vllm.v1.core.sched.scheduler.BatchScheduler"
else:
logger.warning(
"use_batch_scheduler is set to True, "
"but a custom scheduler_cls is also provided. "
"The specified scheduler_cls (%s) will take precedence, "
"and use_batch_scheduler will be ignored.",
self.scheduler_cls)

# V1 should use the new scheduler by default.
# Swap it only if this arg is set to the original V0 default
if self.scheduler_cls == EngineArgs.scheduler_cls:
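Taken together with the config change, the flag is exposed both as `--use-batch-scheduler` on the CLI and as an `EngineArgs` field. Below is a hedged sketch of the `EngineArgs` path; the flag and the precedence rule are from this diff, while the model name and the custom scheduler path are illustrative only:

```python
from vllm.engine.arg_utils import EngineArgs

# Sketch: with only the new flag set, _set_default_args_v1 swaps the default
# scheduler_cls for "vllm.v1.core.sched.scheduler.BatchScheduler" when the
# engine config is created (equivalent to passing --use-batch-scheduler).
args = EngineArgs(
    model="facebook/opt-125m",  # illustrative model choice
    use_batch_scheduler=True,
)

# Sketch: if a custom scheduler_cls is also supplied, it takes precedence and
# a warning is logged; use_batch_scheduler is effectively ignored.
args_with_custom_cls = EngineArgs(
    model="facebook/opt-125m",
    use_batch_scheduler=True,
    scheduler_cls="my_project.sched.MyScheduler",  # hypothetical custom class
)
```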
12 changes: 11 additions & 1 deletion vllm/v1/core/sched/output.py
@@ -4,7 +4,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union

if TYPE_CHECKING:
    import numpy as np
@@ -155,3 +155,13 @@ class SchedulerOutput:

    # KV Cache Connector metadata.
    kv_connector_metadata: Optional[KVConnectorMetadata] = None


@dataclass
class ScheduledRequest:
    request_id: str
    num_new_tokens: int
    encoder_inputs_to_schedule: list[int] | None
    num_scheduled_spec_tokens: int
    spec_token_ids: list[int] | None
    request_data: Union[NewRequestData, CachedRequestData]
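For orientation, here is a small sketch of how the new `ScheduledRequest` container might be filled in for a plain decode step. The field names come from this diff; the helper function and the concrete values are illustrative only:

```python
from vllm.v1.core.sched.output import CachedRequestData, ScheduledRequest


def make_decode_entry(request_id: str,
                      cached_data: CachedRequestData) -> ScheduledRequest:
    """Wrap an already-built CachedRequestData as a ScheduledRequest that
    schedules a single new decode token, with no encoder inputs and no
    speculative tokens."""
    return ScheduledRequest(
        request_id=request_id,
        num_new_tokens=1,
        encoder_inputs_to_schedule=None,
        num_scheduled_spec_tokens=0,
        spec_token_ids=None,
        request_data=cached_data,
    )
```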