Commit 8bbb2e9

[data, trainer] feat: add support for limiting samples from dataset

For example, `filter_overlong_prompts` can be very expensive, so it is useful to be able to cap the number of samples before running it when the dataset is very large.

Signed-off-by: Hollow Man <[email protected]>

1 parent 4da0d3d · commit 8bbb2e9

23 files changed, +224 −24 lines
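
The dataset-side handling of the new setting is not among the hunks shown on this page, so here is a rough illustration of the intended behaviour only: a minimal sketch of capping the sample count before any expensive per-row work, assuming the rows live in a pandas dataframe (the helper name, the seed, and the use of pandas are assumptions, not verl's actual implementation):

import pandas as pd


def apply_max_samples(dataframe: pd.DataFrame, max_samples: int = -1, seed: int = 42) -> pd.DataFrame:
    """Randomly keep at most `max_samples` rows; -1 keeps the full dataframe (sketch only)."""
    if max_samples < 0 or max_samples >= len(dataframe):
        return dataframe
    # Subsample *before* expensive steps such as filter_overlong_prompts,
    # so tokenization-based filtering only runs over the retained rows.
    return dataframe.sample(n=max_samples, random_state=seed).reset_index(drop=True)

Drawing a random subset rather than taking the first rows keeps the subset representative, which matches the "randomly select the specified number of samples" wording added to the YAML configs below.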

docs/examples/config.rst (6 additions, 0 deletions)

@@ -17,6 +17,8 @@ Data
   tokenizer: null
   train_files: ~/data/rlhf/gsm8k/train.parquet
   val_files: ~/data/rlhf/gsm8k/test.parquet
+  train_max_samples: -1 # set to -1 to use full dataset
+  val_max_samples: -1 # set to -1 to use full dataset
   prompt_key: prompt
   max_prompt_length: 512
   max_response_length: 512
@@ -41,6 +43,10 @@ Data
   HDFS path to local path.
 - ``data.val_files``: Validation parquet. Can be a list or a single
   file.
+- ``data.train_max_samples``: Maximum number of samples to use from the
+  training dataset. Set to -1 to use the full dataset.
+- ``data.val_max_samples``: Maximum number of samples to use from the
+  validation dataset. Set to -1 to use the full dataset.
 - ``data.prompt_key``: The field in the dataset where the prompt is
   located. Default is 'prompt'.
 - ``data.max_prompt_length``: Maximum prompt length. All prompts will be
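
The trainer entry points changed in this commit read the new keys with a default, e.g. config.data.get("train_max_samples", -1), so configs written before these two entries existed keep working unchanged. A tiny standalone check of that default behaviour with OmegaConf (illustrative snippet, not part of the commit):

from omegaconf import OmegaConf

# A data config written before this commit: the new keys are absent.
data = OmegaConf.create({"train_files": "~/data/rlhf/gsm8k/train.parquet"})

# DictConfig.get falls back to the supplied default, so a missing key
# behaves exactly like an explicit -1 ("use the full dataset").
assert data.get("train_max_samples", -1) == -1
assert data.get("val_max_samples", -1) == -1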

examples/split_placement/config/ppo_trainer_split.yaml (2 additions, 0 deletions)

@@ -12,6 +12,8 @@ data:
   tokenizer: null
   train_files: ~/data/rlhf/gsm8k/train.parquet
   val_files: ~/data/rlhf/gsm8k/test.parquet
+  train_max_samples: -1 # set to -1 to use full dataset
+  val_max_samples: -1 # set to -1 to use full dataset
   prompt_key: prompt
   max_prompt_length: 512
   max_response_length: 512

recipe/entropy/main_entropy.py (12 additions, 3 deletions)

@@ -162,8 +162,16 @@ def run(self, config):
 
         from verl.utils.dataset.rl_dataset import collate_fn
 
-        train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor)
-        val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor)
+        train_dataset = create_rl_dataset(
+            config.data.train_files,
+            config.data,
+            tokenizer,
+            processor,
+            max_samples=config.data.get("train_max_samples", -1),
+        )
+        val_dataset = create_rl_dataset(
+            config.data.val_files, config.data, tokenizer, processor, max_samples=config.data.get("val_max_samples", -1)
+        )
         train_sampler = create_rl_sampler(config.data, train_dataset)
         trainer = RayEntropyTrainer(
             config=config,
@@ -183,7 +191,7 @@ def run(self, config):
         trainer.fit()
 
 
-def create_rl_dataset(data_paths, data_config, tokenizer, processor):
+def create_rl_dataset(data_paths, data_config, tokenizer, processor, max_samples: int = -1):
     """Create a dataset.
 
     Arguments:
@@ -216,6 +224,7 @@ def create_rl_dataset(data_paths, data_config, tokenizer, processor):
         tokenizer=tokenizer,
         processor=processor,
         config=data_config,
+        max_samples=max_samples,
     )
 
     return dataset

recipe/one_step_off_policy/main_ppo.py (10 additions, 2 deletions)

@@ -212,8 +212,16 @@ def run(self, config):
         from verl.utils.dataset.rl_dataset import collate_fn
 
         # Create training and validation datasets.
-        train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor)
-        val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor)
+        train_dataset = create_rl_dataset(
+            config.data.train_files,
+            config.data,
+            tokenizer,
+            processor,
+            max_samples=config.data.get("train_max_samples", -1),
+        )
+        val_dataset = create_rl_dataset(
+            config.data.val_files, config.data, tokenizer, processor, max_samples=config.data.get("val_max_samples", -1)
+        )
         train_sampler = create_rl_sampler(config.data, train_dataset)
 
         # Initialize the PPO trainer.

recipe/spin/spin_trainer.py (10 additions, 2 deletions)

@@ -393,11 +393,19 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl
 
         if train_dataset is None:
             train_dataset = create_rl_dataset(
-                self.config.data.train_files, self.config.data, self.tokenizer, self.processor
+                self.config.data.train_files,
+                self.config.data,
+                self.tokenizer,
+                self.processor,
+                max_samples=self.config.data.get("train_max_samples", -1),
             )
         if val_dataset is None:
             val_dataset = create_rl_dataset(
-                self.config.data.val_files, self.config.data, self.tokenizer, self.processor
+                self.config.data.val_files,
+                self.config.data,
+                self.tokenizer,
+                self.processor,
+                max_samples=self.config.data.get("val_max_samples", -1),
             )
         self.train_dataset, self.val_dataset = train_dataset, val_dataset

tests/special_e2e/sft/test_sp_loss_match.py (6 additions, 2 deletions)

@@ -112,8 +112,12 @@ def create_trainer(config):
 
     local_model_path = copy_to_local(src=config.model.partial_pretrain, verbose=True)
    tokenizer = hf_tokenizer(local_model_path, trust_remote_code=config.model.trust_remote_code)
-    train_dataset = create_sft_dataset(config.data.train_files, config.data, tokenizer)
-    val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer)
+    train_dataset = create_sft_dataset(
+        config.data.train_files, config.data, tokenizer, max_samples=config.data.get("train_max_samples", -1)
+    )
+    val_dataset = create_sft_dataset(
+        config.data.val_files, config.data, tokenizer, max_samples=config.data.get("val_max_samples", -1)
+    )
 
     return FSDPSFTTrainer(
         config=config,

tests/trainer/config/legacy_ppo_megatron_trainer.yaml (2 additions, 0 deletions)

@@ -2,6 +2,8 @@ data:
   tokenizer: null
   train_files: ~/data/rlhf/gsm8k/train.parquet
   val_files: ~/data/rlhf/gsm8k/test.parquet
+  train_max_samples: -1 # set to -1 to use full dataset
+  val_max_samples: -1 # set to -1 to use full dataset
   prompt_key: prompt
   reward_fn_key: data_source
   max_prompt_length: 512

tests/trainer/config/legacy_ppo_trainer.yaml (10 additions, 0 deletions)

@@ -22,6 +22,16 @@ data:
   # Validation parquet. Can be a list or a single file.
   val_files: ~/data/rlhf/gsm8k/test.parquet
 
+  # Maximum number of samples to be used.
+  # Set to -1 to use the full dataset; otherwise, randomly
+  # select the specified number of samples from the train dataset.
+  train_max_samples: -1
+
+  # Maximum number of samples to be used.
+  # Set to -1 to use the full dataset; otherwise, randomly
+  # select the specified number of samples from the val dataset.
+  val_max_samples: -1
+
   # The field in the dataset where the prompt is located. Default is 'prompt'.
   prompt_key: prompt

tests/utils/dataset/test_rl_dataset_on_cpu.py (18 additions, 0 deletions)

@@ -66,6 +66,24 @@ def test_rl_dataset():
     print(f"\n\noutput: {output}")
 
 
+def test_rl_dataset_with_max_samples():
+    from verl.utils import hf_tokenizer
+    from verl.utils.dataset.rl_dataset import RLHFDataset
+
+    tokenizer = hf_tokenizer("deepseek-ai/deepseek-coder-1.3b-instruct")
+    local_path = get_gsm8k_data()
+    config = OmegaConf.create(
+        {
+            "prompt_key": "prompt",
+            "max_prompt_length": 256,
+            "filter_overlong_prompts": True,
+            "filter_overlong_prompts_workers": 2,
+        }
+    )
+    dataset = RLHFDataset(data_files=local_path, tokenizer=tokenizer, config=config, max_samples=5)
+    assert len(dataset) == 5
+
+
 def test_image_rl_data():
     from verl.utils import hf_processor, hf_tokenizer
     from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn

tests/utils/dataset/test_sft_dataset_on_cpu.py (23 additions, 0 deletions)

@@ -72,3 +72,26 @@ def test_sft_dataset():
     output = tokenizer.batch_decode([data])[0]
     assert len(output) > 1
     assert isinstance(output, str)
+
+
+def test_sft_dataset_with_max_samples():
+    tokenizer = hf_tokenizer("deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct")
+    local_path = get_gsm8k_data()
+    from omegaconf import OmegaConf
+
+    dataset = SFTDataset(
+        parquet_files=local_path,
+        tokenizer=tokenizer,
+        config=OmegaConf.create(
+            {
+                "prompt_key": "extra_info",
+                "prompt_dict_keys": ["question"],
+                "response_key": "extra_info",
+                "response_dict_keys": ["answer"],
+                "max_length": 512,
+                "max_samples": 5,
+            }
+        ),
+    )
+
+    assert len(dataset) == 5
