-
Notifications
You must be signed in to change notification settings - Fork 597
Add deterministic RL training experiment with vLLM #1975
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
bwasti
wants to merge
14
commits into
pytorch:main
Choose a base branch
from
bwasti:main
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+3,357
−0
Open
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
5fc60f8
Add deterministic RL training experiment with vLLM
bwasti 2823b41
Add missing weight conversion files and fix imports
teja-rao aea9af1
Address review comments
teja-rao 62b20b6
Remove redundant round-trip weight loading
teja-rao 3c18115
Optimize vLLM weight reloading using collective_rpc
teja-rao 7be6112
Add comprehensive gradient checking and debugging
bwasti 602e92e
Fix missing gradients for FFN layers by adding SiluAndMul backward pass
bwasti d53740a
Remove gradient debugging statements
bwasti 53b56a5
Update readme
bwasti 90bc444
Cleanup readme and remove unneeded flashv3 backward
bwasti ae4ebd3
lint
bwasti 5c00c9a
More cleanup
bwasti 1d41a08
Remove extra function
bwasti 44c9946
add better loss controls and a real dataset
bwasti File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,256 @@ | ||
| # Deterministic RL Training with vLLM | ||
|
|
||
| This experiment combines vLLM's deterministic kernels with PyTorch autograd to enable reinforcement learning training where forward passes produce bitwise-identical results across runs. | ||
|
|
||
| ## Overview | ||
|
|
||
| RL training requires both fast inference for generating rollouts and gradient computation for policy updates. vLLM provides deterministic forward passes but does not support gradients. This experiment adds backward passes to vLLM's operations. | ||
|
|
||
| The implementation: | ||
| 1. Uses vLLM's batch-invariant kernels for forward passes | ||
| 2. Implements custom backward passes for gradient computation | ||
| 3. Provides weight conversion utilities between TorchTitan and vLLM formats | ||
|
|
||
| ### Features | ||
|
|
||
| - Bitwise determinism: Same inputs produce identical outputs across runs | ||
| - Gradient support: Backward passes through vLLM operations | ||
| - Weight conversion: Utilities to convert between model formats | ||
|
|
||
| Note: Currently supports single-device training only. | ||
|
|
||
| ## Architecture | ||
|
|
||
| ### Components | ||
|
|
||
| 1. `models/attention.py`: VLLMCompatibleFlashAttention | ||
| - Uses vLLM's Flash Attention for forward pass | ||
| - Implements custom backward pass for gradient computation | ||
| - Uses `num_splits=1` for deterministic behavior | ||
|
|
||
| 2. `models/qwen3/model_vllm_compat.py`: Qwen3VLLMCompatModel | ||
| - Qwen3 model with merged gate/up projections matching vLLM format | ||
| - Uses VLLMRMSNorm with gradient support | ||
|
|
||
| 3. `batch_invariant_backward.py`: Backward passes for vLLM operations | ||
| - Registers gradients for vLLM's batch-invariant operations | ||
| - Supports matmul, linear, and RMSNorm | ||
| - Patches Flash Attention for autograd | ||
|
|
||
| 4. `weights_vllm_compat.py`: Weight conversion utilities | ||
| - Converts between TorchTitan format (separate w1, w2, w3) and vLLM format (merged gate_up_proj) | ||
| - Provides bidirectional conversion functions | ||
|
|
||
| 5. `simple_rl.py`: RL training loop | ||
| - Generates rollouts using vLLM engine | ||
| - Computes advantages using GRPO-style ranking | ||
| - Updates policy using PPO | ||
|
|
||
| ## Installation | ||
|
|
||
| ### Prerequisites | ||
|
|
||
| ```bash | ||
| # Install vLLM with deterministic support | ||
| pip install vllm | ||
|
|
||
| # Install TorchTitan (from the repository root) | ||
| pip install -e . | ||
|
|
||
| # Install additional dependencies | ||
| pip install transformers safetensors huggingface_hub tensorboard | ||
| ``` | ||
|
|
||
| ### Enable Batch Invariance | ||
|
|
||
| Initialize vLLM's batch-invariant mode before training: | ||
|
|
||
| ```python | ||
| from vllm.model_executor.layers.batch_invariant import init_batch_invariance | ||
| init_batch_invariance() | ||
| ``` | ||
|
|
||
| ## Usage | ||
|
|
||
| ### Quick Start | ||
|
|
||
| ```python | ||
| import torch | ||
| from vllm.model_executor.layers.batch_invariant import init_batch_invariance | ||
| from torchtitan.experiments.deterministic_vllm_rl import ( | ||
| enable_batch_invariant_backward_mode, | ||
| Qwen3VLLMCompatModel, | ||
| ) | ||
|
|
||
| # 1. Enable deterministic mode | ||
| init_batch_invariance() | ||
| enable_batch_invariant_backward_mode() | ||
|
|
||
| # 2. Load model | ||
| from torchtitan.models.qwen3.model.args import Qwen3ModelArgs | ||
| model_args = Qwen3ModelArgs( | ||
| dim=2048, | ||
| n_layers=24, | ||
| n_heads=16, | ||
| n_kv_heads=2, | ||
| vocab_size=151936, | ||
| ) | ||
| model = Qwen3VLLMCompatModel(model_args) | ||
|
|
||
| # 3. Forward pass (deterministic) | ||
| input_ids = torch.randint(0, 151936, (2, 128), device='cuda') | ||
| logits = model(input_ids) | ||
|
|
||
| # 4. Backward pass | ||
| loss = logits.sum() | ||
| loss.backward() | ||
| ``` | ||
|
|
||
| ### Full RL Training | ||
|
|
||
| Run the RL training loop: | ||
|
|
||
| ```bash | ||
| VLLM_BATCH_INVARIANT=1 VLLM_FLASH_ATTN_VERSION=3 python -m torchtitan.experiments.deterministic_vllm_rl.simple_rl | ||
| ``` | ||
|
|
||
| This will: | ||
| 1. Download Qwen3-1.7B from HuggingFace | ||
| 2. Initialize vLLM engine for rollouts | ||
| 3. Generate samples for training prompts | ||
| 4. Compute rewards and advantages | ||
| 5. Update the policy using PPO | ||
| 6. Log metrics to TensorBoard | ||
|
|
||
| View training progress: | ||
| ```bash | ||
| tensorboard --logdir=./outputs/rl_training | ||
| ``` | ||
|
|
||
| ## How It Works | ||
|
|
||
| ### Deterministic Forward Pass | ||
|
|
||
| vLLM's batch-invariant mode makes operations deterministic: | ||
|
|
||
| ```python | ||
| # These operations are deterministic when batch_invariance is enabled | ||
| y = torch.matmul(a, b) # Uses vLLM's deterministic matmul | ||
| output = flash_attn_varlen_func(q, k, v, num_splits=1) # Deterministic FA | ||
| ``` | ||
|
|
||
| ### Backward Pass with Gradients | ||
|
|
||
| Custom backward passes: | ||
| 1. Re-compute attention weights deterministically | ||
| 2. Use standard chain rule for gradients | ||
| 3. Apply gradients through vLLM's deterministic operations | ||
|
|
||
| ```python | ||
| class FlashAttnWithBackward(torch.autograd.Function): | ||
| @staticmethod | ||
| def forward(ctx, q, k, v, ...): | ||
| # Use vLLM's forward implementation | ||
| return flash_attn_varlen_func(q, k, v, num_splits=1, ...) | ||
|
|
||
| @staticmethod | ||
| def backward(ctx, grad_output): | ||
| # Compute gradients deterministically | ||
| # (re-compute attention weights and apply chain rule) | ||
| return grad_q, grad_k, grad_v, ... | ||
| ``` | ||
|
|
||
| ### Bitwise Determinism Verification | ||
|
|
||
| The training loop compares logprobs from vLLM and TorchTitan: | ||
|
|
||
| ```python | ||
| # During training, compare logprobs | ||
| vllm_logprobs = [from vLLM rollout] | ||
| titan_logprobs = [from TorchTitan forward pass] | ||
|
|
||
| assert torch.equal(vllm_logprobs, titan_logprobs) | ||
| ``` | ||
|
|
||
| ## Testing | ||
|
|
||
| Run the test suite: | ||
|
|
||
| ```bash | ||
| cd torchtitan/experiments/deterministic_vllm_rl/tests | ||
|
|
||
| # Test backward passes | ||
| python test_batch_invariant_backward.py | ||
|
|
||
| # Test determinism | ||
| python test_exact_determinism.py | ||
| ``` | ||
|
|
||
| ## Technical Details | ||
|
|
||
| ### Why Determinism Matters for RL | ||
|
|
||
| RL training steps: | ||
| 1. Generate rollouts by sampling from the policy | ||
| 2. Compute rewards based on the samples | ||
| 3. Update the policy using gradients | ||
|
|
||
| If the forward pass during training differs from the forward pass during rollout, policy gradients may be incorrect. This matters for algorithms like PPO that compare old and new policy probabilities. | ||
|
|
||
| This implementation uses the same kernels for both rollouts (vLLM) and training (TorchTitan) to ensure `logprobs_rollout == logprobs_training` bitwise. | ||
|
|
||
| ### Performance | ||
|
|
||
| - Rollout speed: Uses vLLM's optimized kernels | ||
| - Training speed: Similar to standard TorchTitan | ||
| - Memory: Saves activations for custom backward passes | ||
|
|
||
| ### Limitations | ||
|
|
||
| 1. Custom backward requires uniform sequence lengths | ||
| 2. Only causal attention is supported | ||
| 3. Requires NVIDIA GPUs with Flash Attention support | ||
|
|
||
| ## Project Structure | ||
bwasti marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| ``` | ||
| deterministic_vllm_rl/ | ||
| ├── README.md # Documentation | ||
| ├── __init__.py # Package initialization | ||
| ├── batch_invariant_backward.py # Backward passes for vLLM ops | ||
| ├── weights_vllm_compat.py # Weight conversion utilities | ||
| ├── simple_rl.py # RL training loop | ||
| ├── models/ | ||
| │ ├── __init__.py | ||
| │ ├── attention.py # VLLMCompatibleFlashAttention | ||
| │ └── qwen3/ | ||
| │ ├── __init__.py | ||
| │ └── model_vllm_compat.py # vLLM-compatible Qwen3 model | ||
| ├── weights/ | ||
| │ ├── __init__.py | ||
| │ ├── converter.py # Weight conversion script | ||
| │ └── README.md # Weight conversion documentation | ||
| └── tests/ | ||
| ├── __init__.py | ||
| ├── test_batch_invariant_backward.py # Test backward passes | ||
| └── test_exact_determinism.py # Test determinism | ||
| ``` | ||
|
|
||
| ## Contributing | ||
|
|
||
| This experiment is part of TorchTitan. To contribute: | ||
|
|
||
| 1. Test your changes with `pytest tests/` | ||
| 2. Verify bitwise determinism is maintained | ||
| 3. Update this README if adding new features | ||
|
|
||
| ## References | ||
|
|
||
| - [vLLM Documentation](https://docs.vllm.ai/) | ||
| - [Flash Attention Paper](https://arxiv.org/abs/2205.14135) | ||
| - [PPO Algorithm](https://arxiv.org/abs/1707.06347) | ||
| - [GRPO: Group Relative Policy Optimization](https://arxiv.org/abs/2402.03300) | ||
|
|
||
| ## License | ||
|
|
||
| This code is licensed under the BSD-style license found in the LICENSE file in the TorchTitan repository root directory. | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| """ | ||
| Deterministic RL training with vLLM experiment. | ||
|
|
||
| This experiment provides tools for bitwise-deterministic reinforcement learning | ||
| training using vLLM for fast rollouts and TorchTitan for training. | ||
|
|
||
| Key components: | ||
| - VLLMCompatibleFlashAttention: Flash attention with custom backward pass | ||
| - Qwen3VLLMCompatModel: vLLM-compatible model with merged projections | ||
| - batch_invariant_backward: Gradient support for vLLM's deterministic operations | ||
| - simple_rl: End-to-end RL training loop | ||
| """ | ||
|
|
||
| from .batch_invariant_backward import ( | ||
| enable_batch_invariant_backward_mode, | ||
| rms_norm_with_gradients, | ||
| silu_and_mul_with_gradients, | ||
| ) | ||
| from .models import VLLMCompatibleFlashAttention | ||
| from .models.qwen3 import Qwen3VLLMCompatModel | ||
|
|
||
| __all__ = [ | ||
| "VLLMCompatibleFlashAttention", | ||
| "Qwen3VLLMCompatModel", | ||
| "enable_batch_invariant_backward_mode", | ||
| "rms_norm_with_gradients", | ||
| "silu_and_mul_with_gradients", | ||
| ] |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we mention that this only works for single device right now, and we plan to extend it to work with parallelisms?