diff --git a/examples/offline_inference.py b/examples/offline_inference.py
index 293560767..c98d18d50 100644
--- a/examples/offline_inference.py
+++ b/examples/offline_inference.py
@@ -1,12 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import os
-
-import vllm.envs as envs
+import vllm.envs as vllm_envs
 from vllm import LLM, EngineArgs
 from vllm.utils.argparse_utils import FlexibleArgumentParser
 
+import tpu_inference.envs as envs
 from tpu_inference.core import disagg_utils
 
@@ -87,10 +86,10 @@ def main(args: dict):
         'Who wrote the novel "Pride and Prejudice"?',
     ]
 
-    if envs.VLLM_TORCH_PROFILER_DIR is not None:
+    if vllm_envs.VLLM_TORCH_PROFILER_DIR is not None:
         llm.start_profile()
     outputs = llm.generate(prompts, sampling_params)
-    if envs.VLLM_TORCH_PROFILER_DIR is not None:
+    if vllm_envs.VLLM_TORCH_PROFILER_DIR is not None:
         llm.stop_profile()
 
     # Print the outputs.
@@ -104,7 +103,7 @@ def main(args: dict):
 
 if __name__ == "__main__":
     # Skip long warmup for local simple test.
-    os.environ['SKIP_JAX_PRECOMPILE'] = '1'
+    envs.environment_variables['SKIP_JAX_PRECOMPILE'] = lambda: True
 
     parser = create_parser()
     args: dict = vars(parser.parse_args())
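Note on the pattern above: the example no longer writes the string '1' into `os.environ`; it replaces the parser lambda in `envs.environment_variables`, so every later read of `envs.SKIP_JAX_PRECOMPILE` evaluates to `True`. A minimal sketch of the lazy-accessor module this relies on (illustrative only; the real `tpu_inference/envs.py` appears later in this diff and carries many more variables):

```python
# envs_sketch.py -- minimal sketch of a lazily evaluated env-var module.
import os

# Each entry maps a variable name to a zero-arg callable that parses the
# current process environment on every access.
environment_variables = {
    "SKIP_JAX_PRECOMPILE":
    lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE", "0"))),
}


def __getattr__(name: str):
    # PEP 562 module-level __getattr__: re-runs the lambda on each attribute
    # access, so overriding the lambda (or the environment) takes effect
    # immediately for all later reads.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


def __dir__():
    # Lets dir(envs) enumerate every known variable (exercised by
    # test_dir_returns_all_env_vars in tests/test_envs.py below).
    return list(environment_variables.keys())
```

Swapping the lambda changes what the module attribute returns without touching the process environment, which also means child processes are unaffected by the override.
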
diff --git a/examples/offline_lora_inference.py b/examples/offline_lora_inference.py
index 386c74e5e..6c2c3fe5e 100644
--- a/examples/offline_lora_inference.py
+++ b/examples/offline_lora_inference.py
@@ -1,14 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import os
 import time
 
-import vllm.envs as envs
+import vllm.envs as vllm_envs
 from vllm import LLM, EngineArgs
 from vllm.lora.request import LoRARequest
 from vllm.utils.argparse_utils import FlexibleArgumentParser
 
+import tpu_inference.envs as envs
+
 
 def create_parser():
     parser = FlexibleArgumentParser()
@@ -55,13 +56,13 @@ def main(args: dict):
         "lora_adapter_3", 3,
         "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_3_adapter")
 
-    if envs.VLLM_TORCH_PROFILER_DIR is not None:
+    if vllm_envs.VLLM_TORCH_PROFILER_DIR is not None:
         llm.start_profile()
 
     start = time.perf_counter()
     outputs = llm.generate(prompt,
                            sampling_params=sampling_params,
                            lora_request=lora_request)
-    if envs.VLLM_TORCH_PROFILER_DIR is not None:
+    if vllm_envs.VLLM_TORCH_PROFILER_DIR is not None:
         llm.stop_profile()
 
     # Print the outputs.
@@ -77,7 +78,7 @@ def main(args: dict):
 
 if __name__ == "__main__":
     # Skip long warmup for local simple test.
-    os.environ['SKIP_JAX_PRECOMPILE'] = '1'
+    envs.environment_variables['SKIP_JAX_PRECOMPILE'] = lambda: True
 
     parser = create_parser()
     args: dict = vars(parser.parse_args())
diff --git a/examples/offline_safety_model_inference.py b/examples/offline_safety_model_inference.py
index ebf736148..9fd4b94ed 100644
--- a/examples/offline_safety_model_inference.py
+++ b/examples/offline_safety_model_inference.py
@@ -18,12 +18,11 @@
     --max-num_batched_tokens=4096
 """
 
-import os
-
-import vllm.envs as envs
+import vllm.envs as vllm_envs
 from vllm import LLM, EngineArgs
 from vllm.utils.argparse_utils import FlexibleArgumentParser
 
+import tpu_inference.envs as envs
 from tpu_inference.core import disagg_utils
 
@@ -170,7 +169,7 @@ def main(args: dict):
 
         prompts.append(TokensPrompt(prompt_token_ids=tokenized_prompt))
 
-    if envs.VLLM_TORCH_PROFILER_DIR is not None:
+    if vllm_envs.VLLM_TORCH_PROFILER_DIR is not None:
         llm.start_profile()
 
     outputs = llm.generate(
@@ -179,7 +178,7 @@ def main(args: dict):
         use_tqdm=True,
     )
 
-    if envs.VLLM_TORCH_PROFILER_DIR is not None:
+    if vllm_envs.VLLM_TORCH_PROFILER_DIR is not None:
         llm.stop_profile()
 
     passed_tests = 0
@@ -220,7 +219,7 @@ def main(args: dict):
 
 if __name__ == "__main__":
     # Skip long warmup for local simple test.
-    os.environ['SKIP_JAX_PRECOMPILE'] = '1'
+    envs.environment_variables['SKIP_JAX_PRECOMPILE'] = lambda: True
 
     parser = create_parser()
     args: dict = vars(parser.parse_args())
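The test changes below follow the same principle: because values are produced by zero-arg callables, test overrides patch `envs.environment_variables` with typed lambdas instead of writing strings into `os.environ`. A condensed sketch of the pattern (assumes `tpu_inference` is importable):

```python
from unittest.mock import patch

import tpu_inference.envs as envs

# patch.dict swaps entries in the lambda table for the duration of the
# block and restores the originals afterwards, even if the body raises.
with patch.dict(envs.environment_variables, {
        "NEW_MODEL_DESIGN": lambda: True,  # already-parsed bool, not "1"
        "NUM_SLICES": lambda: 2,           # already-parsed int, not "2"
}):
    assert envs.NEW_MODEL_DESIGN is True
    assert envs.NUM_SLICES == 2
# Outside the context, the real environment (or the defaults) applies again.
```
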
slices.""" num_slices = 2 - with patch.dict(os.environ, {'NEW_MODEL_DESIGN': '1', 'NUM_SLICES': str(num_slices)}), \ + with patch.dict(envs.environment_variables, {'NEW_MODEL_DESIGN': lambda: True, 'NUM_SLICES': lambda: num_slices}), \ patch('tpu_inference.runner.tpu_runner.mesh_utils') as mock_mesh_utils, \ patch('jax.sharding.Mesh') as mock_jax_mesh, \ patch('tpu_inference.runner.tpu_runner.logger'): diff --git a/tests/test_envs.py b/tests/test_envs.py index f707c1d6f..a58cdaf6c 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -63,6 +63,13 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0") assert envs.SKIP_JAX_PRECOMPILE is False + # Test VLLM_XLA_CHECK_RECOMPILATION (default False) + assert envs.VLLM_XLA_CHECK_RECOMPILATION is False + monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "1") + assert envs.VLLM_XLA_CHECK_RECOMPILATION is True + monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "0") + assert envs.VLLM_XLA_CHECK_RECOMPILATION is False + # Test NEW_MODEL_DESIGN (default False) assert envs.NEW_MODEL_DESIGN is False monkeypatch.setenv("NEW_MODEL_DESIGN", "1") @@ -81,6 +88,13 @@ def test_integer_env_vars(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("PYTHON_TRACER_LEVEL", "0") assert envs.PYTHON_TRACER_LEVEL == 0 + # Test NUM_SLICES (default 1) + assert envs.NUM_SLICES == 1 + monkeypatch.setenv("NUM_SLICES", "2") + assert envs.NUM_SLICES == 2 + monkeypatch.setenv("NUM_SLICES", "4") + assert envs.NUM_SLICES == 4 + def test_lowercase_conversion(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "GRPC") @@ -134,6 +148,7 @@ def test_dir_returns_all_env_vars(): assert "JAX_PLATFORMS" in env_vars assert "TPU_NAME" in env_vars assert "SKIP_JAX_PRECOMPILE" in env_vars + assert "VLLM_XLA_CHECK_RECOMPILATION" in env_vars assert "MODEL_IMPL_TYPE" in env_vars diff --git a/tests/worker/tpu_worker_test.py b/tests/worker/tpu_worker_test.py index 4801c861a..11e2dec2b 100644 --- a/tests/worker/tpu_worker_test.py +++ b/tests/worker/tpu_worker_test.py @@ -6,6 +6,7 @@ from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.outputs import DraftTokenIds +import tpu_inference.envs as envs # The class we are testing from tpu_inference.worker.tpu_worker import TPUWorker @@ -280,7 +281,7 @@ def test_add_lora_not_implemented_lora_request(self, mock_vllm_config): # @patch('tpu_inference.worker.tpu_worker.jax') - @patch.dict('os.environ', {"PYTHON_TRACER_LEVEL": "1"}, clear=True) + @patch.dict(envs.environment_variables, {"PYTHON_TRACER_LEVEL": lambda: 1}) def test_profile_start(self, mock_jax, mock_vllm_config): """Tests starting the JAX profiler.""" worker = TPUWorker(vllm_config=mock_vllm_config, @@ -296,7 +297,7 @@ def test_profile_start(self, mock_jax, mock_vllm_config): args, kwargs = mock_jax.profiler.start_trace.call_args assert args[0] == "/tmp/profile_dir" # Verify options from env var were used - assert kwargs['profiler_options'].python_tracer_level == '1' + assert kwargs['profiler_options'].python_tracer_level == 1 @patch('tpu_inference.worker.tpu_worker.jax') def test_profile_stop(self, mock_jax, mock_vllm_config): diff --git a/tpu_inference/envs.py b/tpu_inference/envs.py index e97993204..82bf1f053 100644 --- a/tpu_inference/envs.py +++ b/tpu_inference/envs.py @@ -15,11 +15,13 @@ PREFILL_SLICES: str = "" DECODE_SLICES: str = "" SKIP_JAX_PRECOMPILE: bool = False + VLLM_XLA_CHECK_RECOMPILATION: bool = False MODEL_IMPL_TYPE: str = "flax_nnx" NEW_MODEL_DESIGN: bool = False 
diff --git a/tpu_inference/envs.py b/tpu_inference/envs.py
index e97993204..82bf1f053 100644
--- a/tpu_inference/envs.py
+++ b/tpu_inference/envs.py
@@ -15,11 +15,13 @@
     PREFILL_SLICES: str = ""
     DECODE_SLICES: str = ""
     SKIP_JAX_PRECOMPILE: bool = False
+    VLLM_XLA_CHECK_RECOMPILATION: bool = False
     MODEL_IMPL_TYPE: str = "flax_nnx"
     NEW_MODEL_DESIGN: bool = False
     PHASED_PROFILING_DIR: str = ""
     PYTHON_TRACER_LEVEL: int = 1
     USE_MOE_EP_KERNEL: bool = False
+    NUM_SLICES: int = 1
     RAY_USAGE_STATS_ENABLED: str = "0"
     VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm"
 
@@ -48,6 +50,9 @@
     # Skip JAX precompilation step during initialization
     "SKIP_JAX_PRECOMPILE":
     lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE", "0"))),
+    # Check for XLA recompilation during execution
+    "VLLM_XLA_CHECK_RECOMPILATION":
+    lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION", "0"))),
     # Model implementation type (e.g., "flax_nnx")
     "MODEL_IMPL_TYPE":
     lambda: os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower(),
@@ -63,6 +68,9 @@
     # Use custom expert-parallel kernel for MoE (Mixture of Experts)
     "USE_MOE_EP_KERNEL":
     lambda: bool(int(os.getenv("USE_MOE_EP_KERNEL", "0"))),
+    # Number of TPU slices for the multi-slice mesh
+    "NUM_SLICES":
+    lambda: int(os.getenv("NUM_SLICES", "1")),
     # Enable/disable Ray usage statistics collection
     "RAY_USAGE_STATS_ENABLED":
     lambda: os.getenv("RAY_USAGE_STATS_ENABLED", "0"),
diff --git a/tpu_inference/executors/ray_distributed_executor.py b/tpu_inference/executors/ray_distributed_executor.py
index 1c411a939..26b2f621f 100644
--- a/tpu_inference/executors/ray_distributed_executor.py
+++ b/tpu_inference/executors/ray_distributed_executor.py
@@ -18,6 +18,7 @@
 from vllm.v1.executor.ray_executor import RayWorkerMetaData
 from vllm.v1.executor.ray_utils import RayWorkerWrapper, _wait_until_pg_ready
 
+import tpu_inference.envs as tpu_envs
 from tpu_inference.logger import init_logger
 
 try:
@@ -72,7 +73,9 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
 
     def _init_executor(self) -> None:
         self.forward_dag: Optional[ray.dag.CompiledDAG] = None
-        os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"
+        # Ensure the Ray compiled DAG channel type is set for vLLM.
+        os.environ[
+            "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = tpu_envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
 
         # Currently, this requires USE_RAY_SPMD_WORKER=True.
         self.use_ray_compiled_dag = True
@@ -86,10 +89,11 @@ def _init_executor(self) -> None:
         self._initialize_ray_cluster()
         placement_group = self.parallel_config.placement_group
 
-        # Disable Ray usage stats collection.
-        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
-        if ray_usage != "1":
-            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
+        # Ensure the Ray usage stats setting is propagated to Ray workers.
+        # Workers inherit environment variables at spawn time, so we set this
+        # explicitly from our configuration (defaults to "0", disabling stats).
+        os.environ[
+            "RAY_USAGE_STATS_ENABLED"] = tpu_envs.RAY_USAGE_STATS_ENABLED
 
         # Create the parallel GPU workers.
         self._init_workers_ray(placement_group)
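Why the executor still writes to `os.environ` even though typed accessors exist: Ray workers are separate processes that inherit the parent's environment at spawn time, so the resolved values must be pinned back into the environment before the workers are created. A condensed sketch of that round-trip (assuming the defaults declared in `tpu_inference/envs.py` above):

```python
import os

import tpu_inference.envs as tpu_envs

# Resolve the typed accessor (a user-set value round-trips unchanged;
# otherwise the declared default applies) and write the result back so
# every Ray worker spawned afterwards observes the same setting.
os.environ["RAY_USAGE_STATS_ENABLED"] = tpu_envs.RAY_USAGE_STATS_ENABLED
os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = (
    tpu_envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE)
```
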
diff --git a/tpu_inference/layers/common/sharding.py b/tpu_inference/layers/common/sharding.py
index 1a1a8d169..817d7c76f 100644
--- a/tpu_inference/layers/common/sharding.py
+++ b/tpu_inference/layers/common/sharding.py
@@ -1,6 +1,5 @@
 import json
 import math
-import os
 from dataclasses import asdict, dataclass
 from typing import TYPE_CHECKING, List, Optional
 
@@ -8,7 +7,7 @@
 import numpy as np
 from jax.sharding import Mesh
 
-from tpu_inference import utils
+from tpu_inference import envs, utils
 
 if TYPE_CHECKING:
     from vllm.v1.configs.vllm_config import VllmConfig
@@ -48,7 +47,7 @@ class ShardingAxisName2D:
 
 try:
-    _use_base_sharding = os.getenv("NEW_MODEL_DESIGN", False)
+    _use_base_sharding = envs.NEW_MODEL_DESIGN
     if _use_base_sharding:
         ShardingAxisName = ShardingAxisNameBase
     else:
@@ -166,7 +165,7 @@ def validate(cls, vllm_config, sharding_strategy):
                     f"LoRA is not supported with data parallelism "
                     f"(DP size: {total_dp_size}). Please disable LoRA or "
                     f"set data parallelism to 1.")
-            if not os.environ.get("NEW_MODEL_DESIGN", False):
+            if not envs.NEW_MODEL_DESIGN:
                 raise ValueError(
                     "Must run DP with NEW_MODEL_DESIGN enabled. Please set the "
                     "NEW_MODEL_DESIGN=True.")
diff --git a/tpu_inference/runner/compilation_manager.py b/tpu_inference/runner/compilation_manager.py
index 42b9b199d..92465d1c5 100644
--- a/tpu_inference/runner/compilation_manager.py
+++ b/tpu_inference/runner/compilation_manager.py
@@ -1,13 +1,13 @@
-import os
 import time
 from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
 
 import jax
 import jax.numpy as jnp
 import numpy as np
-import vllm.envs as envs
+import vllm.envs as vllm_envs
 from jax.sharding import NamedSharding, PartitionSpec
 
+import tpu_inference.envs as envs
 from tpu_inference.core.disagg_utils import is_disagg_enabled
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.common.sharding import ShardingAxisName
@@ -30,10 +30,10 @@ class CompilationManager:
 
     def __init__(self, runner: "TPUModelRunner"):
         self.runner = runner
-        if not envs.VLLM_DISABLE_COMPILE_CACHE:
+        if not vllm_envs.VLLM_DISABLE_COMPILE_CACHE:
             logger.info("Enabling JAX compile cache.")
             jax.config.update("jax_compilation_cache_dir",
-                              envs.VLLM_XLA_CACHE_PATH)
+                              vllm_envs.VLLM_XLA_CACHE_PATH)
 
     def _create_dummy_tensor(self,
                              shape: Tuple[int, ...],
@@ -67,8 +67,7 @@ def _run_compilation(self, name: str, fn: Callable, *args,
         logger.info("Compilation finished in %.2f [secs].", end - start)
 
     def capture_model(self) -> None:
-        if os.getenv("SKIP_JAX_PRECOMPILE",
-                     False) or self.runner.model_config.enforce_eager:
+        if envs.SKIP_JAX_PRECOMPILE or self.runner.model_config.enforce_eager:
            return
         logger.info("Precompile all the subgraphs with possible input shapes.")
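These two files show the correctness payoff of the migration: `os.getenv("NEW_MODEL_DESIGN", False)` and `os.getenv("SKIP_JAX_PRECOMPILE", False)` return the raw string when the variable is set, so any non-empty value, including "0", was treated as true under the old checks. A quick illustration:

```python
import os

os.environ["NEW_MODEL_DESIGN"] = "0"

# Old pattern: os.getenv returns the string "0", which is truthy.
assert bool(os.getenv("NEW_MODEL_DESIGN", False)) is True

# New pattern (as defined in tpu_inference/envs.py): parsed to a real bool.
assert bool(int(os.getenv("NEW_MODEL_DESIGN", "0"))) is False
```
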
diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py
index e76b9056b..be0d6af52 100644
--- a/tpu_inference/runner/tpu_runner.py
+++ b/tpu_inference/runner/tpu_runner.py
@@ -1,6 +1,5 @@
 import copy
 import functools
-import os
 import random
 from contextlib import nullcontext
 from dataclasses import dataclass
@@ -11,7 +10,7 @@
 import jaxtyping
 import numpy as np
 import torch
-import vllm.envs as envs
+import vllm.envs as vllm_envs
 from flax import nnx
 from jax.experimental import mesh_utils
 from jax.sharding import NamedSharding, PartitionSpec
@@ -35,6 +34,7 @@
     KVConnectorModelRunnerMixin
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
+import tpu_inference.envs as envs
 from tpu_inference import utils as common_utils
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.common.sharding import (MESH_AXIS_NAMES,
@@ -292,7 +292,7 @@ def _init_random(self):
         self.rng_key = jax.random.key(self.model_config.seed)
 
     def _init_mesh(self) -> None:
-        if os.getenv("NEW_MODEL_DESIGN", False):
+        if envs.NEW_MODEL_DESIGN:
             self.mesh = self._create_new_model_mesh()
         else:
             # NOTE(wenxindongwork): The new MoE kernel expects a 2D mesh, so we need
@@ -303,7 +303,7 @@ def _init_mesh(self) -> None:
         logger.info(f"Init mesh | mesh={self.mesh}")
 
     def _create_new_model_mesh(self) -> jax.sharding.Mesh:
-        num_slices = int(os.environ.get('NUM_SLICES', 1))
+        num_slices = envs.NUM_SLICES
         logger.info(f"Creating new model mesh | devices={len(self.devices)}, "
                     f"num_slices={num_slices}")
 
@@ -372,7 +372,7 @@ def _create_2d_mesh(self) -> jax.sharding.Mesh:
                                              devices=self.devices)
 
     def _init_phased_profiling(self) -> None:
-        self.phased_profiling_dir = os.getenv("PHASED_PROFILING_DIR", "")
+        self.phased_profiling_dir = envs.PHASED_PROFILING_DIR
         self.phase_based_profiler = None
         if self.phased_profiling_dir:
             self.phase_based_profiler = runner_utils.PhasedBasedProfiler(
@@ -414,7 +414,7 @@ def _init_inputs(self) -> None:
             min_token_size=max(16, self.dp_size),
             max_token_size=scheduler_config.max_num_batched_tokens *
             self.dp_size,
-            padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
+            padding_gap=vllm_envs.VLLM_TPU_BUCKET_PADDING_GAP)
         self.num_tokens_paddings_per_dp = [
             padding // self.dp_size for padding in self.num_tokens_paddings
         ]
diff --git a/tpu_inference/runner/utils.py b/tpu_inference/runner/utils.py
index a2d04527e..7b87989d2 100644
--- a/tpu_inference/runner/utils.py
+++ b/tpu_inference/runner/utils.py
@@ -15,6 +15,7 @@
 from jax._src.interpreters import pxla
 from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 from tpu_inference.runner.input_batch import InputBatch
 
@@ -306,8 +307,7 @@ def __init__(self, profile_dir: str):
             InferencePhase.BALANCED: False
         }
         self.default_profiling_options = jax.profiler.ProfileOptions()
-        self.default_profiling_options.python_tracer_level = os.getenv(
-            "PYTHON_TRACER_LEVEL", 0)
+        self.default_profiling_options.python_tracer_level = envs.PYTHON_TRACER_LEVEL
 
         self.current_phase: str = ""
diff --git a/tpu_inference/worker/tpu_worker.py b/tpu_inference/worker/tpu_worker.py
index efab89e07..d76f4f42f 100644
--- a/tpu_inference/worker/tpu_worker.py
+++ b/tpu_inference/worker/tpu_worker.py
@@ -347,7 +347,7 @@ def profile(self, is_start: bool = True):
         if is_start:
             options = jax.profiler.ProfileOptions()
             # default: https://docs.jax.dev/en/latest/profiling.html#general-options
-            options.python_tracer_level = os.getenv("PYTHON_TRACER_LEVEL", 0)
+            options.python_tracer_level = envs.PYTHON_TRACER_LEVEL
             options.host_tracer_level = os.getenv("HOST_TRACER_LEVEL", 1)
             jax.profiler.start_trace(self.profile_dir,
                                      profiler_options=options)
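The profiler call sites round out the type fixes: the old `os.getenv("PYTHON_TRACER_LEVEL", 0)` returned the string "1" when the variable was set (and the int 0 otherwise), whereas `envs.PYTHON_TRACER_LEVEL` always yields an int with a centralized default of 1, which is why `tpu_worker_test.py` now asserts `== 1` instead of `== '1'`. A sketch of the resulting usage, mirroring the calls in this diff:

```python
import jax

import tpu_inference.envs as envs

# ProfileOptions now receives a genuine int from the typed accessor.
options = jax.profiler.ProfileOptions()
options.python_tracer_level = envs.PYTHON_TRACER_LEVEL  # int, default 1

jax.profiler.start_trace("/tmp/profile_dir", profiler_options=options)
# ... run the workload to be profiled ...
jax.profiler.stop_trace()
```
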