Skip to content

Conversation

@juncgu-google
Copy link
Collaborator

@juncgu-google juncgu-google commented Nov 18, 2025

Description

  1. Remove the original singleton LocalCPUBackend per process (shared by the offload scheduler and the offload worker).

    • The offload scheduler now contains a LRUOffloadManager that is responsible for all offloading decisions.
      • The size of CPU cache is set by TPU_OFFLOAD_NUM_CPU_CHUNKS.
      • The hash value of cpu chunk is inherited from vllm request's original block hash (full block only).
      • The completed save and load info is collected from kv_connector_output.kv_connector_stats.
    • The offload worker now uses a standard key-value store.
  2. Code refactor.

TODOs:

  1. The cpu cache size (num of chunks) and the staging buffer capacity are configured to # of chunks or # of tokens, instead of bytes. This limitation is caused by the lack of block size information in the offload connector scheduler.

Tests

pytest -sv tests/distributed/offload/tpu_offload_cpu_backend_test.py
pytest -sv tests/distributed/offload/tpu_offload_connector_worker_test.py
pytest -sv tests/distributed/offload/tpu_offload_connector_scheduler_test.py
pytest -sv tests/distributed/offload/tpu_offload_utils_test.py
pytest -sv tests/distributed/offload/tpu_offload_manager_test.py
pytest -sv tests/distributed/offload/tpu_offload_accuracy_test.py

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.

…he decisions; worker has a regular kv store

Signed-off-by: Juncheng Gu <[email protected]>
Signed-off-by: Juncheng Gu <[email protected]>
Signed-off-by: Juncheng Gu <[email protected]>
@github-actions
Copy link

Description

Start with a short description of what the PR does and how this is a change from
the past.

The rest of the description includes relevant details and context, examples:

  • why is this change being made,
  • the problem being solved and any relevant context,
  • why this is a good solution,
  • some information about the specific implementation,
  • shortcomings of the solution and possible future improvements.

If the change fixes a bug or a Github issue, please include a link, e.g.,:
FIXES: b/123456
FIXES: #123456

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.

Signed-off-by: Juncheng Gu <[email protected]>
Signed-off-by: Juncheng Gu <[email protected]>
Signed-off-by: Juncheng Gu <[email protected]>
Signed-off-by: Juncheng Gu <[email protected]>
Signed-off-by: Juncheng Gu <[email protected]>
Signed-off-by: Juncheng Gu <[email protected]>
Signed-off-by: Juncheng Gu <[email protected]>
Signed-off-by: Juncheng Gu <[email protected]>
Signed-off-by: Juncheng Gu <[email protected]>
Signed-off-by: Juncheng Gu <[email protected]>
Signed-off-by: Juncheng Gu <[email protected]>
Signed-off-by: Juncheng Gu <[email protected]>
@juncgu-google juncgu-google changed the title [TPU Offload][WIP] Separate offload manager and cpu-cache backend, and code structure refactor [TPU Offload] Separate offload manager and cpu-cache backend, and code structure refactor Nov 21, 2025
Signed-off-by: Juncheng Gu <[email protected]>
@dannawang0221
Copy link
Collaborator

/lgtm

@juncgu-google juncgu-google force-pushed the cpu-offloading/dev-cpu-backend branch from 81d8fc0 to 1be2866 Compare November 22, 2025 05:00
@juncgu-google juncgu-google merged commit 0c9f039 into cpu-offloading/dev Nov 22, 2025
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants