Skip to content

Commit c4cf260

Browse files
authored
[Perf][CLI] Improve overall startup time (#19941)
1 parent 33d51f5 commit c4cf260

File tree

14 files changed

+293
-103
lines changed

14 files changed

+293
-103
lines changed

.pre-commit-config.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,11 @@ repos:
115115
entry: python tools/check_spdx_header.py
116116
language: python
117117
types: [python]
118+
- id: check-root-lazy-imports
119+
name: Check root lazy imports
120+
entry: python tools/check_init_lazy_imports.py
121+
language: python
122+
types: [python]
118123
- id: check-filenames
119124
name: Check for spaces in all filenames
120125
entry: bash

tools/check_init_lazy_imports.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""Ensure we perform lazy loading in vllm/__init__.py.
4+
i.e. every internal import appears only within the ``if typing.TYPE_CHECKING:`` guard,
5+
**except** for a short whitelist.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import ast
11+
import pathlib
12+
import sys
13+
from collections.abc import Iterable
14+
from typing import Final
15+
16+
REPO_ROOT: Final = pathlib.Path(__file__).resolve().parent.parent
17+
INIT_PATH: Final = REPO_ROOT / "vllm" / "__init__.py"
18+
19+
# If you need to add items to the whitelist, do it here.
20+
ALLOWED_IMPORTS: Final[frozenset[str]] = frozenset({
21+
"vllm.env_override",
22+
})
23+
ALLOWED_FROM_MODULES: Final[frozenset[str]] = frozenset({
24+
".version",
25+
})
26+
27+
28+
def _is_internal(name: str | None, *, level: int = 0) -> bool:
    """Tell whether an import target belongs to the ``vllm`` package.

    Args:
        name: Module name as written in the import statement, or ``None``
            when the statement names no module (e.g. ``from . import x``).
        level: Number of leading dots of a relative import; ``0`` for an
            absolute import.

    Returns:
        ``True`` for any relative import, and for absolute imports of
        ``vllm`` itself or a ``vllm.``-prefixed module; ``False`` otherwise.
    """
    # A relative import inside vllm/__init__.py can only point into vllm.
    if level > 0:
        return True
    return name is not None and (name == "vllm"
                                 or name.startswith("vllm."))
34+
35+
36+
def _fail(violations: Iterable[tuple[int, str]]) -> None:
    """Report the disallowed imports on stderr and exit with status 1.

    Args:
        violations: ``(line number, message)`` pairs, one per offending
            import statement.
    """
    # The header keeps its trailing newline so that a blank line separates
    # it from the per-violation entries once everything is joined.
    report = ["ERROR: Disallowed eager imports in vllm/__init__.py:\n"]
    report.extend(f" Line {lineno}: {msg}" for lineno, msg in violations)
    print("\n".join(report), file=sys.stderr)
    sys.exit(1)
42+
43+
44+
class _EagerImportVisitor(ast.NodeVisitor):
    """Collect internal imports located outside a ``TYPE_CHECKING`` guard.

    Offending statements are recorded in ``violations`` as
    ``(line number, message)`` pairs, in visiting order.
    """

    def __init__(self) -> None:
        super().__init__()
        self.violations: list[tuple[int, str]] = []
        # True only while walking the body of an ``if TYPE_CHECKING:`` block.
        self._in_type_checking = False

    @staticmethod
    def _is_type_checking_test(test: ast.expr) -> bool:
        """Match the conditions ``TYPE_CHECKING`` and ``typing.TYPE_CHECKING``."""
        if isinstance(test, ast.Name):
            return test.id == "TYPE_CHECKING"
        if isinstance(test, ast.Attribute) and isinstance(
                test.value, ast.Name):
            return (test.value.id == "typing"
                    and test.attr == "TYPE_CHECKING")
        return False

    def visit_If(self, node: ast.If) -> None:
        if not self._is_type_checking_test(node.test):
            self.generic_visit(node)
            return

        # Only the ``if`` body is exempt; the ``else`` branch is walked
        # with whatever state was in force outside the guard.
        outer_state = self._in_type_checking
        self._in_type_checking = True
        for stmt in node.body:
            self.visit(stmt)
        self._in_type_checking = outer_state
        for stmt in node.orelse:
            self.visit(stmt)

    def visit_Import(self, node: ast.Import) -> None:
        if self._in_type_checking:
            return
        for alias in node.names:
            imported = alias.name
            if not _is_internal(imported) or imported in ALLOWED_IMPORTS:
                continue
            self.violations.append((
                node.lineno,
                f"import '{imported}' must be inside typing.TYPE_CHECKING",  # noqa: E501
            ))

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        if self._in_type_checking:
            return
        if not _is_internal(node.module, level=node.level):
            return
        # Rebuild the module exactly as spelled in the source, leading dots
        # included, because the whitelist stores that spelling.
        module_as_written = "." * node.level + (node.module or "")
        if module_as_written in ALLOWED_FROM_MODULES:
            return
        self.violations.append((
            node.lineno,
            f"from '{module_as_written}' import ... must be inside typing.TYPE_CHECKING",  # noqa: E501
        ))


def main() -> None:
    """Parse ``vllm/__init__.py`` and fail on disallowed eager imports.

    The file at ``INIT_PATH`` is parsed with ``ast`` and walked; when any
    non-whitelisted internal import is found outside a ``TYPE_CHECKING``
    guard, ``_fail`` is called with the collected violations.
    """
    tree = ast.parse(INIT_PATH.read_text(encoding="utf-8"),
                     filename=str(INIT_PATH))

    visitor = _EagerImportVisitor()
    visitor.visit(tree)

    if visitor.violations:
        _fail(visitor.violations)
105+
106+
107+
if __name__ == "__main__":
108+
main()

vllm/__init__.py

Lines changed: 59 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,72 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
4+
45
# The version.py should be an independent library, and we always import the
56
# version library first. Such an assumption is critical for some customization.
67
from .version import __version__, __version_tuple__ # isort:skip
78

9+
import typing
10+
811
# The environment variables override should be imported before any other
912
# modules to ensure that the environment variables are set before any
1013
# other modules are imported.
11-
import vllm.env_override # isort:skip # noqa: F401
12-
13-
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
14-
from vllm.engine.async_llm_engine import AsyncLLMEngine
15-
from vllm.engine.llm_engine import LLMEngine
16-
from vllm.entrypoints.llm import LLM
17-
from vllm.executor.ray_utils import initialize_ray_cluster
18-
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
19-
from vllm.model_executor.models import ModelRegistry
20-
from vllm.outputs import (ClassificationOutput, ClassificationRequestOutput,
21-
CompletionOutput, EmbeddingOutput,
22-
EmbeddingRequestOutput, PoolingOutput,
23-
PoolingRequestOutput, RequestOutput, ScoringOutput,
24-
ScoringRequestOutput)
25-
from vllm.pooling_params import PoolingParams
26-
from vllm.sampling_params import SamplingParams
14+
import vllm.env_override # noqa: F401
15+
16+
MODULE_ATTRS = {
17+
"AsyncEngineArgs": ".engine.arg_utils:AsyncEngineArgs",
18+
"EngineArgs": ".engine.arg_utils:EngineArgs",
19+
"AsyncLLMEngine": ".engine.async_llm_engine:AsyncLLMEngine",
20+
"LLMEngine": ".engine.llm_engine:LLMEngine",
21+
"LLM": ".entrypoints.llm:LLM",
22+
"initialize_ray_cluster": ".executor.ray_utils:initialize_ray_cluster",
23+
"PromptType": ".inputs:PromptType",
24+
"TextPrompt": ".inputs:TextPrompt",
25+
"TokensPrompt": ".inputs:TokensPrompt",
26+
"ModelRegistry": ".model_executor.models:ModelRegistry",
27+
"SamplingParams": ".sampling_params:SamplingParams",
28+
"PoolingParams": ".pooling_params:PoolingParams",
29+
"ClassificationOutput": ".outputs:ClassificationOutput",
30+
"ClassificationRequestOutput": ".outputs:ClassificationRequestOutput",
31+
"CompletionOutput": ".outputs:CompletionOutput",
32+
"EmbeddingOutput": ".outputs:EmbeddingOutput",
33+
"EmbeddingRequestOutput": ".outputs:EmbeddingRequestOutput",
34+
"PoolingOutput": ".outputs:PoolingOutput",
35+
"PoolingRequestOutput": ".outputs:PoolingRequestOutput",
36+
"RequestOutput": ".outputs:RequestOutput",
37+
"ScoringOutput": ".outputs:ScoringOutput",
38+
"ScoringRequestOutput": ".outputs:ScoringRequestOutput",
39+
}
40+
41+
if typing.TYPE_CHECKING:
42+
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
43+
from vllm.engine.async_llm_engine import AsyncLLMEngine
44+
from vllm.engine.llm_engine import LLMEngine
45+
from vllm.entrypoints.llm import LLM
46+
from vllm.executor.ray_utils import initialize_ray_cluster
47+
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
48+
from vllm.model_executor.models import ModelRegistry
49+
from vllm.outputs import (ClassificationOutput,
50+
ClassificationRequestOutput, CompletionOutput,
51+
EmbeddingOutput, EmbeddingRequestOutput,
52+
PoolingOutput, PoolingRequestOutput,
53+
RequestOutput, ScoringOutput,
54+
ScoringRequestOutput)
55+
from vllm.pooling_params import PoolingParams
56+
from vllm.sampling_params import SamplingParams
57+
else:
58+
59+
def __getattr__(name: str) -> typing.Any:
    """Lazily resolve a public attribute of this package on first access.

    ``MODULE_ATTRS`` maps each exported name to a ``"<module>:<attribute>"``
    target. The target module is imported relative to ``__package__`` only
    when the name is actually requested.

    Args:
        name: Attribute being looked up on the package.

    Returns:
        The object named by the ``MODULE_ATTRS`` entry for ``name``.

    Raises:
        AttributeError: If ``name`` has no entry in ``MODULE_ATTRS``.
    """
    from importlib import import_module

    target = MODULE_ATTRS.get(name)
    if target is None:
        raise AttributeError(
            f'module {__package__} has no attribute {name}')

    module_name, attr_name = target.split(":")
    value = getattr(import_module(module_name, __package__), attr_name)
    # Module-level __getattr__ is consulted only when normal lookup fails,
    # so storing the result here means each name is resolved once instead
    # of repeating the split/import/getattr on every access.
    globals()[name] = value
    return value
69+
2770

2871
__all__ = [
2972
"__version__",

vllm/config.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
2929
from torch.distributed import ProcessGroup, ReduceOp
3030
from transformers import PretrainedConfig
31-
from typing_extensions import deprecated, runtime_checkable
31+
from typing_extensions import Self, deprecated, runtime_checkable
3232

3333
import vllm.envs as envs
3434
from vllm import version
@@ -1537,7 +1537,6 @@ def compute_hash(self) -> str:
15371537
def __post_init__(self) -> None:
15381538
self.swap_space_bytes = self.swap_space * GiB_bytes
15391539

1540-
self._verify_args()
15411540
self._verify_cache_dtype()
15421541
self._verify_prefix_caching()
15431542

@@ -1546,7 +1545,8 @@ def metrics_info(self):
15461545
# metrics info
15471546
return {key: str(value) for key, value in self.__dict__.items()}
15481547

1549-
def _verify_args(self) -> None:
1548+
@model_validator(mode='after')
1549+
def _verify_args(self) -> Self:
15501550
if self.cpu_offload_gb < 0:
15511551
raise ValueError("CPU offload space must be non-negative"
15521552
f", but got {self.cpu_offload_gb}")
@@ -1556,6 +1556,8 @@ def _verify_args(self) -> None:
15561556
"GPU memory utilization must be less than 1.0. Got "
15571557
f"{self.gpu_memory_utilization}.")
15581558

1559+
return self
1560+
15591561
def _verify_cache_dtype(self) -> None:
15601562
if self.cache_dtype == "auto":
15611563
pass
@@ -1942,15 +1944,14 @@ def __post_init__(self) -> None:
19421944
if self.distributed_executor_backend is None and self.world_size == 1:
19431945
self.distributed_executor_backend = "uni"
19441946

1945-
self._verify_args()
1946-
19471947
@property
19481948
def use_ray(self) -> bool:
19491949
return self.distributed_executor_backend == "ray" or (
19501950
isinstance(self.distributed_executor_backend, type)
19511951
and self.distributed_executor_backend.uses_ray)
19521952

1953-
def _verify_args(self) -> None:
1953+
@model_validator(mode='after')
1954+
def _verify_args(self) -> Self:
19541955
# Lazy import to avoid circular import
19551956
from vllm.executor.executor_base import ExecutorBase
19561957
from vllm.platforms import current_platform
@@ -1977,8 +1978,7 @@ def _verify_args(self) -> None:
19771978
raise ValueError("Unable to use nsight profiling unless workers "
19781979
"run with Ray.")
19791980

1980-
assert isinstance(self.worker_extension_cls, str), (
1981-
"worker_extension_cls must be a string (qualified class name).")
1981+
return self
19821982

19831983

19841984
PreemptionMode = Literal["swap", "recompute"]
@@ -2202,9 +2202,8 @@ def __post_init__(self) -> None:
22022202
self.max_num_partial_prefills, self.max_long_partial_prefills,
22032203
self.long_prefill_token_threshold)
22042204

