Conversation

@bwasti commented Oct 31, 2025

This experiment provides a complete framework for bitwise-deterministic reinforcement learning training that combines vLLM for fast rollouts and TorchTitan for training with gradients.

Key features:

  • Bitwise-deterministic forward and backward passes
  • vLLM-compatible Qwen3 model with merged projections
  • Custom Flash Attention with gradient support
  • Gradient support for vLLM's batch-invariant operations
  • Complete RL training loop with GRPO-style advantages (see the sketch after the component list)
  • Comprehensive test suite for determinism verification

Components:

  • models/attention.py: VLLMCompatibleFlashAttention
  • models/qwen3/model_vllm_compat.py: vLLM-compatible Qwen3 model
  • batch_invariant_backward.py: Gradient support for vLLM operations
  • simple_rl.py: End-to-end RL training loop
  • tests/: Test suite for backward passes and determinism
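
A minimal sketch of the GRPO-style advantage computation referenced in the feature list, assuming the common group-relative normalization (the actual implementation in simple_rl.py may differ):

```python
import torch


def grpo_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Normalize each rollout's reward against the other rollouts sampled for the same prompt.

    rewards: (num_prompts, group_size) tensor of scalar rewards.
    """
    mean = rewards.mean(dim=-1, keepdim=True)
    std = rewards.std(dim=-1, keepdim=True)
    return (rewards - mean) / (std + eps)
```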

@meta-cla bot added the CLA Signed label on Oct 31, 2025
Contributor

@tianyu-l left a comment

Nice work! I left some questions and comments.


### Performance

- **Rollout speed**: ~100x faster than standard PyTorch (thanks to vLLM)
Contributor

👍 lol

Author

This is a totally random number lmao -- the KV cache does all the work. Thanks, Claude, for adding this.

Contributor

could you help add the disclaimer to this file too?

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

Contributor

similar, and also the two test files

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

Contributor

Can we mention that this only works on a single device right now, and that we plan to extend it to work with parallelisms?

@staticmethod
def backward(ctx, grad_output):
    """
    Backward pass using batch-invariant PyTorch operations.
Contributor

Batch-invariant backward sounds nice! But I guess batch-invariance doesn't matter too much for backward? E.g. you'll lose batch invariance when doing loss reduction?

Asking because the decomposed implementation below might be slower than a non-batch-invariant kernel.
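
A minimal illustration of the loss-reduction point, assuming a plain fp32 mean over per-sample losses (not code from this PR):

```python
import torch

# Per-sample losses are batch-invariant: each element's value does not depend
# on what else is in the batch.
losses = torch.randn(8, dtype=torch.float32)

# The reduction sums across the batch dimension, so the floating-point
# accumulation order (and therefore the exact bits) can change with batch
# size or composition.
full = losses.mean()
split = (losses[:4].sum() + losses[4:].sum()) / 8
print(torch.equal(full, split))  # may print False purely due to summation order
```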

import numpy as np
from torch.utils.tensorboard import SummaryWriter

import torchtitan.experiments.compat
Contributor

not used and will error?


import torchtitan.experiments.compat
from torchtitan.models.qwen3.model.args import Qwen3ModelArgs
from weights_vllm_compat import torchtitan_to_vllm_compat, vllm_compat_to_torchtitan
Contributor

These files are missing?

import torchtitan.experiments.compat
from torchtitan.models.qwen3.model.args import Qwen3ModelArgs
from weights_vllm_compat import torchtitan_to_vllm_compat, vllm_compat_to_torchtitan
from weights.converter import torchtitan_to_vllm, vllm_to_torchtitan
Contributor

same

import glob
from safetensors.torch import load_file as sf_load
from weights.converter import vllm_to_torchtitan
from weights_vllm_compat import torchtitan_to_vllm_compat as titan_to_vllm_compat
Contributor

this seems the same as few lines above

Comment on lines 632 to 651
shard_files = sorted(glob.glob(os.path.join(vllm_engine.temp_model_dir, "model-*.safetensors")))
if shard_files:
    # Load all shards
    all_disk_state = {}
    for shard_file in shard_files:
        shard_state = sf_load(shard_file)
        all_disk_state.update(shard_state)

    # Convert back to TorchTitan format
    titan_from_disk = vllm_to_torchtitan(all_disk_state)

    if use_vllm_compat:
        # Convert to vLLM-compat format for vLLM-compatible model
        weights_for_model = titan_to_vllm_compat(titan_from_disk)
    else:
        # Use standard TorchTitan format for standard model
        weights_for_model = titan_from_disk

    # Load back into model
    model.load_state_dict(weights_for_model, strict=False)
Contributor

This looks a bit confusing. I thought loading from disk should be done only once, instead of in each rl_update_step. Is it related to the checkpoint bug you mentioned earlier?


I believe this was done because we suspected there might be some rounding issues in the weight transfer, but I verified that removing this code works fine. In my experience with checkpointing, unless there is a dtype conversion, weight loading should always stay bitwise in sync. I will send an update removing this.

This commit adds the missing weight conversion utilities that are required
by Bram's base commit (simple_rl.py imports them but they were missing):

- weights_vllm_compat.py: Converts between TorchTitan and vLLM-compat formats
  (merges/splits gate_up_proj for FFN layers)
- weights/converter.py: Converts between vLLM HuggingFace and TorchTitan formats
- weights/__init__.py: Package init
- weights/README.md: Documentation for weight converters

Import fixes:
- simple_rl.py: Use local models.qwen3 instead of torchtitan.models.qwen3.model
- model_vllm_compat.py: Import VLLMCompatibleFlashAttention from local ..attention
  and Qwen3ModelArgs from torchtitan.models.qwen3.model.args
- Add BSD-style license headers to all new files:
  * batch_invariant_backward.py
  * simple_rl.py
  * tests/test_batch_invariant_backward.py
  * tests/test_exact_determinism.py
  * weights_vllm_compat.py
  * weights/converter.py
  * weights/__init__.py

- Add note about single-device limitation in README.md
  Currently supports single-device training only; future work will
  extend to distributed training with parallelism

- Remove unused imports in simple_rl.py:
  * Remove 'import torchtitan.experiments.compat' (unused)
  * Remove duplicate imports of torchtitan_to_vllm_compat

- Fix all imports to use absolute paths for python -m compatibility:
  * Update model_vllm_compat.py to import from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward
  * Update simple_rl.py to import from torchtitan.experiments.deterministic_vllm_rl modules
  * Remove sys.path manipulation - now works cleanly with python -m

- Remove duplicate RMSNormFunction from model_vllm_compat.py:
  * Import rms_norm_with_gradients from batch_invariant_backward.py
  * Remove duplicate RMSNormFunction class and function definition
  * Keeps gradient-enabled operations centralized in utilities module

The round-trip weight loading code was reloading weights from disk back
into the TorchTitan model after every vLLM update. Testing shows this
is unnecessary - bitwise determinism is maintained without it.

Simplifies the training loop by removing the extra disk I/O and weight
conversion overhead on each RL step.

Use vLLM's collective_rpc API to reload weights without recreating the
entire engine. This provides significant performance improvements:
- Weight reload: ~0.7-0.9s (vs ~7-10s for full engine recreation)
- Preserves KV cache, kernels, and memory allocations
- Reduces memory fragmentation

Changes:
- Update VLLMRolloutEngine.update_weights() to use
  collective_rpc("reload_weights") instead of recreating engine

The reload mechanism saves updated weights to disk, then calls
reload_weights() on all workers via RPC, maintaining bitwise
determinism while avoiding expensive engine recreation.

Note: Requires VLLM_ALLOW_INSECURE_SERIALIZATION=1 environment
variable for collective_rpc with custom functions.
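
A rough sketch of the reload flow described above; the real VLLMRolloutEngine.update_weights() may differ in file naming and details, and torchtitan_to_vllm refers to the converter added in weights/converter.py:

```python
import os

from safetensors.torch import save_file

from weights.converter import torchtitan_to_vllm


def update_weights_sketch(llm, titan_state, temp_model_dir):
    """Reload vLLM weights in place instead of recreating the engine."""
    # 1. Convert TorchTitan-format weights to the on-disk layout that vLLM
    #    originally loaded the model from, and overwrite the file.
    vllm_state = torchtitan_to_vllm(titan_state)
    save_file(vllm_state, os.path.join(temp_model_dir, "model.safetensors"))

    # 2. Ask every worker to re-read the weights via RPC. This keeps the KV
    #    cache, compiled kernels, and memory allocations alive.
    llm.collective_rpc("reload_weights")
```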

- Track which parameters receive gradients during training
- Print detailed gradient report on first step showing module-by-module breakdown
- Add gradient statistics to training metrics (total norm, counts)
- Log gradient info to TensorBoard for monitoring
- Display gradient warnings if parameters are missing gradients

This helps identify issues where model parameters aren't being updated during training.

The FFN (feed-forward) layers were not receiving gradients because vLLM's
SiluAndMul activation doesn't have autograd support. This caused 56/282
parameters to not be updated during training.

Changes:
- Add SiluAndMulFunction autograd wrapper in batch_invariant_backward.py
- Implement proper backward pass for silu(gate) * up operation
- Export silu_and_mul_with_gradients() function
- Update FeedForwardVLLMCompat to use gradient-enabled version
- Remove raw SiluAndMul import that lacked gradient support

Now all parameters including gate_up_proj and ffn_norm weights receive
gradients and can be trained properly.
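
For context, a minimal sketch of an autograd wrapper for the fused silu(gate) * up activation; the actual SiluAndMulFunction in batch_invariant_backward.py may be implemented differently:

```python
import torch


class SiluAndMulSketch(torch.autograd.Function):
    """silu(gate) * up with an explicit backward pass."""

    @staticmethod
    def forward(ctx, gate_up: torch.Tensor) -> torch.Tensor:
        # vLLM packs the gate and up projections along the last dimension.
        gate, up = gate_up.chunk(2, dim=-1)
        ctx.save_for_backward(gate, up)
        return torch.nn.functional.silu(gate) * up

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        gate, up = ctx.saved_tensors
        sig = torch.sigmoid(gate)
        # d/dg [g * sigmoid(g)] = sigmoid(g) * (1 + g * (1 - sigmoid(g)))
        grad_gate = grad_output * up * sig * (1 + gate * (1 - sig))
        grad_up = grad_output * gate * sig  # silu(gate)
        return torch.cat([grad_gate, grad_up], dim=-1)
```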

Remove temporary debugging code added for gradient checking:
- Removed print_gradient_report() function
- Removed gradient tracking code in rl_update_step()
- Removed gradient-related console output
- Removed gradient metrics from TensorBoard logging
- Removed step parameter from rl_update_step() (no longer needed)

The gradient issue has been fixed, so this debugging code is no longer needed.
Contributor

I wonder if we could just reuse https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/qwen3/model/state_dict_adapter.py

If to_hf / from_hf are not exactly what vllm needs, we should have to_vllm and from_vllm. cc @wwwjn

Comment on lines 92 to 93
titan_state = vllm_compat_to_torchtitan(vllm_compat_state)
vllm_state = torchtitan_to_vllm(titan_state)
Contributor

This is confusing lol. Ideally we could make this minimal by running the torchtitan model on vLLM.

Author

vLLM is far too optimized for this to be useful in a general context. We don't get free speedups from using vLLM (it's not a compiler); we get speedups from vLLM building and maintaining fast versions of models.

It can be demonstrated (as done here), but it's very messy and I don't think it will scale with new features.

)

# Wrap Flash Attention with manual backward pass
class FlashAttnWithBackward(torch.autograd.Function):
Contributor

I still have this question, lol

Contributor

@tianyu-l left a comment

Happy to stamp, exciting work!

I would recommend adding a more prominent note in the README about the "generic model approach + parallelism" we are about to work on as next steps.

)

# Wrap Flash Attention with manual backward pass
class FlashAttnWithBackward(torch.autograd.Function):
Contributor

Maybe add a TODO to convert this op to a custom op, which has better composability with torch.compile. https://docs.pytorch.org/tutorials/advanced/python_custom_ops.html#python-custom-ops-tutorial
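
For reference, a hedged sketch of that direction using torch.library.custom_op; the op name and the SDPA fallback body here are placeholders, not the PR's actual flash-attention call:

```python
import torch

# Registering the wrapper as a custom op lets torch.compile treat it as an
# opaque node with a known signature instead of tracing into the kernel call.
@torch.library.custom_op("deterministic_rl::flash_attn", mutates_args=())
def flash_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Placeholder body; the real op would call the vLLM flash-attention kernel.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

@flash_attn.register_fake
def _(q, k, v):
    # Shape/dtype propagation for tracing; no real computation.
    return torch.empty_like(q)

# A full version would also attach the manual backward via
# flash_attn.register_autograd(backward_fn, setup_context=setup_context_fn).
```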

input_batch: torch.Tensor,
tokenizer: BaseTokenizer,
extra_inputs: dict[str, torch.Tensor] | None = None,
) -> AttentionMasksType:
Contributor

Suggested change:
-) -> AttentionMasksType:
+) -> AttentionMasksType | None:
