Skip to content

Commit d6d0b13

Browse files
authored
[Feature] RayLLMCollector.sync_iter (#3015)
1 parent 7ee5248 commit d6d0b13

File tree

6 files changed

+48
-4
lines changed

6 files changed

+48
-4
lines changed

sota-implementations/grpo/README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,11 @@ for step in range(total_steps):
107107

108108
Key differences:
109109
1. **Data Collection**:
110-
- Sync: Data collection and optimization happen sequentially
110+
- Sync: Data collection and optimization happen sequentially.
111+
112+
*Note*: The `train.sync_iter=False` argument can be used to collect data whilst optimizing. In this context, the
113+
maximum policy age will be 1. If `train.sync_iter=True` (default), the maximum policy age is `0`.
114+
111115
- Async: Data collection runs in background while optimization happens
112116

113117
2. **Buffer Size**:

sota-implementations/grpo/config/mode/async.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,5 @@ train:
99
buffer_size: 128
1010
# Update policy weights every N steps - can be set to any positive integer in async mode
1111
weight_update_frequency: 10
12+
# Sync the collector between iterations. Deactivated when async.
13+
sync_iter:

sota-implementations/grpo/config/mode/sync.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,6 @@ train:
99
buffer_size:
1010
# Update policy weights every N steps - must be left empty in sync mode
1111
weight_update_frequency:
12+
# Sync the collector between iterations. Not syncing means that the collector will collect the next batch of data in between yielding.
13+
# When sync_iter=True, the maximum policy age is 0. When sync_iter=False, the maximum policy age is 1.
14+
sync_iter: true

sota-implementations/grpo/grpo-async.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,8 @@ def main(cfg):
465465
)
466466
torchrl_logger.info(f"Starting collector with {collector_config=}")
467467

468+
if cfg.train.sync_iter is not None:
469+
raise ValueError("sync_iter is not supported in async mode.")
468470
collector = RayLLMCollector(
469471
env=partial(make_env, cfg, devices=device_config["ref_model_devices"]),
470472
policy=inference_policy,

sota-implementations/grpo/grpo-sync.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,7 @@ def main(cfg):
483483
# The ref model will be instantiated within the collector, so we only need to allocate the number of devices for the inference model
484484
cfg.ref_model.num_devices
485485
)
486+
collector_config["num_cpus"] = cfg.ray.collector_config.get("num_cpus", 1)
486487
torchrl_logger.info(f"Starting collector with {collector_config=}")
487488

488489
collector = RayLLMCollector(
@@ -495,6 +496,7 @@ def main(cfg):
495496
weight_updater=None, # We'll create this after getting the remote LLM
496497
track_policy_version=True,
497498
remote_config=collector_config,
499+
sync_iter=cfg.train.sync_iter,
498500
verbose=True,
499501
)
500502
# Ensure collector is initialized by calling a method that will block until ready

torchrl/collectors/llm/ray_collector.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# LICENSE file in the root directory of this source tree.
55
from __future__ import annotations
66

7+
import copy
8+
79
import warnings
810
from typing import Any, Callable, Iterator
911

@@ -55,6 +57,19 @@ class RayLLMCollector(LLMCollector):
5557
or its subclass, responsible for updating the policy weights on remote inference workers.
5658
ray_init_config (dict[str, Any], optional): keyword arguments to pass to ray.init().
5759
remote_config (dict[str, Any], optional): keyword arguments to pass to cls.as_remote().
60+
sync_iter (bool, optional): if `True`, items yielded by the collector will be synced to the local process.
61+
If `False`, the collector will collect the next batch of data in between yielding.
62+
This has no effect when data is collected through the :meth:`start` method.
63+
For example:
64+
65+
>>> collector = RayLLMCollector(..., sync_iter=True)
66+
>>> for data in collector: # blocking
67+
... # expensive operation - collector is idle
68+
>>> collector = RayLLMCollector(..., sync_iter=False)
69+
>>> for data in collector: # non-blocking
70+
... # expensive operation - collector is collecting data
71+
72+
Defaults to `True`.
5873
verbose (bool, optional): if ``True``, the collector will print progress information.
5974
Defaults to `False`.
6075
"""
@@ -81,6 +96,7 @@ def __init__(
8196
ray_init_config: dict[str, Any] | None = None,
8297
remote_config: dict[str, Any] | None = None,
8398
track_policy_version: bool | PolicyVersion = False,
99+
sync_iter: bool = True,
84100
verbose: bool = False,
85101
) -> None:
86102
if not _has_ray:
@@ -93,8 +109,11 @@ def __init__(
93109

94110
ray_init_config = DEFAULT_RAY_INIT_CONFIG
95111
ray.init(**ray_init_config)
96-
112+
if not sync_iter:
113+
remote_config = copy.copy(remote_config)
114+
remote_config.setdefault("max_concurrency", 2)
97115
remote_cls = LLMCollector.as_remote(remote_config).remote
116+
self.sync_iter = sync_iter
98117
self._collector = remote_cls(
99118
env=env,
100119
policy=policy,
@@ -113,19 +132,31 @@ def __init__(
113132
verbose=verbose,
114133
)
115134

135+
def _next_remote(self) -> None:
136+
return self._collector.next.remote()
137+
116138
def next(self) -> None:
117139
"""Get the next batch of data from the collector.
118140
119141
Returns:
120142
None as the data is written directly to the replay buffer.
121143
"""
122-
return ray.get(self._collector.next.remote())
144+
return ray.get(self._next_remote())
123145

124146
def __iter__(self) -> Iterator[None]:
125147
"""Returns an iterator that yields None as the collector writes directly to the replay buffer."""
148+
if not self.sync_iter:
149+
future = self._next_remote()
150+
else:
151+
future = None
126152
while True:
127153
try:
128-
yield self.next()
154+
if self.sync_iter:
155+
yield self.next()
156+
else:
157+
result = ray.get(future)
158+
future = self._next_remote()
159+
yield result
129160
except StopIteration:
130161
break
131162

0 commit comments

Comments
 (0)