Commit e36d562

[BugFix] Fix cuda cache empty in GRPO scripts (#3016)
1 parent d6d0b13 commit e36d562

3 files changed: +10 −4 lines

sota-implementations/grpo/grpo-async.py

Lines changed: 4 additions & 2 deletions
@@ -354,7 +354,8 @@ def train(
         with timeit("update_policy_weights"):
             torchrl_logger.info("Updating policy weights...")
             weight_updater.push_weights(policy_training)
-            torch.cuda.empty_cache()
+            # TODO: do we need this? Does it interfere with other processes?
+            # torch.cuda.empty_cache()
             gc.collect()

         # Checkpointing disabled to prevent disk space issues
@@ -380,7 +381,8 @@ def train(

         # Clear memory
         del loss_val
-        torch.cuda.empty_cache()
+        # TODO: do we need this? Does it interfere with other processes?
+        # torch.cuda.empty_cache()
         gc.collect()

     pbar.close()

sota-implementations/grpo/grpo-sync.py

Lines changed: 4 additions & 2 deletions
@@ -288,7 +288,8 @@ def train(

             # Clear memory
             del loss_val
-            torch.cuda.empty_cache()
+            # TODO: do we need this? Does it interfere with other processes?
+            # torch.cuda.empty_cache()
             gc.collect()

             # Update metrics
@@ -387,7 +388,8 @@ def train(
         with timeit("update_policy_weights"):
             torchrl_logger.info("Updating policy weights...")
             weight_updater.push_weights(policy_training)
-            torch.cuda.empty_cache()
+            # TODO: do we need this? Does it interfere with other processes?
+            # torch.cuda.empty_cache()
             gc.collect()

         timeit.print(prefix="timeit")
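Both GRPO scripts get the same change: the explicit torch.cuda.empty_cache() that ran after pushing updated policy weights and after deleting the loss value is commented out, leaving only gc.collect(), since flushing the CUDA caching allocator can stall other processes sharing the GPU. As a minimal sketch of an alternative (not part of this commit; the helper name and flag are hypothetical), the flush could be kept behind an opt-in switch instead of being removed outright:

import gc

import torch


def clear_memory(empty_cuda_cache: bool = False) -> None:
    """Collect Python garbage; optionally flush the CUDA caching allocator.

    torch.cuda.empty_cache() returns cached blocks to the driver, which can
    slow down or stall other processes sharing the device, so it is opt-in.
    """
    gc.collect()
    if empty_cuda_cache and torch.cuda.is_available():
        torch.cuda.empty_cache()

The training loops would then call clear_memory() by default and clear_memory(empty_cuda_cache=True) only when a run is known to need the allocator flushed.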

torchrl/collectors/llm/ray_collector.py

Lines changed: 2 additions & 0 deletions
@@ -69,6 +69,8 @@ class RayLLMCollector(LLMCollector):
             >>> for data in collector:  # non-blocking
             ...     # expensive operation - collector is collecting data

+            This is somewhat equivalent to using :class:`~torchrl.collectors.MultiSyncDataCollector` (`sync_iter=True`) or
+            :class:`~torchrl.collectors.MultiAsyncDataCollector` (`sync_iter=False`).
             Defaults to `True`.
         verbose (bool, optional): if ``True``, the collector will print progress information.
             Defaults to `False`.
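The docstring addition only states the analogy. As a rough usage sketch of the two multiprocess collectors it points to (the environment choice, batch sizes, and random-policy fallback below are illustrative assumptions, not taken from this commit):

from torchrl.collectors import MultiAsyncDataCollector, MultiSyncDataCollector
from torchrl.envs import GymEnv

make_env = lambda: GymEnv("CartPole-v1")

# Like sync_iter=True: every worker contributes to each yielded batch, and
# iteration blocks until all workers have delivered their share.
sync_collector = MultiSyncDataCollector(
    [make_env, make_env], policy=None, frames_per_batch=64, total_frames=256
)

# Like sync_iter=False: a batch is yielded as soon as any worker finishes,
# while the remaining workers keep collecting in the background.
async_collector = MultiAsyncDataCollector(
    [make_env, make_env], policy=None, frames_per_batch=64, total_frames=256
)

for data in sync_collector:
    ...  # training step would go here
sync_collector.shutdown()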
