diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh
index bbcde4009c0..8db8c3a05fb 100644
--- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh
@@ -24,13 +24,22 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE
 numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu .

 # Run the image, setting --shm-size=4g for tensor parallel.
-docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
-docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2
+docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
+docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2

 function cpu_tests() {
   set -e
   export NUMA_NODE=$2

+  # list packages
+  docker exec cpu-test-"$NUMA_NODE"-avx2 bash -c "
+    set -e
+    pip list"
+
+  docker exec cpu-test-"$NUMA_NODE" bash -c "
+    set -e
+    pip list"
+
   # offline inference
   docker exec cpu-test-"$NUMA_NODE"-avx2 bash -c "
     set -e
@@ -72,7 +81,7 @@ function cpu_tests() {
     set -e
     python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half &
     timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
-    python3 benchmarks/benchmark_serving.py \
+    VLLM_CPU_CI_ENV=0 python3 benchmarks/benchmark_serving.py \
       --backend vllm \
       --dataset-name random \
       --model facebook/opt-125m \
diff --git a/docker/Dockerfile.cpu b/docker/Dockerfile.cpu
index 3e9fa0e7af2..13bd03c5696 100644
--- a/docker/Dockerfile.cpu
+++ b/docker/Dockerfile.cpu
@@ -66,7 +66,7 @@ ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512}
 WORKDIR /workspace/vllm

 RUN --mount=type=cache,target=/root/.cache/uv \
-    --mount=type=bind,src=requirements/build.txt,target=requirements/build.txt \
+    --mount=type=bind,src=requirements/cpu-build.txt,target=requirements/build.txt \
     uv pip install -r requirements/build.txt

 COPY . .
@@ -79,6 +79,22 @@ RUN --mount=type=cache,target=/root/.cache/uv \
     --mount=type=bind,source=.git,target=.git \
     VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel

+######################### TEST DEPS #########################
+FROM base AS vllm-test-deps
+
+WORKDIR /workspace/vllm
+
+RUN --mount=type=bind,src=requirements/test.in,target=requirements/test.in \
+    cp requirements/test.in requirements/cpu-test.in && \
+    sed -i '/mamba_ssm/d' requirements/cpu-test.in && \
+    sed -i 's/torch==.*/torch==2.6.0/g' requirements/cpu-test.in && \
+    sed -i 's/torchaudio.*/torchaudio/g' requirements/cpu-test.in && \
+    sed -i 's/torchvision.*/torchvision/g' requirements/cpu-test.in && \
+    uv pip compile requirements/cpu-test.in -o requirements/cpu-test.txt --index-strategy unsafe-best-match --torch-backend cpu
+
+RUN --mount=type=cache,target=/root/.cache/uv \
+    uv pip install -r requirements/cpu-test.txt
+
 ######################### DEV IMAGE #########################
 FROM vllm-build AS vllm-dev

@@ -97,28 +113,19 @@ RUN --mount=type=cache,target=/root/.cache/uv \
     --mount=type=bind,source=.git,target=.git \
     VLLM_TARGET_DEVICE=cpu python3 setup.py develop

+COPY --from=vllm-test-deps /workspace/vllm/requirements/cpu-test.txt requirements/test.txt
+
 RUN --mount=type=cache,target=/root/.cache/uv \
-    --mount=type=bind,src=requirements/test.in,target=requirements/test.in \
-    cp requirements/test.in requirements/test-cpu.in && \
-    sed -i '/mamba_ssm/d' requirements/test-cpu.in && \
-    uv pip compile requirements/test-cpu.in -o requirements/test.txt && \
     uv pip install -r requirements/dev.txt && \
     pre-commit install --hook-type pre-commit --hook-type commit-msg

 ENTRYPOINT ["bash"]

 ######################### TEST IMAGE #########################
-FROM base AS vllm-test
+FROM vllm-test-deps AS vllm-test

 WORKDIR /workspace/

-RUN --mount=type=cache,target=/root/.cache/uv \
-    --mount=type=bind,src=requirements/test.in,target=requirements/test.in \
-    cp requirements/test.in requirements/test-cpu.in && \
-    sed -i '/mamba_ssm/d' requirements/test-cpu.in && \
-    uv pip compile requirements/test-cpu.in -o requirements/cpu-test.txt && \
-    uv pip install -r requirements/cpu-test.txt
-
 RUN --mount=type=cache,target=/root/.cache/uv \
     --mount=type=bind,from=vllm-build,src=/workspace/vllm/dist,target=dist \
     uv pip install dist/*.whl
diff --git a/requirements/cpu-build.txt b/requirements/cpu-build.txt
new file mode 100644
index 00000000000..12f2e86da3f
--- /dev/null
+++ b/requirements/cpu-build.txt
@@ -0,0 +1,12 @@
+# Temporarily used for x86 CPU backend to avoid performance regression of torch>=2.6.0+cpu,
+# see https://github.com/pytorch/pytorch/pull/151218
+cmake>=3.26.1
+ninja
+packaging>=24.2
+setuptools>=77.0.3,<80.0.0
+setuptools-scm>=8
+--extra-index-url https://download.pytorch.org/whl/cpu
+torch==2.6.0+cpu
+wheel
+jinja2>=3.1.6
+regex
diff --git a/requirements/cpu.txt b/requirements/cpu.txt
index d7b0fc6d80a..595083e80e6 100644
--- a/requirements/cpu.txt
+++ b/requirements/cpu.txt
@@ -8,7 +8,7 @@ numba == 0.61.2; python_version > '3.9'
 packaging>=24.2
 setuptools>=77.0.3,<80.0.0
 --extra-index-url https://download.pytorch.org/whl/cpu
-torch==2.7.0+cpu; platform_machine == "x86_64"
+torch==2.6.0+cpu; platform_machine == "x86_64" # torch>=2.6.0+cpu has performance regression on x86 platform, see https://github.com/pytorch/pytorch/pull/151218
 torch==2.7.0; platform_system == "Darwin"
 torch==2.7.0; platform_machine == "ppc64le" or platform_machine == "aarch64"

@@ -26,6 +26,6 @@ triton==3.2.0; platform_machine == "x86_64"

 # Intel Extension for PyTorch, only for x86_64 CPUs
 intel-openmp==2024.2.1; platform_machine == "x86_64"
-intel_extension_for_pytorch==2.7.0; platform_machine == "x86_64"
+intel_extension_for_pytorch==2.6.0; platform_machine == "x86_64" # torch>=2.6.0+cpu has performance regression on x86 platform, see https://github.com/pytorch/pytorch/pull/151218
 py-libnuma; platform_system != "Darwin"
 psutil; platform_system != "Darwin"
diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py
index 496850b19af..e6179619d26 100644
--- a/tests/models/multimodal/generation/test_common.py
+++ b/tests/models/multimodal/generation/test_common.py
@@ -107,6 +107,7 @@
             ),
             limit_mm_per_prompt={"image": 4},
         )],
+        dtype="bfloat16" if current_platform.is_cpu() else "auto",
         marks=[pytest.mark.core_model, pytest.mark.cpu_model],
     ),
     "paligemma": VLMTestInfo(
diff --git a/tests/models/multimodal/generation/vlm_utils/builders.py b/tests/models/multimodal/generation/vlm_utils/builders.py
index 7d20dd66089..03c08240d6a 100644
--- a/tests/models/multimodal/generation/vlm_utils/builders.py
+++ b/tests/models/multimodal/generation/vlm_utils/builders.py
@@ -203,6 +203,9 @@ def build_embedding_inputs_from_test_info(
     images = [asset.pil_image for asset in image_assets]
     embeds = test_info.convert_assets_to_embeddings(image_assets)
+    if test_info.dtype != "auto":
+        dtype = getattr(torch, test_info.dtype)  # type: ignore
+        embeds = [e.to(dtype=dtype) for e in embeds]
     assert len(images) == len(model_prompts)

     inputs = build_single_image_inputs(images, model_prompts, size_wrapper)

diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 85b7bbfbd93..f599d7a3bb5 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1562,14 +1562,20 @@ def _set_default_args_v1(self, usage_context: UsageContext) -> None:
                 UsageContext.LLM_CLASS: 16384,
                 UsageContext.OPENAI_API_SERVER: 8192,
             }
-            default_max_num_seqs = 1024
+            default_max_num_seqs = {
+                UsageContext.LLM_CLASS: 1024,
+                UsageContext.OPENAI_API_SERVER: 1024,
+            }
         else:
             # TODO(woosuk): Tune the default values for other hardware.
             default_max_num_batched_tokens = {
                 UsageContext.LLM_CLASS: 8192,
                 UsageContext.OPENAI_API_SERVER: 2048,
             }
-            default_max_num_seqs = 256
+            default_max_num_seqs = {
+                UsageContext.LLM_CLASS: 256,
+                UsageContext.OPENAI_API_SERVER: 256,
+            }

         # tpu specific default values.
         if current_platform.is_tpu():
@@ -1586,6 +1592,17 @@ def _set_default_args_v1(self, usage_context: UsageContext) -> None:
                 }
             }

+        # cpu specific default values.
+        if current_platform.is_cpu():
+            default_max_num_batched_tokens = {
+                UsageContext.LLM_CLASS: 4096,
+                UsageContext.OPENAI_API_SERVER: 2048,
+            }
+            default_max_num_seqs = {
+                UsageContext.LLM_CLASS: 128,
+                UsageContext.OPENAI_API_SERVER: 32,
+            }
+
         use_context_value = usage_context.value if usage_context else None
         if (self.max_num_batched_tokens is None
                 and usage_context in default_max_num_batched_tokens):
@@ -1606,8 +1623,9 @@ def _set_default_args_v1(self, usage_context: UsageContext) -> None:
                 "Setting max_num_batched_tokens to %d for %s usage context.",
                 self.max_num_batched_tokens, use_context_value)

-        if self.max_num_seqs is None:
-            self.max_num_seqs = default_max_num_seqs
+        if (self.max_num_seqs is None
+                and usage_context in default_max_num_seqs):
+            self.max_num_seqs = default_max_num_seqs[usage_context]

         logger.debug("Setting max_num_seqs to %d for %s usage context.",
                      self.max_num_seqs, use_context_value)
diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py
index 2bdc96e297c..c30879add69 100644
--- a/vllm/model_executor/layers/fused_moe/__init__.py
+++ b/vllm/model_executor/layers/fused_moe/__init__.py
@@ -6,6 +6,7 @@
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
+from vllm.platforms import current_platform
 from vllm.triton_utils import HAS_TRITON

 _config: Optional[dict[str, Any]] = None

@@ -32,7 +33,7 @@ def get_config() -> Optional[dict[str, Any]]:
     "get_config",
 ]

-if HAS_TRITON:
+if HAS_TRITON and not current_platform.is_cpu():
     # import to register the custom ops
     import vllm.model_executor.layers.fused_moe.fused_marlin_moe  # noqa
     import vllm.model_executor.layers.fused_moe.fused_moe  # noqa
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index cf8e4ee6509..729f9391554 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -53,6 +53,8 @@
 if is_rocm_aiter_moe_enabled():
     from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
         rocm_aiter_grouped_topk as grouped_topk)
+elif current_platform.is_cpu():
+    pass
 else:
     from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
 if current_platform.is_tpu():
diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py
index 31ad96eccaf..8ceab149bd5 100644
--- a/vllm/model_executor/layers/quantization/ipex_quant.py
+++ b/vllm/model_executor/layers/quantization/ipex_quant.py
@@ -15,7 +15,7 @@
 from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
 from vllm.platforms import current_platform

-MIN_IPEX_VERSION = "2.7.0"
+MIN_IPEX_VERSION = "2.4.0"


 class IPEXConfig(QuantizationConfig):
diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index 27c591e3bab..2d10d700fa2 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -89,10 +89,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         import vllm.envs as envs
         from vllm.utils import GiB_bytes
         model_config = vllm_config.model_config
-        # Reminder: Please update docs/features/compatibility_matrix.md
-        # If the feature combo become valid
-        if not model_config.enforce_eager:
-            model_config.enforce_eager = True

         model_config.disable_cascade_attn = True

@@ -171,9 +167,21 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         compilation_config = vllm_config.compilation_config
         if (envs.VLLM_USE_V1
                 and vllm_config.compilation_config.level
                 == CompilationLevel.PIECEWISE):
+
+            # Note: vLLM V1 is using PIECEWISE level compilation, which will
+            # take time to compile kernels just-in-time with the inductor
+            # backend. For CPU CI tests, most of them are executed fast and
+            # compilations consume too much time, even with torch compile
+            # cache. So use VLLM_CPU_CI_ENV to indicate the CI environment,
+            # and just execute model with dynamo + eager mode to save time.
+            # VLLM_CPU_CI_ENV is only used as an internal variable.
+            if os.environ.get("VLLM_CPU_CI_ENV", "0") != "0":
+                backend = "eager"
+            else:
+                backend = "inductor"
+
             compilation_config.level = CompilationLevel.DYNAMO_ONCE
-            compilation_config.backend = "eager"
-            compilation_config.custom_ops += ["none"]
+            compilation_config.backend = backend
             compilation_config.inductor_compile_config.update({
                 "dce": True,
diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py
index 607cfc0ef69..6631c9636ea 100644
--- a/vllm/v1/worker/cpu_model_runner.py
+++ b/vllm/v1/worker/cpu_model_runner.py
@@ -60,7 +60,8 @@ def load_model(self) -> None:
     def warming_up_model(self) -> None:
         logger.info("Warming up model for the compilation...")
         # Only generate graph for the generic shape
-        self._dummy_run(max(16, self.max_num_reqs))
+        with _set_global_compilation_settings(self.vllm_config):
+            self._dummy_run(max(16, self.max_num_reqs))
         logger.info("Warming up done.")

     def _init_device_properties(self) -> None:
@@ -71,16 +72,15 @@ def _sync_device(self) -> None:


 @contextmanager
-def _set_global_compilation_settings():
+def _set_global_compilation_settings(config: VllmConfig):
     import torch._inductor.config

-    # Note: The CPPGEMM backend requires freezing parameters.
-    freezing_value = torch._inductor.config.freezing
-    torch._inductor.config.freezing = True
-    # Note: workaround for "ValueError: fast mode: can't pickle cyclic objects
-    # including object type dict"
-    force_disable_caches = torch._inductor.config.force_disable_caches
-    torch._inductor.config.force_disable_caches = True
-    yield
-    torch._inductor.config.freezing = freezing_value
-    torch._inductor.config.force_disable_caches = force_disable_caches
+    inductor_config = config.compilation_config.inductor_compile_config
+    try:
+        # Note: The MKLDNN and CPPGEMM backend requires freezing parameters.
+        freezing_value = torch._inductor.config.freezing
+        if inductor_config.get("max_autotune", False):
+            torch._inductor.config.freezing = True
+        yield
+    finally:
+        torch._inductor.config.freezing = freezing_value