168 changes: 42 additions & 126 deletions docs/advance/rollout_is_migration.md → docs/advance/rollout_is.md
@@ -1,13 +1,13 @@
# Rollout Importance Sampling - Migration Guide
# Rollout Importance Sampling

Last updated: 10/11/2025.

This document provides a comprehensive overview of the Rollout Importance Sampling (IS) implementation merged from aiic_verl into verl.
This document provides a comprehensive overview of the Rollout Importance Sampling (IS) implementation in verl.

## References

- **When Speed Kills Stability**: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
- **Off-policy RL**: https://fengyao.notion.site/off-policy-rl
- [When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda)
- [Your Efficient RL Framework Secretly Brings You Off-Policy RL Training](https://fengyao.notion.site/off-policy-rl)

## Overview

@@ -17,26 +17,10 @@ Rollout Importance Sampling corrects for distribution mismatch between:

This mismatch can lead to biased gradient estimates and unstable training. Rollout IS applies importance sampling weights to correct these biases.
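
As a rough illustration of the correction described above, the sketch below computes per-token importance weights in log space and caps them at an upper threshold (token-level truncate). It is a minimal pure-Python sketch, not the code in `mismatch_helper.py`: the function name `truncated_is_weights`, the `clamp` safety bound, and the ratio direction (training policy over rollout policy) are assumptions made for illustration.

```python
import math

def truncated_is_weights(train_log_probs, rollout_log_probs, threshold=2.0, clamp=20.0):
    """Per-token IS weights exp(train - rollout), capped at `threshold`.

    Hypothetical helper: the name, the `clamp` bound, and the ratio
    direction are assumptions, not verl's actual API.
    """
    weights = []
    for lp_train, lp_rollout in zip(train_log_probs, rollout_log_probs):
        # Work with the log-ratio and clamp it so exp() cannot overflow.
        log_ratio = max(-clamp, min(clamp, lp_train - lp_rollout))
        # Truncate mode: cap the weight at the upper threshold.
        weights.append(min(math.exp(log_ratio), threshold))
    return weights

weights = truncated_is_weights([-1.0, -0.5, -2.0], [-1.0, -1.5, -1.9])
```

Here the second token has a raw ratio of about 2.72, so it is capped at the threshold of 2.0, while the other two weights pass through unchanged.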

## What Changed

### **Removed (Old Implementation)**
## Configuration

```yaml
# Old TIS configuration (REMOVED)
actor:
tis_imp_ratio_cap: 2.0 # ❌ No longer supported
```

The old implementation:
- Only supported token-level truncate mode
- Had no metrics tracking
- Lacked numerical stability safeguards
- No configurability for different scenarios

### **Added (New Implementation)**

```yaml
# New Rollout IS configuration (all in algorithm config)
# Rollout IS configuration (all in algorithm config)
algorithm:
# Main control: set threshold to enable (null = disabled)
rollout_is_threshold: 2.0
@@ -53,7 +37,7 @@ actor_rollout_ref:
calculate_log_probs: true
```

The new implementation:
Key features:
- ✅ Three aggregation levels: token, sequence, geometric
- ✅ Two bounding modes: truncate, mask
- ✅ Dual threshold support (upper/lower)
@@ -62,64 +46,35 @@ The new implementation:
- ✅ Log-space computation for numerical stability
- ✅ Memory-efficient implementation
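
To make the aggregation levels and bounding modes listed above concrete, here is a small pure-Python sketch. It assumes that `sequence` means the product of token ratios, `geometric` means their geometric mean, `truncate` caps only at the upper threshold, and `mask` zeroes weights outside `[lower, upper]`. These semantics and the helper names are assumptions for illustration; the actual implementation may differ.

```python
import math

def aggregate_log_ratios(log_ratios, level):
    """Return one log-ratio per token after aggregating at `level` (hypothetical helper)."""
    if level == "token":
        return list(log_ratios)
    total = sum(log_ratios)
    if level == "sequence":
        shared = total                    # log of the product of token ratios
    elif level == "geometric":
        shared = total / len(log_ratios)  # log of the geometric mean
    else:
        raise ValueError(f"unknown level: {level}")
    # Every token in the sequence shares the aggregated value.
    return [shared] * len(log_ratios)

def bound_weights(weights, upper, lower, mode):
    """Apply the assumed truncate or mask semantics to a list of weights."""
    if mode == "truncate":
        return [min(w, upper) for w in weights]
    if mode == "mask":
        return [w if lower <= w <= upper else 0.0 for w in weights]
    raise ValueError(f"unknown mode: {mode}")

log_ratios = [0.2, -0.1, 0.5]
seq_weights = [math.exp(x) for x in aggregate_log_ratios(log_ratios, "sequence")]
geo_weights = [math.exp(x) for x in aggregate_log_ratios(log_ratios, "geometric")]
masked = bound_weights([0.3, 1.0, 2.5], upper=2.0, lower=0.5, mode="mask")
truncated = bound_weights([0.3, 1.0, 2.5], upper=2.0, lower=0.5, mode="truncate")
```

Aggregation happens in log space, so a long sequence sums log-ratios instead of multiplying many small probabilities together.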

## Files Modified
## Files

### **Core Implementation**

1. **NEW**: `verl/trainer/ppo/mismatch_helper.py`
- Contains `compute_rollout_importance_weights()` - main function
- Contains `compute_is_metrics()` - comprehensive metrics

2. **MODIFIED**: `verl/trainer/ppo/core_algos.py` (lines 962-991)
- Replaced old TIS implementation (lines 962-967)
- Added new rollout IS with metrics support

3. **MODIFIED**: `verl/workers/actor/dp_actor.py`
- Updated to use `rollout_is_threshold` instead of `tis_imp_ratio_cap`
- Collects and logs all rollout IS metrics
- `verl/trainer/ppo/mismatch_helper.py` - Contains `compute_rollout_importance_weights()` and `compute_is_metrics()`
- `verl/trainer/ppo/core_algos.py` - Rollout IS integration with PPO
- `verl/workers/actor/dp_actor.py` - Metrics collection and logging

### **Configuration Files**

4. **MODIFIED**: `verl/trainer/config/algorithm.py` (lines 95-100)
- Added 6 new rollout IS parameters to `AlgoConfig`

5. **MODIFIED**: `verl/workers/config/actor.py` (lines 110-115)
- Added 6 new rollout IS parameters to `ActorConfig`

6. **MODIFIED**: `verl/trainer/config/actor/actor.yaml` (lines 77-89)
- Added rollout IS configuration section

7. **MODIFIED**: `verl/trainer/config/ppo_trainer.yaml` (lines 116-133)
- Added rollout IS to algorithm config
- `verl/trainer/config/algorithm.py` - Rollout IS parameters in `AlgoConfig`
- `verl/workers/config/actor.py` - Rollout IS parameters in `ActorConfig`
- `verl/trainer/config/actor/actor.yaml` - Rollout IS configuration section
- `verl/trainer/config/ppo_trainer.yaml` - Algorithm config with rollout IS

### **Documentation**

8. **MODIFIED**: `docs/examples/config.rst`
- Updated actor config with rollout IS parameters
- Updated algorithm config with rollout IS parameters
- Added detailed parameter descriptions
- `docs/examples/config.rst` - Configuration parameter descriptions

### **Example Scripts**

9. **MODIFIED**: `recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`
- Updated from `tis_imp_ratio_cap` to rollout IS parameters
- Added comprehensive comments

10. **NEW**: `examples/rollout_importance_sampling/README.md`
- Comprehensive guide with usage patterns
- Troubleshooting section
- Performance considerations

11. **NEW**: `examples/rollout_importance_sampling/run_with_rollout_is.sh`
- Basic example with token-level truncate
- `recipe/dapo/run_dapo_qwen2.5_32b_rollout_is.sh` - DAPO example with rollout IS
- `examples/rollout_importance_sampling/README.md` - Comprehensive usage guide
- `examples/rollout_importance_sampling/run_with_rollout_is.sh` - Basic example

### **Tests**

12. **NEW**: `tests/trainer/ppo/test_rollout_is.py`
- Unit tests for rollout IS functionality

13. **NEW**: `tests/trainer/ppo/test_rollout_is_integration.py`
- Integration tests with PPO
- `tests/trainer/ppo/test_rollout_is.py` - Unit tests
- `tests/trainer/ppo/test_rollout_is_integration.py` - Integration tests

## Configuration Parameters

@@ -156,20 +111,10 @@ Bounding mode:
Per-token veto threshold. If any token's IS ratio falls below this value, the entire sequence is rejected.
Default: `1e-4` (ratio 10,000x off)
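
The veto rule can be sketched as a check on per-token log-ratios. This is a minimal illustration under stated assumptions: the function name `sequence_is_vetoed` is hypothetical, and the ratio is taken as training policy over rollout policy.

```python
import math

def sequence_is_vetoed(train_log_probs, rollout_log_probs, veto_threshold=1e-4):
    """True if any token's IS ratio is below `veto_threshold` (hypothetical helper)."""
    # Compare in log space to avoid exponentiating very negative values.
    log_veto = math.log(veto_threshold)
    return any(
        lp_train - lp_rollout < log_veto
        for lp_train, lp_rollout in zip(train_log_probs, rollout_log_probs)
    )

# First token has log-ratio -10 (ratio ~4.5e-5), below the 1e-4 threshold.
vetoed = sequence_is_vetoed([-12.0, -1.0], [-2.0, -1.0])
# First token has log-ratio -5 (ratio ~6.7e-3), above the threshold.
kept = sequence_is_vetoed([-7.0, -1.0], [-2.0, -1.0])
```

A single catastrophic token is enough to reject the sequence, which is why the check uses `any` over tokens and not an average.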

## Migration Steps

### Step 1: Update Your Configuration
## Usage

**Before (Old):**
```yaml
actor_rollout_ref:
actor:
tis_imp_ratio_cap: 2.0
rollout:
calculate_log_probs: true
```
### Basic Setup

**After (New):**
```yaml
algorithm:
rollout_is_threshold: 2.0 # Main control
@@ -179,10 +124,10 @@ algorithm:

actor_rollout_ref:
rollout:
calculate_log_probs: true # Still required!
calculate_log_probs: true # Required!
```

### Step 2: Monitor New Metrics
### Metrics

All metrics are prefixed with `mismatch/`. For example, `rollout_is_mean` appears as `mismatch/rollout_is_mean` in logs.

@@ -203,17 +148,6 @@ All metrics are prefixed with `mismatch/`. For example, `rollout_is_mean` appear
- Shows the most overweighted token/sequence
- Compare with `rollout_is_threshold` to see truncation impact

#### **Percentile Metrics**

- **`rollout_is_p25`**: 25th percentile of IS weights
- **`rollout_is_p50`**: Median IS weight (50th percentile)
- Should be close to `rollout_is_mean` if distribution is symmetric
- **`rollout_is_p75`**: 75th percentile of IS weights
- **`rollout_is_p95`**: 95th percentile of IS weights
- Use to detect outliers
- **`rollout_is_p99`**: 99th percentile of IS weights
- Should be close to `rollout_is_threshold` if truncation is working

#### **Effective Sample Size**

- **`rollout_is_eff_sample_size`**: Effective sample size after IS weighting
@@ -427,7 +361,7 @@ if not is_healthy:
print(" - Checking if rollout and training policies are too different")
```

### Step 3: Test Your Training
### Running Examples

Start with the basic token-level truncate configuration:
```bash
@@ -536,29 +470,23 @@ def plot_is_metrics(metrics_history):
axes[0, 2].set_xlabel('Step')
axes[0, 2].legend()

# Plot 4: IS weight distribution (latest step)
latest_idx = -1
percentiles = [25, 50, 75, 95, 99]
values = [metrics_history[f'mismatch/rollout_is_p{p}'][latest_idx] for p in percentiles]
axes[1, 0].bar([f'p{p}' for p in percentiles], values)
axes[1, 0].axhline(y=1.0, color='r', linestyle='--', label='Ideal')
axes[1, 0].set_title('IS Weight Percentiles (Latest)')
# Plot 4: KL divergence over time
axes[1, 0].plot(metrics_history['mismatch/mismatch_kl'], label='KL')
axes[1, 0].plot(metrics_history['mismatch/mismatch_k3_kl'], label='K3 KL')
axes[1, 0].axhline(y=0, color='g', linestyle='--', alpha=0.3)
axes[1, 0].set_title('KL Divergence')
axes[1, 0].set_xlabel('Step')
axes[1, 0].legend()

# Plot 5: KL divergence over time
axes[1, 1].plot(metrics_history['mismatch/mismatch_kl'], label='KL')
axes[1, 1].plot(metrics_history['mismatch/mismatch_k3_kl'], label='K3 KL')
axes[1, 1].axhline(y=0, color='g', linestyle='--', alpha=0.3)
axes[1, 1].set_title('KL Divergence')
# Plot 5: PPL ratio over time
axes[1, 1].plot(metrics_history['mismatch/mismatch_ppl_ratio'])
axes[1, 1].axhline(y=1.0, color='r', linestyle='--', label='Ideal')
axes[1, 1].set_title('PPL Ratio (Training/Rollout)')
axes[1, 1].set_xlabel('Step')
axes[1, 1].legend()

# Plot 6: PPL ratio over time
axes[1, 2].plot(metrics_history['mismatch/mismatch_ppl_ratio'])
axes[1, 2].axhline(y=1.0, color='r', linestyle='--', label='Ideal')
axes[1, 2].set_title('PPL Ratio (Training/Rollout)')
axes[1, 2].set_xlabel('Step')
axes[1, 2].legend()
# Hide unused subplot
axes[1, 2].axis('off')

plt.tight_layout()
plt.savefig('rollout_is_metrics.png', dpi=150)
@@ -573,11 +501,6 @@ metrics_history = {
'mismatch/rollout_is_mean': [],
'mismatch/rollout_is_eff_sample_size': [],
'mismatch/rollout_is_veto_fraction': [],
'mismatch/rollout_is_p25': [],
'mismatch/rollout_is_p50': [],
'mismatch/rollout_is_p75': [],
'mismatch/rollout_is_p95': [],
'mismatch/rollout_is_p99': [],
'mismatch/mismatch_kl': [],
'mismatch/mismatch_k3_kl': [],
'mismatch/mismatch_ppl_ratio': [],
@@ -604,11 +527,6 @@ for step in range(num_steps):
- **Computational overhead**: 1-3% depending on level
- **Training stability**: Significantly improved when mismatch exists

## Backward Compatibility

**The old `tis_imp_ratio_cap` parameter is completely removed.** There is no backward compatibility mode.

All scripts and configurations must be updated to use the new rollout IS parameters.

## Testing

@@ -628,15 +546,13 @@ Expected output: All tests pass ✓

- **Implementation**: `verl/trainer/ppo/mismatch_helper.py`
- **Examples**: `examples/rollout_importance_sampling/`
- **DAPO Example**: `recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`
- **DAPO Example**: `recipe/dapo/run_dapo_qwen2.5_32b_rollout_is.sh`

## Summary

The new Rollout Importance Sampling implementation provides:
- ✅ More robust handling of distribution mismatch
- ✅ Better numerical stability
Rollout Importance Sampling provides:
- ✅ Robust handling of distribution mismatch
- ✅ Numerical stability
- ✅ Comprehensive metrics for monitoring
- ✅ Flexibility for different scenarios
- ✅ Memory-efficient computation

Migration is straightforward: replace `tis_imp_ratio_cap` with the new `rollout_is_*` parameters in the `algorithm` config section.
2 changes: 1 addition & 1 deletion docs/index.rst
@@ -121,7 +121,7 @@ verl is fast with:
examples/sandbox_fusion_example
advance/rollout_trace.rst
advance/rollout_skip.rst
advance/rollout_is_migration.md
advance/rollout_is.md
advance/one_step_off
advance/agent_loop
advance/fully_async
1 change: 0 additions & 1 deletion examples/rollout_importance_sampling/README.md
@@ -133,7 +133,6 @@ Key metrics to watch (all prefixed with `mismatch/` in logs):
### Distribution Metrics
- `rollout_is_max`, `rollout_is_min`: Weight extremes
- `rollout_is_std`: Standard deviation
- `rollout_is_p50`, `rollout_is_p95`, `rollout_is_p99`: Percentiles

### Diagnostic Metrics
- `rollout_is_ratio_fraction_high`: Fraction exceeding upper threshold
@@ -16,13 +16,13 @@ kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0

# Rollout Importance Sampling parameters (matches original TIS with threshold=2)
# Rollout Importance Sampling parameters
rollout_is=True
rollout_is_threshold=2.0
rollout_is_threshold_lower=null # No lower bound (original TIS behavior)
rollout_is_level=token # token-level (original TIS behavior)
rollout_is_mode=truncate # truncate mode (original TIS behavior)
rollout_is_veto_threshold=null # No veto (original TIS behavior)
rollout_is_threshold_lower=null # No lower bound
rollout_is_level=token # token-level
rollout_is_mode=truncate # truncate mode
rollout_is_veto_threshold=null # No veto

clip_ratio_low=0.2
clip_ratio_high=0.28
7 changes: 1 addition & 6 deletions tests/trainer/ppo/test_rollout_is.py
@@ -47,7 +47,7 @@ def test_basic_rollout_is():
rollout_log_prob = old_log_prob + torch.randn(batch_size, seq_length, device=device) * 0.1
eos_mask = torch.ones(batch_size, seq_length, device=device)

# Test token-level truncate mode (equivalent to old TIS)
# Test token-level truncate mode
print("\n1. Testing token-level truncate mode...")
weights_proto, metrics = compute_rollout_importance_weights(
old_log_prob=old_log_prob,
@@ -180,11 +180,6 @@ def test_metrics_completeness():
"mismatch/rollout_is_catastrophic_token_fraction",
"mismatch/rollout_is_ratio_fraction_high",
"mismatch/rollout_is_ratio_fraction_low",
"mismatch/rollout_is_p25",
"mismatch/rollout_is_p50",
"mismatch/rollout_is_p75",
"mismatch/rollout_is_p95",
"mismatch/rollout_is_p99",
]

# Expected mismatch/diagnostic metrics (also included now)
2 changes: 1 addition & 1 deletion verl/trainer/config/algorithm.py
@@ -93,7 +93,7 @@ class AlgoConfig(BaseConfig):
use_pf_ppo: bool = False
pf_ppo: dict[str, Any] = field(default_factory=dict)
filter_groups: Optional[FilterGroupsConfig] = None
# Rollout Importance Sampling (replaces legacy tis_imp_ratio_cap)
# Rollout Importance Sampling
# Controls computation of IS weights and mismatch metrics
rollout_is_threshold: Optional[float] = None # null = disabled, float = enabled
rollout_is_threshold_lower: Optional[float] = None