Add deterministic RL training experiment with vLLM #1975
base: main
Conversation
This experiment provides a complete framework for bitwise-deterministic reinforcement learning training that combines vLLM for fast rollouts and TorchTitan for training with gradients.

Key features:
- Bitwise-deterministic forward and backward passes
- vLLM-compatible Qwen3 model with merged projections
- Custom Flash Attention with gradient support
- Gradient support for vLLM's batch-invariant operations
- Complete RL training loop with GRPO-style advantages
- Comprehensive test suite for determinism verification

Components:
- models/attention.py: VLLMCompatibleFlashAttention
- models/qwen3/model_vllm_compat.py: vLLM-compatible Qwen3 model
- batch_invariant_backward.py: Gradient support for vLLM operations
- simple_rl.py: End-to-end RL training loop
- tests/: Test suite for backward passes and determinism
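For readers unfamiliar with the term, GRPO-style advantages are group-relative: each rollout's reward is normalized against the other rollouts sampled for the same prompt, so no learned value function is needed. A minimal sketch (illustrative only, not the code in simple_rl.py):

```python
import torch


def grpo_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Group-relative advantages.

    rewards: (num_prompts, rollouts_per_prompt) scalar rewards.
    Returns a tensor of the same shape.
    """
    mean = rewards.mean(dim=-1, keepdim=True)
    std = rewards.std(dim=-1, keepdim=True)
    return (rewards - mean) / (std + eps)


# Example: 2 prompts, 4 rollouts each
rewards = torch.tensor([[1.0, 0.0, 0.5, 0.5],
                        [0.2, 0.8, 0.2, 0.8]])
print(grpo_advantages(rewards))
```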
Nice work! I left some questions and comments.
### Performance

- **Rollout speed**: ~100x faster than standard PyTorch (thanks to vLLM)
👍 lol
this is a totally random number lmao -- kv cache does all the work. thanks claude for adding this
could you help add the disclaimer to this file too?
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
similar, and also the two test files
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Can we mention that this only works for single device right now, and we plan to extend it to work with parallelisms?
@staticmethod
def backward(ctx, grad_output):
    """
    Backward pass using batch-invariant PyTorch operations.
Batch-invariant backward sounds nice! But I guess batch-invariance doesn't matter too much for backward? E.g. you'll lose batch invariance when doing loss reduction?
Asking because the decomposed implementation below might be slower than a non-batch-invariant kernel.
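For context on the question: batch invariance means a given sample produces bit-identical outputs whether it is processed alone or inside a larger batch. A quick way to probe it (an illustrative sketch, unrelated to this PR's kernels):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
linear = torch.nn.Linear(64, 64, dtype=torch.float16, device=device)
x = torch.randn(8, 64, dtype=torch.float16, device=device)

# Row 0 processed alone vs. as part of the full batch.
alone = linear(x[:1])
in_batch = linear(x)[:1]

# Without batch-invariant kernels these can differ in the last bits, because
# the GEMM may pick different tiling / reduction orders for different shapes.
# (On CPU, or with batch-invariant ops enabled, they should match exactly.)
print(torch.equal(alone, in_batch))
```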
import numpy as np
from torch.utils.tensorboard import SummaryWriter

import torchtitan.experiments.compat
not used and will error?
import torchtitan.experiments.compat
from torchtitan.models.qwen3.model.args import Qwen3ModelArgs
from weights_vllm_compat import torchtitan_to_vllm_compat, vllm_compat_to_torchtitan
These files are missing?
import torchtitan.experiments.compat
from torchtitan.models.qwen3.model.args import Qwen3ModelArgs
from weights_vllm_compat import torchtitan_to_vllm_compat, vllm_compat_to_torchtitan
from weights.converter import torchtitan_to_vllm, vllm_to_torchtitan
same
import glob
from safetensors.torch import load_file as sf_load
from weights.converter import vllm_to_torchtitan
from weights_vllm_compat import torchtitan_to_vllm_compat as titan_to_vllm_compat
this seems the same as a few lines above
shard_files = sorted(glob.glob(os.path.join(vllm_engine.temp_model_dir, "model-*.safetensors")))
if shard_files:
    # Load all shards
    all_disk_state = {}
    for shard_file in shard_files:
        shard_state = sf_load(shard_file)
        all_disk_state.update(shard_state)

    # Convert back to TorchTitan format
    titan_from_disk = vllm_to_torchtitan(all_disk_state)

    if use_vllm_compat:
        # Convert to vLLM-compat format for vLLM-compatible model
        weights_for_model = titan_to_vllm_compat(titan_from_disk)
    else:
        # Use standard TorchTitan format for standard model
        weights_for_model = titan_from_disk

    # Load back into model
    model.load_state_dict(weights_for_model, strict=False)
This looks a bit confusing. I thought loading from disk should be done only once, instead of in each rl_update_step. Is it related to the checkpoint bug you mentioned earlier?
So I believe this was done because we suspected there may be some rounding issues in weight transfer, but I verified that removing this code works fine. In my experience with checkpointing, unless there is a dtype conversion, weight loading should always be bitwise in sync. So I will send an update removing this.
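A round-trip check along these lines supports the "bitwise in sync unless there is a dtype conversion" point (a generic sketch, not the PR's test code):

```python
import torch
from safetensors.torch import load_file, save_file

model = torch.nn.Linear(16, 16, dtype=torch.bfloat16)
state = {k: v.contiguous() for k, v in model.state_dict().items()}

save_file(state, "/tmp/roundtrip.safetensors")
reloaded = load_file("/tmp/roundtrip.safetensors")

# With no dtype conversion, every tensor should match bit for bit.
assert all(torch.equal(state[k], reloaded[k]) for k in state)
```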
This commit adds the missing weight conversion utilities that are required by Bram's base commit (simple_rl.py imports them but they were missing):
- weights_vllm_compat.py: Converts between TorchTitan and vLLM-compat formats (merges/splits gate_up_proj for FFN layers)
- weights/converter.py: Converts between vLLM HuggingFace and TorchTitan formats
- weights/__init__.py: Package init
- weights/README.md: Documentation for weight converters

Import fixes:
- simple_rl.py: Use local models.qwen3 instead of torchtitan.models.qwen3.model
- model_vllm_compat.py: Import VLLMCompatibleFlashAttention from local ..attention and Qwen3ModelArgs from torchtitan.models.qwen3.model.args
- Add BSD-style license headers to all new files:
  * batch_invariant_backward.py
  * simple_rl.py
  * tests/test_batch_invariant_backward.py
  * tests/test_exact_determinism.py
  * weights_vllm_compat.py
  * weights/converter.py
  * weights/__init__.py
- Add note about single-device limitation in README.md
  Currently supports single-device training only; future work will extend to distributed training with parallelism
- Remove unused imports in simple_rl.py:
  * Remove 'import torchtitan.experiments.compat' (unused)
  * Remove duplicate imports of torchtitan_to_vllm_compat
- Fix all imports to use absolute paths for python -m compatibility:
  * Update model_vllm_compat.py to import from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward
  * Update simple_rl.py to import from torchtitan.experiments.deterministic_vllm_rl modules
  * Removes sys.path manipulation - now works cleanly with python -m
- Remove duplicate RMSNormFunction from model_vllm_compat.py:
  * Import rms_norm_with_gradients from batch_invariant_backward.py
  * Remove duplicate RMSNormFunction class and function definition
  * Keeps gradient-enabled operations centralized in utilities module
The round-trip weight loading code was reloading weights from disk back into the TorchTitan model after every vLLM update. Testing shows this is unnecessary - bitwise determinism is maintained without it. Simplifies the training loop by removing the extra disk I/O and weight conversion overhead on each RL step.
Use vLLM's collective_rpc API to reload weights without recreating the
entire engine. This provides significant performance improvements:
- Weight reload: ~0.7-0.9s (vs ~7-10s for full engine recreation)
- Preserves KV cache, kernels, and memory allocations
- Reduces memory fragmentation
Changes:
- Update VLLMRolloutEngine.update_weights() to use
collective_rpc("reload_weights") instead of recreating engine
The reload mechanism saves updated weights to disk, then calls
reload_weights() on all workers via RPC, maintaining bitwise
determinism while avoiding expensive engine recreation.
Note: Requires VLLM_ALLOW_INSECURE_SERIALIZATION=1 environment
variable for collective_rpc with custom functions.
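Roughly, the reload path could look like the sketch below. This is illustrative only: the worker-side attribute names (model_runner.model.load_weights) and the exact collective_rpc calling convention are assumptions based on the description above, not code from this PR.

```python
import glob
import os

from safetensors.torch import load_file


def reload_weights(worker, weights_dir: str) -> None:
    # Runs inside each vLLM worker process: load the freshly saved
    # safetensors shards into the live model, leaving the engine,
    # KV cache, and compiled kernels untouched.
    state = {}
    for shard in sorted(glob.glob(os.path.join(weights_dir, "model-*.safetensors"))):
        state.update(load_file(shard))
    worker.model_runner.model.load_weights(state.items())


# Driver side (needs VLLM_ALLOW_INSECURE_SERIALIZATION=1 so the custom
# callable can be shipped to the workers):
# llm.collective_rpc(reload_weights, args=("/path/to/updated/weights",))
```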
- Track which parameters receive gradients during training
- Print detailed gradient report on first step showing module-by-module breakdown
- Add gradient statistics to training metrics (total norm, counts)
- Log gradient info to TensorBoard for monitoring
- Display gradient warnings if parameters are missing gradients

This helps identify issues where model parameters aren't being updated during training.
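A minimal version of such a check just walks named_parameters after loss.backward() (a generic sketch, independent of the helper added here and removed in a later commit):

```python
import torch


def report_missing_gradients(model: torch.nn.Module) -> list[str]:
    """Return names of trainable parameters that received no gradient
    in the last backward pass. Call after loss.backward()."""
    missing = [
        name
        for name, param in model.named_parameters()
        if param.requires_grad and param.grad is None
    ]
    if missing:
        print(f"{len(missing)} parameters received no gradient:")
        for name in missing:
            print(f"  {name}")
    return missing
```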
The FFN (feed-forward) layers were not receiving gradients because vLLM's SiluAndMul activation doesn't have autograd support. This caused 56/282 parameters to not be updated during training.

Changes:
- Add SiluAndMulFunction autograd wrapper in batch_invariant_backward.py
- Implement proper backward pass for silu(gate) * up operation
- Export silu_and_mul_with_gradients() function
- Update FeedForwardVLLMCompat to use gradient-enabled version
- Remove raw SiluAndMul import that lacked gradient support

Now all parameters including gate_up_proj and ffn_norm weights receive gradients and can be trained properly.
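For reference, the backward of silu(gate) * up decomposes as below. This is a self-contained sketch of the math, not necessarily the exact implementation in batch_invariant_backward.py:

```python
import torch


class SiluAndMulFunction(torch.autograd.Function):
    """silu(gate) * up, where the input packs [gate, up] along the last dim."""

    @staticmethod
    def forward(ctx, gate_up: torch.Tensor) -> torch.Tensor:
        gate, up = gate_up.chunk(2, dim=-1)
        ctx.save_for_backward(gate, up)
        return torch.nn.functional.silu(gate) * up

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        gate, up = ctx.saved_tensors
        sig = torch.sigmoid(gate)
        silu = gate * sig
        # d silu(g) / dg = sigmoid(g) * (1 + g * (1 - sigmoid(g)))
        dsilu = sig * (1 + gate * (1 - sig))
        grad_gate = grad_output * up * dsilu
        grad_up = grad_output * silu
        return torch.cat([grad_gate, grad_up], dim=-1)


# Sanity check against autograd numerics.
x = torch.randn(2, 8, dtype=torch.double, requires_grad=True)
torch.autograd.gradcheck(SiluAndMulFunction.apply, (x,))
```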
Remove temporary debugging code added for gradient checking:
- Removed print_gradient_report() function
- Removed gradient tracking code in rl_update_step()
- Removed gradient-related console output
- Removed gradient metrics from TensorBoard logging
- Removed step parameter from rl_update_step() (no longer needed)

The gradient issue has been fixed, so this debugging code is no longer needed.
I wonder if we could just reuse https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/qwen3/model/state_dict_adapter.py
If to_hf / from_hf are not exactly what vllm needs, we should have to_vllm and from_vllm. cc @wwwjn
titan_state = vllm_compat_to_torchtitan(vllm_compat_state)
vllm_state = torchtitan_to_vllm(titan_state)
this is confusing lol. Ideally we could make this minimal, by running torchtitan model on vLLM.
vllm is far too optimized for this to be useful in a general context. we don't get free speedups from using vllm (it's not a compiler); we get speedups from vllm building and maintaining fast versions of models.
it can be demonstrated (here), but it's very messy to do and I don't think it will scale with new features.
)

# Wrap Flash Attention with manual backward pass
class FlashAttnWithBackward(torch.autograd.Function):
I still have this question, lol
happy to stamp, exciting work!
I would recommend adding a more notable message in README, about the "generic model approach + parallelism" we are about to work on as next steps.
)

# Wrap Flash Attention with manual backward pass
class FlashAttnWithBackward(torch.autograd.Function):
Maybe add a TODO to convert this op to a custom op, which composes better with torch.compile. https://docs.pytorch.org/tutorials/advanced/python_custom_ops.html#python-custom-ops-tutorial
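For reference, a custom-op wrapper would look roughly like this (a sketch using torch.library.custom_op; the op name is made up and SDPA stands in for the real flash-attention kernel):

```python
import torch


@torch.library.custom_op("det_rl::flash_attn", mutates_args=())
def flash_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Placeholder body; the real op would call vLLM's flash-attention kernel.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)


@flash_attn.register_fake
def _(q, k, v):
    # Shape/dtype propagation so torch.compile can trace through the op.
    return torch.empty_like(q)


# A backward can then be attached with torch.library.register_autograd,
# instead of wrapping everything in a torch.autograd.Function.
```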
input_batch: torch.Tensor,
tokenizer: BaseTokenizer,
extra_inputs: dict[str, torch.Tensor] | None = None,
) -> AttentionMasksType:
Suggested change:
-) -> AttentionMasksType:
+) -> AttentionMasksType | None: