27 changes: 27 additions & 0 deletions lm_eval/models/vllm_causallms.py
@@ -172,6 +172,7 @@ def __init__(
"swap_space": int(swap_space),
"quantization": quantization,
"seed": int(seed),
"data_parallel_size": int(data_parallel_size),
"enable_lora": True if lora_local_path else False,
"max_lora_rank": int(max_lora_rank),
}
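Note (not part of the PR): the new "data_parallel_size" entry is forwarded into the vLLM LLM(...) constructor arguments alongside the existing ones. A minimal, hedged usage sketch of how these settings could be passed through model_args via lm_eval's Python API is below; the model name and task are illustrative placeholders, whether "distributed_executor_backend" reaches vLLM depends on how extra kwargs are merged elsewhere in this file, and for the external-launcher path added further down the script would itself need to be started by an external launcher such as torchrun, one process per DP rank.

```python
# Hedged sketch: exercising the new data-parallel settings through model_args.
# Assumes lm_eval's simple_evaluate entry point; the model name and task are
# illustrative placeholders, not values taken from this PR.
import lm_eval

results = lm_eval.simple_evaluate(
    model="vllm",
    model_args=(
        "pretrained=meta-llama/Llama-3.1-8B-Instruct,"
        "data_parallel_size=2,"
        "distributed_executor_backend=external_launcher"
    ),
    tasks=["gsm8k"],
)
```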
@@ -181,8 +182,18 @@ def __init__(
if isinstance(batch_size, str) and "auto" in batch_size
else int(batch_size)
)

self.dp_group = None
self.torch_dist = None
self.is_external_launcher_dp = (
    self.model_args.get("distributed_executor_backend") == "external_launcher"
    and self.data_parallel_size > 1
)
if self.data_parallel_size <= 1:
self.model = LLM(**self.model_args)
elif self.is_external_launcher_dp:
self.model = LLM(**self.model_args)
from vllm.distributed.parallel_state import get_dp_group
import torch.distributed as dist
self.dp_group = get_dp_group()
self.torch_dist = dist
else:
eval_logger.warning(
"You might experience occasional issues with model weight downloading when data_parallel is in use. To ensure stable performance, run with data_parallel_size=1 until the weights are downloaded and cached."
@@ -417,6 +428,22 @@ def run_inference_one_model(
list(sp) for sp in distribute(self.data_parallel_size, sampling_params)
)
procs, resq = [], Queue()
if self.is_external_launcher_dp:
dp_rank = self.model.llm_engine.vllm_config.parallel_config.data_parallel_rank
local_requests = list(requests)[dp_rank]
local_sampling_params = list(sampling_params)[dp_rank]
local_results = self.model.generate(
[TokensPrompt(prompt_token_ids=request) for request in local_requests],
sampling_params=local_sampling_params,
use_tqdm=True if self.batch_size == "auto" else False,
)
# Gather results from all ranks in the data parallel group
assert self.dp_group is not None
assert self.torch_dist is not None
gathered_results = [None] * self.dp_group.world_size
self.torch_dist.all_gather_object(gathered_results, local_results, group=self.dp_group.cpu_group)
return undistribute(gathered_results)
# We use Process as it is non-daemonic
try:
for rank, (sp, req) in enumerate(zip(requests, sampling_params)):
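For reference, a standalone sketch (not code from this PR) of the gather pattern used in the external-launcher branch above: each data parallel rank generates only its own shard of requests, then torch.distributed.all_gather_object gives every rank the full, rank-ordered list of per-rank results, which the diff then re-interleaves with undistribute(). The sketch below swaps vLLM's get_dp_group() for a plain gloo process group so it can run on its own under torchrun.

```python
# Minimal illustration of the all_gather_object pattern from the diff.
# Run with e.g. `torchrun --nproc-per-node=2 gather_sketch.py`; the file name
# and the fake "results" are illustrative, not lm_eval or vLLM APIs.
import torch.distributed as dist


def main():
    # CPU-backed group, standing in for dp_group.cpu_group in the diff.
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Pretend these are the outputs this rank produced for its shard of requests.
    local_results = [f"rank{rank}-result{i}" for i in range(3)]

    # Same call shape as self.torch_dist.all_gather_object(...) in the diff:
    # after this call, every rank holds a list of per-rank result lists in rank order.
    gathered_results = [None] * world_size
    dist.all_gather_object(gathered_results, local_results)

    print(rank, gathered_results)
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```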