2205-
self._verify_args()
2206-
2207-
def _verify_args(self) -> None:
2205+
@model_validator(mode='after')
2206+
def _verify_args(self) -> Self:
22082207
if (self.max_num_batched_tokens < self.max_model_len
22092208
and not self.chunked_prefill_enabled):
22102209
raise ValueError(
@@ -2263,6 +2262,8 @@ def _verify_args(self) -> None:
22632262
"must be greater than or equal to 1 and less than or equal to "
22642263
f"max_num_partial_prefills ({self.max_num_partial_prefills}).")
22652264

2265+
return self
2266+
22662267
@property
22672268
def is_multi_step(self) -> bool:
22682269
return self.num_scheduler_steps > 1
@@ -2669,8 +2670,6 @@ def __post_init__(self):
26692670
if self.posterior_alpha is None:
26702671
self.posterior_alpha = 0.3
26712672

2672-
self._verify_args()
2673-
26742673
@staticmethod
26752674
def _maybe_override_draft_max_model_len(
26762675
speculative_max_model_len: Optional[int],
@@ -2761,7 +2760,8 @@ def create_draft_parallel_config(
27612760

27622761
return draft_parallel_config
27632762

2764-
def _verify_args(self) -> None:
2763+
@model_validator(mode='after')
2764+
def _verify_args(self) -> Self:
27652765
if self.num_speculative_tokens is None:
27662766
raise ValueError(
27672767
"num_speculative_tokens must be provided with "
@@ -2812,6 +2812,8 @@ def _verify_args(self) -> None:
28122812
"Eagle3 is only supported for Llama models. "
28132813
f"Got {self.target_model_config.hf_text_config.model_type=}")
28142814

2815+
return self
2816+
28152817
@property
28162818
def num_lookahead_slots(self) -> int:
28172819
"""The number of additional slots the scheduler should allocate per

vllm/engine/arg_utils.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33

44
# yapf: disable
55
import argparse
6+
import copy
67
import dataclasses
8+
import functools
79
import json
810
import sys
911
import threading
@@ -168,7 +170,8 @@ def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
168170
return type_hints
169171

170172

171-
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
173+
@functools.lru_cache(maxsize=30)
174+
def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
172175
cls_docs = get_attr_docs(cls)
173176
kwargs = {}
174177
for field in fields(cls):
@@ -269,6 +272,16 @@ def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
269272
return kwargs
270273

271274

275+
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
    """Build the argparse kwargs for a Config dataclass.

    The result of ``_compute_kwargs`` is memoized with
    ``functools.lru_cache``; callers receive a deep copy so they are free
    to mutate the returned dictionary without touching the cached value.
    """
    cached_kwargs = _compute_kwargs(cls)
    return copy.deepcopy(cached_kwargs)
283+
284+
272285
@dataclass
273286
class EngineArgs:
274287
"""Arguments for vLLM engine."""

vllm/entrypoints/cli/benchmark/main.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
from __future__ import annotations
5+
36
import argparse
7+
import typing
48

59
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
610
from vllm.entrypoints.cli.types import CLISubcommand
7-
from vllm.utils import FlexibleArgumentParser
11+
12+
if typing.TYPE_CHECKING:
13+
from vllm.utils import FlexibleArgumentParser
814

915

1016
class BenchmarkSubcommand(CLISubcommand):
@@ -23,7 +29,6 @@ def validate(self, args: argparse.Namespace) -> None:
2329
def subparser_init(
2430
self,
2531
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
26-
2732
bench_parser = subparsers.add_parser(
2833
self.name,
2934
help=self.help,

0 commit comments

Comments
 (0)