diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index fe775bb370f..d6c9ee680ab 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -615,13 +615,16 @@ steps:
   - vllm/executor/
   - vllm/model_executor/models/
   - tests/distributed/
+  - tests/examples/offline_inference/data_parallel.py
   commands:
   - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
   - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
+  - python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=0 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
   - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py
   - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
   - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
   - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
+  - python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=1 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code

 - label: Distributed Tests (2 GPUs) # 40min
   mirror_hardwares: [amdexperimental]
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 05e0be61ada..63967e4d2d4 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -1568,6 +1568,8 @@ def _run_engine(
                             pbar.update(n)
                         else:
                             pbar.update(1)
+                        if pbar.n == num_requests:
+                            pbar.refresh()

         if use_tqdm:
             pbar.close()
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index da65550354d..453ed364dc8 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -877,12 +877,16 @@ def run_busy_loop(self):
                 local_unfinished_reqs)

             if not self.engines_running:
-                if self.dp_rank == 0:
+                if self.dp_rank == 0 or not self.has_coordinator:
                     # Notify client that we are pausing the loop.
                     logger.debug("Wave %d finished, pausing engine loop.",
                                  self.current_wave)
+                    # In the coordinator case, dp rank 0 sends updates to the
+                    # coordinator. Otherwise (offline spmd case), each rank
+                    # sends the update to its colocated front-end process.
+                    client_index = -1 if self.has_coordinator else 0
                     self.output_queue.put_nowait(
-                        (-1,
+                        (client_index,
                          EngineCoreOutputs(wave_complete=self.current_wave)))
                 self.current_wave += 1
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index 8058cd3127d..856310df588 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -155,6 +155,11 @@ def collective_rpc(self,
                        kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
         raise NotImplementedError

+    def dp_engines_running(self) -> bool:
+        """Returns True if data parallel engines are collectively in a
+        running state."""
+        raise NotImplementedError
+
     async def get_output_async(self) -> EngineCoreOutputs:
         raise NotImplementedError

@@ -282,6 +287,9 @@ def collective_rpc(self,
                        kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
         return self.engine_core.collective_rpc(method, timeout, args, kwargs)

+    def dp_engines_running(self) -> bool:
+        return False
+

 @dataclass
 class BackgroundResources:
@@ -384,6 +392,9 @@ def __init__(
         dp_size = parallel_config.data_parallel_size
         dp_rank = parallel_config.data_parallel_rank

+        # State used for data parallel.
+        self.engines_running = False
+
         # SPMD mode is where there is an LLM instance per DP rank and
         # one core engine per LLM, see
         # examples/offline_inference/data_parallel.py.
@@ -539,6 +550,9 @@ def free_pending_messages(self):
         while self.pending_messages and self.pending_messages[-1][0].done:
             self.pending_messages.pop()

+    def dp_engines_running(self) -> bool:
+        return self.engines_running
+

 def _process_utility_output(output: UtilityOutput,
                             utility_results: dict[int, AnyFuture]):
@@ -562,6 +576,7 @@ def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
             log_stats=log_stats,
         )

+        self.is_dp = self.vllm_config.parallel_config.data_parallel_size > 1
         self.outputs_queue = queue.Queue[Union[EngineCoreOutputs, Exception]]()

         # Ensure that the outputs socket processing thread does not have
@@ -623,6 +638,8 @@ def get_output(self) -> EngineCoreOutputs:
         outputs = self.outputs_queue.get()
         if isinstance(outputs, Exception):
             raise self._format_exception(outputs) from None
+        if outputs.wave_complete is not None:
+            self.engines_running = False
         return outputs

     def _send_input(self, request_type: EngineCoreRequestType, request: Any):
@@ -650,6 +667,8 @@ def call_utility(self, method: str, *args) -> Any:
         return future.result()

     def add_request(self, request: EngineCoreRequest) -> None:
+        if self.is_dp:
+            self.engines_running = True
         self._send_input(EngineCoreRequestType.ADD, request)

     def abort_requests(self, request_ids: list[str]) -> None:
@@ -911,7 +930,6 @@ def __init__(self,
                  client_addresses: Optional[dict[str, str]] = None,
                  client_index: int = 0):
         self.current_wave = 0
-        self.engines_running = False

         # To route aborts to the correct engine.
         self.reqs_in_flight: dict[str, CoreEngine] = {}
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index 1932cd10bb1..25fab271311 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -160,7 +160,7 @@ def get_num_unfinished_requests(self) -> int:
     def has_unfinished_requests(self) -> bool:
         has_unfinished = self.output_processor.has_unfinished_requests()
         if self.dp_group is None:
-            return has_unfinished
+            return has_unfinished or self.engine_core.dp_engines_running()
         return self.has_unfinished_requests_dp(has_unfinished)

     def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
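
For context, the per-rank drive loop that these client-side changes target (offline SPMD data parallel, one LLM/LLMEngine per DP rank) looks roughly as follows. This is a hedged sketch rather than code from this diff; only step(), add_request(), has_unfinished_requests() and dp_engines_running() come from the changed files, while drive_one_wave and local_requests are illustrative names.

def drive_one_wave(llm_engine, local_requests):
    # Each rank enqueues only its own share of the prompts. In the DP case,
    # add_request() now also marks the engines as collectively running
    # (engines_running = True in the core client).
    for request_id, prompt, params in local_requests:
        llm_engine.add_request(request_id, prompt, params)

    outputs = []
    # has_unfinished_requests() now also returns True while
    # engine_core.dp_engines_running() is True, so a rank whose own prompts
    # have finished keeps stepping (executing dummy batches) until every
    # rank's wave completes and the wave_complete message clears the flag
    # in get_output().
    while llm_engine.has_unfinished_requests():
        for output in llm_engine.step():
            if output.finished:
                outputs.append(output)
    return outputs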