11 changes: 5 additions & 6 deletions examples/offline_inference.py
@@ -1,12 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import os
-
-import vllm.envs as envs
+import vllm.envs as vllm_envs
 from vllm import LLM, EngineArgs
 from vllm.utils.argparse_utils import FlexibleArgumentParser
 
+import tpu_inference.envs as envs
 from tpu_inference.core import disagg_utils

@@ -87,10 +86,10 @@ def main(args: dict):
         'Who wrote the novel "Pride and Prejudice"?',
     ]
 
-    if envs.VLLM_TORCH_PROFILER_DIR is not None:
+    if vllm_envs.VLLM_TORCH_PROFILER_DIR is not None:
         llm.start_profile()
     outputs = llm.generate(prompts, sampling_params)
-    if envs.VLLM_TORCH_PROFILER_DIR is not None:
+    if vllm_envs.VLLM_TORCH_PROFILER_DIR is not None:
         llm.stop_profile()
 
     # Print the outputs.
@@ -104,7 +103,7 @@
 
 if __name__ == "__main__":
     # Skip long warmup for local simple test.
-    os.environ['SKIP_JAX_PRECOMPILE'] = '1'
+    envs.environment_variables['SKIP_JAX_PRECOMPILE'] = lambda: True
 
     parser = create_parser()
     args: dict = vars(parser.parse_args())
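
A note on the pattern the examples switch to: instead of writing the string '1' into os.environ before engine start, they replace the resolver for SKIP_JAX_PRECOMPILE in tpu_inference.envs, leaving the process environment untouched. A minimal sketch of the idiom, assuming (as the tpu_inference/envs.py diff below suggests) that environment_variables maps each name to a zero-argument resolver:

```python
import tpu_inference.envs as envs

# Each entry in environment_variables is a zero-arg callable returning the
# typed value; replacing it changes what envs.SKIP_JAX_PRECOMPILE resolves to.
envs.environment_variables['SKIP_JAX_PRECOMPILE'] = lambda: True

assert envs.SKIP_JAX_PRECOMPILE is True  # already a bool, no string parsing
```
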
11 changes: 6 additions & 5 deletions examples/offline_lora_inference.py
@@ -1,14 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import os
 import time
 
-import vllm.envs as envs
+import vllm.envs as vllm_envs
 from vllm import LLM, EngineArgs
 from vllm.lora.request import LoRARequest
 from vllm.utils.argparse_utils import FlexibleArgumentParser
 
+import tpu_inference.envs as envs
+
 
 def create_parser():
     parser = FlexibleArgumentParser()
@@ -55,13 +56,13 @@ def main(args: dict):
"lora_adapter_3", 3,
"Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_3_adapter")

if envs.VLLM_TORCH_PROFILER_DIR is not None:
if vllm_envs.VLLM_TORCH_PROFILER_DIR is not None:
llm.start_profile()
start = time.perf_counter()
outputs = llm.generate(prompt,
sampling_params=sampling_params,
lora_request=lora_request)
if envs.VLLM_TORCH_PROFILER_DIR is not None:
if vllm_envs.VLLM_TORCH_PROFILER_DIR is not None:
llm.stop_profile()

# Print the outputs.
@@ -77,7 +78,7 @@
 
 if __name__ == "__main__":
     # Skip long warmup for local simple test.
-    os.environ['SKIP_JAX_PRECOMPILE'] = '1'
+    envs.environment_variables['SKIP_JAX_PRECOMPILE'] = lambda: True
 
     parser = create_parser()
     args: dict = vars(parser.parse_args())
11 changes: 5 additions & 6 deletions examples/offline_safety_model_inference.py
@@ -18,12 +18,11 @@
 --max-num_batched_tokens=4096
 """
 
-import os
-
-import vllm.envs as envs
+import vllm.envs as vllm_envs
 from vllm import LLM, EngineArgs
 from vllm.utils.argparse_utils import FlexibleArgumentParser
 
+import tpu_inference.envs as envs
 from tpu_inference.core import disagg_utils
 
 
@@ -170,7 +169,7 @@ def main(args: dict):
 
     prompts.append(TokensPrompt(prompt_token_ids=tokenized_prompt))
 
-    if envs.VLLM_TORCH_PROFILER_DIR is not None:
+    if vllm_envs.VLLM_TORCH_PROFILER_DIR is not None:
         llm.start_profile()
 
     outputs = llm.generate(
@@ -179,7 +178,7 @@
         use_tqdm=True,
     )
 
-    if envs.VLLM_TORCH_PROFILER_DIR is not None:
+    if vllm_envs.VLLM_TORCH_PROFILER_DIR is not None:
         llm.stop_profile()
 
     passed_tests = 0
@@ -220,7 +219,7 @@ def main(args: dict):
 
 if __name__ == "__main__":
     # Skip long warmup for local simple test.
-    os.environ['SKIP_JAX_PRECOMPILE'] = '1'
+    envs.environment_variables['SKIP_JAX_PRECOMPILE'] = lambda: True
 
     parser = create_parser()
     args: dict = vars(parser.parse_args())
10 changes: 5 additions & 5 deletions tests/runner/test_tpu_runner_mesh.py
@@ -1,9 +1,9 @@
"""Unit tests for TPUModelRunner mesh initialization."""
import os
from unittest.mock import Mock, patch

import pytest

import tpu_inference.envs as envs
from tpu_inference.runner.tpu_runner import TPUModelRunner


@@ -53,7 +53,7 @@ def runner_instance(self, mock_vllm_config, mock_devices):
     def test_init_mesh_2d_model_without_device_order(self, runner_instance,
                                                      mock_vllm_config):
         """Test 2d mesh creation without enforced device order."""
-        with patch.dict(os.environ, {'NEW_MODEL_DESIGN': ''}), \
+        with patch.dict(envs.environment_variables, {'NEW_MODEL_DESIGN': lambda: False}), \
                 patch('tpu_inference.runner.tpu_runner.make_optimized_mesh') as mock_make_mesh, \
                 patch('tpu_inference.runner.tpu_runner.logger'):
 
@@ -79,7 +79,7 @@ def test_init_mesh_2d_model_with_device_order(self, runner_instance,
"""Test 2d mesh creation with enforced device order."""
mock_vllm_config.sharding_config.device_indexes = [0, 1, 2, 3]

with patch.dict(os.environ, {'NEW_MODEL_DESIGN': ''}), \
with patch.dict(envs.environment_variables, {'NEW_MODEL_DESIGN': lambda: False}), \
patch('jax.make_mesh') as mock_jax_mesh, \
patch('tpu_inference.runner.tpu_runner.logger'):

@@ -103,7 +103,7 @@ def test_init_mesh_2d_model_with_device_order(self, runner_instance,
     def test_init_mesh_new_model_single_slice(self, runner_instance,
                                               mock_vllm_config):
         """Test new model mesh creation with single slice."""
-        with patch.dict(os.environ, {'NEW_MODEL_DESIGN': '1', 'NUM_SLICES': '1'}), \
+        with patch.dict(envs.environment_variables, {'NEW_MODEL_DESIGN': lambda: True, 'NUM_SLICES': lambda: 1}), \
                 patch('tpu_inference.runner.tpu_runner.mesh_utils') as mock_mesh_utils, \
                 patch('jax.sharding.Mesh') as mock_jax_mesh, \
                 patch('tpu_inference.runner.tpu_runner.logger'):
@@ -134,7 +134,7 @@ def test_init_mesh_new_model_multi_slice(self, runner_instance,
                                              mock_vllm_config):
         """Test new model mesh creation with multiple slices."""
         num_slices = 2
-        with patch.dict(os.environ, {'NEW_MODEL_DESIGN': '1', 'NUM_SLICES': str(num_slices)}), \
+        with patch.dict(envs.environment_variables, {'NEW_MODEL_DESIGN': lambda: True, 'NUM_SLICES': lambda: num_slices}), \
                 patch('tpu_inference.runner.tpu_runner.mesh_utils') as mock_mesh_utils, \
                 patch('jax.sharding.Mesh') as mock_jax_mesh, \
                 patch('tpu_inference.runner.tpu_runner.logger'):
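
Because environment_variables is a plain dict, the tests can point unittest.mock.patch.dict at it instead of at os.environ, and the original resolvers are restored when the context manager exits. A minimal sketch of the pattern, under the same assumption about the mapping:

```python
from unittest.mock import patch

import tpu_inference.envs as envs

# Swap in a typed resolver for the duration of the block; patch.dict
# restores the original entry afterwards.
with patch.dict(envs.environment_variables, {'NUM_SLICES': lambda: 2}):
    assert envs.NUM_SLICES == 2
# Outside the block, NUM_SLICES resolves from the real environment again.
```
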
15 changes: 15 additions & 0 deletions tests/test_envs.py
@@ -63,6 +63,13 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
assert envs.SKIP_JAX_PRECOMPILE is False

# Test VLLM_XLA_CHECK_RECOMPILATION (default False)
assert envs.VLLM_XLA_CHECK_RECOMPILATION is False
monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "1")
assert envs.VLLM_XLA_CHECK_RECOMPILATION is True
monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "0")
assert envs.VLLM_XLA_CHECK_RECOMPILATION is False

# Test NEW_MODEL_DESIGN (default False)
assert envs.NEW_MODEL_DESIGN is False
monkeypatch.setenv("NEW_MODEL_DESIGN", "1")
@@ -81,6 +88,13 @@ def test_integer_env_vars(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("PYTHON_TRACER_LEVEL", "0")
assert envs.PYTHON_TRACER_LEVEL == 0

# Test NUM_SLICES (default 1)
assert envs.NUM_SLICES == 1
monkeypatch.setenv("NUM_SLICES", "2")
assert envs.NUM_SLICES == 2
monkeypatch.setenv("NUM_SLICES", "4")
assert envs.NUM_SLICES == 4


def test_lowercase_conversion(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "GRPC")
@@ -134,6 +148,7 @@ def test_dir_returns_all_env_vars():
assert "JAX_PLATFORMS" in env_vars
assert "TPU_NAME" in env_vars
assert "SKIP_JAX_PRECOMPILE" in env_vars
assert "VLLM_XLA_CHECK_RECOMPILATION" in env_vars
assert "MODEL_IMPL_TYPE" in env_vars


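
Note that these assertions only hold because each attribute access re-invokes the resolver against the current environment; if the module cached values at import time, the monkeypatch.setenv calls after the first read would go unnoticed.
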
5 changes: 3 additions & 2 deletions tests/worker/tpu_worker_test.py
@@ -6,6 +6,7 @@
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import DraftTokenIds
 
+import tpu_inference.envs as envs
 # The class we are testing
 from tpu_inference.worker.tpu_worker import TPUWorker
 
@@ -280,7 +281,7 @@ def test_add_lora_not_implemented_lora_request(self, mock_vllm_config):
     #
 
     @patch('tpu_inference.worker.tpu_worker.jax')
-    @patch.dict('os.environ', {"PYTHON_TRACER_LEVEL": "1"}, clear=True)
+    @patch.dict(envs.environment_variables, {"PYTHON_TRACER_LEVEL": lambda: 1})
     def test_profile_start(self, mock_jax, mock_vllm_config):
         """Tests starting the JAX profiler."""
         worker = TPUWorker(vllm_config=mock_vllm_config,
@@ -296,7 +297,7 @@ def test_profile_start(self, mock_jax, mock_vllm_config):
         args, kwargs = mock_jax.profiler.start_trace.call_args
         assert args[0] == "/tmp/profile_dir"
         # Verify options from env var were used
-        assert kwargs['profiler_options'].python_tracer_level == '1'
+        assert kwargs['profiler_options'].python_tracer_level == 1
 
     @patch('tpu_inference.worker.tpu_worker.jax')
     def test_profile_stop(self, mock_jax, mock_vllm_config):
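
The expected python_tracer_level changes from '1' to 1 because PYTHON_TRACER_LEVEL is now read through the typed accessor, which parses the variable with int(...) before it reaches the profiler options.
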
8 changes: 8 additions & 0 deletions tpu_inference/envs.py
@@ -15,11 +15,13 @@
 PREFILL_SLICES: str = ""
 DECODE_SLICES: str = ""
 SKIP_JAX_PRECOMPILE: bool = False
+VLLM_XLA_CHECK_RECOMPILATION: bool = False
 MODEL_IMPL_TYPE: str = "flax_nnx"
 NEW_MODEL_DESIGN: bool = False
 PHASED_PROFILING_DIR: str = ""
 PYTHON_TRACER_LEVEL: int = 1
 USE_MOE_EP_KERNEL: bool = False
+NUM_SLICES: int = 1
 RAY_USAGE_STATS_ENABLED: str = "0"
 VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm"
@@ -48,6 +50,9 @@
     # Skip JAX precompilation step during initialization
     "SKIP_JAX_PRECOMPILE":
     lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE", "0"))),
+    # Check for XLA recompilation during execution
+    "VLLM_XLA_CHECK_RECOMPILATION":
+    lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION", "0"))),
     # Model implementation type (e.g., "flax_nnx")
     "MODEL_IMPL_TYPE":
     lambda: os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower(),
@@ -63,6 +68,9 @@
     # Use custom expert-parallel kernel for MoE (Mixture of Experts)
     "USE_MOE_EP_KERNEL":
     lambda: bool(int(os.getenv("USE_MOE_EP_KERNEL", "0"))),
+    # Number of TPU slices for multi-slice mesh
+    "NUM_SLICES":
+    lambda: int(os.getenv("NUM_SLICES", "1")),
     # Enable/disable Ray usage statistics collection
     "RAY_USAGE_STATS_ENABLED":
     lambda: os.getenv("RAY_USAGE_STATS_ENABLED", "0"),
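
The diff adds the new entries to environment_variables but does not show how attribute access such as envs.NUM_SLICES is resolved. The tests (typed attribute reads plus the dir() assertions) imply the module-level __getattr__/__dir__ pattern from PEP 562, the same approach vllm.envs takes. A self-contained sketch of that pattern, not the verbatim tpu_inference/envs.py source:

```python
import os

# Hypothetical standalone module illustrating the lazy env-var pattern.
environment_variables = {
    "SKIP_JAX_PRECOMPILE":
    lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE", "0"))),
    "NUM_SLICES":
    lambda: int(os.getenv("NUM_SLICES", "1")),
}


def __getattr__(name: str):
    # PEP 562 module-level __getattr__: envs.NUM_SLICES invokes the resolver,
    # so the environment is re-read and re-parsed on every access.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


def __dir__():
    # Lets callers and tests discover every known variable via dir(envs).
    return list(environment_variables)
```
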
14 changes: 9 additions & 5 deletions tpu_inference/executors/ray_distributed_executor.py
@@ -18,6 +18,7 @@
 from vllm.v1.executor.ray_executor import RayWorkerMetaData
 from vllm.v1.executor.ray_utils import RayWorkerWrapper, _wait_until_pg_ready
 
+import tpu_inference.envs as tpu_envs
 from tpu_inference.logger import init_logger
 
 try:
@@ -72,7 +73,9 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
     def _init_executor(self) -> None:
         self.forward_dag: Optional[ray.dag.CompiledDAG] = None
 
-        os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"
+        # Ensure Ray compiled DAG channel type is set for vLLM
+        os.environ[
+            "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = tpu_envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
 
         # Currently, this requires USE_RAY_SPMD_WORKER=True.
         self.use_ray_compiled_dag = True
@@ -86,10 +89,11 @@ def _init_executor(self) -> None:
         self._initialize_ray_cluster()
         placement_group = self.parallel_config.placement_group
 
-        # Disable Ray usage stats collection.
-        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
-        if ray_usage != "1":
-            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
+        # Ensure Ray usage stats collection setting is propagated to Ray workers.
+        # Ray workers inherit environment variables, so we explicitly set this
+        # based on our configuration (defaults to "0" to disable stats).
+        os.environ[
+            "RAY_USAGE_STATS_ENABLED"] = tpu_envs.RAY_USAGE_STATS_ENABLED
 
         # Create the parallel GPU workers.
         self._init_workers_ray(placement_group)
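
Both os.environ writes land before self._init_workers_ray(placement_group) executes, which is what makes them effective: Ray workers capture the parent process environment when they are spawned, so values set afterwards would not propagate.
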
7 changes: 3 additions & 4 deletions tpu_inference/layers/common/sharding.py
@@ -1,14 +1,13 @@
 import json
 import math
-import os
 from dataclasses import asdict, dataclass
 from typing import TYPE_CHECKING, List, Optional
 
 import jax.numpy as jnp
 import numpy as np
 from jax.sharding import Mesh
 
-from tpu_inference import utils
+from tpu_inference import envs, utils
 
 if TYPE_CHECKING:
     from vllm.v1.configs.vllm_config import VllmConfig
@@ -48,7 +47,7 @@ class ShardingAxisName2D:
 
 
 try:
-    _use_base_sharding = os.getenv("NEW_MODEL_DESIGN", False)
+    _use_base_sharding = envs.NEW_MODEL_DESIGN
     if _use_base_sharding:
         ShardingAxisName = ShardingAxisNameBase
     else:
@@ -166,7 +165,7 @@ def validate(cls, vllm_config, sharding_strategy):
f"LoRA is not supported with data parallelism "
f"(DP size: {total_dp_size}). Please disable LoRA or "
f"set data parallelism to 1.")
if not os.environ.get("NEW_MODEL_DESIGN", False):
if not envs.NEW_MODEL_DESIGN:
raise ValueError(
"Must run DP with NEW_MODEL_DESIGN enabled. Please set the "
"NEW_MODEL_DESIGN=True.")
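
The switch to envs.NEW_MODEL_DESIGN fixes a latent bug rather than just style: os.getenv returns a string, so any non-empty value, including "0", was truthy in the old checks. The same pitfall applies to the SKIP_JAX_PRECOMPILE check removed from compilation_manager.py below. A small demonstration:

```python
import os

os.environ["NEW_MODEL_DESIGN"] = "0"

# Old pattern: os.getenv returns the string "0", which is truthy.
assert bool(os.getenv("NEW_MODEL_DESIGN", False)) is True

# New pattern: the boolean flags in tpu_inference/envs.py parse with
# bool(int(...)), so "0" becomes a real False.
assert bool(int(os.getenv("NEW_MODEL_DESIGN", "0"))) is False
```
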
11 changes: 5 additions & 6 deletions tpu_inference/runner/compilation_manager.py
@@ -1,13 +1,13 @@
-import os
 import time
 from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
 
 import jax
 import jax.numpy as jnp
 import numpy as np
-import vllm.envs as envs
+import vllm.envs as vllm_envs
 from jax.sharding import NamedSharding, PartitionSpec
 
+import tpu_inference.envs as envs
 from tpu_inference.core.disagg_utils import is_disagg_enabled
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.common.sharding import ShardingAxisName
@@ -30,10 +30,10 @@ class CompilationManager:
 
     def __init__(self, runner: "TPUModelRunner"):
         self.runner = runner
-        if not envs.VLLM_DISABLE_COMPILE_CACHE:
+        if not vllm_envs.VLLM_DISABLE_COMPILE_CACHE:
             logger.info("Enabling JAX compile cache.")
             jax.config.update("jax_compilation_cache_dir",
-                              envs.VLLM_XLA_CACHE_PATH)
+                              vllm_envs.VLLM_XLA_CACHE_PATH)
 
     def _create_dummy_tensor(self,
                              shape: Tuple[int, ...],
@@ -67,8 +67,7 @@ def _run_compilation(self, name: str, fn: Callable, *args,
logger.info("Compilation finished in %.2f [secs].", end - start)

def capture_model(self) -> None:
if os.getenv("SKIP_JAX_PRECOMPILE",
False) or self.runner.model_config.enforce_eager:
if envs.SKIP_JAX_PRECOMPILE or self.runner.model_config.enforce_eager:
return
logger.info("Precompile all the subgraphs with possible input shapes.")
