[BugFix] Fix multi-node offline data parallel #19937

Open · wants to merge 4 commits into main
3 changes: 3 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -615,13 +615,16 @@ steps:
- vllm/executor/
- vllm/model_executor/models/
- tests/distributed/
+ - tests/examples/offline_inference/data_parallel.py
commands:
- # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
+ - python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=0 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
- # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
+ - python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=1 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code

- label: Distributed Tests (2 GPUs) # 40min
mirror_hardwares: [amdexperimental]
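
For context, both nodes launch the same example script and differ only in --node-rank. Below is a minimal sketch of what a single DP rank's offline run looks like, assuming (as the flags above suggest) that the example passes data-parallel settings to the engine through VLLM_DP_* environment variables; the model and prompt are illustrative placeholders, not part of this PR.

    # One DP rank of the offline SPMD setup (sketch, see caveats above).
    import os
    from vllm import LLM, SamplingParams

    os.environ["VLLM_DP_RANK"] = "0"        # this process's global DP rank
    os.environ["VLLM_DP_SIZE"] = "2"        # total number of DP ranks
    os.environ["VLLM_DP_MASTER_IP"] = "192.168.10.10"
    os.environ["VLLM_DP_MASTER_PORT"] = "12345"

    llm = LLM(model="facebook/opt-125m", enforce_eager=True)
    outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
    for out in outputs:
        print(out.outputs[0].text)
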
2 changes: 2 additions & 0 deletions vllm/entrypoints/llm.py
@@ -1568,6 +1568,8 @@ def _run_engine(
pbar.update(n)
else:
pbar.update(1)
+ if pbar.n == num_requests:
+     pbar.refresh()

if use_tqdm:
pbar.close()
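
Why the refresh is needed: with the data-parallel changes in this PR, _run_engine can keep looping after the last local request finishes (it waits for the global DP wave to complete), so pbar.close() may happen well after the final update, and tqdm's redraw rate limiting can leave the bar rendered short of 100% in the meantime. A standalone sketch of the pattern, with an illustrative work loop:

    # tqdm rate-limits redraws (mininterval), so the final update() may not
    # repaint the bar; refresh() forces the 100% state to render immediately.
    from tqdm import tqdm

    num_requests = 100
    pbar = tqdm(total=num_requests, mininterval=5.0)
    for _ in range(num_requests):
        pbar.update(1)
        if pbar.n == num_requests:
            pbar.refresh()  # paint the final state now; close() may come later
    pbar.close()
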
8 changes: 6 additions & 2 deletions vllm/v1/engine/core.py
@@ -877,12 +877,16 @@ def run_busy_loop(self):
local_unfinished_reqs)

if not self.engines_running:
- if self.dp_rank == 0:
+ if self.dp_rank == 0 or not self.has_coordinator:
# Notify client that we are pausing the loop.
logger.debug("Wave %d finished, pausing engine loop.",
self.current_wave)
+ # In the coordinator case, dp rank 0 sends updates to the
+ # coordinator. Otherwise (offline spmd case), each rank
+ # sends the update to its colocated front-end process.
+ client_index = -1 if self.has_coordinator else 0
self.output_queue.put_nowait(
-     (-1,
+     (client_index,
EngineCoreOutputs(wave_complete=self.current_wave)))
Comment on lines +887 to 890 (Contributor, medium):
Consider adding a comment explaining why client_index is set to -1 when has_coordinator is true, and 0 otherwise. This will improve code readability.

Suggested change:
      client_index = -1 if self.has_coordinator else 0
+     # In the coordinator case, dp rank 0 sends updates to the
+     # coordinator. Otherwise (offline spmd case), each rank
+     # sends the update to its colocated front-end process.
      self.output_queue.put_nowait(
          (client_index,
           EngineCoreOutputs(wave_complete=self.current_wave)))

self.current_wave += 1

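
The fix hinges on how entries on the engine's output queue are addressed: each entry is a (client_index, EngineCoreOutputs) pair, where -1 routes to the DP coordinator and 0 routes to the engine's colocated front-end process. A self-contained sketch of that dispatch convention (the queue and Outputs class here are illustrative stand-ins, not vLLM's actual types):

    import queue
    from dataclasses import dataclass

    @dataclass
    class Outputs:
        wave_complete: int  # index of the wave that just finished

    # Entries are (client_index, outputs): -1 = DP coordinator, 0 = local front-end.
    output_queue: "queue.Queue[tuple[int, Outputs]]" = queue.Queue()

    def notify_wave_complete(has_coordinator: bool, current_wave: int) -> None:
        # Coordinator deployments: dp rank 0 notifies the coordinator (-1).
        # Offline SPMD deployments: every rank notifies its own front-end (0).
        client_index = -1 if has_coordinator else 0
        output_queue.put_nowait((client_index, Outputs(wave_complete=current_wave)))

    notify_wave_complete(has_coordinator=False, current_wave=3)
    target, outs = output_queue.get()
    print(target, outs.wave_complete)  # -> 0 3
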
20 changes: 19 additions & 1 deletion vllm/v1/engine/core_client.py
@@ -155,6 +155,11 @@ def collective_rpc(self,
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
raise NotImplementedError

+ def dp_engines_running(self) -> bool:
+     """Returns True if data parallel engines are collectively in a
+     running state."""
+     raise NotImplementedError

async def get_output_async(self) -> EngineCoreOutputs:
raise NotImplementedError

@@ -282,6 +287,9 @@ def collective_rpc(self,
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return self.engine_core.collective_rpc(method, timeout, args, kwargs)

+ def dp_engines_running(self) -> bool:
+     return False


@dataclass
class BackgroundResources:
@@ -384,6 +392,9 @@ def __init__(
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank

+ # State used for data parallel.
+ self.engines_running = False

# SPMD mode is where there is an LLM instance per DP rank and
# one core engine per LLM, see
# examples/offline_inference/data_parallel.py.
@@ -539,6 +550,9 @@ def free_pending_messages(self):
while self.pending_messages and self.pending_messages[-1][0].done:
self.pending_messages.pop()

+ def dp_engines_running(self) -> bool:
+     return self.engines_running


def _process_utility_output(output: UtilityOutput,
utility_results: dict[int, AnyFuture]):
@@ -562,6 +576,7 @@ def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
log_stats=log_stats,
)

+ self.is_dp = self.vllm_config.parallel_config.data_parallel_size > 1
self.outputs_queue = queue.Queue[Union[EngineCoreOutputs, Exception]]()

# Ensure that the outputs socket processing thread does not have
@@ -623,6 +638,8 @@ def get_output(self) -> EngineCoreOutputs:
outputs = self.outputs_queue.get()
if isinstance(outputs, Exception):
raise self._format_exception(outputs) from None
+ if outputs.wave_complete is not None:
+     self.engines_running = False
return outputs

def _send_input(self, request_type: EngineCoreRequestType, request: Any):
@@ -650,6 +667,8 @@ def call_utility(self, method: str, *args) -> Any:
return future.result()

def add_request(self, request: EngineCoreRequest) -> None:
+ if self.is_dp:
+     self.engines_running = True
self._send_input(EngineCoreRequestType.ADD, request)

def abort_requests(self, request_ids: list[str]) -> None:
@@ -911,7 +930,6 @@ def __init__(self,
client_addresses: Optional[dict[str, str]] = None,
client_index: int = 0):
self.current_wave = 0
- self.engines_running = False
# To route aborts to the correct engine.
self.reqs_in_flight: dict[str, CoreEngine] = {}

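
Taken together, the client-side bookkeeping is a small state machine: adding a request while data parallel is enabled marks the engines as running, and receiving a wave_complete output marks them idle again. A condensed, self-contained sketch of that flow (the class is illustrative; only the engines_running logic mirrors the diff):

    from typing import Optional

    class MiniDPClient:
        def __init__(self, data_parallel_size: int):
            self.is_dp = data_parallel_size > 1
            self.engines_running = False

        def add_request(self, request: str) -> None:
            if self.is_dp:
                # A new request (re)starts the current DP wave.
                self.engines_running = True

        def on_outputs(self, wave_complete: Optional[int]) -> None:
            if wave_complete is not None:
                # All DP ranks finished the wave; engines pause together.
                self.engines_running = False

        def dp_engines_running(self) -> bool:
            return self.engines_running

    client = MiniDPClient(data_parallel_size=2)
    client.add_request("req-0")
    assert client.dp_engines_running()
    client.on_outputs(wave_complete=0)
    assert not client.dp_engines_running()
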
2 changes: 1 addition & 1 deletion vllm/v1/engine/llm_engine.py
@@ -160,7 +160,7 @@ def get_num_unfinished_requests(self) -> int:
def has_unfinished_requests(self) -> bool:
has_unfinished = self.output_processor.has_unfinished_requests()
if self.dp_group is None:
- return has_unfinished
+ return has_unfinished or self.engine_core.dp_engines_running()
return self.has_unfinished_requests_dp(has_unfinished)

def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
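
This last change is what keeps multi-node offline data parallel from stopping early: a rank whose local requests have all finished must still report unfinished work while the DP wave is in flight, so its front-end keeps stepping the engine rather than returning and leaving the other ranks blocked in collectives. A sketch of the driving loop this condition guards, with an illustrative engine object (vLLM's actual loop lives in LLM._run_engine):

    # Keeps stepping until both local requests and the global DP wave are done.
    def run_to_completion(llm_engine) -> list:
        outputs = []
        while llm_engine.has_unfinished_requests():
            for output in llm_engine.step():
                if output.finished:
                    outputs.append(output)
        return outputs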