diff --git a/docs/data/transfer_queue.md b/docs/data/transfer_queue.md new file mode 100644 index 00000000000..4532d42ed56 --- /dev/null +++ b/docs/data/transfer_queue.md @@ -0,0 +1,128 @@ +# TransferQueue Data System + +Last updated: 09/28/2025. + +This doc introduces [TransferQueue](https://github.com/TransferQueue/TransferQueue), an asynchronous streaming data management system for efficient post-training. + +

+## Overview

+ +TransferQueue is a high-performance data storage and transfer system with panoramic data visibility and streaming scheduling capabilities, optimized for efficient dataflow in post-training workflows. + +

+ +

+ + +TransferQueue offers **fine-grained, sample-level** data management capabilities, serving as a data gateway that decouples explicit data dependencies across computational tasks. This enables a divide-and-conquer approach, significantly simplifying the design of the algorithm controller. + + +

+ +

+ + + + +

+## Components

+ + + +### Control Plane: Panoramic Data Management + +In the control plane, `TransferQueueController` tracks the **production status** and **consumption status** of each training sample as metadata. Once all the required data fields are ready (i.e., written to the `TransferQueueStorage`), the controller knows that this sample can be consumed by downstream tasks. + +Consumption status is recorded separately for each computational task (e.g., `generate_sequences`, `compute_log_prob`, etc.). Therefore, even if different computational tasks require the same data field, they can consume the data independently without interfering with each other. + +
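+
+To make this bookkeeping concrete, here is a conceptual sketch of per-sample production/consumption tracking. It only illustrates the idea and is **not** the actual `TransferQueueController` implementation; the field and task names are made up.
+
+```python
+# Conceptual sketch only -- not the real TransferQueueController code.
+production: dict[int, set[str]] = {}           # sample index -> fields already written to storage
+consumption: dict[tuple[str, int], bool] = {}  # (task name, sample index) -> consumed?
+
+REQUIRED_FIELDS = {"input_ids", "attention_mask", "responses"}  # hypothetical required fields
+
+
+def mark_produced(sample_idx: int, field: str) -> None:
+    production.setdefault(sample_idx, set()).add(field)
+
+
+def ready_for(task_name: str, sample_idx: int) -> bool:
+    # A sample can be dispatched to a task once all required fields exist
+    # and this particular task has not consumed it yet.
+    has_all_fields = REQUIRED_FIELDS <= production.get(sample_idx, set())
+    not_consumed = not consumption.get((task_name, sample_idx), False)
+    return has_all_fields and not_consumed
+
+
+for f in ("input_ids", "attention_mask", "responses"):
+    mark_produced(0, f)
+
+# Two tasks can consume the same sample independently:
+consumption[("compute_log_prob", 0)] = True  # does not affect other tasks
+assert not ready_for("compute_log_prob", 0)
+assert ready_for("compute_values", 0)
+```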

+ +

+ + +> In the future, we plan to support **load-balancing** and **dynamic batching** capabilities in the control plane. We also plan to support data management for disaggregated frameworks, where each rank manages its own data retrieval rather than being coordinated by a single controller. + +### Data Plane: Distributed Data Storage + +In the data plane, `TransferQueueStorageSimpleUnit` serves as a simple CPU-memory-based storage unit, responsible for the actual storage and retrieval of data. Each storage unit can be deployed on a separate node, allowing for distributed data management. + +`TransferQueueStorageSimpleUnit` employs a 2D data structure as follows: + +- Each row corresponds to a training sample, assigned a unique index within the corresponding global batch. +- Each column represents an input/output data field of the computational tasks. + +This design is motivated by the computational characteristics of post-training, where each training sample is produced in a relayed manner across the task pipeline. It provides precise addressing, which allows fine-grained, concurrent read/write operations in a streaming manner.
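+
+As an illustration of this 2D layout (rows = samples, columns = data fields), here is a minimal, self-contained sketch. The field names are invented for illustration; this is not the `TransferQueueStorageSimpleUnit` API.
+
+```python
+# Illustrative 2D layout: rows are sample indices within a global batch,
+# columns are the data fields produced/consumed along the task pipeline.
+GLOBAL_BATCH_SIZE = 4
+FIELDS = ["prompts", "responses", "old_log_probs", "advantages"]  # hypothetical columns
+
+# storage[row][field] holds the payload of one sample/field cell.
+storage = {row: dict.fromkeys(FIELDS) for row in range(GLOBAL_BATCH_SIZE)}
+
+# The dataloader fills the "prompts" column for every row ...
+for row in range(GLOBAL_BATCH_SIZE):
+    storage[row]["prompts"] = f"prompt-{row}"
+
+# ... and rollout can stream results back row by row, writing the "responses"
+# column for sample 0 without waiting for the rest of the batch.
+storage[0]["responses"] = "response-0"
+
+# Downstream tasks (e.g. compute_log_prob) can already consume row 0.
+ready_rows = [row for row, cells in storage.items() if cells["responses"] is not None]
+print(ready_rows)  # -> [0]
+```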

+ +

+ + +> In the future, we plan to implement a **general storage abstraction layer** to support various storage backends. Through this abstraction, we hope to integrate high-performance storage solutions such as [MoonCakeStore](https://github.com/kvcache-ai/Mooncake) to support device-to-device data transfer through RDMA, further improving transfer efficiency for large-scale data. + + +### User Interface: Asynchronous & Synchronous Client + + +The interaction workflow of the TransferQueue system is as follows (a minimal client-side sketch is given at the end of this section): + +1. A process sends a read request to the `TransferQueueController`. +2. `TransferQueueController` scans the production and consumption metadata of each sample (row) and dynamically assembles micro-batch metadata according to the load-balancing policy. This mechanism enables sample-level data scheduling. +3. The process retrieves the actual data from the distributed storage units using the metadata provided by the controller. + +To simplify the usage of TransferQueue, we have encapsulated this process into `AsyncTransferQueueClient` and `TransferQueueClient`. These clients provide asynchronous and synchronous interfaces for data transfer, allowing users to easily integrate TransferQueue into their frameworks. + + +> In the future, we will provide a `StreamingDataLoader` interface for disaggregated frameworks, as discussed in [RFC#2662](https://github.com/volcengine/verl/discussions/2662). Leveraging this abstraction, each rank can automatically fetch its own data, just like a PyTorch `DataLoader`. The TransferQueue system will handle the underlying data scheduling and transfer logic required by different parallelism strategies, significantly simplifying the design of disaggregated frameworks. +
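+
+A minimal sketch of the three-step workflow above with the asynchronous client is shown below. Client construction is omitted (in the verl recipe it is provided by the helpers in `verl.utils.transferqueue_utils`), and the data fields, batch size, and task name are placeholders.
+
+```python
+import asyncio
+
+
+def run_task(batch):
+    # Placeholder computation; a real task would add new fields (e.g. "old_log_probs").
+    return batch
+
+
+async def consume_one_micro_batch(client, step: int):
+    # 1. Ask the controller for a micro-batch of ready samples (metadata only).
+    batch_meta = await client.async_get_meta(
+        data_fields=["input_ids", "attention_mask"],  # placeholder fields
+        batch_size=8,
+        global_step=step,
+        get_n_samples=False,
+        task_name="compute_log_prob",
+    )
+    # 2. Fetch the actual tensors (a TensorDict) from the distributed storage units.
+    batch = await client.async_get_data(batch_meta)
+    output = run_task(batch)
+    # 3. Write the produced fields back so that downstream tasks can consume them.
+    await client.async_put(data=output, metadata=batch_meta)
+
+
+# asyncio.run(consume_one_micro_batch(client, step=0))
+```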

+## Show Cases

+ +### General Usage + +The primary interaction points are `AsyncTransferQueueClient` and `TransferQueueClient`, which serve as the communication interface with the TransferQueue system. + +Core interfaces: + +- `(async_)get_meta(data_fields: list[str], batch_size: int, global_step: int, get_n_samples: bool, task_name: str) -> BatchMeta` +- `(async_)get_data(metadata: BatchMeta) -> TensorDict` +- `(async_)put(data: TensorDict, metadata: BatchMeta, global_step)` +- `(async_)clear(global_step: int)` + + +We will soon release a detailed tutorial and API documentation. + + +### verl Example + + +The primary motivation for integrating TransferQueue into verl is to **alleviate the data transfer bottleneck of the single-controller `RayPPOTrainer`**. Currently, all `DataProto` objects must be routed through `RayPPOTrainer`, creating a single-point bottleneck for the whole post-training system. + +![verl_dataflow_DataProto](https://cdn.nlark.com/yuque/0/2025/jpeg/23208217/1758704289414-bcc54228-716b-4d4a-ad3b-f9ace6d10fcf.jpeg) + +Leveraging TransferQueue, we separate experience data transfer from metadata dispatch by: + +- Replacing `DataProto` with `BatchMeta` (metadata) and `TensorDict` (actual data) structures +- Preserving verl's original Dispatch/Collect logic via `BatchMeta` (maintaining single-controller debuggability) +- Accelerating data transfer through TransferQueue's distributed storage units + +![verl_dataflow_TransferQueue](https://cdn.nlark.com/yuque/0/2025/jpeg/23208217/1758704301666-0807dc06-766c-4a2d-9cde-889a6bb56b34.jpeg) + + +You may refer to the [recipe](https://github.com/TransferQueue/TransferQueue/tree/dev/recipe/simple_use_case), where we mimic verl's usage in both async and sync scenarios. A minimal synchronous sketch is also given below. +
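+
+Based on the interfaces listed above, a synchronous round trip for a single task might look like the following sketch. The client construction, field names, and reward logic are assumptions for illustration; see the recipe above for a complete, working setup.
+
+```python
+from tensordict import TensorDict
+
+
+def score_one_micro_batch(client, step: int):
+    meta = client.get_meta(
+        data_fields=["responses", "attention_mask"],  # assumed fields
+        batch_size=16,
+        global_step=step,
+        get_n_samples=False,
+        task_name="compute_reward",
+    )
+    batch = client.get_data(meta)  # TensorDict holding the requested fields
+
+    # Dummy "reward": response length; a real reward function goes here.
+    scores = batch["attention_mask"].sum(-1, keepdim=True).float()
+    output = TensorDict({"token_level_scores": scores}, batch_size=scores.size(0))
+
+    client.put(data=output, metadata=meta, global_step=step)
+
+
+# After a global step is fully processed, its storage can be released:
+# client.clear(global_step=step)
+```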

+## Citation

+Please kindly cite our paper if you find this repo is useful: + +```bibtex +@article{han2025asyncflow, + title={AsyncFlow: An Asynchronous Streaming RL Framework for Efficient LLM Post-Training}, + author={Han, Zhenyu and You, Ansheng and Wang, Haibo and Luo, Kui and Yang, Guang and Shi, Wenqi and Chen, Menglong and Zhang, Sicheng and Lan, Zeshun and Deng, Chunshi and others}, + journal={arXiv preprint arXiv:2507.01663}, + year={2025} +} +``` \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index e8467dc965a..06411b5cfcb 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -125,6 +125,7 @@ verl is fast with: advance/one_step_off advance/agent_loop advance/fully_async + data/transfer_queue.md .. toctree:: :maxdepth: 1 diff --git a/recipe/transfer_queue/agent_loop.py b/recipe/transfer_queue/agent_loop.py new file mode 100644 index 00000000000..871ae8025c0 --- /dev/null +++ b/recipe/transfer_queue/agent_loop.py @@ -0,0 +1,76 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import ray +from transfer_queue import BatchMeta + +import verl.experimental.agent_loop.agent_loop as agent_loop +from verl import DataProto + + +class AgentLoopManager(agent_loop.AgentLoopManager): + def generate_sequences(self, prompts: BatchMeta) -> BatchMeta: + """Split input batch and dispatch to agent loop workers. + + Args: + prompts (BatchMeta): Input batch. + + Returns: + BatchMeta: Output batch metadata. 
+ """ + + if self.rm_micro_batch_size and len(prompts) % self.rm_micro_batch_size != 0: + raise ValueError( + f"The length of prompts {len(prompts)} cannot divide the world size of rm_wg {self.rm_micro_batch_size}" + ) + if self.config.actor_rollout_ref.rollout.free_cache_engine: + self.wake_up() + chunkes = prompts.chunk(len(self.agent_loop_workers)) + outputs = ray.get( + [ + worker.generate_sequences.remote(chunk) + for worker, chunk in zip(self.agent_loop_workers, chunkes, strict=True) + ] + ) + output = BatchMeta.concat(outputs) + if self.config.actor_rollout_ref.rollout.free_cache_engine: + self.sleep() + + # calculate performance metrics + metrics = [output.extra_info.pop("metrics") for output in outputs] # List[List[Dict[str, str]]] + timing = self._performance_metrics(metrics, output) + + output.set_extra_info("timing", timing) + return output + + def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: DataProto) -> dict[str, float]: + timing = {} + t_generate_sequences = np.array([metric["generate_sequences"] for chunk in metrics for metric in chunk]) + t_tool_calls = np.array([metric["tool_calls"] for chunk in metrics for metric in chunk]) + timing["agent_loop/generate_sequences/min"] = t_generate_sequences.min() + timing["agent_loop/generate_sequences/max"] = t_generate_sequences.max() + timing["agent_loop/generate_sequences/mean"] = t_generate_sequences.mean() + timing["agent_loop/tool_calls/min"] = t_tool_calls.min() + timing["agent_loop/tool_calls/max"] = t_tool_calls.max() + timing["agent_loop/tool_calls/mean"] = t_tool_calls.mean() + + return timing + + def create_transferqueue_client(self, controller_infos, storage_infos, role): + ray.get( + [ + worker.create_transferqueue_client.remote(controller_infos, storage_infos, role) + for worker in self.agent_loop_workers + ] + ) diff --git a/recipe/transfer_queue/config/transfer_queue_ppo_trainer.yaml b/recipe/transfer_queue/config/transfer_queue_ppo_trainer.yaml new file mode 100644 index 00000000000..4f9c574e8ee --- /dev/null +++ b/recipe/transfer_queue/config/transfer_queue_ppo_trainer.yaml @@ -0,0 +1,11 @@ +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_trainer + - _self_ + +# config for TransferQueue +transfer_queue: + enable: True diff --git a/recipe/transfer_queue/main_ppo.py b/recipe/transfer_queue/main_ppo.py new file mode 100644 index 00000000000..ca4c0ae0915 --- /dev/null +++ b/recipe/transfer_queue/main_ppo.py @@ -0,0 +1,200 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. 
+""" + +import os +import socket + +import hydra +import ray +from omegaconf import OmegaConf + +from verl.trainer.constants_ppo import get_ppo_ray_runtime_env +from verl.trainer.main_ppo import ( + TaskRunner as MainTaskRunner, +) +from verl.trainer.main_ppo import ( + create_rl_dataset, + create_rl_sampler, +) +from verl.trainer.ppo.reward import load_reward_manager +from verl.trainer.ppo.utils import need_critic, need_reference_policy +from verl.utils.config import validate_config +from verl.utils.device import is_cuda_available + +from .ray_trainer import RayPPOTrainer + + +@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None) +def main(config): + """Main entry point for PPO training with Hydra configuration management. + + Args: + config_dict: Hydra configuration dictionary containing training parameters. + """ + run_ppo(config) + + +# Define a function to run the PPO-like training process +def run_ppo(config, task_runner_class=None) -> None: + """Initialize Ray cluster and run distributed PPO training process. + + Args: + config: Training configuration object containing all necessary parameters + for distributed PPO training including Ray initialization settings, + model paths, and training hyperparameters. + task_runner_class: For recipe to change TaskRunner. + """ + # Check if Ray is not initialized + if not ray.is_initialized(): + # Initialize Ray with a local cluster configuration + # Set environment variables in the runtime environment to control tokenizer parallelism, + # NCCL debug level, VLLM logging level, and allow runtime LoRA updating + # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration + default_runtime_env = get_ppo_ray_runtime_env() + ray_init_kwargs = config.ray_kwargs.get("ray_init", {}) + runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {}) + runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs) + ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env}) + if config.transfer_queue.enable: + ray_init_kwargs["TRANSFER_QUEUE_ENABLE"] = "1" + print(f"ray init kwargs: {ray_init_kwargs}") + ray.init(**OmegaConf.to_container(ray_init_kwargs)) + + if task_runner_class is None: + task_runner_class = ray.remote(num_cpus=1)(TaskRunner) # please make sure main_task is not scheduled on head + + # Create a remote instance of the TaskRunner class, and + # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete + if ( + is_cuda_available + and config.global_profiler.tool == "nsys" + and config.global_profiler.get("steps") is not None + and len(config.global_profiler.get("steps", [])) > 0 + ): + from verl.utils.import_utils import is_nvtx_available + + assert is_nvtx_available(), "nvtx is not available in CUDA platform. Please 'pip3 install nvtx'" + nsight_options = OmegaConf.to_container( + config.global_profiler.global_tool_config.nsys.controller_nsight_options + ) + runner = task_runner_class.options(runtime_env={"nsight": nsight_options}).remote() + else: + runner = task_runner_class.remote() + ray.get(runner.run.remote(config)) + + # [Optional] get the path of the timeline trace file from the configuration, default to None + # This file is used for performance analysis + timeline_json_file = config.ray_kwargs.get("timeline_json_file", None) + if timeline_json_file: + ray.timeline(filename=timeline_json_file) + + +class TaskRunner(MainTaskRunner): + def run(self, config): + """Execute the main PPO training workflow. 
+ + This method sets up the distributed training environment, initializes + workers, datasets, and reward functions, then starts the training process. + + Args: + config: Training configuration object containing all parameters needed + for setting up and running the PPO training process. + """ + # Print the initial configuration. `resolve=True` will evaluate symbolic values. + from pprint import pprint + + from verl.utils.fs import copy_to_local + + print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") + pprint(OmegaConf.to_container(config, resolve=True)) + OmegaConf.resolve(config) + + actor_rollout_cls, ray_worker_group_cls = self.add_actor_rollout_worker(config) + self.add_critic_worker(config) + + # We should adopt a multi-source reward function here: + # - for rule-based rm, we directly call a reward score + # - for model-based rm, we call a model + # - for code related prompt, we send to a sandbox if there are test cases + # finally, we combine all the rewards together + # The reward type depends on the tag of the data + self.add_reward_model_worker(config) + + # Add a reference policy worker if KL loss or KL reward is used. + self.add_ref_policy_worker(config, actor_rollout_cls) + + # validate config + validate_config( + config=config, + use_reference_policy=need_reference_policy(self.role_worker_mapping), + use_critic=need_critic(config), + ) + + # Download the checkpoint from HDFS to the local machine. + # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on + local_path = copy_to_local( + config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False) + ) + + # Instantiate the tokenizer and processor. + from verl.utils import hf_processor, hf_tokenizer + + trust_remote_code = config.data.get("trust_remote_code", False) + tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + # Used for multimodal LLM, could be None + processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) + + # Load the reward manager for training and validation. + reward_fn = load_reward_manager( + config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}) + ) + val_reward_fn = load_reward_manager( + config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {}) + ) + + resource_pool_manager = self.init_resource_pool_mgr(config) + + from verl.utils.dataset.rl_dataset import collate_fn + + # Create training and validation datasets. + train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor, is_train=True) + val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor, is_train=False) + train_sampler = create_rl_sampler(config.data, train_dataset) + + # Initialize the PPO trainer. + trainer = RayPPOTrainer( + config=config, + tokenizer=tokenizer, + processor=processor, + role_worker_mapping=self.role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn, + train_dataset=train_dataset, + val_dataset=val_dataset, + collate_fn=collate_fn, + train_sampler=train_sampler, + ) + # Initialize the workers of the trainer. + trainer.init_workers() + # Start the training process. 
+ trainer.fit() + + +if __name__ == "__main__": + main() diff --git a/recipe/transfer_queue/ray_trainer.py b/recipe/transfer_queue/ray_trainer.py new file mode 100644 index 00000000000..d6adbddb676 --- /dev/null +++ b/recipe/transfer_queue/ray_trainer.py @@ -0,0 +1,1885 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +PPO Trainer with Ray-based single controller. +This trainer supports model-agonistic model initialization with huggingface +""" + +import asyncio +import json +import logging +import math +import os +import uuid +from collections import defaultdict +from dataclasses import dataclass, field +from pprint import pprint +from typing import Any, Optional + +import numpy as np +import ray +import tensordict +import torch +from omegaconf import OmegaConf, open_dict +from packaging.version import parse as parse_version +from tensordict import TensorDict +from torch.utils.data import Dataset, Sampler +from torchdata.stateful_dataloader import StatefulDataLoader +from tqdm import tqdm +from transfer_queue import ( + BatchMeta, + TransferQueueController, + TransferQueueStorageSimpleUnit, + get_placement_group, + process_zmq_server_info, +) + +from verl import DataProto +from verl.experimental.dataset.sampler import AbstractCurriculumSampler +from verl.single_controller.ray import ( + RayClassWithInitArgs, + RayResourcePool, + RayWorkerGroup, +) +from verl.single_controller.ray.base import create_colocated_worker_cls +from verl.trainer.config import AlgoConfig +from verl.trainer.ppo import core_algos +from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss +from verl.trainer.ppo.metric_utils import ( + compute_data_metrics, + compute_throughout_metrics, + compute_timing_metrics, + process_validation_metrics, +) +from verl.trainer.ppo.reward import compute_reward, compute_reward_async +from verl.trainer.ppo.utils import ( + Role, + WorkerType, + need_critic, + need_reference_policy, + need_reward_model, +) +from verl.utils.checkpoint.checkpoint_manager import ( + find_latest_ckpt_path, + should_save_ckpt_esi, +) +from verl.utils.config import omega_conf_to_dataclass +from verl.utils.debug import marked_timer +from verl.utils.metric import reduce_metrics +from verl.utils.rollout_skip import RolloutSkip +from verl.utils.seqlen_balancing import ( + get_seqlen_balanced_partitions, + log_seqlen_unbalance, +) +from verl.utils.torch_functional import masked_mean +from verl.utils.tracking import ValidationGenerationsLogger +from verl.utils.transferqueue_utils import ( + create_transferqueue_client, + get_transferqueue_client, + get_val_transferqueue_client, + tqbridge, +) + + +@dataclass +class ResourcePoolManager: + """ + Define a resource pool specification. Resource pool will be initialized first. 
+ """ + + resource_pool_spec: dict[str, list[int]] + mapping: dict[Role, str] + resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict) + + def create_resource_pool(self): + """Create Ray resource pools for distributed training. + + Initializes resource pools based on the resource pool specification, + with each pool managing GPU resources across multiple nodes. + For FSDP backend, uses max_colocate_count=1 to merge WorkerGroups. + For Megatron backend, uses max_colocate_count>1 for different models. + """ + for resource_pool_name, process_on_nodes in self.resource_pool_spec.items(): + # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool + # For FSDP backend, we recommend using max_colocate_count=1 that merge all WorkerGroups into one. + # For Megatron backend, we recommend using max_colocate_count>1 + # that can utilize different WorkerGroup for differnt models + resource_pool = RayResourcePool( + process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name + ) + self.resource_pool_dict[resource_pool_name] = resource_pool + + self._check_resource_available() + + def get_resource_pool(self, role: Role) -> RayResourcePool: + """Get the resource pool of the worker_cls""" + return self.resource_pool_dict[self.mapping[role]] + + def get_n_gpus(self) -> int: + """Get the number of gpus in this cluster.""" + return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]) + + def _check_resource_available(self): + """Check if the resource pool can be satisfied in this ray cluster.""" + node_available_resources = ray._private.state.available_resources_per_node() + node_available_gpus = { + node: node_info.get("GPU", 0) if "GPU" in node_info else node_info.get("NPU", 0) + for node, node_info in node_available_resources.items() + } + + # check total required gpus can be satisfied + total_available_gpus = sum(node_available_gpus.values()) + total_required_gpus = sum( + [n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes] + ) + if total_available_gpus < total_required_gpus: + raise ValueError( + f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}" + ) + + +@tqbridge(put_data=False) +def compute_reward_decorated(data, reward_fn): + return compute_reward(data, reward_fn) + + +@tqbridge(put_data=False) +def compute_reward_async_decorated(data, reward_fn): + return compute_reward_async.remote(data, reward_fn) + + +@tqbridge(put_data=False) +def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"): + """Apply KL penalty to the token-level rewards. + + This function computes the KL divergence between the reference policy and current policy, + then applies a penalty to the token-level rewards based on this divergence. + + Args: + data (DataProto): The data containing batched model outputs and inputs. + kl_ctrl (core_algos.AdaptiveKLController): Controller for adaptive KL penalty. + kl_penalty (str, optional): Type of KL penalty to apply. Defaults to "kl". 
+ + Returns: + tuple: A tuple containing: + - The updated data with token-level rewards adjusted by KL penalty + - A dictionary of metrics related to the KL penalty + """ + response_mask = data.batch["response_mask"] + token_level_scores = data.batch["token_level_scores"] + batch_size = data.batch.batch_size[0] + + # compute kl between ref_policy and current policy + # When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled. + kld = core_algos.kl_penalty( + data.batch["old_log_probs"], data.batch["ref_log_prob"], kl_penalty=kl_penalty + ) # (batch_size, response_length) + kld = kld * response_mask + beta = kl_ctrl.value + + token_level_rewards = token_level_scores - beta * kld + + current_kl = masked_mean(kld, mask=response_mask, axis=-1) # average over sequence + current_kl = torch.mean(current_kl, dim=0).item() + + # according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837 + kl_ctrl.update(current_kl=current_kl, n_steps=batch_size) + + metrics = {"actor/reward_kl_penalty": current_kl, "actor/reward_kl_penalty_coeff": beta} + + return token_level_rewards, metrics + + +def compute_response_mask(batch_meta: BatchMeta, data_system_client): + """Compute the attention mask for the response part of the sequence. + + This function extracts the portion of the attention mask that corresponds to the model's response, + which is used for masking computations that should only apply to response tokens. + + Args: + batch_meta (BatchMeta): The data containing batched model outputs and inputs. + + Returns: + BatchMeta: The BatchMeta of attention mask for the response tokens. + """ + data = asyncio.run(data_system_client.async_get_data(batch_meta)) + + responses = data["responses"] + response_length = responses.size(1) + attention_mask = data["attention_mask"] + response_mask = attention_mask[:, -response_length:] + output = TensorDict({"response_mask": response_mask}, batch_size=response_mask.size(0)) + + asyncio.run(data_system_client.async_put(data=output, metadata=batch_meta)) + batch_meta.add_fields(output) + + return batch_meta + + +@tqbridge(put_data=False) +def compute_advantage( + data: DataProto, + adv_estimator: AdvantageEstimator, + gamma: float = 1.0, + lam: float = 1.0, + num_repeat: int = 1, + norm_adv_by_std_in_grpo: bool = True, + config: Optional[AlgoConfig] = None, +) -> tuple[Any, Any]: + """Compute advantage estimates for policy optimization. + + This function computes advantage estimates using various estimators like GAE, GRPO, REINFORCE++, etc. + The advantage estimates are used to guide policy optimization in RL algorithms. + + Args: + data (DataProto): The data containing batched model outputs and inputs. + adv_estimator (AdvantageEstimator): The advantage estimator to use (e.g., GAE, GRPO, REINFORCE++). + gamma (float, optional): Discount factor for future rewards. Defaults to 1.0. + lam (float, optional): Lambda parameter for GAE. Defaults to 1.0. + num_repeat (int, optional): Number of times to repeat the computation. Defaults to 1. + norm_adv_by_std_in_grpo (bool, optional): Whether to normalize advantages by standard deviation in + GRPO. Defaults to True. + config (dict, optional): Configuration dictionary for algorithm settings. Defaults to None. + + Returns: + tuple: A tuple containing: + - advantages: The computed advantage estimates. + - returns: The computed returns. 
+ """ + # prepare response group + if adv_estimator == AdvantageEstimator.GAE: + # Compute advantages and returns using Generalized Advantage Estimation (GAE) + advantages, returns = core_algos.compute_gae_advantage_return( + token_level_rewards=data.batch["token_level_rewards"], + values=data.batch["values"], + response_mask=data.batch["response_mask"], + gamma=gamma, + lam=lam, + ) + # TODO: (TQ) adapt core_algos.compute_pf_ppo_reweight_data function to support transfer queue + if config.get("use_pf_ppo", False): + data = core_algos.compute_pf_ppo_reweight_data( + data, + config.pf_ppo.get("reweight_method"), + config.pf_ppo.get("weight_pow"), + ) + elif adv_estimator == AdvantageEstimator.GRPO: + # Initialize the mask for GRPO calculation + grpo_calculation_mask = data.batch["response_mask"] + # Call compute_grpo_outcome_advantage with parameters matching its definition + advantages, returns = core_algos.compute_grpo_outcome_advantage( + token_level_rewards=data.batch["token_level_rewards"], + response_mask=grpo_calculation_mask, + index=data.non_tensor_batch["uid"], + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + ) + else: + # handle all other adv estimator type other than GAE and GRPO + adv_estimator_fn = core_algos.get_adv_estimator_fn(adv_estimator) + adv_kwargs = { + "token_level_rewards": data.batch["token_level_rewards"], + "response_mask": data.batch["response_mask"], + "config": config, + } + if "uid" in data.non_tensor_batch: # optional + adv_kwargs["index"] = data.non_tensor_batch["uid"] + if "reward_baselines" in data.batch: # optional + adv_kwargs["reward_baselines"] = data.batch["reward_baselines"] + + # calculate advantage estimator + advantages, returns = adv_estimator_fn(**adv_kwargs) + return advantages, returns + + +@tqbridge(put_data=False) +def compute_data_metrics_decorated(batch, use_critic: bool = True): + return compute_data_metrics(batch, use_critic) + + +@tqbridge(put_data=False) +def compute_timing_metrics_decorated(batch, timing_raw: dict[str, float]) -> dict[str, Any]: + return compute_timing_metrics(batch, timing_raw) + + +@tqbridge(put_data=False) +def compute_throughout_metrics_decorated(batch, timing_raw: dict[str, float], n_gpus: int) -> dict[str, Any]: + return compute_throughout_metrics(batch, timing_raw, n_gpus) + + +@tqbridge(put_data=False) +def calculate_debug_metrics_decorated(data): + from verl.utils.debug.metrics import calculate_debug_metrics + + return calculate_debug_metrics(data) + + +@tqbridge(put_data=False) +def compute_val_reward_decorated(reward_fn, data, return_dict): + return reward_fn(data, return_dict) + + +class RayPPOTrainer: + """Distributed PPO trainer using Ray for scalable reinforcement learning. + + This trainer orchestrates distributed PPO training across multiple nodes and GPUs, + managing actor rollouts, critic training, and reward computation with Ray backend. + Supports various model architectures including FSDP, Megatron, vLLM, and SGLang integration. 
+ """ + + # TODO: support each role have individual ray_worker_group_cls, + # i.e., support different backend of different role + def __init__( + self, + config, + tokenizer, + role_worker_mapping: dict[Role, WorkerType], + resource_pool_manager: ResourcePoolManager, + ray_worker_group_cls: type[RayWorkerGroup] = RayWorkerGroup, + processor=None, + reward_fn=None, + val_reward_fn=None, + train_dataset: Optional[Dataset] = None, + val_dataset: Optional[Dataset] = None, + collate_fn=None, + train_sampler: Optional[Sampler] = None, + device_name=None, + ): + """ + Initialize distributed PPO trainer with Ray backend. + Note that this trainer runs on the driver process on a single CPU/GPU node. + + Args: + config: Configuration object containing training parameters. + tokenizer: Tokenizer used for encoding and decoding text. + role_worker_mapping (dict[Role, WorkerType]): Mapping from roles to worker classes. + resource_pool_manager (ResourcePoolManager): Manager for Ray resource pools. + ray_worker_group_cls (RayWorkerGroup, optional): Class for Ray worker groups. Defaults to RayWorkerGroup. + processor: Optional data processor, used for multimodal data + reward_fn: Function for computing rewards during training. + val_reward_fn: Function for computing rewards during validation. + train_dataset (Optional[Dataset], optional): Training dataset. Defaults to None. + val_dataset (Optional[Dataset], optional): Validation dataset. Defaults to None. + collate_fn: Function to collate data samples into batches. + train_sampler (Optional[Sampler], optional): Sampler for the training dataset. Defaults to None. + device_name (str, optional): Device name for training (e.g., "cuda", "cpu"). Defaults to None. + """ + + # Store the tokenizer for text processing + self.tokenizer = tokenizer + self.processor = processor + self.config = config + self.reward_fn = reward_fn + self.val_reward_fn = val_reward_fn + + self.hybrid_engine = config.actor_rollout_ref.hybrid_engine + assert self.hybrid_engine, "Currently, only support hybrid engine" + + if self.hybrid_engine: + assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}" + + self.role_worker_mapping = role_worker_mapping + self.resource_pool_manager = resource_pool_manager + self.use_reference_policy = need_reference_policy(self.role_worker_mapping) + self.use_rm = need_reward_model(self.role_worker_mapping) + self.use_critic = need_critic(self.config) + self.ray_worker_group_cls = ray_worker_group_cls + self.device_name = device_name if device_name else self.config.trainer.device + self.validation_generations_logger = ValidationGenerationsLogger( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + ) + + # if ref_in_actor is True, the reference policy will be actor without lora applied + self.ref_in_actor = config.actor_rollout_ref.model.get("lora_rank", 0) > 0 + + # define in-reward KL control + # kl loss control currently not suppoorted + if self.config.algorithm.use_kl_in_reward: + self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl) + + self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) + + self.data_system_client = self._initialize_train_data_system( + self.config.data.train_batch_size, self.config.actor_rollout_ref.rollout.n + ) + self.val_data_system_client = self._initialize_val_data_system( + self.val_batch_size, self.config.actor_rollout_ref.rollout.val_kwargs.n + ) + + def _initialize_train_data_system(self, 
global_batch_size, num_n_samples, role="train"): + # 1. initialize TransferQueueStorage + total_storage_size = global_batch_size * self.config.trainer.num_global_batch * num_n_samples + self.data_system_storage_units = {} + storage_placement_group = get_placement_group(self.config.trainer.num_data_storage_units, num_cpus_per_actor=1) + for storage_unit_rank in range(self.config.trainer.num_data_storage_units): + storage_node = TransferQueueStorageSimpleUnit.options( + placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank + ).remote(storage_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units)) + self.data_system_storage_units[storage_unit_rank] = storage_node + logging.info(f"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.") + + # 2. initialize TransferQueueController + # we support inilialize multiple controller instances for large-scale scenario. Please allocate exactly + # one controller for a single WorkerGroup. + self.data_system_controllers = {} + controller_placement_group = get_placement_group(self.config.trainer.num_data_controllers, num_cpus_per_actor=1) + for controller_rank in range(self.config.trainer.num_data_controllers): + self.data_system_controllers[controller_rank] = TransferQueueController.options( + placement_group=controller_placement_group, placement_group_bundle_index=controller_rank + ).remote( + num_storage_units=self.config.trainer.num_data_storage_units, + global_batch_size=global_batch_size, + num_global_batch=self.config.trainer.num_global_batch, + num_n_samples=num_n_samples, + ) + logging.info(f"TransferQueueController #{controller_rank} has been created.") + + # 3. register controller & storage + self.data_system_controller_infos = process_zmq_server_info(self.data_system_controllers) + self.data_system_storage_unit_infos = process_zmq_server_info(self.data_system_storage_units) + + ray.get( + [ + storage_unit.register_controller_info.remote(self.data_system_controller_infos) + for storage_unit in self.data_system_storage_units.values() + ] + ) + + # 4. create client + # each client should be allocated to exactly one controller + create_transferqueue_client( + client_id="Trainer-" + role, + controller_infos=self.data_system_controller_infos, + storage_infos=self.data_system_storage_unit_infos, + ) + data_system_client = get_transferqueue_client() + return data_system_client + + def _initialize_val_data_system(self, global_batch_size, num_n_samples, role="val"): + # 1. initialize TransferQueueStorage + total_storage_size = global_batch_size * self.config.trainer.num_global_batch * num_n_samples + self.val_data_system_storage_units = {} + storage_placement_group = get_placement_group(self.config.trainer.num_data_storage_units, num_cpus_per_actor=1) + for storage_unit_rank in range(self.config.trainer.num_data_storage_units): + storage_node = TransferQueueStorageSimpleUnit.options( + placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank + ).remote(storage_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units)) + self.val_data_system_storage_units[storage_unit_rank] = storage_node + logging.info(f"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.") + + # 2. initialize TransferQueueController + # we support inilialize multiple controller instances for large-scale scenario. Please allocate exactly + # one controller for a single WorkerGroup. 
+ self.val_data_system_controllers = {} + controller_placement_group = get_placement_group(self.config.trainer.num_data_controllers, num_cpus_per_actor=1) + for controller_rank in range(self.config.trainer.num_data_controllers): + self.val_data_system_controllers[controller_rank] = TransferQueueController.options( + placement_group=controller_placement_group, placement_group_bundle_index=controller_rank + ).remote( + num_storage_units=self.config.trainer.num_data_storage_units, + global_batch_size=global_batch_size, + num_global_batch=self.config.trainer.num_global_batch, + num_n_samples=num_n_samples, + ) + logging.info(f"TransferQueueController #{controller_rank} has been created.") + + # 3. register controller & storage + self.val_data_system_controller_infos = process_zmq_server_info(self.val_data_system_controllers) + self.val_data_system_storage_unit_infos = process_zmq_server_info(self.val_data_system_storage_units) + + ray.get( + [ + storage_unit.register_controller_info.remote(self.val_data_system_controller_infos) + for storage_unit in self.val_data_system_storage_units.values() + ] + ) + + # 4. create client + # each client should be allocated to exactly one controller + create_transferqueue_client( + client_id="Trainer-" + role, + controller_infos=self.val_data_system_controller_infos, + storage_infos=self.val_data_system_storage_unit_infos, + ) + data_system_client = get_val_transferqueue_client() + return data_system_client + + def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]): + """ + Creates the train and validation dataloaders. + """ + # TODO: we have to make sure the batch size is divisible by the dp size + from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler + + if train_dataset is None: + train_dataset = create_rl_dataset( + self.config.data.train_files, self.config.data, self.tokenizer, self.processor + ) + if val_dataset is None: + val_dataset = create_rl_dataset( + self.config.data.val_files, self.config.data, self.tokenizer, self.processor + ) + self.train_dataset, self.val_dataset = train_dataset, val_dataset + + if train_sampler is None: + train_sampler = create_rl_sampler(self.config.data, self.train_dataset) + if collate_fn is None: + from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn + + collate_fn = default_collate_fn + + num_workers = self.config.data["dataloader_num_workers"] + + self.train_dataloader = StatefulDataLoader( + dataset=self.train_dataset, + batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size), + num_workers=num_workers, + drop_last=True, + collate_fn=collate_fn, + sampler=train_sampler, + ) + + val_batch_size = self.config.data.val_batch_size # Prefer config value if set + if val_batch_size is None: + val_batch_size = len(self.val_dataset) + self.val_batch_size = val_batch_size + + self.val_dataloader = StatefulDataLoader( + dataset=self.val_dataset, + batch_size=val_batch_size, + num_workers=num_workers, + shuffle=self.config.data.get("validation_shuffle", True), + drop_last=False, + collate_fn=collate_fn, + ) + + assert len(self.train_dataloader) >= 1, "Train dataloader is empty!" + assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!" 
+ + print( + f"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: " + f"{len(self.val_dataloader)}" + ) + + total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + + if self.config.trainer.total_training_steps is not None: + total_training_steps = self.config.trainer.total_training_steps + + self.total_training_steps = total_training_steps + print(f"Total training steps: {self.total_training_steps}") + + try: + OmegaConf.set_struct(self.config, True) + with open_dict(self.config): + if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"): + self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps + if OmegaConf.select(self.config, "critic.optim"): + self.config.critic.optim.total_training_steps = total_training_steps + except Exception as e: + print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}") + + def _dump_generations(self, inputs, outputs, gts, scores, reward_extra_infos_dict, dump_path): + """Dump rollout/validation samples as JSONL.""" + os.makedirs(dump_path, exist_ok=True) + filename = os.path.join(dump_path, f"{self.global_steps}.jsonl") + + n = len(inputs) + base_data = { + "input": inputs, + "output": outputs, + "gts": gts, + "score": scores, + "step": [self.global_steps] * n, + } + + for k, v in reward_extra_infos_dict.items(): + if len(v) == n: + base_data[k] = v + + lines = [] + for i in range(n): + entry = {k: v[i] for k, v in base_data.items()} + lines.append(json.dumps(entry, ensure_ascii=False)) + + with open(filename, "w") as f: + f.write("\n".join(lines) + "\n") + + print(f"Dumped generations to {filename}") + + def _log_rollout_data( + self, log_rollout_meta: BatchMeta, reward_extra_infos_dict: dict, timing_raw: dict, rollout_data_dir: str + ): + """ + Log rollout data to disk. 
+ + Args: + log_rollout_meta (BatchMeta): The batch_meta of rollout data + reward_extra_infos_dict (dict): Additional reward information to log + timing_raw (dict): Timing information for profiling + rollout_data_dir (str): Directory path to save the rollout data + """ + with marked_timer("dump_rollout_generations", timing_raw, color="green"): + data = asyncio.run(self.data_system_client.async_get_data(log_rollout_meta)) + + inputs = self.tokenizer.batch_decode(data["prompts"], skip_special_tokens=True) + outputs = self.tokenizer.batch_decode(data["responses"], skip_special_tokens=True) + scores = data["token_level_scores"].sum(-1).cpu().tolist() + sample_gts = [item.get("ground_truth", None) for item in data.get("reward_model", {})] + + reward_extra_infos_to_dump = reward_extra_infos_dict.copy() + if "request_id" in log_rollout_meta.field_names: + reward_extra_infos_dict.setdefault( + "request_id", + data["request_id"].tolist(), + ) + + self._dump_generations( + inputs=inputs, + outputs=outputs, + gts=sample_gts, + scores=scores, + reward_extra_infos_dict=reward_extra_infos_to_dump, + dump_path=rollout_data_dir, + ) + + def _maybe_log_val_generations(self, inputs, outputs, scores): + """Log a table of validation samples to the configured logger (wandb or swanlab)""" + + generations_to_log = self.config.trainer.log_val_generations + + if generations_to_log == 0: + return + + import numpy as np + + # Create tuples of (input, output, score) and sort by input text + samples = list(zip(inputs, outputs, scores, strict=True)) + samples.sort(key=lambda x: x[0]) # Sort by input text + + # Use fixed random seed for deterministic shuffling + rng = np.random.RandomState(42) + rng.shuffle(samples) + + # Take first N samples after shuffling + samples = samples[:generations_to_log] + + # Log to each configured logger + self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps) + + def _get_gen_batch(self, batch: DataProto) -> DataProto: + reward_model_keys = set({"data_source", "reward_model", "extra_info", "uid"}) & batch.non_tensor_batch.keys() + + # pop those keys for generation + batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] + non_tensor_batch_keys_to_pop = set(batch.non_tensor_batch.keys()) - reward_model_keys + gen_batch = batch.pop( + batch_keys=batch_keys_to_pop, + non_tensor_batch_keys=list(non_tensor_batch_keys_to_pop), + ) + + # For agent loop, we need reward model keys to compute score. 
+ if self.async_rollout_mode: + gen_batch.non_tensor_batch.update(batch.non_tensor_batch) + + return gen_batch + + def _validate(self): + data_source_lst = [] + reward_extra_infos_dict: dict[str, list] = defaultdict(list) + + # Lists to collect samples for the table + sample_inputs = [] + sample_outputs = [] + sample_gts = [] + sample_scores = [] + sample_turns = [] + sample_uids = [] + + for test_data in self.val_dataloader: + if "uid" not in test_data.keys(): + test_data["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(test_data["input_ids"]))], dtype=object + ) + + # repeat test data + repeated_test_data = self.repeat_dict( + test_data, repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True + ) + + test_batch: TensorDict = self.dict_to_tensordict(repeated_test_data) + + # we only do validation on rule-based rm + if self.config.reward_model.enable and test_batch[0]["reward_model"]["style"] == "model": + return {} + + asyncio.run(self.val_data_system_client.async_put(data=test_batch, global_step=self.global_steps - 1)) + + # Store original inputs + batch_meta = asyncio.run( + self.val_data_system_client.async_get_meta( + data_fields=["input_ids", "uid", "reward_model"], + batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, + global_step=self.global_steps - 1, + get_n_samples=False, + task_name="get_data", + ) + ) + data = asyncio.run(self.val_data_system_client.async_get_data(batch_meta)) + input_ids = data["input_ids"] + # TODO: Can we keep special tokens except for padding tokens? + input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] + sample_inputs.extend(input_texts) + sample_uids.extend(data["uid"]) + + ground_truths = [item.get("ground_truth", None) for item in data.get("reward_model", {})] + sample_gts.extend(ground_truths) + + test_gen_meta = asyncio.run( + self.val_data_system_client.async_get_meta( + data_fields=list(test_batch.keys()), # TODO: (TQ) Get metadata by specified fields + batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, + global_step=self.global_steps - 1, # self.global_steps start from 1 + get_n_samples=False, + task_name="generate_sequences", + ) + ) + test_gen_meta.extra_info = { + "eos_token_id": self.tokenizer.eos_token_id, + "pad_token_id": self.tokenizer.pad_token_id, + "recompute_log_prob": False, + "do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample, + "validate": True, + "global_steps": self.global_steps, + } + print(f"test_gen_batch meta info: {test_gen_meta.extra_info}") + + # TODO: (TQ) Support padding and unpadding to make DataProto divisible by dp_size with TransferQueue + if not self.async_rollout_mode: + test_output_gen_meta = self.actor_rollout_wg.generate_sequences(test_gen_meta) + else: + test_output_gen_meta = self.async_rollout_manager.generate_sequences(test_gen_meta) + + test_batch_meta = test_gen_meta.union(test_output_gen_meta) + + print("validation generation end") + + # Store generated outputs + test_response_meta = asyncio.run( + self.val_data_system_client.async_get_meta( + data_fields=["responses"], + batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, + global_step=self.global_steps - 1, # self.global_steps start from 1 + get_n_samples=False, + task_name="get_response", + ) + ) + data = asyncio.run(self.val_data_system_client.async_get_data(test_response_meta)) + output_ids = data["responses"] + output_texts = [self.tokenizer.decode(ids, 
skip_special_tokens=True) for ids in output_ids] + sample_outputs.extend(output_texts) + + test_batch_meta.set_extra_info("validate", True) + + # evaluate using reward_function + if self.val_reward_fn is None: + raise ValueError("val_reward_fn must be provided for validation.") + + compute_reward_fields = [ + "responses", + "prompts", + "attention_mask", + "reward_model", + "data_source", + ] + if "rm_scores" in batch_meta.field_names: + compute_reward_fields = ["rm_scores"] + val_reward_meta = asyncio.run( + self.val_data_system_client.async_get_meta( + data_fields=compute_reward_fields, + batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, + global_step=self.global_steps - 1, + get_n_samples=False, + task_name="compute_reward", + ) + ) + val_reward_meta.update_extra_info(test_batch_meta.extra_info) + result = compute_val_reward_decorated(self.val_reward_fn, val_reward_meta, return_dict=True) + reward_tensor = result["reward_tensor"] + scores = reward_tensor.sum(-1).cpu().tolist() + sample_scores.extend(scores) + + reward_extra_infos_dict["reward"].extend(scores) + print(f"len reward_extra_infos_dict['reward']: {len(reward_extra_infos_dict['reward'])}") + if "reward_extra_info" in result: + for key, lst in result["reward_extra_info"].items(): + reward_extra_infos_dict[key].extend(lst) + print(f"len reward_extra_infos_dict['{key}']: {len(reward_extra_infos_dict[key])}") + + # collect num_turns of each prompt + if "__num_turns__" in test_batch_meta.field_names: + num_turns_meta = asyncio.run( + self.val_data_system_client.async_get_meta( + data_fields=["__num_turns__"], + batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, + global_step=self.global_steps - 1, # self.global_steps start from 1 + get_n_samples=False, + task_name="get_num_turns", + ) + ) + data = asyncio.run(self.val_data_system_client.async_get_data(num_turns_meta)) + sample_turns.append(data["__num_turns__"]) + + data_source = ["unknown"] * reward_tensor.shape[0] + if "data_source" in test_batch_meta.field_names: + data_source_meta = asyncio.run( + self.val_data_system_client.async_get_meta( + data_fields=["data_source"], + batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, + global_step=self.global_steps - 1, # self.global_steps start from 1 + get_n_samples=False, + task_name="get_data_source", + ) + ) + data = asyncio.run(self.val_data_system_client.async_get_data(data_source_meta)) + data_source = data["data_source"] + + data_source_lst.append(data_source) + + self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores) + + # dump generations + val_data_dir = self.config.trainer.get("validation_data_dir", None) + if val_data_dir: + self._dump_generations( + inputs=sample_inputs, + outputs=sample_outputs, + gts=sample_gts, + scores=sample_scores, + reward_extra_infos_dict=reward_extra_infos_dict, + dump_path=val_data_dir, + ) + + for key_info, lst in reward_extra_infos_dict.items(): + assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}" + + data_sources = np.concatenate(data_source_lst, axis=0) + + data_src2var2metric2val = process_validation_metrics(data_sources, sample_uids, reward_extra_infos_dict) + metric_dict = {} + for data_source, var2metric2val in data_src2var2metric2val.items(): + core_var = "acc" if "acc" in var2metric2val else "reward" + for var_name, metric2val in var2metric2val.items(): + n_max = 
max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()]) + for metric_name, metric_val in metric2val.items(): + if ( + (var_name == core_var) + and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"]) + and (f"@{n_max}" in metric_name) + ): + metric_sec = "val-core" + else: + metric_sec = "val-aux" + pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}" + metric_dict[pfx] = metric_val + + if len(sample_turns) > 0: + sample_turns = np.concatenate(sample_turns) + metric_dict["val-aux/num_turns/min"] = sample_turns.min() + metric_dict["val-aux/num_turns/max"] = sample_turns.max() + metric_dict["val-aux/num_turns/mean"] = sample_turns.mean() + + asyncio.run(self.val_data_system_client.async_clear(self.global_steps - 1)) + return metric_dict + + def init_workers(self): + """Initialize distributed training workers using Ray backend. + + Creates: + 1. Ray resource pools from configuration + 2. Worker groups for each role (actor, critic, etc.) + """ + self.resource_pool_manager.create_resource_pool() + + self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} + + # create actor and rollout + if self.hybrid_engine: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) + actor_rollout_cls = RayClassWithInitArgs( + cls=self.role_worker_mapping[Role.ActorRollout], + config=self.config.actor_rollout_ref, + role="actor_rollout", + ) + self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls + else: + raise NotImplementedError + + # create critic + if self.use_critic: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) + critic_cfg = omega_conf_to_dataclass(self.config.critic) + critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg) + self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls + + # create reference policy if needed + if self.use_reference_policy: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) + ref_policy_cls = RayClassWithInitArgs( + self.role_worker_mapping[Role.RefPolicy], + config=self.config.actor_rollout_ref, + role="ref", + ) + self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls + + # create a reward model if reward_fn is None + if self.use_rm: + # we create a RM here + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) + rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model) + self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls + + # initialize WorkerGroup + # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, + # you should not use `create_colocated_worker_cls`. + # Instead, directly pass different resource pool to different worker groups. + # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. 
+ all_wg = {} + wg_kwargs = {} # Setting up kwargs for RayWorkerGroup + if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None: + wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout + if OmegaConf.select(self.config.global_profiler, "steps") is not None: + wg_kwargs["profile_steps"] = OmegaConf.select(self.config.global_profiler, "steps") + # Only require nsight worker options when tool is nsys + if OmegaConf.select(self.config.global_profiler, "tool") == "nsys": + assert ( + OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") + is not None + ), "worker_nsight_options must be set when using nsys with profile_steps" + wg_kwargs["worker_nsight_options"] = OmegaConf.to_container( + OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") + ) + wg_kwargs["device_name"] = self.device_name + + for resource_pool, class_dict in self.resource_pool_to_cls.items(): + worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) + wg_dict = self.ray_worker_group_cls( + resource_pool=resource_pool, + ray_cls_with_init=worker_dict_cls, + **wg_kwargs, + ) + spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) + all_wg.update(spawn_wg) + + if self.use_critic: + self.critic_wg = all_wg["critic"] + self.critic_wg.init_model() + + if self.use_reference_policy and not self.ref_in_actor: + self.ref_policy_wg = all_wg["ref"] + self.ref_policy_wg.init_model() + + self.rm_wg = None + if self.use_rm: + self.rm_wg = all_wg["rm"] + self.rm_wg.init_model() + + # we should create rollout at the end so that vllm can have a better estimation of kv cache memory + self.actor_rollout_wg = all_wg["actor_rollout"] + self.actor_rollout_wg.init_model() + + # set transferqueue server info for each worker + for _, wg in all_wg.items(): + wg.create_transferqueue_client( + self.data_system_controller_infos, self.data_system_storage_unit_infos, role="train" + ) + wg.create_transferqueue_client( + self.val_data_system_controller_infos, self.val_data_system_storage_unit_infos, role="val" + ) + + # create async rollout manager and request scheduler + self.async_rollout_mode = False + if self.config.actor_rollout_ref.rollout.mode == "async": + from .agent_loop import AgentLoopManager + + self.async_rollout_mode = True + self.async_rollout_manager = AgentLoopManager( + config=self.config, worker_group=self.actor_rollout_wg, rm_wg=self.rm_wg + ) + + self.async_rollout_manager.create_transferqueue_client( + self.data_system_controller_infos, self.data_system_storage_unit_infos, role="train" + ) + self.async_rollout_manager.create_transferqueue_client( + self.val_data_system_controller_infos, self.val_data_system_storage_unit_infos, role="val" + ) + + def _save_checkpoint(self): + from verl.utils.fs import local_mkdir_safe + + # path: given_path + `/global_step_{global_steps}` + `/actor` + local_global_step_folder = os.path.join( + self.config.trainer.default_local_dir, f"global_step_{self.global_steps}" + ) + + print(f"local_global_step_folder: {local_global_step_folder}") + actor_local_path = os.path.join(local_global_step_folder, "actor") + + actor_remote_path = ( + None + if self.config.trainer.default_hdfs_dir is None + else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor") + ) + + remove_previous_ckpt_in_save = self.config.trainer.get("remove_previous_ckpt_in_save", False) + if remove_previous_ckpt_in_save: + print( + 
"Warning: remove_previous_ckpt_in_save is deprecated," + + " set max_actor_ckpt_to_keep=1 and max_critic_ckpt_to_keep=1 instead" + ) + max_actor_ckpt_to_keep = ( + self.config.trainer.get("max_actor_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 + ) + max_critic_ckpt_to_keep = ( + self.config.trainer.get("max_critic_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 + ) + + self.actor_rollout_wg.save_checkpoint( + actor_local_path, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep + ) + + if self.use_critic: + critic_local_path = os.path.join(local_global_step_folder, "critic") + critic_remote_path = ( + None + if self.config.trainer.default_hdfs_dir is None + else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "critic") + ) + self.critic_wg.save_checkpoint( + critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep + ) + + # save dataloader + local_mkdir_safe(local_global_step_folder) + dataloader_local_path = os.path.join(local_global_step_folder, "data.pt") + dataloader_state_dict = self.train_dataloader.state_dict() + torch.save(dataloader_state_dict, dataloader_local_path) + + # latest checkpointed iteration tracker (for atomic usage) + local_latest_checkpointed_iteration = os.path.join( + self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt" + ) + with open(local_latest_checkpointed_iteration, "w") as f: + f.write(str(self.global_steps)) + + def _load_checkpoint(self): + if self.config.trainer.resume_mode == "disable": + return 0 + + # load from hdfs + if self.config.trainer.default_hdfs_dir is not None: + raise NotImplementedError("load from hdfs is not implemented yet") + else: + checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path + if not os.path.isabs(checkpoint_folder): + working_dir = os.getcwd() + checkpoint_folder = os.path.join(working_dir, checkpoint_folder) + global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest + + # find global_step_folder + if self.config.trainer.resume_mode == "auto": + if global_step_folder is None: + print("Training from scratch") + return 0 + else: + if self.config.trainer.resume_mode == "resume_path": + assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type" + assert "global_step_" in self.config.trainer.resume_from_path, ( + "resume ckpt must specify the global_steps" + ) + global_step_folder = self.config.trainer.resume_from_path + if not os.path.isabs(global_step_folder): + working_dir = os.getcwd() + global_step_folder = os.path.join(working_dir, global_step_folder) + print(f"Load from checkpoint folder: {global_step_folder}") + # set global step + self.global_steps = int(global_step_folder.split("global_step_")[-1]) + + print(f"Setting global step to {self.global_steps}") + print(f"Resuming from {global_step_folder}") + + actor_path = os.path.join(global_step_folder, "actor") + critic_path = os.path.join(global_step_folder, "critic") + # load actor + self.actor_rollout_wg.load_checkpoint( + actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load + ) + # load critic + if self.use_critic: + self.critic_wg.load_checkpoint( + critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load + ) + + # load dataloader, + # TODO: from remote not implemented yet + dataloader_local_path = os.path.join(global_step_folder, "data.pt") + if 
os.path.exists(dataloader_local_path): + dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False) + self.train_dataloader.load_state_dict(dataloader_state_dict) + else: + print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch") + + def _start_profiling(self, do_profile: bool) -> None: + """Start profiling for all worker groups if profiling is enabled.""" + if do_profile: + self.actor_rollout_wg.start_profile(role="e2e", profile_step=self.global_steps) + if self.use_reference_policy: + self.ref_policy_wg.start_profile(profile_step=self.global_steps) + if self.use_critic: + self.critic_wg.start_profile(profile_step=self.global_steps) + if self.use_rm: + self.rm_wg.start_profile(profile_step=self.global_steps) + + def _stop_profiling(self, do_profile: bool) -> None: + """Stop profiling for all worker groups if profiling is enabled.""" + if do_profile: + self.actor_rollout_wg.stop_profile() + if self.use_reference_policy: + self.ref_policy_wg.stop_profile() + if self.use_critic: + self.critic_wg.stop_profile() + if self.use_rm: + self.rm_wg.stop_profile() + + def _balance_batch(self, batch: BatchMeta, data_system_client, metrics, logging_prefix="global_seqlen"): + """Reorder the batchmeta on single controller such that each dp rank gets similar total tokens""" + data = asyncio.run(data_system_client.async_get_data(batch)) + + attention_mask = data["attention_mask"] + batch_size = attention_mask.shape[0] + global_seqlen_lst = data["attention_mask"].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,) + world_size = self.actor_rollout_wg.world_size + global_partition_lst = get_seqlen_balanced_partitions( + global_seqlen_lst, k_partitions=world_size, equal_size=True + ) + # reorder based on index. The data will be automatically equally partitioned by dispatch function + global_idx = [j for partition in global_partition_lst for j in partition] + global_balance_stats = log_seqlen_unbalance( + seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix + ) + metrics.update(global_balance_stats) + return global_idx + + @classmethod + def repeat_dict( + cls, batch_dict: dict[str, torch.Tensor | np.ndarray], repeat_times=2, interleave=True + ) -> dict[str, torch.Tensor | np.ndarray]: + """ + Repeat the batch dict a specified number of times. + + Args: + repeat_times (int): Number of times to repeat the data. + interleave (bool): Whether to interleave the repeated data. + + Returns: + dict: A new dict with repeated data. 
+ """ + if repeat_times == 1: + return batch_dict + + repeated_batch_dict = {} + if batch_dict: + if interleave: + # Interleave the data + for key, val in batch_dict.items(): + if isinstance(val, torch.Tensor): + repeated_batch_dict[key] = val.repeat_interleave(repeat_times, dim=0) + elif isinstance(val, np.ndarray): + repeated_batch_dict[key] = np.repeat(val, repeat_times, axis=0) + else: + raise ValueError(f"Unsupported type in data {type(val)}") + else: + # Stack the data + for key, val in batch_dict.items(): + if isinstance(val, torch.Tensor): + repeated_batch_dict[key] = ( + val.unsqueeze(0).expand(repeat_times, *val.shape).reshape(-1, *val.shape[1:]) + ) + elif isinstance(val, np.ndarray): + repeated_batch_dict[key] = np.tile(val, (repeat_times,) + (1,) * (val.ndim - 1)) + else: + raise ValueError(f"Unsupported type in data {type(val)}") + return repeated_batch_dict + + @classmethod + def dict_to_tensordict(cls, data: dict[str, torch.Tensor | np.ndarray]) -> TensorDict: + """ + Create a TensorDict from a dict of tensors and non_tensors. + Note that this requires tensordict version at least 0.10 + """ + assert parse_version(tensordict.__version__) >= parse_version("0.10"), ( + "Storing non-tensor data in TensorDict at least requires tensordict version 0.10" + ) + tensors_batch = {} + batch_size = None + + for key, val in data.items(): + if isinstance(val, torch.Tensor | np.ndarray): + tensors_batch[key] = val + else: + raise ValueError(f"Unsupported type in data {type(val)}") + + if batch_size is None: + batch_size = len(val) + else: + assert len(val) == batch_size + + if batch_size is None: + batch_size = [] + else: + batch_size = [batch_size] + + return TensorDict(tensors_batch, batch_size=batch_size) + + def fit(self): + """ + The training loop of PPO. + The driver process only need to call the compute functions of the worker group through RPC + to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + from omegaconf import OmegaConf + + from verl.utils.tracking import Tracking + + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + + self.global_steps = 0 + + # load checkpoint before doing anything + self._load_checkpoint() + + # perform validation before training + # currently, we only support validation using the reward_function. 
+ if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + val_metrics = self._validate() + assert val_metrics, f"{val_metrics=}" + pprint(f"Initial validation metrics: {val_metrics}") + logger.log(data=val_metrics, step=self.global_steps) + if self.config.trainer.get("val_only", False): + return + + if self.config.actor_rollout_ref.rollout.get("skip_rollout", False): + rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg) + rollout_skip.wrap_generate_sequences() + + # add tqdm + progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") + + # we start from step 1 + self.global_steps += 1 + last_val_metrics = None + self.max_steps_duration = 0 + + prev_step_profile = False + curr_step_profile = ( + self.global_steps in self.config.global_profiler.steps + if self.config.global_profiler.steps is not None + else False + ) + next_step_profile = False + + for epoch in range(self.config.trainer.total_epochs): + for batch_dict in self.train_dataloader: + metrics = {} + timing_raw = {} + base_get_meta_kwargs = dict( + batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, + global_step=self.global_steps - 1, # self.global_steps starts from 1 + get_n_samples=False, + ) + + with marked_timer("start_profile", timing_raw): + self._start_profiling( + not prev_step_profile and curr_step_profile + if self.config.global_profiler.profile_continuous_steps + else curr_step_profile + ) + + # add uid to batch + batch_dict["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(batch_dict["input_ids"]))], dtype=object + ) + # When n > 1, repeat input data before putting to data system, simulating DataProto repeat. + repeated_batch_dict = self.repeat_dict( + batch_dict, repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True + ) + batch: TensorDict = self.dict_to_tensordict(repeated_batch_dict) + asyncio.run(self.data_system_client.async_put(data=batch, global_step=self.global_steps - 1)) + + gen_meta = asyncio.run( + self.data_system_client.async_get_meta( + data_fields=list(batch.keys()), # TODO: (TQ) Get metadata by specified fields + task_name="generate_sequences", + **base_get_meta_kwargs, + ) + ) + # pass global_steps to trace + gen_meta.set_extra_info("global_steps", self.global_steps) + + is_last_step = self.global_steps >= self.total_training_steps + + with marked_timer("step", timing_raw): + # generate a batch + with marked_timer("gen", timing_raw, color="red"): + if not self.async_rollout_mode: + gen_output_meta = self.actor_rollout_wg.generate_sequences(gen_meta) + else: + gen_output_meta = self.async_rollout_manager.generate_sequences(gen_meta) + timing_raw.update(gen_output_meta.extra_info["timing"]) + gen_output_meta.extra_info.pop("timing", None) + + # TODO: (TQ) support transfer queue + # if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: + # if self.reward_fn is None: + # raise ValueError("A reward_fn is required for REMAX advantage estimation.") + # + # with marked_timer("gen_max", timing_raw, color="purple"): + # gen_baseline_meta = deepcopy(gen_meta) + # gen_baseline_meta.extra_info["do_sample"] = False + # if not self.async_rollout_mode: + # gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_meta) + # else: + # gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_meta) + # batch = batch.union(gen_baseline_output) + # reward_baseline_tensor = self.reward_fn(batch) + # 
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) + # + # batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) + # + # batch.batch["reward_baselines"] = reward_baseline_tensor + # + # del gen_baseline_batch, gen_baseline_output + + batch_meta: BatchMeta = gen_meta.union(gen_output_meta) + + if "response_mask" not in batch_meta.field_names: + response_mask_meta = asyncio.run( + self.data_system_client.async_get_meta( + data_fields=["responses", "attention_mask"], + task_name="compute_response_mask", + **base_get_meta_kwargs, + ) + ) + response_mask_output_meta = compute_response_mask(response_mask_meta, self.data_system_client) + batch_meta = batch_meta.union(response_mask_output_meta) + + # Balance the number of valid tokens across DP ranks. + # NOTE: This usually changes the order of data in the `batch`, + # which won't affect the advantage calculation (since it's based on uid), + # but might affect the loss calculation (due to the change of mini-batching). + # TODO: Decouple the DP balancing and mini-batching. + balanced_idx = None + if self.config.trainer.balance_batch: + attention_mask_meta = asyncio.run( + self.data_system_client.async_get_meta( + data_fields=["attention_mask"], + task_name="balance_batch", + **base_get_meta_kwargs, + ) + ) + + balanced_idx = self._balance_batch( + attention_mask_meta, self.data_system_client, metrics=metrics + ) + batch_meta.reorder(balanced_idx) + + # compute global_valid tokens + data = asyncio.run(self.data_system_client.async_get_data(attention_mask_meta)) + batch_meta.extra_info["global_token_num"] = torch.sum(data["attention_mask"], dim=-1).tolist() + + with marked_timer("reward", timing_raw, color="yellow"): + # compute reward model score + if self.use_rm and "rm_scores" not in batch_meta.field_names: + reward_meta = self.rm_wg.compute_rm_score(batch_meta) + batch_meta = batch_meta.union(reward_meta) + + compute_reward_fields = [ + "responses", + "prompts", + "attention_mask", + "reward_model", + "data_source", + ] + if "rm_scores" in batch_meta.field_names: + compute_reward_fields.append("rm_scores") + compute_reward_meta = asyncio.run( + self.data_system_client.async_get_meta( + data_fields=compute_reward_fields, + task_name="compute_reward", + **base_get_meta_kwargs, + ) + ) + compute_reward_meta.reorder(balanced_idx) + if self.config.reward_model.launch_reward_fn_async: + future_reward = compute_reward_async_decorated( + data=compute_reward_meta, + reward_fn=self.reward_fn, + ) + else: + reward_tensor, reward_extra_infos_dict = compute_reward_decorated( + compute_reward_meta, self.reward_fn + ) + batch_meta = batch_meta.union(compute_reward_meta) + + # recompute old_log_probs + with marked_timer("old_log_prob", timing_raw, color="blue"): + old_log_prob_meta = asyncio.run( + self.data_system_client.async_get_meta( + data_fields=[ + "input_ids", + "attention_mask", + "position_ids", + "prompts", + "responses", + "response_mask", + "data_source", + "reward_model", + "extra_info", + "uid", + "index", + "tools_kwargs", + "interaction_kwargs", + "ability", + ], + task_name="compute_log_prob", + **base_get_meta_kwargs, + ) + ) + old_log_prob_meta.reorder(balanced_idx) + + old_log_prob_output_meta = self.actor_rollout_wg.compute_log_prob(old_log_prob_meta) + data = asyncio.run(self.data_system_client.async_get_data(old_log_prob_output_meta)) + entropys = data["entropys"] + response_masks = data["response_mask"] + loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode + entropy_agg = agg_loss(loss_mat=entropys, 
loss_mask=response_masks, loss_agg_mode=loss_agg_mode) + old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} + metrics.update(old_log_prob_metrics) + + batch_meta = batch_meta.union(old_log_prob_output_meta) + + if "rollout_log_probs" in batch_meta.field_names: + # TODO: we may want to add diff of probs too. + data_fields = ["rollout_log_probs", "old_log_probs", "responses"] + if "response_mask" in batch_meta.field_names: + data_fields.append("response_mask") + if "attention_mask" in batch_meta.field_names: + data_fields.append("attention_mask") + calculate_debug_metrics_meta = asyncio.run( + self.data_system_client.async_get_meta( + data_fields=data_fields, + task_name="calculate_debug_metrics", + **base_get_meta_kwargs, + ) + ) + calculate_debug_metrics_meta.reorder(balanced_idx) + + metrics.update(calculate_debug_metrics_decorated(calculate_debug_metrics_meta)) + + if self.use_reference_policy: + # compute reference log_prob + ref_log_prob_meta = asyncio.run( + self.data_system_client.async_get_meta( + data_fields=[ + "input_ids", + "attention_mask", + "position_ids", + "prompts", + "responses", + "response_mask", + "old_log_probs", + "data_source", + "reward_model", + "extra_info", + "uid", + "index", + "tools_kwargs", + "interaction_kwargs", + "ability", + ], + task_name="compute_ref_log_prob", + **base_get_meta_kwargs, + ) + ) + ref_log_prob_meta.reorder(balanced_idx) + with marked_timer("ref", timing_raw, color="olive"): + if not self.ref_in_actor: + ref_log_prob_output_meta = self.ref_policy_wg.compute_ref_log_prob(ref_log_prob_meta) + else: + ref_log_prob_output_meta = self.actor_rollout_wg.compute_ref_log_prob(ref_log_prob_meta) + batch_meta = batch_meta.union(ref_log_prob_output_meta) + + # compute values + if self.use_critic: + with marked_timer("values", timing_raw, color="cyan"): + values_meta = self.critic_wg.compute_values(batch_meta) + batch_meta = batch_meta.union(values_meta) + + with marked_timer("adv", timing_raw, color="brown"): + # we combine with rule-based rm + reward_extra_infos_dict: dict[str, list] + if self.config.reward_model.launch_reward_fn_async: + reward_tensor, reward_extra_infos_dict = ray.get(future_reward) + reward_td = TensorDict({"token_level_scores": reward_tensor}, batch_size=reward_tensor.size(0)) + asyncio.run(self.data_system_client.async_put(data=reward_td, metadata=batch_meta)) + batch_meta.add_fields(reward_td) + + if reward_extra_infos_dict: + reward_extra_infos_dict_new = {k: np.array(v) for k, v in reward_extra_infos_dict.items()} + reward_extra_infos_td = self.dict_to_tensordict(reward_extra_infos_dict_new) + asyncio.run( + self.data_system_client.async_put(data=reward_extra_infos_td, metadata=batch_meta) + ) + batch_meta.add_fields(reward_extra_infos_td) + + # compute rewards. 
apply_kl_penalty if available + if self.config.algorithm.use_kl_in_reward: + apply_kl_penalty_fields = [ + "response_mask", + "token_level_scores", + "old_log_probs", + "ref_log_prob", + ] + apply_kl_penalty_meta = asyncio.run( + self.data_system_client.async_get_meta( + data_fields=apply_kl_penalty_fields, + task_name="apply_kl_penalty", + **base_get_meta_kwargs, + ) + ) + apply_kl_penalty_meta.reorder(balanced_idx) + token_level_rewards, kl_metrics = apply_kl_penalty( + apply_kl_penalty_meta, + kl_ctrl=self.kl_ctrl_in_reward, + kl_penalty=self.config.algorithm.kl_penalty, + ) + token_level_rewards_td = TensorDict( + {"token_level_rewards": token_level_rewards}, batch_size=token_level_rewards.size(0) + ) + asyncio.run( + self.data_system_client.async_put( + data=token_level_rewards_td, metadata=apply_kl_penalty_meta + ) + ) + apply_kl_penalty_meta.add_fields(token_level_rewards_td) + + metrics.update(kl_metrics) + batch_meta = batch_meta.union(apply_kl_penalty_meta) + else: + token_level_scores_meta = asyncio.run( + self.data_system_client.async_get_meta( + data_fields=["token_level_scores"], + task_name="token_level_scores", + **base_get_meta_kwargs, + ) + ) + token_level_scores_meta.reorder(balanced_idx) + data = asyncio.run(self.data_system_client.async_get_data(token_level_scores_meta)) + token_level_rewards_td = TensorDict( + {"token_level_rewards": data["token_level_scores"]}, + batch_size=data["token_level_scores"].size(0), + ) + asyncio.run( + self.data_system_client.async_put( + data=token_level_rewards_td, metadata=token_level_scores_meta + ) + ) + batch_meta.add_fields(token_level_rewards_td) + + # compute advantages, executed on the driver process + + norm_adv_by_std_in_grpo = self.config.algorithm.get( + "norm_adv_by_std_in_grpo", True + ) # GRPO adv normalization factor + + assert "response_mask" in batch_meta.field_names, ( + f"`response_mask` must be in batch_meta {batch_meta.field_names} for advantage computation" + ) + compute_advantage_fields = [ + "response_mask", + "token_level_rewards", + ] + if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE: + compute_advantage_fields.append("values") + elif self.config.algorithm.adv_estimator == AdvantageEstimator.GRPO: + compute_advantage_fields.append("uid") + else: + if "uid" in batch_meta.field_names: + compute_advantage_fields.append("uid") + if "reward_baselines" in batch_meta.field_names: + compute_advantage_fields.append("reward_baselines") + + compute_advantage_meta = asyncio.run( + self.data_system_client.async_get_meta( + data_fields=compute_advantage_fields, + task_name="compute_advantage", + **base_get_meta_kwargs, + ) + ) + compute_advantage_meta.reorder(balanced_idx) + + advantages, returns = compute_advantage( + compute_advantage_meta, + adv_estimator=self.config.algorithm.adv_estimator, + gamma=self.config.algorithm.gamma, + lam=self.config.algorithm.lam, + num_repeat=self.config.actor_rollout_ref.rollout.n, + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + config=self.config.algorithm, + ) + + advantages_td = TensorDict( + {"advantages": advantages, "returns": returns}, batch_size=advantages.size(0) + ) + asyncio.run( + self.data_system_client.async_put(data=advantages_td, metadata=compute_advantage_meta) + ) + compute_advantage_meta.add_fields(advantages_td) + + batch_meta = batch_meta.union(compute_advantage_meta) + + # update critic + if self.use_critic: + with marked_timer("update_critic", timing_raw, color="pink"): + critic_output_meta = self.critic_wg.update_critic(batch_meta) + batch_meta 
= batch_meta.union(critic_output_meta) + critic_output_metrics = reduce_metrics(critic_output_meta.extra_info["metrics"]) + metrics.update(critic_output_metrics) + + # implement critic warmup + if self.config.trainer.critic_warmup <= self.global_steps: + # update actor + with marked_timer("update_actor", timing_raw, color="red"): + batch_meta.extra_info["multi_turn"] = ( + self.config.actor_rollout_ref.rollout.multi_turn.enable + ) + + update_actor_meta = asyncio.run( + self.data_system_client.async_get_meta( + data_fields=[ + "input_ids", + "attention_mask", + "position_ids", + "prompts", + "responses", + "response_mask", + "old_log_probs", + "ref_log_prob", + "advantages", + "returns", + "token_level_rewards", + "token_level_scores", + "data_source", + "reward_model", + "extra_info", + "uid", + "index", + "tools_kwargs", + "interaction_kwargs", + "ability", + ], + batch_size=self.config.data.train_batch_size + * self.config.actor_rollout_ref.rollout.n, + global_step=self.global_steps - 1, + get_n_samples=False, + task_name="update_actor", + ) + ) + update_actor_meta.reorder(balanced_idx) + update_actor_meta.set_extra_info( + "global_token_num", batch_meta.get_extra_info("global_token_num") + ) + update_actor_meta.set_extra_info("temperature", batch_meta.get_extra_info("temperature")) + + actor_output_meta = self.actor_rollout_wg.update_actor(update_actor_meta) + batch_meta = batch_meta.union(actor_output_meta) + actor_output_metrics = reduce_metrics(actor_output_meta.extra_info["metrics"]) + metrics.update(actor_output_metrics) + + # Log rollout generations if enabled + rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) + if rollout_data_dir: + data_fields = ["prompts", "responses", "token_level_scores", "reward_model"] + if "request_id" in batch_meta.field_names: + data_fields.append("request_id") + log_rollout_meta = asyncio.run( + self.data_system_client.async_get_meta( + data_fields=data_fields, + batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, + global_step=self.global_steps - 1, + get_n_samples=False, + task_name="log_rollout", + ) + ) + log_rollout_meta.reorder(balanced_idx) + self._log_rollout_data(log_rollout_meta, reward_extra_infos_dict, timing_raw, rollout_data_dir) + + # TODO: clear meta after iteration + + # TODO: validate + if ( + self.val_reward_fn is not None + and self.config.trainer.test_freq > 0 + and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + ): + with marked_timer("testing", timing_raw, color="green"): + val_metrics: dict = self._validate() + if is_last_step: + last_val_metrics = val_metrics + metrics.update(val_metrics) + + # Check if the ESI (Elastic Server Instance)/training plan is close to expiration. + esi_close_to_expiration = should_save_ckpt_esi( + max_steps_duration=self.max_steps_duration, + redundant_time=self.config.trainer.esi_redundant_time, + ) + # Check if the conditions for saving a checkpoint are met. + # The conditions include a mandatory condition (1) and + # one of the following optional conditions (2/3/4): + # 1. The save frequency is set to a positive value. + # 2. It's the last training step. + # 3. The current step number is a multiple of the save frequency. + # 4. The ESI(Elastic Server Instance)/training plan is close to expiration. 
+ if self.config.trainer.save_freq > 0 and ( + is_last_step or self.global_steps % self.config.trainer.save_freq == 0 or esi_close_to_expiration + ): + if esi_close_to_expiration: + print("Force saving checkpoint: ESI instance expiration approaching.") + with marked_timer("save_checkpoint", timing_raw, color="green"): + self._save_checkpoint() + + with marked_timer("stop_profile", timing_raw): + next_step_profile = ( + self.global_steps + 1 in self.config.global_profiler.steps + if self.config.global_profiler.steps is not None + else False + ) + self._stop_profiling( + curr_step_profile and not next_step_profile + if self.config.global_profiler.profile_continuous_steps + else curr_step_profile + ) + prev_step_profile = curr_step_profile + curr_step_profile = next_step_profile + + steps_duration = timing_raw["step"] + self.max_steps_duration = max(self.max_steps_duration, steps_duration) + + # training metrics + metrics.update( + { + "training/global_step": self.global_steps, + "training/epoch": epoch, + } + ) + # collect metrics + compute_data_metrics_fields = [ + "token_level_rewards", + "token_level_scores", + "advantages", + "returns", + "responses", + "attention_mask", + "response_mask", + ] + if "__num_turns__" in batch_meta.field_names: + compute_data_metrics_fields.append("__num_turns__") + if "tool_call_counts" in batch_meta.field_names: + compute_data_metrics_fields.append("tool_call_counts") + compute_data_metrics_meta = asyncio.run( + self.data_system_client.async_get_meta( + data_fields=compute_data_metrics_fields, + task_name="compute_data_metrics", + **base_get_meta_kwargs, + ) + ) + compute_data_metrics_meta.reorder(balanced_idx) + metrics.update( + compute_data_metrics_decorated(batch=compute_data_metrics_meta, use_critic=self.use_critic) + ) + + compute_timing_metrics_fields = ["responses", "attention_mask"] + compute_timing_metrics_meta = asyncio.run( + self.data_system_client.async_get_meta( + data_fields=compute_timing_metrics_fields, + task_name="compute_timing_metrics", + **base_get_meta_kwargs, + ) + ) + compute_timing_metrics_meta.reorder(balanced_idx) + metrics.update( + compute_timing_metrics_decorated(batch=compute_timing_metrics_meta, timing_raw=timing_raw) + ) + + compute_throughout_metrics_meta = BatchMeta( + samples=[], + extra_info={"global_token_num": batch_meta.get_extra_info("global_token_num")}, + ) + # TODO: implement actual tflpo and theoretical tflpo + n_gpus = self.resource_pool_manager.get_n_gpus() + metrics.update( + compute_throughout_metrics_decorated( + batch=compute_throughout_metrics_meta, timing_raw=timing_raw, n_gpus=n_gpus + ) + ) + + # this is experimental and may be changed/removed in the future in favor of a general-purpose one + if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler): + # TODO: (TQ) support transfer queue + self.train_dataloader.sampler.update(batch=batch) + + asyncio.run(self.data_system_client.async_clear(self.global_steps - 1)) + # TODO: make a canonical logger that supports various backend + logger.log(data=metrics, step=self.global_steps) + + progress_bar.update(1) + self.global_steps += 1 + + if ( + hasattr(self.config.actor_rollout_ref.actor, "profiler") + and self.config.actor_rollout_ref.actor.profiler.tool == "torch_memory" + ): + self.actor_rollout_wg.dump_memory_snapshot( + tag=f"post_update_step{self.global_steps}", sub_dir=f"step{self.global_steps}" + ) + + if is_last_step: + pprint(f"Final validation metrics: {last_val_metrics}") + progress_bar.close() + return + + # this is experimental 
and may be changed/removed in the future + # in favor of a general-purpose data buffer pool + if hasattr(self.train_dataset, "on_batch_end"): + # The dataset may be changed after each training batch + # TODO: (TQ) support transfer queue + self.train_dataset.on_batch_end(batch=batch) diff --git a/recipe/transfer_queue/run_qwen3-8b_transferqueue_npu.sh b/recipe/transfer_queue/run_qwen3-8b_transferqueue_npu.sh new file mode 100644 index 00000000000..70b7e23976d --- /dev/null +++ b/recipe/transfer_queue/run_qwen3-8b_transferqueue_npu.sh @@ -0,0 +1,63 @@ +set -x + +project_name='GRPO-Qwen3' +exp_name='GRPO-Qwen3-8B-npu' +gen_tp=2 +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-8B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +python3 -m recipe.transfer_queue.main_ppo \ + --config-name='transfer_queue_ppo_trainer' \ + algorithm.adv_estimator=grpo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.train_batch_size=256 \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=${MODEL_PATH} \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=64 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.ref.use_torch_compile=False \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.default_local_dir=${CKPTS_DIR} \ + trainer.device=npu \ + trainer.resume_mode=auto \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.ref.fsdp_config.forward_prefetch=True \ + ++actor_rollout_ref.actor.entropy_from_logits_with_chunking=True \ + ++actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \ + trainer.val_before_train=False \ + trainer.save_freq=5 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 \ + +trainer.num_global_batch=1 \ + +trainer.num_data_storage_units=2 \ + +trainer.num_data_controllers=1 \ No newline at end of file diff --git a/requirements_transferqueue.txt b/requirements_transferqueue.txt new file mode 100644 index 00000000000..8479d27bb21 --- /dev/null +++ b/requirements_transferqueue.txt @@ -0,0 +1,2 @@ +# requirements.txt records the full set of dependencies for development 
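+# Install with `pip install -r requirements_transferqueue.txt`; the same commit pin
+# is also exposed as the optional `transferqueue` extra in setup.py.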
+git+https://github.com/TransferQueue/TransferQueue.git@68c04e7 diff --git a/setup.py b/setup.py index 4a86f035d09..b419f7f1af0 100644 --- a/setup.py +++ b/setup.py @@ -57,6 +57,7 @@ ] TRL_REQUIRES = ["trl<=0.9.6"] MCORE_REQUIRES = ["mbridge"] +TRANSFERQUEUE_REQUIRES = ["TransferQueue @ git+https://github.com/TransferQueue/TransferQueue.git@68c04e7"] extras_require = { "test": TEST_REQUIRES, @@ -68,6 +69,7 @@ "sglang": SGLANG_REQUIRES, "trl": TRL_REQUIRES, "mcore": MCORE_REQUIRES, + "transferqueue": TRANSFERQUEUE_REQUIRES, } diff --git a/tests/special_sanity/check_device_api_usage.py b/tests/special_sanity/check_device_api_usage.py index 8d3cfda27c8..f33bf704900 100644 --- a/tests/special_sanity/check_device_api_usage.py +++ b/tests/special_sanity/check_device_api_usage.py @@ -29,6 +29,7 @@ "recipe/spin/spin_trainer.py", # appear in default device_name "recipe/sppo/sppo_ray_trainer.py", # appear in default device_name "recipe/one_step_off_policy/ray_trainer.py", # appear in default device_name + "recipe/transfer_queue/ray_trainer.py", # appear in default device_name "verl/utils/profiler/nvtx_profile.py", # appear in NsightSystemsProfiler "verl/utils/kernel/linear_cross_entropy.py", # appear in nvidia nvtx "verl/utils/rendezvous/ray_backend.py", # appear in cupy importance diff --git a/verl/experimental/agent_loop/agent_loop.py b/verl/experimental/agent_loop/agent_loop.py index fe54edd51a0..5f452e8ec52 100644 --- a/verl/experimental/agent_loop/agent_loop.py +++ b/verl/experimental/agent_loop/agent_loop.py @@ -38,7 +38,12 @@ from verl.utils import hf_processor, hf_tokenizer from verl.utils.fs import copy_to_local from verl.utils.model import compute_position_id_with_mask -from verl.utils.rollout_trace import RolloutTraceConfig, rollout_trace_attr, rollout_trace_op +from verl.utils.rollout_trace import ( + RolloutTraceConfig, + rollout_trace_attr, + rollout_trace_op, +) +from verl.utils.transferqueue_utils import tqbridge from verl.workers.rollout.replica import TokenOutput, get_rollout_replica_class logger = logging.getLogger(__file__) @@ -421,6 +426,7 @@ def __init__( trace_config.get("token2text", False), ) + @tqbridge() async def generate_sequences(self, batch: DataProto) -> DataProto: """Generate sequences from agent loop. 
@@ -722,6 +728,18 @@ def _postprocess(self, inputs: list[_InternalAgentLoopOutput]) -> DataProto: meta_info={"metrics": metrics, "reward_extra_keys": reward_extra_keys}, ) + def create_transferqueue_client(self, controller_infos, storage_infos, role): + """Create a client for data system(transfer queue).""" + from verl.single_controller.ray.base import get_random_string + from verl.utils.transferqueue_utils import create_transferqueue_client + + client_name = get_random_string(length=6) + create_transferqueue_client( + client_id=f"{role}_worker_{client_name}", + controller_infos=controller_infos, + storage_infos=storage_infos, + ) + @ray.remote class AgentLoopWorker(AgentLoopWorkerBase): diff --git a/verl/single_controller/base/decorator.py b/verl/single_controller/base/decorator.py index b246aaf7eec..9cab8f3ccdf 100644 --- a/verl/single_controller/base/decorator.py +++ b/verl/single_controller/base/decorator.py @@ -18,6 +18,7 @@ from verl.protocol import DataProtoFuture, _padding_size_key from verl.utils.py_functional import DynamicEnum +from verl.utils.transferqueue_utils import BatchMeta # here we add a magic number of avoid user-defined function already have this attribute MAGIC_ATTR = "attrs_3141562937" @@ -73,12 +74,12 @@ def _split_args_kwargs_data_proto(chunks, *args, **kwargs): splitted_args = [] for arg in args: - assert isinstance(arg, DataProto | DataProtoFuture) + assert isinstance(arg, DataProto | DataProtoFuture | BatchMeta) splitted_args.append(arg.chunk(chunks=chunks)) splitted_kwargs = {} for key, val in kwargs.items(): - assert isinstance(val, DataProto | DataProtoFuture) + assert isinstance(val, DataProto | DataProtoFuture | BatchMeta) splitted_kwargs[key] = val.chunk(chunks=chunks) return splitted_args, splitted_kwargs @@ -146,6 +147,8 @@ def _concat_data_proto_or_future(output: list): return DataProto.concat(output) elif isinstance(o, ray.ObjectRef): return DataProtoFuture.concat(output) + elif isinstance(o, BatchMeta): + return BatchMeta.concat(output) else: raise NotImplementedError @@ -265,7 +268,9 @@ def collect_nd_compute_dataproto(collect_mask: list[bool], worker_group, output) from verl.protocol import DataProto for o in output: - assert isinstance(o, DataProto | ray.ObjectRef), f"expecting {o} to be DataProto, but got {type(o)}" + assert isinstance(o, DataProto | ray.ObjectRef | BatchMeta), ( + f"expecting {o} to be DataProto or BatchMeta, but got {type(o)}" + ) return _concat_data_proto_or_future(output) @@ -422,10 +427,14 @@ def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocki A decorator that wraps the original function with distributed execution configuration. 
""" + from verl.utils.transferqueue_utils import tqbridge + _check_dispatch_mode(dispatch_mode=dispatch_mode) _check_execute_mode(execute_mode=execute_mode) def decorator(func): + func = tqbridge()(func) + @wraps(func) def inner(*args, **kwargs): if materialize_futures: diff --git a/verl/single_controller/base/worker.py b/verl/single_controller/base/worker.py index 59ab27a9c39..2513c57f99c 100644 --- a/verl/single_controller/base/worker.py +++ b/verl/single_controller/base/worker.py @@ -130,6 +130,16 @@ def _query_collect_info(self, mesh_name: str): assert mesh_name in self.__collect_dp_rank, f"{mesh_name} is not registered in {self.__class__.__name__}" return self.__collect_dp_rank[mesh_name] + @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True) + def create_transferqueue_client(self, controller_infos, storage_infos, role="train"): + from verl.utils.transferqueue_utils import create_transferqueue_client + + create_transferqueue_client( + client_id=f"{role}_worker_{self.rank}", + controller_infos=controller_infos, + storage_infos=storage_infos, + ) + @classmethod def env_keys(cls): """The keys of the environment variables that are used to configure the Worker.""" diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml index fc98da09fa4..82ba944f1a8 100644 --- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -542,6 +542,8 @@ global_profiler: context: all stacks: all kw_args: {} +transfer_queue: + enable: false ray_kwargs: ray_init: num_cpus: null diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml index f366b0d4579..9cd652d3099 100644 --- a/verl/trainer/config/_generated_ppo_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_trainer.yaml @@ -526,6 +526,8 @@ global_profiler: context: all stacks: all kw_args: {} +transfer_queue: + enable: false ray_kwargs: ray_init: num_cpus: null diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index 6bdaffe838c..8a23a3b72d4 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -189,6 +189,12 @@ global_profiler: # devices, record_context etc. kw_args: {} +# configs for TransferQueue +transfer_queue: + + # Whether to enable transfer queue + enable: False + ray_kwargs: ray_init: num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index b1724e4de48..72c97e4a7cd 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -317,6 +317,12 @@ global_profiler: # devices, record_context etc. 
kw_args: {} +# configs for TransferQueue +transfer_queue: + + # Whether to enable transfer queue + enable: False + # configs related to ray ray_kwargs: diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index b9a357dcfc2..83666084467 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -63,11 +63,13 @@ def run_ppo(config, task_runner_class=None) -> None: runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {}) runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs) ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env}) + if config.transfer_queue.enable: + ray_init_kwargs["TRANSFER_QUEUE_ENABLE"] = "1" print(f"ray init kwargs: {ray_init_kwargs}") ray.init(**OmegaConf.to_container(ray_init_kwargs)) if task_runner_class is None: - task_runner_class = TaskRunner + task_runner_class = ray.remote(num_cpus=1)(TaskRunner) # please make sure main_task is not scheduled on head # Create a remote instance of the TaskRunner class, and # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete @@ -95,7 +97,6 @@ def run_ppo(config, task_runner_class=None) -> None: ray.timeline(filename=timeline_json_file) -@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head class TaskRunner: """Ray remote class for executing distributed PPO training tasks. diff --git a/verl/trainer/ppo/reward.py b/verl/trainer/ppo/reward.py index 5be7a68a330..828206affdd 100644 --- a/verl/trainer/ppo/reward.py +++ b/verl/trainer/ppo/reward.py @@ -26,6 +26,7 @@ from verl import DataProto from verl.utils.reward_score import default_compute_score +from verl.utils.transferqueue_utils import tqbridge from verl.workers.reward_manager import get_reward_manager_cls from verl.workers.reward_manager.abstract import AbstractRewardManager, RawRewardFn @@ -152,6 +153,7 @@ def load_reward_manager( ) +@tqbridge(put_data=False) def compute_reward(data: DataProto, reward_fn: AbstractRewardManager) -> tuple[torch.Tensor, dict[str, Any]]: """ Compute reward for a batch of data. diff --git a/verl/utils/transferqueue_utils.py b/verl/utils/transferqueue_utils.py new file mode 100644 index 00000000000..27160571ef3 --- /dev/null +++ b/verl/utils/transferqueue_utils.py @@ -0,0 +1,211 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import inspect +import os +import threading +from functools import wraps +from typing import Any, Callable + +from tensordict import TensorDict + +try: + from transfer_queue import ( + AsyncTransferQueueClient, + BatchMeta, + ZMQServerInfo, + ) + +except ImportError: + # TODO: Use a hacky workaround for ImportError since + # transfer_queue isn't a default verl dependency. 
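+    # Minimal stand-in so that `isinstance(..., BatchMeta)` checks (e.g. in the
+    # dispatch decorator and in tqbridge below) keep working when the optional
+    # transfer_queue package is not installed; the stub only serves those checks
+    # and is not meant to be instantiated.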
+ class BatchMeta: + pass + + +from verl.protocol import DataProto + +_TRANSFER_QUEUE_CLIENT = None +_VAL_TRANSFER_QUEUE_CLIENT = None + +is_transferqueue_enabled = os.environ.get("TRANSFER_QUEUE_ENABLE", False) + + +def create_transferqueue_client( + client_id: str, + controller_infos: dict[Any, "ZMQServerInfo"], + storage_infos: dict[Any, "ZMQServerInfo"], +) -> None: + global _TRANSFER_QUEUE_CLIENT + global _VAL_TRANSFER_QUEUE_CLIENT + if "val" in client_id: + _VAL_TRANSFER_QUEUE_CLIENT = AsyncTransferQueueClient(client_id, controller_infos, storage_infos) + else: + _TRANSFER_QUEUE_CLIENT = AsyncTransferQueueClient(client_id, controller_infos, storage_infos) + + +def get_transferqueue_client() -> "AsyncTransferQueueClient": + return _TRANSFER_QUEUE_CLIENT + + +def get_val_transferqueue_client() -> "AsyncTransferQueueClient": + return _VAL_TRANSFER_QUEUE_CLIENT + + +def _run_async_in_temp_loop(async_func: Callable[..., Any], *args, **kwargs) -> Any: + # Use a temporary event loop in a new thread because event + # loop may already exist in server mode + tmp_event_loop = asyncio.new_event_loop() + thread = threading.Thread( + target=tmp_event_loop.run_forever, + name="batchmeta dataproto converter", + daemon=True, + ) + + def run_coroutine(coroutine): + if not thread.is_alive(): + thread.start() + future = asyncio.run_coroutine_threadsafe(coroutine, tmp_event_loop) + return future.result() + + async def stop_loop(): + tmp_event_loop.stop() + + try: + return run_coroutine(async_func(*args, **kwargs)) + finally: + if thread.is_alive(): + asyncio.run_coroutine_threadsafe(stop_loop(), tmp_event_loop) + thread.join() + + +def _find_batchmeta(*args, **kwargs): + for arg in args: + if isinstance(arg, BatchMeta): + return arg + for v in kwargs.values(): + if isinstance(v, BatchMeta): + return v + return None + + +async def _async_batchmeta_to_dataproto(batchmeta: "BatchMeta") -> DataProto: + if batchmeta.samples == [] or batchmeta.samples is None: + return DataProto( + batch=TensorDict({}, batch_size=(0,)), + non_tensor_batch={}, + meta_info=batchmeta.extra_info.copy(), + ) + + if batchmeta.extra_info.get("validate", False): + tensordict = await _VAL_TRANSFER_QUEUE_CLIENT.async_get_data(batchmeta) + else: + tensordict = await _TRANSFER_QUEUE_CLIENT.async_get_data(batchmeta) + return DataProto.from_tensordict(tensordict, meta_info=batchmeta.extra_info.copy()) + + +def _batchmeta_to_dataproto(batchmeta: "BatchMeta") -> DataProto: + return _run_async_in_temp_loop(_async_batchmeta_to_dataproto, batchmeta) + + +async def _async_update_batchmeta_with_output(output: DataProto, batchmeta: "BatchMeta") -> None: + for k, v in output.meta_info.items(): + batchmeta.set_extra_info(k, v) + + if len(output) > 0: + tensordict = output.to_tensordict() + # pop meta_info + for key in output.meta_info.keys(): + tensordict.pop(key) + batchmeta.add_fields(tensordict) + if batchmeta.extra_info.get("validate", False): + await _VAL_TRANSFER_QUEUE_CLIENT.async_put(data=tensordict, metadata=batchmeta) + else: + await _TRANSFER_QUEUE_CLIENT.async_put(data=tensordict, metadata=batchmeta) + + +def _update_batchmeta_with_output(output: DataProto, batchmeta: "BatchMeta") -> None: + _run_async_in_temp_loop(_async_update_batchmeta_with_output, output, batchmeta) + + +def tqbridge(put_data: bool = True): + """ "Creates a decorator for bridging BatchMeta and DataProto. 
+
+    This decorator automatically handles conversions between `BatchMeta` and
+    `DataProto` in function parameters, and decides whether to sync the function
+    output back to `BatchMeta` based on the `put_data` flag. It supports both
+    synchronous and asynchronous functions (async def). The bridging logic is only
+    active when the `TRANSFER_QUEUE_ENABLE` environment variable is set (see
+    `is_transferqueue_enabled`); when disabled, the original function is called as-is.
+
+    Args:
+        put_data: Whether to put the resulting DataProto into storage after the
+            function returns. If True, the function output is written back to
+            `BatchMeta` and the `BatchMeta` is returned; if False, the function
+            output is returned directly. Defaults to True.
+
+    Returns:
+        A decorator function used to decorate target functions (synchronous or asynchronous).
+    """
+
+    def decorator(func):
+        @wraps(func)
+        def inner(*args, **kwargs):
+            batchmeta = _find_batchmeta(*args, **kwargs)
+            if batchmeta is None:
+                return func(*args, **kwargs)
+            else:
+                args = [_batchmeta_to_dataproto(arg) if isinstance(arg, BatchMeta) else arg for arg in args]
+                kwargs = {k: _batchmeta_to_dataproto(v) if isinstance(v, BatchMeta) else v for k, v in kwargs.items()}
+                output = func(*args, **kwargs)
+                if put_data:
+                    _update_batchmeta_with_output(output, batchmeta)
+                    return batchmeta
+                else:
+                    return output
+
+        @wraps(func)
+        async def async_inner(*args, **kwargs):
+            batchmeta = _find_batchmeta(*args, **kwargs)
+            if batchmeta is None:
+                return await func(*args, **kwargs)
+            else:
+                args = [
+                    await _async_batchmeta_to_dataproto(arg) if isinstance(arg, BatchMeta) else arg for arg in args
+                ]
+                kwargs = {
+                    k: await _async_batchmeta_to_dataproto(v) if isinstance(v, BatchMeta) else v
+                    for k, v in kwargs.items()
+                }
+                output = await func(*args, **kwargs)
+                if put_data:
+                    await _async_update_batchmeta_with_output(output, batchmeta)
+                    return batchmeta
+                return output
+
+        @wraps(func)
+        def dummy_inner(*args, **kwargs):
+            return func(*args, **kwargs)
+
+        @wraps(func)
+        async def dummy_async_inner(*args, **kwargs):
+            return await func(*args, **kwargs)
+
+        wrapper_inner = inner if is_transferqueue_enabled else dummy_inner
+        wrapper_async_inner = async_inner if is_transferqueue_enabled else dummy_async_inner
+
+        wrapper = wrapper_async_inner if inspect.iscoroutinefunction(func) else wrapper_inner
+        return wrapper
+
+    return decorator
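+
+
+# Illustrative usage (a sketch, not part of the public API): wrapping a worker
+# method so it accepts either a DataProto or a BatchMeta. This mirrors how
+# `register()` in verl/single_controller/base/decorator.py applies `tqbridge()`
+# to registered worker methods.
+#
+#     @tqbridge()
+#     def compute_log_prob(self, data):
+#         ...  # operates on a DataProto as usual
+#
+# With TRANSFER_QUEUE_ENABLE=1 and a BatchMeta argument, the wrapper fetches the
+# actual tensors from the storage units, calls the function on the resulting
+# DataProto, writes the outputs back via async_put, and returns the updated
+# BatchMeta; otherwise the original function is called unchanged.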