
Commit c54495a

downgrade torch
Signed-off-by: jiang.li <[email protected]>

avoid import marlin globally
Signed-off-by: jiang.li <[email protected]>

llava test uses bf16
Signed-off-by: jiang.li <[email protected]>

refine compile config
Signed-off-by: jiang.li <[email protected]>

opt cpu default batchsize
Signed-off-by: jiang.li <[email protected]>

format
Signed-off-by: jiang.li <[email protected]>

fix llava embedding
Signed-off-by: jiang.li <[email protected]>

format
Signed-off-by: jiang.li <[email protected]>

fix import
Signed-off-by: jiang.li <[email protected]>

Revert "avoid import marlin globally"
This reverts commit d0ebbd265a443d90b99c2342abd88faf42aa9481.
Signed-off-by: jiang.li <[email protected]>

fix ipex quant
Signed-off-by: jiang.li <[email protected]>

list packages
Signed-off-by: jiang.li <[email protected]>

refine test deps
Signed-off-by: jiang1.li <[email protected]>

update compile config
Signed-off-by: jiang1.li <[email protected]>
1 parent 2f1c19b commit c54495a

File tree

12 files changed: +103 additions, -42 deletions


.buildkite/scripts/hardware_ci/run-cpu-test.sh

Lines changed: 12 additions & 3 deletions
@@ -24,13 +24,22 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE"
 numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu .

 # Run the image, setting --shm-size=4g for tensor parallel.
-docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
-docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2
+docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
+docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2

 function cpu_tests() {
   set -e
   export NUMA_NODE=$2

+  # list packages
+  docker exec cpu-test-"$NUMA_NODE"-avx2 bash -c "
+    set -e
+    pip list"
+
+  docker exec cpu-test-"$NUMA_NODE" bash -c "
+    set -e
+    pip list"
+
   # offline inference
   docker exec cpu-test-"$NUMA_NODE"-avx2 bash -c "
     set -e
@@ -72,7 +81,7 @@ function cpu_tests() {
     set -e
     python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half &
     timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
-    python3 benchmarks/benchmark_serving.py \
+    VLLM_CPU_CI_ENV=0 python3 benchmarks/benchmark_serving.py \
       --backend vllm \
       --dataset-name random \
       --model facebook/opt-125m \

docker/Dockerfile.cpu

Lines changed: 20 additions & 13 deletions
@@ -66,7 +66,7 @@ ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512}
 WORKDIR /workspace/vllm

 RUN --mount=type=cache,target=/root/.cache/uv \
-    --mount=type=bind,src=requirements/build.txt,target=requirements/build.txt \
+    --mount=type=bind,src=requirements/cpu-build.txt,target=requirements/build.txt \
     uv pip install -r requirements/build.txt

 COPY . .
@@ -79,6 +79,22 @@ RUN --mount=type=cache,target=/root/.cache/uv \
     --mount=type=bind,source=.git,target=.git \
     VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel

+######################### TEST DEPS #########################
+FROM base AS vllm-test-deps
+
+WORKDIR /workspace/vllm
+
+RUN --mount=type=bind,src=requirements/test.in,target=requirements/test.in \
+    cp requirements/test.in requirements/cpu-test.in && \
+    sed -i '/mamba_ssm/d' requirements/cpu-test.in && \
+    sed -i 's/torch==.*/torch==2.6.0/g' requirements/cpu-test.in && \
+    sed -i 's/torchaudio.*/torchaudio/g' requirements/cpu-test.in && \
+    sed -i 's/torchvision.*/torchvision/g' requirements/cpu-test.in && \
+    uv pip compile requirements/cpu-test.in -o requirements/cpu-test.txt --index-strategy unsafe-best-match --torch-backend cpu
+
+RUN --mount=type=cache,target=/root/.cache/uv \
+    uv pip install -r requirements/cpu-test.txt
+
 ######################### DEV IMAGE #########################
 FROM vllm-build AS vllm-dev

@@ -97,28 +113,19 @@ RUN --mount=type=cache,target=/root/.cache/uv \
     --mount=type=bind,source=.git,target=.git \
     VLLM_TARGET_DEVICE=cpu python3 setup.py develop

+COPY --from=vllm-test-deps /workspace/vllm/requirements/cpu-test.txt requirements/test.txt
+
 RUN --mount=type=cache,target=/root/.cache/uv \
-    --mount=type=bind,src=requirements/test.in,target=requirements/test.in \
-    cp requirements/test.in requirements/test-cpu.in && \
-    sed -i '/mamba_ssm/d' requirements/test-cpu.in && \
-    uv pip compile requirements/test-cpu.in -o requirements/test.txt && \
     uv pip install -r requirements/dev.txt && \
     pre-commit install --hook-type pre-commit --hook-type commit-msg

 ENTRYPOINT ["bash"]

 ######################### TEST IMAGE #########################
-FROM base AS vllm-test
+FROM vllm-test-deps AS vllm-test

 WORKDIR /workspace/

-RUN --mount=type=cache,target=/root/.cache/uv \
-    --mount=type=bind,src=requirements/test.in,target=requirements/test.in \
-    cp requirements/test.in requirements/test-cpu.in && \
-    sed -i '/mamba_ssm/d' requirements/test-cpu.in && \
-    uv pip compile requirements/test-cpu.in -o requirements/cpu-test.txt && \
-    uv pip install -r requirements/cpu-test.txt
-
 RUN --mount=type=cache,target=/root/.cache/uv \
     --mount=type=bind,from=vllm-build,src=/workspace/vllm/dist,target=dist \
     uv pip install dist/*.whl
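
The sed pipeline in the new vllm-test-deps stage derives a CPU-specific test requirement set from the shared requirements/test.in: it drops mamba_ssm, pins torch to 2.6.0, and relaxes the torchaudio/torchvision pins before compiling with uv. A rough Python sketch of what those rewrites do (the helper name rewrite_cpu_test_requirements is made up for illustration; the image itself only runs sed):

    import re

    def rewrite_cpu_test_requirements(text: str) -> str:
        out = []
        for line in text.splitlines():
            if "mamba_ssm" in line:                            # sed -i '/mamba_ssm/d'
                continue
            line = re.sub(r"torch==.*", "torch==2.6.0", line)  # pin torch for CPU
            line = re.sub(r"torchaudio.*", "torchaudio", line)  # unpin torchaudio
            line = re.sub(r"torchvision.*", "torchvision", line)  # unpin torchvision
            out.append(line)
        return "\n".join(out) + "\n"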

requirements/cpu-build.txt

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+# Temporarily used for x86 CPU backend to avoid performance regression of torch>=2.6.0+cpu,
+# see https://github.com/pytorch/pytorch/pull/151218
+cmake>=3.26.1
+ninja
+packaging>=24.2
+setuptools>=77.0.3,<80.0.0
+setuptools-scm>=8
+--extra-index-url https://download.pytorch.org/whl/cpu
+torch==2.6.0+cpu
+wheel
+jinja2>=3.1.6
+regex

requirements/cpu.txt

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@ numba == 0.61.2; python_version > '3.9'
 packaging>=24.2
 setuptools>=77.0.3,<80.0.0
 --extra-index-url https://download.pytorch.org/whl/cpu
-torch==2.7.0+cpu; platform_machine == "x86_64"
+torch==2.6.0+cpu; platform_machine == "x86_64" # torch>=2.6.0+cpu has performance regression on x86 platform, see https://github.com/pytorch/pytorch/pull/151218
 torch==2.7.0; platform_system == "Darwin"
 torch==2.7.0; platform_machine == "ppc64le" or platform_machine == "aarch64"

@@ -26,6 +26,6 @@ triton==3.2.0; platform_machine == "x86_64"

 # Intel Extension for PyTorch, only for x86_64 CPUs
 intel-openmp==2024.2.1; platform_machine == "x86_64"
-intel_extension_for_pytorch==2.7.0; platform_machine == "x86_64"
+intel_extension_for_pytorch==2.6.0; platform_machine == "x86_64" # torch>=2.6.0+cpu has performance regression on x86 platform, see https://github.com/pytorch/pytorch/pull/151218
 py-libnuma; platform_system != "Darwin"
 psutil; platform_system != "Darwin"

tests/models/multimodal/generation/test_common.py

Lines changed: 1 addition & 0 deletions
@@ -107,6 +107,7 @@
         ),
         limit_mm_per_prompt={"image": 4},
     )],
+    dtype="bfloat16" if current_platform.is_cpu() else "auto",
     marks=[pytest.mark.core_model, pytest.mark.cpu_model],
 ),
 "paligemma": VLMTestInfo(

tests/models/multimodal/generation/vlm_utils/builders.py

Lines changed: 3 additions & 0 deletions
@@ -203,6 +203,9 @@ def build_embedding_inputs_from_test_info(

     images = [asset.pil_image for asset in image_assets]
     embeds = test_info.convert_assets_to_embeddings(image_assets)
+    if test_info.dtype != "auto":
+        dtype = getattr(torch, test_info.dtype)  # type: ignore
+        embeds = [e.to(dtype=dtype) for e in embeds]
     assert len(images) == len(model_prompts)

     inputs = build_single_image_inputs(images, model_prompts, size_wrapper)
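
The added branch casts the precomputed image embeddings to the dtype the test requests, which is bfloat16 on CPU per the test_common.py change above. A minimal standalone illustration of that string-to-dtype cast (the placeholder tensor shape is arbitrary):

    import torch

    # "bfloat16" -> torch.bfloat16, then cast the embedding tensor so it
    # matches the bf16 model weights used by the CPU llava test.
    dtype = getattr(torch, "bfloat16")
    embed = torch.randn(1, 4096)      # placeholder embedding tensor
    embed = embed.to(dtype=dtype)
    assert embed.dtype is torch.bfloat16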

vllm/engine/arg_utils.py

Lines changed: 22 additions & 4 deletions
@@ -1562,14 +1562,20 @@ def _set_default_args_v1(self, usage_context: UsageContext) -> None:
                 UsageContext.LLM_CLASS: 16384,
                 UsageContext.OPENAI_API_SERVER: 8192,
             }
-            default_max_num_seqs = 1024
+            default_max_num_seqs = {
+                UsageContext.LLM_CLASS: 1024,
+                UsageContext.OPENAI_API_SERVER: 1024,
+            }
         else:
             # TODO(woosuk): Tune the default values for other hardware.
             default_max_num_batched_tokens = {
                 UsageContext.LLM_CLASS: 8192,
                 UsageContext.OPENAI_API_SERVER: 2048,
             }
-            default_max_num_seqs = 256
+            default_max_num_seqs = {
+                UsageContext.LLM_CLASS: 256,
+                UsageContext.OPENAI_API_SERVER: 256,
+            }

         # tpu specific default values.
         if current_platform.is_tpu():
@@ -1586,6 +1592,17 @@ def _set_default_args_v1(self, usage_context: UsageContext) -> None:
                 }
             }

+        # cpu specific default values.
+        if current_platform.is_cpu():
+            default_max_num_batched_tokens = {
+                UsageContext.LLM_CLASS: 4096,
+                UsageContext.OPENAI_API_SERVER: 2048,
+            }
+            default_max_num_seqs = {
+                UsageContext.LLM_CLASS: 128,
+                UsageContext.OPENAI_API_SERVER: 32,
+            }
+
         use_context_value = usage_context.value if usage_context else None
         if (self.max_num_batched_tokens is None
                 and usage_context in default_max_num_batched_tokens):
@@ -1606,8 +1623,9 @@ def _set_default_args_v1(self, usage_context: UsageContext) -> None:
                 "Setting max_num_batched_tokens to %d for %s usage context.",
                 self.max_num_batched_tokens, use_context_value)

-        if self.max_num_seqs is None:
-            self.max_num_seqs = default_max_num_seqs
+        if (self.max_num_seqs is None
+                and usage_context in default_max_num_seqs):
+            self.max_num_seqs = default_max_num_seqs[usage_context]

             logger.debug("Setting max_num_seqs to %d for %s usage context.",
                          self.max_num_seqs, use_context_value)
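
With this change, max_num_seqs defaults are looked up per usage context just like max_num_batched_tokens, and CPU gets its own smaller table. A condensed sketch of how the CPU defaults above resolve (the standalone function and simplified UsageContext enum are illustrative, not the vLLM implementation):

    from enum import Enum

    class UsageContext(Enum):  # simplified stand-in for the real UsageContext
        LLM_CLASS = "LLM_CLASS"
        OPENAI_API_SERVER = "OPENAI_API_SERVER"

    def resolve_cpu_defaults(usage_context, max_num_batched_tokens=None, max_num_seqs=None):
        default_max_num_batched_tokens = {
            UsageContext.LLM_CLASS: 4096,
            UsageContext.OPENAI_API_SERVER: 2048,
        }
        default_max_num_seqs = {
            UsageContext.LLM_CLASS: 128,
            UsageContext.OPENAI_API_SERVER: 32,
        }
        # Only fill in a default when the user left the value unset and the
        # usage context has an entry in the table.
        if max_num_batched_tokens is None and usage_context in default_max_num_batched_tokens:
            max_num_batched_tokens = default_max_num_batched_tokens[usage_context]
        if max_num_seqs is None and usage_context in default_max_num_seqs:
            max_num_seqs = default_max_num_seqs[usage_context]
        return max_num_batched_tokens, max_num_seqs

    # e.g. an OpenAI-compatible server on CPU now defaults to 2048 batched tokens / 32 sequences
    assert resolve_cpu_defaults(UsageContext.OPENAI_API_SERVER) == (2048, 32)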

vllm/model_executor/layers/fused_moe/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -6,6 +6,7 @@

 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
+from vllm.platforms import current_platform
 from vllm.triton_utils import HAS_TRITON

 _config: Optional[dict[str, Any]] = None
@@ -32,7 +33,7 @@ def get_config() -> Optional[dict[str, Any]]:
     "get_config",
 ]

-if HAS_TRITON:
+if HAS_TRITON and not current_platform.is_cpu():
     # import to register the custom ops
     import vllm.model_executor.layers.fused_moe.fused_marlin_moe  # noqa
     import vllm.model_executor.layers.fused_moe.fused_moe  # noqa

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 2 additions & 0 deletions
@@ -53,6 +53,8 @@
 if is_rocm_aiter_moe_enabled():
     from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
         rocm_aiter_grouped_topk as grouped_topk)
+elif current_platform.is_cpu():
+    pass
 else:
     from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
 if current_platform.is_tpu():

vllm/model_executor/layers/quantization/ipex_quant.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@
 from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
 from vllm.platforms import current_platform

-MIN_IPEX_VERSION = "2.7.0"
+MIN_IPEX_VERSION = "2.4.0"


 class IPEXConfig(QuantizationConfig):

vllm/platforms/cpu.py

Lines changed: 14 additions & 6 deletions
@@ -89,10 +89,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         import vllm.envs as envs
         from vllm.utils import GiB_bytes
         model_config = vllm_config.model_config
-        # Reminder: Please update docs/features/compatibility_matrix.md
-        # If the feature combo become valid
-        if not model_config.enforce_eager:
-            model_config.enforce_eager = True

         model_config.disable_cascade_attn = True

@@ -171,9 +167,21 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         compilation_config = vllm_config.compilation_config
         if (envs.VLLM_USE_V1 and vllm_config.compilation_config.level
                 == CompilationLevel.PIECEWISE):
+
+            # Note: vLLM V1 is using PIECEWISE level compilation, which will
+            # take time to compile kernels just-in-time with the inductor
+            # backend. For CPU CI tests, most of them are executed fast and
+            # compilations consume too much time, even with torch compile
+            # cache. So use VLLM_CPU_CI_ENV to indicate the CI environment,
+            # and just execute model with dynamo + eager mode to save time.
+            # VLLM_CPU_CI_ENV is only used as an internal variable.
+            if os.environ.get("VLLM_CPU_CI_ENV", "0") != "0":
+                backend = "eager"
+            else:
+                backend = "inductor"
+
             compilation_config.level = CompilationLevel.DYNAMO_ONCE
-            compilation_config.backend = "eager"
-            compilation_config.custom_ops += ["none"]
+            compilation_config.backend = backend
             compilation_config.inductor_compile_config.update({
                 "dce":
                 True,
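
A condensed sketch of the backend selection introduced above (standalone function for illustration; in the commit the logic lives inline in check_and_update_config and feeds compilation_config.backend):

    import os

    def select_cpu_compile_backend() -> str:
        # In CI (VLLM_CPU_CI_ENV set to a non-"0" value) skip inductor JIT
        # compilation and run with dynamo + eager; otherwise use inductor.
        if os.environ.get("VLLM_CPU_CI_ENV", "0") != "0":
            return "eager"
        return "inductor"

    os.environ["VLLM_CPU_CI_ENV"] = "1"   # as set by the CI docker run above
    assert select_cpu_compile_backend() == "eager"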

vllm/v1/worker/cpu_model_runner.py

Lines changed: 12 additions & 12 deletions
@@ -60,7 +60,8 @@ def load_model(self) -> None:
     def warming_up_model(self) -> None:
         logger.info("Warming up model for the compilation...")
         # Only generate graph for the generic shape
-        self._dummy_run(max(16, self.max_num_reqs))
+        with _set_global_compilation_settings(self.vllm_config):
+            self._dummy_run(max(16, self.max_num_reqs))
         logger.info("Warming up done.")

     def _init_device_properties(self) -> None:
@@ -71,16 +72,15 @@ def _sync_device(self) -> None:


 @contextmanager
-def _set_global_compilation_settings():
+def _set_global_compilation_settings(config: VllmConfig):
     import torch._inductor.config

-    # Note: The CPPGEMM backend requires freezing parameters.
-    freezing_value = torch._inductor.config.freezing
-    torch._inductor.config.freezing = True
-    # Note: workaround for "ValueError: fast mode: can't pickle cyclic objects
-    # including object type dict"
-    force_disable_caches = torch._inductor.config.force_disable_caches
-    torch._inductor.config.force_disable_caches = True
-    yield
-    torch._inductor.config.freezing = freezing_value
-    torch._inductor.config.force_disable_caches = force_disable_caches
+    inductor_config = config.compilation_config.inductor_compile_config
+    try:
+        # Note: The MKLDNN and CPPGEMM backend requires freezing parameters.
+        freezing_value = torch._inductor.config.freezing
+        if inductor_config.get("max_autotune", False):
+            torch._inductor.config.freezing = True
+        yield
+    finally:
+        torch._inductor.config.freezing = freezing_value
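
The reworked context manager follows the usual save/restore pattern for process-global torch._inductor settings: the freezing flag is only enabled when max_autotune is requested, and it is restored in a finally block even if warmup raises. A minimal standalone sketch of that pattern (not the vLLM code itself):

    from contextlib import contextmanager

    import torch._inductor.config

    @contextmanager
    def temporarily_enable_freezing(enable: bool):
        # Save the global flag, optionally flip it, and always restore it,
        # even if the wrapped block raises.
        saved = torch._inductor.config.freezing
        try:
            if enable:
                torch._inductor.config.freezing = True
            yield
        finally:
            torch._inductor.config.freezing = saved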
