Commit ffaf769

szrleenoah authored and committed
[algo] fix: remove torch.quantile-based percentile metrics to resolve tensor size limit error (volcengine#3810)
## Summary

Fixes volcengine#3787 by removing `torch.quantile()`-based percentile metrics (`rollout_is_p25`, `rollout_is_p50`, `rollout_is_p75`) that caused `RuntimeError: quantile() input tensor is too large` when using large batch sizes or response lengths.

## Problem

When using configurations with large tensor sizes (e.g., `max_response_length: 32k`, `rollout.n: 16`, `train_batch_size: 16`), the `torch.quantile()` function fails with a runtime error due to PyTorch's internal tensor size limitations (~2^24 to 2^27 elements depending on version, GPU memory, and dtype).

The error occurred in `verl/trainer/ppo/mismatch_helper.py`:

```python
metrics["rollout_is_p25"] = torch.quantile(flat_weights, 0.25)
metrics["rollout_is_p50"] = torch.quantile(flat_weights, 0.50)
metrics["rollout_is_p75"] = torch.quantile(flat_weights, 0.75)
```

## Solution

Removed the three quantile-based percentile metrics from the Rollout IS framework. The remaining metrics (`rollout_is_mean`, `rollout_is_std`, `rollout_is_min`, `rollout_is_max`, `rollout_is_eff_sample_size`, etc.) provide sufficient monitoring capabilities for importance sampling health without triggering tensor size limitations.

## Changes

- **Modified**: [verl/trainer/ppo/mismatch_helper.py](verl/trainer/ppo/mismatch_helper.py)
  - Removed `rollout_is_p25`, `rollout_is_p50`, `rollout_is_p75` metric calculations
  - All other rollout IS and mismatch metrics remain functional

## Testing

Verified that:

- Rollout IS framework continues to function correctly without percentile metrics
- No runtime errors with large tensor configurations
- All other metrics (mean, std, min, max, ESS, veto fraction, etc.) are computed correctly

Resolves volcengine#3787
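
This commit drops the percentile metrics outright. If quartile monitoring is still wanted for debugging, one possible workaround (not part of this change) is to estimate the quartiles from a bounded random subsample, so the input size stays far below any limit. The sketch below uses only the Python standard library on a plain list of floats; `approx_percentiles`, `max_samples`, and `seed` are hypothetical names, and a real version would need to work on the flattened weight tensor instead.

```python
import random
import statistics


def approx_percentiles(weights, max_samples=100_000, seed=0):
    """Estimate p25/p50/p75 of `weights` from at most `max_samples` values.

    Hypothetical helper for illustration only; it is not in verl.
    """
    rng = random.Random(seed)
    if len(weights) > max_samples:
        # Sample without replacement so the quantile input stays bounded.
        sample = rng.sample(weights, max_samples)
    else:
        sample = list(weights)
    # n=4 yields the three quartile cut points; "inclusive" treats the
    # sample minimum and maximum as the 0th and 100th percentiles.
    q1, q2, q3 = statistics.quantiles(sample, n=4, method="inclusive")
    return {"p25": q1, "p50": q2, "p75": q3}


# Small input: no subsampling happens, so the quartiles are exact.
print(approx_percentiles([float(i) for i in range(101)]))
# → {'p25': 25.0, 'p50': 50.0, 'p75': 75.0}
```

The trade-off is that the result is an estimate whose error shrinks as `max_samples` grows, which is usually acceptable for a monitoring metric but not for anything that feeds back into the loss.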
1 parent 8074b6d commit ffaf769

7 files changed: +55 -156 lines changed

docs/advance/rollout_is_migration.md renamed to docs/advance/rollout_is.md

Lines changed: 42 additions & 126 deletions
@@ -1,13 +1,13 @@
-# Rollout Importance Sampling - Migration Guide
+# Rollout Importance Sampling
 
 Last updated: 10/11/2025.
 
-This document provides a comprehensive overview of the Rollout Importance Sampling (IS) implementation merged from aiic_verl into verl.
+This document provides a comprehensive overview of the Rollout Importance Sampling (IS) implementation in verl.
 
 ## References
 
-- **When Speed Kills Stability**: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
-- **Off-policy RL**: https://fengyao.notion.site/off-policy-rl
+- [When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda)
+- [Your Efficient RL Framework Secretly Brings You Off-Policy RL Training](https://fengyao.notion.site/off-policy-rl)
 
 ## Overview
 
@@ -17,26 +17,10 @@ Rollout Importance Sampling corrects for distribution mismatch between:
 
 This mismatch can lead to biased gradient estimates and unstable training. Rollout IS applies importance sampling weights to correct these biases.
 
-## What Changed
-
-### **Removed (Old Implementation)**
+## Configuration
 
 ```yaml
-# Old TIS configuration (REMOVED)
-actor:
-  tis_imp_ratio_cap: 2.0 # ❌ No longer supported
-```
-
-The old implementation:
-- Only supported token-level truncate mode
-- Had no metrics tracking
-- Lacked numerical stability safeguards
-- No configurability for different scenarios
-
-### **Added (New Implementation)**
-
-```yaml
-# New Rollout IS configuration (all in algorithm config)
+# Rollout IS configuration (all in algorithm config)
 algorithm:
   # Main control: set threshold to enable (null = disabled)
   rollout_is_threshold: 2.0
@@ -53,7 +37,7 @@ actor_rollout_ref:
     calculate_log_probs: true
 ```
 
-The new implementation:
+Key features:
 - ✅ Three aggregation levels: token, sequence, geometric
 - ✅ Two bounding modes: truncate, mask
 - ✅ Dual threshold support (upper/lower)
@@ -62,64 +46,35 @@ The new implementation:
 - ✅ Log-space computation for numerical stability
 - ✅ Memory-efficient implementation
 
-## Files Modified
+## Files
 
 ### **Core Implementation**
 
-1. **NEW**: `verl/trainer/ppo/mismatch_helper.py`
-   - Contains `compute_rollout_importance_weights()` - main function
-   - Contains `compute_is_metrics()` - comprehensive metrics
-
-2. **MODIFIED**: `verl/trainer/ppo/core_algos.py` (lines 962-991)
-   - Replaced old TIS implementation (lines 962-967)
-   - Added new rollout IS with metrics support
-
-3. **MODIFIED**: `verl/workers/actor/dp_actor.py`
-   - Updated to use `rollout_is_threshold` instead of `tis_imp_ratio_cap`
-   - Collects and logs all rollout IS metrics
+- `verl/trainer/ppo/mismatch_helper.py` - Contains `compute_rollout_importance_weights()` and `compute_is_metrics()`
+- `verl/trainer/ppo/core_algos.py` - Rollout IS integration with PPO
+- `verl/workers/actor/dp_actor.py` - Metrics collection and logging
 
 ### **Configuration Files**
 
-4. **MODIFIED**: `verl/trainer/config/algorithm.py` (lines 95-100)
-   - Added 6 new rollout IS parameters to `AlgoConfig`
-
-5. **MODIFIED**: `verl/workers/config/actor.py` (lines 110-115)
-   - Added 6 new rollout IS parameters to `ActorConfig`
-
-6. **MODIFIED**: `verl/trainer/config/actor/actor.yaml` (lines 77-89)
-   - Added rollout IS configuration section
-
-7. **MODIFIED**: `verl/trainer/config/ppo_trainer.yaml` (lines 116-133)
-   - Added rollout IS to algorithm config
+- `verl/trainer/config/algorithm.py` - Rollout IS parameters in `AlgoConfig`
+- `verl/workers/config/actor.py` - Rollout IS parameters in `ActorConfig`
+- `verl/trainer/config/actor/actor.yaml` - Rollout IS configuration section
+- `verl/trainer/config/ppo_trainer.yaml` - Algorithm config with rollout IS
 
 ### **Documentation**
 
-8. **MODIFIED**: `docs/examples/config.rst`
-   - Updated actor config with rollout IS parameters
-   - Updated algorithm config with rollout IS parameters
-   - Added detailed parameter descriptions
+- `docs/examples/config.rst` - Configuration parameter descriptions
 
 ### **Example Scripts**
 
-9. **MODIFIED**: `recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`
-   - Updated from `tis_imp_ratio_cap` to rollout IS parameters
-   - Added comprehensive comments
-
-10. **NEW**: `examples/rollout_importance_sampling/README.md`
-    - Comprehensive guide with usage patterns
-    - Troubleshooting section
-    - Performance considerations
-
-11. **NEW**: `examples/rollout_importance_sampling/run_with_rollout_is.sh`
-    - Basic example with token-level truncate
+- `recipe/dapo/run_dapo_qwen2.5_32b_rollout_is.sh` - DAPO example with rollout IS
+- `examples/rollout_importance_sampling/README.md` - Comprehensive usage guide
+- `examples/rollout_importance_sampling/run_with_rollout_is.sh` - Basic example
 
 ### **Tests**
 
-12. **NEW**: `tests/trainer/ppo/test_rollout_is.py`
-    - Unit tests for rollout IS functionality
-
-13. **NEW**: `tests/trainer/ppo/test_rollout_is_integration.py`
-    - Integration tests with PPO
+- `tests/trainer/ppo/test_rollout_is.py` - Unit tests
+- `tests/trainer/ppo/test_rollout_is_integration.py` - Integration tests
 
 ## Configuration Parameters
 
@@ -156,20 +111,10 @@ Bounding mode:
 Per-token veto threshold. If any token ratio < this, entire sequence is rejected.
 Default: `1e-4` (ratio 10,000x off)
 
-## Migration Steps
-
-### Step 1: Update Your Configuration
+## Usage
 
-**Before (Old):**
-```yaml
-actor_rollout_ref:
-  actor:
-    tis_imp_ratio_cap: 2.0
-  rollout:
-    calculate_log_probs: true
-```
+### Basic Setup
 
-**After (New):**
 ```yaml
 algorithm:
   rollout_is_threshold: 2.0 # Main control
@@ -179,10 +124,10 @@ algorithm:
 
 actor_rollout_ref:
   rollout:
-    calculate_log_probs: true # Still required!
+    calculate_log_probs: true # Required!
 ```
 
-### Step 2: Monitor New Metrics
+### Metrics
 
 All metrics are prefixed with `mismatch/`. For example, `rollout_is_mean` appears as `mismatch/rollout_is_mean` in logs.
 
@@ -203,17 +148,6 @@ All metrics are prefixed with `mismatch/`. For example, `rollout_is_mean` appear
   - Shows the most overweighted token/sequence
   - Compare with `rollout_is_threshold` to see truncation impact
 
-#### **Percentile Metrics**
-
-- **`rollout_is_p25`**: 25th percentile of IS weights
-- **`rollout_is_p50`**: Median IS weight (50th percentile)
-  - Should be close to `rollout_is_mean` if distribution is symmetric
-- **`rollout_is_p75`**: 75th percentile of IS weights
-- **`rollout_is_p95`**: 95th percentile of IS weights
-  - Use to detect outliers
-- **`rollout_is_p99`**: 99th percentile of IS weights
-  - Should be close to `rollout_is_threshold` if truncation is working
-
 #### **Effective Sample Size**
 
 - **`rollout_is_eff_sample_size`**: Effective sample size after IS weighting
@@ -427,7 +361,7 @@ if not is_healthy:
     print(" - Checking if rollout and training policies are too different")
 ```
 
-### Step 3: Test Your Training
+### Running Examples
 
 Start with the basic token-level truncate configuration:
 ```bash
@@ -536,29 +470,23 @@ def plot_is_metrics(metrics_history):
     axes[0, 2].set_xlabel('Step')
     axes[0, 2].legend()
 
-    # Plot 4: IS weight distribution (latest step)
-    latest_idx = -1
-    percentiles = [25, 50, 75, 95, 99]
-    values = [metrics_history[f'mismatch/rollout_is_p{p}'][latest_idx] for p in percentiles]
-    axes[1, 0].bar([f'p{p}' for p in percentiles], values)
-    axes[1, 0].axhline(y=1.0, color='r', linestyle='--', label='Ideal')
-    axes[1, 0].set_title('IS Weight Percentiles (Latest)')
+    # Plot 4: KL divergence over time
+    axes[1, 0].plot(metrics_history['mismatch/mismatch_kl'], label='KL')
+    axes[1, 0].plot(metrics_history['mismatch/mismatch_k3_kl'], label='K3 KL')
+    axes[1, 0].axhline(y=0, color='g', linestyle='--', alpha=0.3)
+    axes[1, 0].set_title('KL Divergence')
+    axes[1, 0].set_xlabel('Step')
     axes[1, 0].legend()
 
-    # Plot 5: KL divergence over time
-    axes[1, 1].plot(metrics_history['mismatch/mismatch_kl'], label='KL')
-    axes[1, 1].plot(metrics_history['mismatch/mismatch_k3_kl'], label='K3 KL')
-    axes[1, 1].axhline(y=0, color='g', linestyle='--', alpha=0.3)
-    axes[1, 1].set_title('KL Divergence')
+    # Plot 5: PPL ratio over time
+    axes[1, 1].plot(metrics_history['mismatch/mismatch_ppl_ratio'])
+    axes[1, 1].axhline(y=1.0, color='r', linestyle='--', label='Ideal')
+    axes[1, 1].set_title('PPL Ratio (Training/Rollout)')
     axes[1, 1].set_xlabel('Step')
     axes[1, 1].legend()
 
-    # Plot 6: PPL ratio over time
-    axes[1, 2].plot(metrics_history['mismatch/mismatch_ppl_ratio'])
-    axes[1, 2].axhline(y=1.0, color='r', linestyle='--', label='Ideal')
-    axes[1, 2].set_title('PPL Ratio (Training/Rollout)')
-    axes[1, 2].set_xlabel('Step')
-    axes[1, 2].legend()
+    # Hide unused subplot
+    axes[1, 2].axis('off')
 
     plt.tight_layout()
     plt.savefig('rollout_is_metrics.png', dpi=150)
@@ -573,11 +501,6 @@ metrics_history = {
     'mismatch/rollout_is_mean': [],
     'mismatch/rollout_is_eff_sample_size': [],
     'mismatch/rollout_is_veto_fraction': [],
-    'mismatch/rollout_is_p25': [],
-    'mismatch/rollout_is_p50': [],
-    'mismatch/rollout_is_p75': [],
-    'mismatch/rollout_is_p95': [],
-    'mismatch/rollout_is_p99': [],
     'mismatch/mismatch_kl': [],
     'mismatch/mismatch_k3_kl': [],
     'mismatch/mismatch_ppl_ratio': [],
@@ -604,11 +527,6 @@ for step in range(num_steps):
 - **Computational overhead**: 1-3% depending on level
 - **Training stability**: Significantly improved when mismatch exists
 
-## Backward Compatibility
-
-**The old `tis_imp_ratio_cap` parameter is completely removed.** There is no backward compatibility mode.
-
-All scripts and configurations must be updated to use the new rollout IS parameters.
 
 ## Testing
 
@@ -628,15 +546,13 @@ Expected output: All tests pass ✓
 
 - **Implementation**: `verl/trainer/ppo/mismatch_helper.py`
 - **Examples**: `examples/rollout_importance_sampling/`
-- **DAPO Example**: `recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`
+- **DAPO Example**: `recipe/dapo/run_dapo_qwen2.5_32b_rollout_is.sh`
 
 ## Summary
 
-The new Rollout Importance Sampling implementation provides:
-- More robust handling of distribution mismatch
-- Better numerical stability
+Rollout Importance Sampling provides:
+- Robust handling of distribution mismatch
+- Numerical stability
 - ✅ Comprehensive metrics for monitoring
 - ✅ Flexibility for different scenarios
 - ✅ Memory-efficient computation
-
-Migration is straightforward: replace `tis_imp_ratio_cap` with the new `rollout_is_*` parameters in the `algorithm` config section.

docs/index.rst

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ verl is fast with:
    examples/sandbox_fusion_example
    advance/rollout_trace.rst
    advance/rollout_skip.rst
-   advance/rollout_is_migration.md
+   advance/rollout_is.md
    advance/one_step_off
    advance/agent_loop
    advance/fully_async

examples/rollout_importance_sampling/README.md

Lines changed: 0 additions & 1 deletion
@@ -133,7 +133,6 @@ Key metrics to watch (all prefixed with `mismatch/` in logs):
 ### Distribution Metrics
 - `rollout_is_max`, `rollout_is_min`: Weight extremes
 - `rollout_is_std`: Standard deviation
-- `rollout_is_p50`, `rollout_is_p95`, `rollout_is_p99`: Percentiles
 
 ### Diagnostic Metrics
 - `rollout_is_ratio_fraction_high`: Fraction exceeding upper threshold

recipe/dapo/run_dapo_qwen2.5_32b_tis.sh renamed to recipe/dapo/run_dapo_qwen2.5_32b_rollout_is.sh

Lines changed: 5 additions & 5 deletions
@@ -16,13 +16,13 @@ kl_coef=0.0
 use_kl_loss=False
 kl_loss_coef=0.0
 
-# Rollout Importance Sampling parameters (matches original TIS with threshold=2)
+# Rollout Importance Sampling parameters
 rollout_is=True
 rollout_is_threshold=2.0
-rollout_is_threshold_lower=null # No lower bound (original TIS behavior)
-rollout_is_level=token # token-level (original TIS behavior)
-rollout_is_mode=truncate # truncate mode (original TIS behavior)
-rollout_is_veto_threshold=null # No veto (original TIS behavior)
+rollout_is_threshold_lower=null # No lower bound
+rollout_is_level=token # token-level
+rollout_is_mode=truncate # truncate mode
+rollout_is_veto_threshold=null # No veto
 
 clip_ratio_low=0.2
 clip_ratio_high=0.28

tests/trainer/ppo/test_rollout_is.py

Lines changed: 1 addition & 6 deletions
@@ -47,7 +47,7 @@ def test_basic_rollout_is():
     rollout_log_prob = old_log_prob + torch.randn(batch_size, seq_length, device=device) * 0.1
     eos_mask = torch.ones(batch_size, seq_length, device=device)
 
-    # Test token-level truncate mode (equivalent to old TIS)
+    # Test token-level truncate mode
     print("\n1. Testing token-level truncate mode...")
     weights_proto, metrics = compute_rollout_importance_weights(
         old_log_prob=old_log_prob,
@@ -180,11 +180,6 @@ def test_metrics_completeness():
         "mismatch/rollout_is_catastrophic_token_fraction",
         "mismatch/rollout_is_ratio_fraction_high",
         "mismatch/rollout_is_ratio_fraction_low",
-        "mismatch/rollout_is_p25",
-        "mismatch/rollout_is_p50",
-        "mismatch/rollout_is_p75",
-        "mismatch/rollout_is_p95",
-        "mismatch/rollout_is_p99",
     ]
 
     # Expected mismatch/diagnostic metrics (also included now)

verl/trainer/config/algorithm.py

Lines changed: 1 addition & 1 deletion
@@ -93,7 +93,7 @@ class AlgoConfig(BaseConfig):
     use_pf_ppo: bool = False
     pf_ppo: dict[str, Any] = field(default_factory=dict)
     filter_groups: Optional[FilterGroupsConfig] = None
-    # Rollout Importance Sampling (replaces legacy tis_imp_ratio_cap)
+    # Rollout Importance Sampling
     # Controls computation of IS weights and mismatch metrics
     rollout_is_threshold: Optional[float] = None # null = disabled, float = enabled
     rollout_is_threshold_lower: Optional[float] = None
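
To make the role of `rollout_is_threshold` concrete, here is a minimal standard-library sketch of token-level truncation together with a normalized effective sample size. It assumes each weight is `exp(old_log_prob - rollout_log_prob)` capped at the threshold, with the cap applied in log space before exponentiating. The function names and the exact ESS normalization are illustrative assumptions, not the code in `mismatch_helper.py`.

```python
import math


def truncated_is_weights(old_log_probs, rollout_log_probs, threshold=2.0):
    """Token-level truncate mode: w = min(exp(old - rollout), threshold).

    The cap is applied to the log-ratio so exp() never sees a huge value.
    Hypothetical helper for illustration only.
    """
    log_cap = math.log(threshold)
    return [
        math.exp(min(old - rollout, log_cap))
        for old, rollout in zip(old_log_probs, rollout_log_probs)
    ]


def effective_sample_size(weights):
    """Normalized ESS in (0, 1]: (sum w)^2 / (n * sum w^2)."""
    total = sum(weights)
    total_sq = sum(w * w for w in weights)
    return (total * total) / (len(weights) * total_sq)


# The middle token has a raw ratio of exp(5), about 148, and gets capped at 2.0.
weights = truncated_is_weights([0.0, 0.0, 0.0], [0.0, -5.0, 0.0], threshold=2.0)
print(weights)
print(effective_sample_size(weights))
```

An ESS of 1.0 means all weights are equal; values well below 1.0 indicate that a few tokens dominate the weighted estimate.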
