diff --git a/docs/advance/rollout_is_migration.md b/docs/advance/rollout_is.md similarity index 76% rename from docs/advance/rollout_is_migration.md rename to docs/advance/rollout_is.md index 6ba809ea372..8f75c803a5c 100644 --- a/docs/advance/rollout_is_migration.md +++ b/docs/advance/rollout_is.md @@ -1,13 +1,13 @@ -# Rollout Importance Sampling - Migration Guide +# Rollout Importance Sampling Last updated: 10/11/2025. -This document provides a comprehensive overview of the Rollout Importance Sampling (IS) implementation merged from aiic_verl into verl. +This document provides a comprehensive overview of the Rollout Importance Sampling (IS) implementation in verl. ## References -- **When Speed Kills Stability**: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda -- **Off-policy RL**: https://fengyao.notion.site/off-policy-rl +- [When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda) +- [Your Efficient RL Framework Secretly Brings You Off-Policy RL Training](https://fengyao.notion.site/off-policy-rl) ## Overview @@ -17,26 +17,10 @@ Rollout Importance Sampling corrects for distribution mismatch between: This mismatch can lead to biased gradient estimates and unstable training. Rollout IS applies importance sampling weights to correct these biases. -## What Changed - -### **Removed (Old Implementation)** +## Configuration ```yaml -# Old TIS configuration (REMOVED) -actor: - tis_imp_ratio_cap: 2.0 # ❌ No longer supported -``` - -The old implementation: -- Only supported token-level truncate mode -- Had no metrics tracking -- Lacked numerical stability safeguards -- No configurability for different scenarios - -### **Added (New Implementation)** - -```yaml -# New Rollout IS configuration (all in algorithm config) +# Rollout IS configuration (all in algorithm config) algorithm: # Main control: set threshold to enable (null = disabled) rollout_is_threshold: 2.0 @@ -53,7 +37,7 @@ actor_rollout_ref: calculate_log_probs: true ``` -The new implementation: +Key features: - ✅ Three aggregation levels: token, sequence, geometric - ✅ Two bounding modes: truncate, mask - ✅ Dual threshold support (upper/lower) @@ -62,64 +46,35 @@ The new implementation: - ✅ Log-space computation for numerical stability - ✅ Memory-efficient implementation -## Files Modified +## Files ### **Core Implementation** -1. **NEW**: `verl/trainer/ppo/mismatch_helper.py` - - Contains `compute_rollout_importance_weights()` - main function - - Contains `compute_is_metrics()` - comprehensive metrics - -2. **MODIFIED**: `verl/trainer/ppo/core_algos.py` (lines 962-991) - - Replaced old TIS implementation (lines 962-967) - - Added new rollout IS with metrics support - -3. **MODIFIED**: `verl/workers/actor/dp_actor.py` - - Updated to use `rollout_is_threshold` instead of `tis_imp_ratio_cap` - - Collects and logs all rollout IS metrics +- `verl/trainer/ppo/mismatch_helper.py` - Contains `compute_rollout_importance_weights()` and `compute_is_metrics()` +- `verl/trainer/ppo/core_algos.py` - Rollout IS integration with PPO +- `verl/workers/actor/dp_actor.py` - Metrics collection and logging ### **Configuration Files** -4. **MODIFIED**: `verl/trainer/config/algorithm.py` (lines 95-100) - - Added 6 new rollout IS parameters to `AlgoConfig` - -5. **MODIFIED**: `verl/workers/config/actor.py` (lines 110-115) - - Added 6 new rollout IS parameters to `ActorConfig` - -6. **MODIFIED**: `verl/trainer/config/actor/actor.yaml` (lines 77-89) - - Added rollout IS configuration section - -7. **MODIFIED**: `verl/trainer/config/ppo_trainer.yaml` (lines 116-133) - - Added rollout IS to algorithm config +- `verl/trainer/config/algorithm.py` - Rollout IS parameters in `AlgoConfig` +- `verl/workers/config/actor.py` - Rollout IS parameters in `ActorConfig` +- `verl/trainer/config/actor/actor.yaml` - Rollout IS configuration section +- `verl/trainer/config/ppo_trainer.yaml` - Algorithm config with rollout IS ### **Documentation** -8. **MODIFIED**: `docs/examples/config.rst` - - Updated actor config with rollout IS parameters - - Updated algorithm config with rollout IS parameters - - Added detailed parameter descriptions +- `docs/examples/config.rst` - Configuration parameter descriptions ### **Example Scripts** -9. **MODIFIED**: `recipe/dapo/run_dapo_qwen2.5_32b_tis.sh` - - Updated from `tis_imp_ratio_cap` to rollout IS parameters - - Added comprehensive comments - -10. **NEW**: `examples/rollout_importance_sampling/README.md` - - Comprehensive guide with usage patterns - - Troubleshooting section - - Performance considerations - -11. **NEW**: `examples/rollout_importance_sampling/run_with_rollout_is.sh` - - Basic example with token-level truncate +- `recipe/dapo/run_dapo_qwen2.5_32b_rollout_is.sh` - DAPO example with rollout IS +- `examples/rollout_importance_sampling/README.md` - Comprehensive usage guide +- `examples/rollout_importance_sampling/run_with_rollout_is.sh` - Basic example ### **Tests** -12. **NEW**: `tests/trainer/ppo/test_rollout_is.py` - - Unit tests for rollout IS functionality - -13. **NEW**: `tests/trainer/ppo/test_rollout_is_integration.py` - - Integration tests with PPO +- `tests/trainer/ppo/test_rollout_is.py` - Unit tests +- `tests/trainer/ppo/test_rollout_is_integration.py` - Integration tests ## Configuration Parameters @@ -156,20 +111,10 @@ Bounding mode: Per-token veto threshold. If any token ratio < this, entire sequence is rejected. Default: `1e-4` (ratio 10,000x off) -## Migration Steps - -### Step 1: Update Your Configuration +## Usage -**Before (Old):** -```yaml -actor_rollout_ref: - actor: - tis_imp_ratio_cap: 2.0 - rollout: - calculate_log_probs: true -``` +### Basic Setup -**After (New):** ```yaml algorithm: rollout_is_threshold: 2.0 # Main control @@ -179,10 +124,10 @@ algorithm: actor_rollout_ref: rollout: - calculate_log_probs: true # Still required! + calculate_log_probs: true # Required! ``` -### Step 2: Monitor New Metrics +### Metrics All metrics are prefixed with `mismatch/`. For example, `rollout_is_mean` appears as `mismatch/rollout_is_mean` in logs. @@ -203,17 +148,6 @@ All metrics are prefixed with `mismatch/`. For example, `rollout_is_mean` appear - Shows the most overweighted token/sequence - Compare with `rollout_is_threshold` to see truncation impact -#### **Percentile Metrics** - -- **`rollout_is_p25`**: 25th percentile of IS weights -- **`rollout_is_p50`**: Median IS weight (50th percentile) - - Should be close to `rollout_is_mean` if distribution is symmetric -- **`rollout_is_p75`**: 75th percentile of IS weights -- **`rollout_is_p95`**: 95th percentile of IS weights - - Use to detect outliers -- **`rollout_is_p99`**: 99th percentile of IS weights - - Should be close to `rollout_is_threshold` if truncation is working - #### **Effective Sample Size** - **`rollout_is_eff_sample_size`**: Effective sample size after IS weighting @@ -427,7 +361,7 @@ if not is_healthy: print(" - Checking if rollout and training policies are too different") ``` -### Step 3: Test Your Training +### Running Examples Start with the basic token-level truncate configuration: ```bash @@ -536,29 +470,23 @@ def plot_is_metrics(metrics_history): axes[0, 2].set_xlabel('Step') axes[0, 2].legend() - # Plot 4: IS weight distribution (latest step) - latest_idx = -1 - percentiles = [25, 50, 75, 95, 99] - values = [metrics_history[f'mismatch/rollout_is_p{p}'][latest_idx] for p in percentiles] - axes[1, 0].bar([f'p{p}' for p in percentiles], values) - axes[1, 0].axhline(y=1.0, color='r', linestyle='--', label='Ideal') - axes[1, 0].set_title('IS Weight Percentiles (Latest)') + # Plot 4: KL divergence over time + axes[1, 0].plot(metrics_history['mismatch/mismatch_kl'], label='KL') + axes[1, 0].plot(metrics_history['mismatch/mismatch_k3_kl'], label='K3 KL') + axes[1, 0].axhline(y=0, color='g', linestyle='--', alpha=0.3) + axes[1, 0].set_title('KL Divergence') + axes[1, 0].set_xlabel('Step') axes[1, 0].legend() - # Plot 5: KL divergence over time - axes[1, 1].plot(metrics_history['mismatch/mismatch_kl'], label='KL') - axes[1, 1].plot(metrics_history['mismatch/mismatch_k3_kl'], label='K3 KL') - axes[1, 1].axhline(y=0, color='g', linestyle='--', alpha=0.3) - axes[1, 1].set_title('KL Divergence') + # Plot 5: PPL ratio over time + axes[1, 1].plot(metrics_history['mismatch/mismatch_ppl_ratio']) + axes[1, 1].axhline(y=1.0, color='r', linestyle='--', label='Ideal') + axes[1, 1].set_title('PPL Ratio (Training/Rollout)') axes[1, 1].set_xlabel('Step') axes[1, 1].legend() - # Plot 6: PPL ratio over time - axes[1, 2].plot(metrics_history['mismatch/mismatch_ppl_ratio']) - axes[1, 2].axhline(y=1.0, color='r', linestyle='--', label='Ideal') - axes[1, 2].set_title('PPL Ratio (Training/Rollout)') - axes[1, 2].set_xlabel('Step') - axes[1, 2].legend() + # Hide unused subplot + axes[1, 2].axis('off') plt.tight_layout() plt.savefig('rollout_is_metrics.png', dpi=150) @@ -573,11 +501,6 @@ metrics_history = { 'mismatch/rollout_is_mean': [], 'mismatch/rollout_is_eff_sample_size': [], 'mismatch/rollout_is_veto_fraction': [], - 'mismatch/rollout_is_p25': [], - 'mismatch/rollout_is_p50': [], - 'mismatch/rollout_is_p75': [], - 'mismatch/rollout_is_p95': [], - 'mismatch/rollout_is_p99': [], 'mismatch/mismatch_kl': [], 'mismatch/mismatch_k3_kl': [], 'mismatch/mismatch_ppl_ratio': [], @@ -604,11 +527,6 @@ for step in range(num_steps): - **Computational overhead**: 1-3% depending on level - **Training stability**: Significantly improved when mismatch exists -## Backward Compatibility - -**The old `tis_imp_ratio_cap` parameter is completely removed.** There is no backward compatibility mode. - -All scripts and configurations must be updated to use the new rollout IS parameters. ## Testing @@ -628,15 +546,13 @@ Expected output: All tests pass ✓ - **Implementation**: `verl/trainer/ppo/mismatch_helper.py` - **Examples**: `examples/rollout_importance_sampling/` -- **DAPO Example**: `recipe/dapo/run_dapo_qwen2.5_32b_tis.sh` +- **DAPO Example**: `recipe/dapo/run_dapo_qwen2.5_32b_rollout_is.sh` ## Summary -The new Rollout Importance Sampling implementation provides: -- ✅ More robust handling of distribution mismatch -- ✅ Better numerical stability +Rollout Importance Sampling provides: +- ✅ Robust handling of distribution mismatch +- ✅ Numerical stability - ✅ Comprehensive metrics for monitoring - ✅ Flexibility for different scenarios - ✅ Memory-efficient computation - -Migration is straightforward: replace `tis_imp_ratio_cap` with the new `rollout_is_*` parameters in the `algorithm` config section. diff --git a/docs/index.rst b/docs/index.rst index e8467dc965a..382cbb0edd2 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -121,7 +121,7 @@ verl is fast with: examples/sandbox_fusion_example advance/rollout_trace.rst advance/rollout_skip.rst - advance/rollout_is_migration.md + advance/rollout_is.md advance/one_step_off advance/agent_loop advance/fully_async diff --git a/examples/rollout_importance_sampling/README.md b/examples/rollout_importance_sampling/README.md index 38132d0e81a..1d0ce66394e 100644 --- a/examples/rollout_importance_sampling/README.md +++ b/examples/rollout_importance_sampling/README.md @@ -133,7 +133,6 @@ Key metrics to watch (all prefixed with `mismatch/` in logs): ### Distribution Metrics - `rollout_is_max`, `rollout_is_min`: Weight extremes - `rollout_is_std`: Standard deviation -- `rollout_is_p50`, `rollout_is_p95`, `rollout_is_p99`: Percentiles ### Diagnostic Metrics - `rollout_is_ratio_fraction_high`: Fraction exceeding upper threshold diff --git a/recipe/dapo/run_dapo_qwen2.5_32b_tis.sh b/recipe/dapo/run_dapo_qwen2.5_32b_rollout_is.sh similarity index 95% rename from recipe/dapo/run_dapo_qwen2.5_32b_tis.sh rename to recipe/dapo/run_dapo_qwen2.5_32b_rollout_is.sh index 21190223302..bec96da90ed 100644 --- a/recipe/dapo/run_dapo_qwen2.5_32b_tis.sh +++ b/recipe/dapo/run_dapo_qwen2.5_32b_rollout_is.sh @@ -16,13 +16,13 @@ kl_coef=0.0 use_kl_loss=False kl_loss_coef=0.0 -# Rollout Importance Sampling parameters (matches original TIS with threshold=2) +# Rollout Importance Sampling parameters rollout_is=True rollout_is_threshold=2.0 -rollout_is_threshold_lower=null # No lower bound (original TIS behavior) -rollout_is_level=token # token-level (original TIS behavior) -rollout_is_mode=truncate # truncate mode (original TIS behavior) -rollout_is_veto_threshold=null # No veto (original TIS behavior) +rollout_is_threshold_lower=null # No lower bound +rollout_is_level=token # token-level +rollout_is_mode=truncate # truncate mode +rollout_is_veto_threshold=null # No veto clip_ratio_low=0.2 clip_ratio_high=0.28 diff --git a/tests/trainer/ppo/test_rollout_is.py b/tests/trainer/ppo/test_rollout_is.py index 2627e832a4c..584dec5a8dd 100644 --- a/tests/trainer/ppo/test_rollout_is.py +++ b/tests/trainer/ppo/test_rollout_is.py @@ -47,7 +47,7 @@ def test_basic_rollout_is(): rollout_log_prob = old_log_prob + torch.randn(batch_size, seq_length, device=device) * 0.1 eos_mask = torch.ones(batch_size, seq_length, device=device) - # Test token-level truncate mode (equivalent to old TIS) + # Test token-level truncate mode print("\n1. Testing token-level truncate mode...") weights_proto, metrics = compute_rollout_importance_weights( old_log_prob=old_log_prob, @@ -180,11 +180,6 @@ def test_metrics_completeness(): "mismatch/rollout_is_catastrophic_token_fraction", "mismatch/rollout_is_ratio_fraction_high", "mismatch/rollout_is_ratio_fraction_low", - "mismatch/rollout_is_p25", - "mismatch/rollout_is_p50", - "mismatch/rollout_is_p75", - "mismatch/rollout_is_p95", - "mismatch/rollout_is_p99", ] # Expected mismatch/diagnostic metrics (also included now) diff --git a/verl/trainer/config/algorithm.py b/verl/trainer/config/algorithm.py index 2f672d86fc3..859ce90ff32 100644 --- a/verl/trainer/config/algorithm.py +++ b/verl/trainer/config/algorithm.py @@ -93,7 +93,7 @@ class AlgoConfig(BaseConfig): use_pf_ppo: bool = False pf_ppo: dict[str, Any] = field(default_factory=dict) filter_groups: Optional[FilterGroupsConfig] = None - # Rollout Importance Sampling (replaces legacy tis_imp_ratio_cap) + # Rollout Importance Sampling # Controls computation of IS weights and mismatch metrics rollout_is_threshold: Optional[float] = None # null = disabled, float = enabled rollout_is_threshold_lower: Optional[float] = None diff --git a/verl/trainer/ppo/mismatch_helper.py b/verl/trainer/ppo/mismatch_helper.py index 48d0696c845..78bb5d6accb 100644 --- a/verl/trainer/ppo/mismatch_helper.py +++ b/verl/trainer/ppo/mismatch_helper.py @@ -20,7 +20,7 @@ Key Features: 1. Three aggregation levels: token, sequence, geometric -2. Two handling modes: truncate (TIS), mask (MIS) +2. Two handling modes: truncate, mask 3. Per-token veto mechanism for catastrophic outliers 4. Memory-efficient computation to prevent CUDA OOM 5. Comprehensive metrics tracking @@ -76,8 +76,8 @@ def compute_rollout_importance_weights( - "sequence": Product of ratios (unbiased) - "geometric": Geometric mean of ratios (experimental) rollout_is_mode: How to handle weights exceeding threshold: - - "truncate": Cap weights at upper_threshold only (TIS) - - "mask": Zero out weights outside [lower_threshold, upper_threshold] (MIS) + - "truncate": Cap weights at upper_threshold only + - "mask": Zero out weights outside [lower_threshold, upper_threshold] rollout_is_threshold: Upper threshold for IS weights rollout_is_threshold_lower: Lower threshold for IS weights (mask mode only; if None, defaults to 1/upper) rollout_is_veto_threshold: Per-token veto threshold. If any token ratio < this, zero entire sequence. @@ -181,11 +181,11 @@ def compute_rollout_importance_weights( # Step 3: Apply truncation or masking based on mode if rollout_is_mode == "truncate": - # Truncated IS (TIS): only cap upper bound to prevent overweighting + # Truncate mode: only cap upper bound to prevent overweighting rollout_is_weights = rollout_is_weights.clamp(max=upper_threshold) elif rollout_is_mode == "mask": - # Masked IS (MIS): zero out weights outside [lower_threshold, upper_threshold] + # Mask mode: zero out weights outside [lower_threshold, upper_threshold] mask = (rollout_is_weights >= lower_threshold) & (rollout_is_weights <= upper_threshold) mask = mask.float() @@ -355,17 +355,6 @@ def compute_is_metrics( metrics["rollout_is_seq_fraction_high"] = (seq_mean_weights > rollout_is_threshold).float().mean() metrics["rollout_is_seq_fraction_low"] = (seq_mean_weights < rollout_is_threshold_lower).float().mean() - # Percentile metrics for better distribution understanding - # Get all valid IS weights - flat_weights = rollout_is_weights[response_mask.bool()] - # Compute key percentiles (guaranteed to have elements due to assertion at function start) - assert flat_weights.numel() > 0, "flat_weights should not be empty" - metrics["rollout_is_p25"] = torch.quantile(flat_weights, 0.25) - metrics["rollout_is_p50"] = torch.quantile(flat_weights, 0.50) # median - metrics["rollout_is_p75"] = torch.quantile(flat_weights, 0.75) - metrics["rollout_is_p95"] = torch.quantile(flat_weights, 0.95) - metrics["rollout_is_p99"] = torch.quantile(flat_weights, 0.99) - return metrics