Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/examples/config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Data
return_raw_chat: False
return_full_prompt: False
shuffle: True
seed: 42
filter_overlong_prompts: False
filter_overlong_prompts_workers: 1
truncation: error
Expand Down Expand Up @@ -60,6 +61,8 @@ Data
without applying chat template.
- ``data.return_full_prompt``: Whether to return the full prompt with chat template
- ``data.shuffle``: Whether to shuffle the data in the dataloader.
- ``data.seed``: An integer seed to use when shuffling the data. If not set or set to
``null``, the data shuffling will not be seeded, resulting in a different data order on each run.
- ``data.filter_overlong_prompts``: Default don't filter.
- ``data.filter_overlong_prompts_workers``: For large-scale dataset, filtering
overlong prompts could be time-consuming. You can set the ``filter_overlong_prompts_workers``
Expand Down
1 change: 1 addition & 0 deletions examples/split_placement/config/ppo_trainer_split.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ data:
return_raw_chat: False
return_full_prompt: False
shuffle: True
seed: 42

actor_rollout_ref:
hybrid_engine: True
Expand Down
4 changes: 3 additions & 1 deletion recipe/entropy/main_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,9 @@ def create_rl_sampler(data_config, dataset):
# use sampler for better ckpt resume
if data_config.shuffle:
train_dataloader_generator = torch.Generator()
train_dataloader_generator.manual_seed(data_config.get("seed", 1))
seed = data_config.get("seed")
if seed is not None:
train_dataloader_generator.manual_seed(seed)
sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator)
else:
sampler = SequentialSampler(data_source=dataset)
Expand Down
4 changes: 3 additions & 1 deletion recipe/prime/prime_ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,9 @@ def _create_dataloader(self, *args, **kwargs):
# use sampler for better ckpt resume
if self.config.data.shuffle:
train_dataloader_generator = torch.Generator()
train_dataloader_generator.manual_seed(self.config.data.get("seed", 1))
seed = self.config.data.get("seed")
if seed is not None:
train_dataloader_generator.manual_seed(seed)
sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator)
else:
sampler = SequentialSampler(data_source=self.train_dataset)
Expand Down
1 change: 1 addition & 0 deletions tests/trainer/config/legacy_ppo_megatron_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ data:
return_raw_chat: False
return_full_prompt: False
shuffle: True
seed: null # An integer seed to use when shuffling the data. If not set or set to `null`, the data shuffling will not be seeded, resulting in a different data order on each run.
filter_overlong_prompts: False # for large-scale dataset, filtering overlong prompts could be time-consuming. You can set the filter_overlong_prompts_workers to use multiprocessing to speed up.
filter_overlong_prompts_workers: 1
truncation: error
Expand Down
4 changes: 4 additions & 0 deletions tests/trainer/config/legacy_ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ data:
# Whether to shuffle the data in the dataloader.
shuffle: True

# An integer seed to use when shuffling the data. If not set or set to
# `null`, the data shuffling will not be seeded, resulting in a different data order on each run.
seed: null

# num dataloader workers
dataloader_num_workers: 8

Expand Down
1 change: 1 addition & 0 deletions verl/trainer/config/_generated_ppo_megatron_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ data:
return_raw_chat: false
return_full_prompt: false
shuffle: true
seed: null
dataloader_num_workers: 8
validation_shuffle: false
filter_overlong_prompts: false
Expand Down
1 change: 1 addition & 0 deletions verl/trainer/config/_generated_ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ data:
return_raw_chat: false
return_full_prompt: false
shuffle: true
seed: null
dataloader_num_workers: 8
validation_shuffle: false
filter_overlong_prompts: false
Expand Down
3 changes: 3 additions & 0 deletions verl/trainer/config/data/legacy_data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ return_full_prompt: False
# Whether to shuffle the data in the dataloader.
shuffle: True

# Integer seed to use when shuffling the data. If null, shuffling is not seeded,
# resulting in a different data order on each run.
seed: null

# num dataloader workers
dataloader_num_workers: 8

Expand Down
4 changes: 3 additions & 1 deletion verl/trainer/main_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,9 @@ def create_rl_sampler(data_config, dataset):
# If shuffling is enabled in the data configuration, create a random sampler.
elif data_config.shuffle:
train_dataloader_generator = torch.Generator()
train_dataloader_generator.manual_seed(data_config.get("seed", 1))
seed = data_config.get("seed")
if seed is not None:
train_dataloader_generator.manual_seed(seed)
sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator)
else:
# If shuffling is disabled, use a sequential sampler to iterate through the dataset in order.
Expand Down
Loading