From b0cf1b601b6912f22721bb38f16ebc80d029ae1a Mon Sep 17 00:00:00 2001
From: Hollow Man
Date: Sat, 18 Oct 2025 16:49:37 +0300
Subject: [PATCH] [data, trainer] feat: add support for limiting samples from
 dataset

For RLHFDataset in particular, `filter_overlong_prompts` can be very
expensive, so it is useful to be able to limit the number of samples before
that filtering runs when the dataset is very large. The same option is added
to the other kinds of datasets for unification.
---
 docs/examples/config.rst                      |  6 +++++
 .../config/ppo_trainer_split.yaml             |  2 ++
 recipe/entropy/main_entropy.py                | 15 +++++++++---
 recipe/one_step_off_policy/main_ppo.py        | 12 ++++++++--
 recipe/spin/spin_trainer.py                   | 12 ++++++++--
 tests/special_e2e/sft/test_sp_loss_match.py   |  8 +++++--
 .../config/legacy_ppo_megatron_trainer.yaml   |  2 ++
 tests/trainer/config/legacy_ppo_trainer.yaml  | 10 ++++++++
 tests/utils/dataset/test_rl_dataset_on_cpu.py | 19 +++++++++++++++
 .../utils/dataset/test_sft_dataset_on_cpu.py  | 23 +++++++++++++++++++
 .../_generated_ppo_megatron_trainer.yaml      |  2 ++
 .../config/_generated_ppo_trainer.yaml        |  2 ++
 verl/trainer/config/data/legacy_data.yaml     | 10 ++++++++
 verl/trainer/config/sft_trainer.yaml          |  2 ++
 verl/trainer/config/sft_trainer_engine.yaml   |  2 ++
 verl/trainer/fsdp_sft_trainer.py              | 12 ++++++----
 verl/trainer/main_ppo.py                      | 21 ++++++++++++++---
 verl/trainer/ppo/ray_trainer.py               | 12 ++++++++--
 verl/trainer/sft_trainer.py                   | 12 ++++++----
 verl/utils/dataset/multiturn_sft_dataset.py   | 18 ++++++++++++++-
 verl/utils/dataset/rl_dataset.py              | 15 ++++++++++++
 verl/utils/dataset/rm_dataset.py              | 22 ++++++++++++++++++
 verl/utils/dataset/sft_dataset.py             | 20 +++++++++++++++-
 23 files changed, 235 insertions(+), 24 deletions(-)

diff --git a/docs/examples/config.rst b/docs/examples/config.rst
index 05456a54b53..2919e00bb88 100644
--- a/docs/examples/config.rst
+++ b/docs/examples/config.rst
@@ -17,6 +17,8 @@ Data
    tokenizer: null
    train_files: ~/data/rlhf/gsm8k/train.parquet
    val_files: ~/data/rlhf/gsm8k/test.parquet
+   train_max_samples: -1 # set to -1 to use full dataset
+   val_max_samples: -1 # set to -1 to use full dataset
    prompt_key: prompt
    max_prompt_length: 512
    max_response_length: 512
@@ -42,6 +44,10 @@ Data
   HDFS path to local path.
 - ``data.val_files``: Validation parquet. Can be a list or a single file.
+- ``data.train_max_samples``: Maximum number of samples to use from the
+  training dataset. Set to -1 to use the full dataset.
+- ``data.val_max_samples``: Maximum number of samples to use from the
+  validation dataset. Set to -1 to use the full dataset.
 - ``data.prompt_key``: The field in the dataset where the prompt is
   located. Default is 'prompt'.
 - ``data.max_prompt_length``: Maximum prompt length. All prompts will be
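The two new keys default to -1 and can be overridden like any other `data.*`
option. A minimal sketch of such an override with OmegaConf (the same merge
semantics Hydra applies to command-line overrides; the values are
illustrative):

    from omegaconf import OmegaConf

    defaults = OmegaConf.create({"data": {"train_max_samples": -1, "val_max_samples": -1}})
    override = OmegaConf.from_dotlist(["data.train_max_samples=1000"])
    config = OmegaConf.merge(defaults, override)

    assert config.data.train_max_samples == 1000  # cap the training set at 1000 rows
    assert config.data.val_max_samples == -1      # validation still uses the full dataset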
diff --git a/examples/split_placement/config/ppo_trainer_split.yaml b/examples/split_placement/config/ppo_trainer_split.yaml
index 2c2ae6199a4..f602f799c7c 100644
--- a/examples/split_placement/config/ppo_trainer_split.yaml
+++ b/examples/split_placement/config/ppo_trainer_split.yaml
@@ -12,6 +12,8 @@ data:
   tokenizer: null
   train_files: ~/data/rlhf/gsm8k/train.parquet
   val_files: ~/data/rlhf/gsm8k/test.parquet
+  train_max_samples: -1 # set to -1 to use full dataset
+  val_max_samples: -1 # set to -1 to use full dataset
   prompt_key: prompt
   max_prompt_length: 512
   max_response_length: 512
diff --git a/recipe/entropy/main_entropy.py b/recipe/entropy/main_entropy.py
index d7a12d8dd15..81da54471f1 100644
--- a/recipe/entropy/main_entropy.py
+++ b/recipe/entropy/main_entropy.py
@@ -162,8 +162,16 @@ def run(self, config):
 
         from verl.utils.dataset.rl_dataset import collate_fn
 
-        train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor)
-        val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor)
+        train_dataset = create_rl_dataset(
+            config.data.train_files,
+            config.data,
+            tokenizer,
+            processor,
+            max_samples=config.data.get("train_max_samples", -1),
+        )
+        val_dataset = create_rl_dataset(
+            config.data.val_files, config.data, tokenizer, processor, max_samples=config.data.get("val_max_samples", -1)
+        )
         train_sampler = create_rl_sampler(config.data, train_dataset)
         trainer = RayEntropyTrainer(
             config=config,
@@ -183,7 +191,7 @@ def run(self, config):
     trainer.fit()
 
 
-def create_rl_dataset(data_paths, data_config, tokenizer, processor):
+def create_rl_dataset(data_paths, data_config, tokenizer, processor, max_samples: int = -1):
     """Create a dataset.
 
     Arguments:
@@ -216,6 +224,7 @@ def create_rl_dataset(data_paths, data_config, tokenizer, processor):
         tokenizer=tokenizer,
         processor=processor,
         config=data_config,
+        max_samples=max_samples,
     )
 
     return dataset
diff --git a/recipe/one_step_off_policy/main_ppo.py b/recipe/one_step_off_policy/main_ppo.py
index 344fe4b9f0c..ec87e924ce3 100644
--- a/recipe/one_step_off_policy/main_ppo.py
+++ b/recipe/one_step_off_policy/main_ppo.py
@@ -212,8 +212,16 @@ def run(self, config):
         from verl.utils.dataset.rl_dataset import collate_fn
 
         # Create training and validation datasets.
-        train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor)
-        val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor)
+        train_dataset = create_rl_dataset(
+            config.data.train_files,
+            config.data,
+            tokenizer,
+            processor,
+            max_samples=config.data.get("train_max_samples", -1),
+        )
+        val_dataset = create_rl_dataset(
+            config.data.val_files, config.data, tokenizer, processor, max_samples=config.data.get("val_max_samples", -1)
+        )
         train_sampler = create_rl_sampler(config.data, train_dataset)
 
         # Initialize the PPO trainer.
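Every call site above reads the new keys with `config.data.get(..., -1)`
rather than attribute access, so configs that predate this patch keep working
unchanged. A small sketch of that fallback, assuming an OmegaConf
`DictConfig` like the ones these entry points receive:

    from omegaconf import OmegaConf

    data_config = OmegaConf.create({"train_files": "train.parquet"})  # no *_max_samples keys

    # .get() returns the -1 default when the key is missing, which the
    # datasets treat as "use the full dataset".
    assert data_config.get("train_max_samples", -1) == -1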
diff --git a/recipe/spin/spin_trainer.py b/recipe/spin/spin_trainer.py
index d312e7e4841..3f69d831daa 100644
--- a/recipe/spin/spin_trainer.py
+++ b/recipe/spin/spin_trainer.py
@@ -393,11 +393,19 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler):
         if train_dataset is None:
             train_dataset = create_rl_dataset(
-                self.config.data.train_files, self.config.data, self.tokenizer, self.processor
+                self.config.data.train_files,
+                self.config.data,
+                self.tokenizer,
+                self.processor,
+                max_samples=self.config.data.get("train_max_samples", -1),
             )
         if val_dataset is None:
             val_dataset = create_rl_dataset(
-                self.config.data.val_files, self.config.data, self.tokenizer, self.processor
+                self.config.data.val_files,
+                self.config.data,
+                self.tokenizer,
+                self.processor,
+                max_samples=self.config.data.get("val_max_samples", -1),
             )
         self.train_dataset, self.val_dataset = train_dataset, val_dataset
diff --git a/tests/special_e2e/sft/test_sp_loss_match.py b/tests/special_e2e/sft/test_sp_loss_match.py
index 4dc0cbdae5a..5d8e59e721d 100644
--- a/tests/special_e2e/sft/test_sp_loss_match.py
+++ b/tests/special_e2e/sft/test_sp_loss_match.py
@@ -112,8 +112,12 @@ def create_trainer(config):
     local_model_path = copy_to_local(src=config.model.partial_pretrain, verbose=True)
     tokenizer = hf_tokenizer(local_model_path, trust_remote_code=config.model.trust_remote_code)
 
-    train_dataset = create_sft_dataset(config.data.train_files, config.data, tokenizer)
-    val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer)
+    train_dataset = create_sft_dataset(
+        config.data.train_files, config.data, tokenizer, max_samples=config.data.get("train_max_samples", -1)
+    )
+    val_dataset = create_sft_dataset(
+        config.data.val_files, config.data, tokenizer, max_samples=config.data.get("val_max_samples", -1)
+    )
 
     return FSDPSFTTrainer(
         config=config,
diff --git a/tests/trainer/config/legacy_ppo_megatron_trainer.yaml b/tests/trainer/config/legacy_ppo_megatron_trainer.yaml
index 47e91377dd5..60cc9d0cee0 100644
--- a/tests/trainer/config/legacy_ppo_megatron_trainer.yaml
+++ b/tests/trainer/config/legacy_ppo_megatron_trainer.yaml
@@ -2,6 +2,8 @@ data:
   tokenizer: null
   train_files: ~/data/rlhf/gsm8k/train.parquet
   val_files: ~/data/rlhf/gsm8k/test.parquet
+  train_max_samples: -1 # set to -1 to use full dataset
+  val_max_samples: -1 # set to -1 to use full dataset
   prompt_key: prompt
   reward_fn_key: data_source
   max_prompt_length: 512
diff --git a/tests/trainer/config/legacy_ppo_trainer.yaml b/tests/trainer/config/legacy_ppo_trainer.yaml
index 4b8f5957fa8..cd77323df73 100644
--- a/tests/trainer/config/legacy_ppo_trainer.yaml
+++ b/tests/trainer/config/legacy_ppo_trainer.yaml
@@ -22,6 +22,16 @@ data:
   # Validation parquet. Can be a list or a single file.
   val_files: ~/data/rlhf/gsm8k/test.parquet
 
+  # Maximum number of samples to be used.
+  # Set to -1 to use the full dataset; otherwise, select the given number
+  # of samples from the train dataset (randomly if shuffle is enabled)
+  train_max_samples: -1
+
+  # Maximum number of samples to be used.
+  # Set to -1 to use the full dataset; otherwise, select the given number
+  # of samples from the val dataset (randomly if shuffle is enabled)
+  val_max_samples: -1
+
   # The field in the dataset where the prompt is located. Default is 'prompt'.
   prompt_key: prompt
diff --git a/tests/utils/dataset/test_rl_dataset_on_cpu.py b/tests/utils/dataset/test_rl_dataset_on_cpu.py
index 391e89a94d5..009172158ef 100644
--- a/tests/utils/dataset/test_rl_dataset_on_cpu.py
+++ b/tests/utils/dataset/test_rl_dataset_on_cpu.py
@@ -66,6 +66,25 @@ def test_rl_dataset():
     print(f"\n\noutput: {output}")
 
 
+def test_rl_dataset_with_max_samples():
+    from verl.utils import hf_tokenizer
+    from verl.utils.dataset.rl_dataset import RLHFDataset
+
+    tokenizer = hf_tokenizer("deepseek-ai/deepseek-coder-1.3b-instruct")
+    local_path = get_gsm8k_data()
+    config = OmegaConf.create(
+        {
+            "prompt_key": "prompt",
+            "max_prompt_length": 256,
+            "filter_overlong_prompts": True,
+            "filter_overlong_prompts_workers": 2,
+            "max_samples": 5,
+        }
+    )
+    dataset = RLHFDataset(data_files=local_path, tokenizer=tokenizer, config=config, max_samples=5)
+    assert len(dataset) == 5
+
+
 def test_image_rl_data():
     from verl.utils import hf_processor, hf_tokenizer
     from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
diff --git a/tests/utils/dataset/test_sft_dataset_on_cpu.py b/tests/utils/dataset/test_sft_dataset_on_cpu.py
index 680fce45a2a..81a6977d965 100644
--- a/tests/utils/dataset/test_sft_dataset_on_cpu.py
+++ b/tests/utils/dataset/test_sft_dataset_on_cpu.py
@@ -72,3 +72,26 @@ def test_sft_dataset():
     output = tokenizer.batch_decode([data])[0]
     assert len(output) > 1
     assert isinstance(output, str)
+
+
+def test_sft_dataset_with_max_samples():
+    tokenizer = hf_tokenizer("deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct")
+    local_path = get_gsm8k_data()
+    from omegaconf import OmegaConf
+
+    dataset = SFTDataset(
+        parquet_files=local_path,
+        tokenizer=tokenizer,
+        config=OmegaConf.create(
+            {
+                "prompt_key": "extra_info",
+                "prompt_dict_keys": ["question"],
+                "response_key": "extra_info",
+                "response_dict_keys": ["answer"],
+                "max_length": 512,
+            }
+        ),
+        max_samples=5,
+    )
+
+    assert len(dataset) == 5
diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
index 579595fd71a..6231be3ce52 100644
--- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
+++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
@@ -276,6 +276,8 @@ data:
   use_shm: false
   train_files: ~/data/rlhf/gsm8k/train.parquet
   val_files: ~/data/rlhf/gsm8k/test.parquet
+  train_max_samples: -1
+  val_max_samples: -1
   prompt_key: prompt
   reward_fn_key: data_source
   max_prompt_length: 512
diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml
index 8a71e0637f5..06fa856b792 100644
--- a/verl/trainer/config/_generated_ppo_trainer.yaml
+++ b/verl/trainer/config/_generated_ppo_trainer.yaml
@@ -272,6 +272,8 @@ data:
   use_shm: false
   train_files: ~/data/rlhf/gsm8k/train.parquet
   val_files: ~/data/rlhf/gsm8k/test.parquet
+  train_max_samples: -1
+  val_max_samples: -1
   prompt_key: prompt
   reward_fn_key: data_source
   max_prompt_length: 512
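The tests above exercise the cap by constructing the datasets directly; the
same check works outside pytest. A sketch, assuming a local GSM8K parquet
file (the path is a placeholder):

    from omegaconf import OmegaConf

    from verl.utils import hf_tokenizer
    from verl.utils.dataset.rl_dataset import RLHFDataset

    tokenizer = hf_tokenizer("deepseek-ai/deepseek-coder-1.3b-instruct")
    config = OmegaConf.create({"prompt_key": "prompt", "max_prompt_length": 256})

    dataset = RLHFDataset(
        data_files="~/data/gsm8k/train.parquet",  # placeholder path
        tokenizer=tokenizer,
        config=config,
        max_samples=5,
    )
    assert len(dataset) == 5  # only 5 rows are kept, chosen before any prompt filtering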
diff --git a/verl/trainer/config/data/legacy_data.yaml b/verl/trainer/config/data/legacy_data.yaml
index 8f5f334dcde..a6e69b5dcc8 100644
--- a/verl/trainer/config/data/legacy_data.yaml
+++ b/verl/trainer/config/data/legacy_data.yaml
@@ -13,6 +13,16 @@ train_files: ~/data/rlhf/gsm8k/train.parquet
 # Validation parquet. Can be a list or a single file.
 val_files: ~/data/rlhf/gsm8k/test.parquet
 
+# Maximum number of samples to be used.
+# Set to -1 to use the full dataset; otherwise, select the given number
+# of samples from the train dataset (randomly if shuffle is enabled)
+train_max_samples: -1
+
+# Maximum number of samples to be used.
+# Set to -1 to use the full dataset; otherwise, select the given number
+# of samples from the val dataset (randomly if shuffle is enabled)
+val_max_samples: -1
+
 # The field in the dataset where the prompt is located.
 prompt_key: prompt
diff --git a/verl/trainer/config/sft_trainer.yaml b/verl/trainer/config/sft_trainer.yaml
index bb946be88ab..3d58266440c 100644
--- a/verl/trainer/config/sft_trainer.yaml
+++ b/verl/trainer/config/sft_trainer.yaml
@@ -4,6 +4,8 @@ data:
   micro_batch_size_per_gpu: 4 # this is also val batch size
   train_files: ~/data/gsm8k/train.parquet
   val_files: ~/data/gsm8k/test.parquet
+  train_max_samples: -1 # set to -1 to use full dataset
+  val_max_samples: -1 # set to -1 to use full dataset
   # Single-turn settings
   prompt_key: question
   response_key: answer
diff --git a/verl/trainer/config/sft_trainer_engine.yaml b/verl/trainer/config/sft_trainer_engine.yaml
index 1579f582338..0f7491d5f9d 100644
--- a/verl/trainer/config/sft_trainer_engine.yaml
+++ b/verl/trainer/config/sft_trainer_engine.yaml
@@ -19,6 +19,8 @@ data:
   use_dynamic_bsz: True
   train_files: ~/data/gsm8k/train.parquet
   val_files: null
+  train_max_samples: -1 # set to -1 to use full dataset
+  val_max_samples: -1 # set to -1 to use full dataset
   # Multi-turn settings
   messages_key: messages # Key for messages list in multi-turn mode
   tools_key: tools # Key for tools list in multi-turn mode
diff --git a/verl/trainer/fsdp_sft_trainer.py b/verl/trainer/fsdp_sft_trainer.py
index f8de9339cd4..d935f448ef6 100644
--- a/verl/trainer/fsdp_sft_trainer.py
+++ b/verl/trainer/fsdp_sft_trainer.py
@@ -800,8 +800,12 @@ def run_sft(config):
     local_model_path = copy_to_local(src=config.model.partial_pretrain, verbose=True)
     tokenizer = hf_tokenizer(local_model_path, trust_remote_code=config.model.trust_remote_code)
 
-    train_dataset = create_sft_dataset(config.data.train_files, config.data, tokenizer)
-    val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer)
+    train_dataset = create_sft_dataset(
+        config.data.train_files, config.data, tokenizer, max_samples=config.data.get("train_max_samples", -1)
+    )
+    val_dataset = create_sft_dataset(
+        config.data.val_files, config.data, tokenizer, max_samples=config.data.get("val_max_samples", -1)
+    )
 
     trainer = FSDPSFTTrainer(
         config=config,
@@ -822,7 +826,7 @@ def main(config):
     run_sft(config)
 
 
-def create_sft_dataset(data_paths, data_config, tokenizer):
+def create_sft_dataset(data_paths, data_config, tokenizer, max_samples: int = -1):
     """Create a dataset."""
     # build dataset
     # First check if a custom dataset class is specified
@@ -838,7 +842,7 @@ def create_sft_dataset(data_paths, data_config, tokenizer):
         dataset_cls = SFTDataset
 
     # Create datasets based on the selected class
-    dataset = dataset_cls(parquet_files=data_paths, tokenizer=tokenizer, config=data_config)
+    dataset = dataset_cls(parquet_files=data_paths, tokenizer=tokenizer, config=data_config, max_samples=max_samples)
 
     return dataset
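A sketch of calling the updated helper directly, with a config shaped like
`sft_trainer.yaml` above (model name, path, and values are illustrative):

    from omegaconf import OmegaConf

    from verl.trainer.fsdp_sft_trainer import create_sft_dataset
    from verl.utils import hf_tokenizer

    data_config = OmegaConf.create(
        {
            "prompt_key": "question",
            "response_key": "answer",
            "max_length": 1024,
            "train_max_samples": 1000,
        }
    )
    tokenizer = hf_tokenizer("deepseek-ai/deepseek-coder-1.3b-instruct")

    # Caps the training set at 1000 rows; -1 would keep the full dataset.
    train_dataset = create_sft_dataset(
        "~/data/gsm8k/train.parquet",  # placeholder path
        data_config,
        tokenizer,
        max_samples=data_config.get("train_max_samples", -1),
    )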
diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py
index a6010f8950f..a80c9b81def 100644
--- a/verl/trainer/main_ppo.py
+++ b/verl/trainer/main_ppo.py
@@ -295,8 +295,22 @@ def run(self, config):
         from verl.utils.dataset.rl_dataset import collate_fn
 
         # Create training and validation datasets.
-        train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor, is_train=True)
-        val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor, is_train=False)
+        train_dataset = create_rl_dataset(
+            config.data.train_files,
+            config.data,
+            tokenizer,
+            processor,
+            is_train=True,
+            max_samples=config.data.get("train_max_samples", -1),
+        )
+        val_dataset = create_rl_dataset(
+            config.data.val_files,
+            config.data,
+            tokenizer,
+            processor,
+            is_train=False,
+            max_samples=config.data.get("val_max_samples", -1),
+        )
         train_sampler = create_rl_sampler(config.data, train_dataset)
 
         # Initialize the PPO trainer.
@@ -321,7 +335,7 @@ def run(self, config):
     trainer.fit()
 
 
-def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=True):
+def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=True, max_samples: int = -1):
    """Create a dataset.
 
     Arguments:
@@ -365,6 +379,7 @@ def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=True):
         tokenizer=tokenizer,
         processor=processor,
         config=data_config,
+        max_samples=max_samples,
     )
 
     return dataset
diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index 02ac4d67185..4d5a61d3a39 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -350,11 +350,19 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler):
         if train_dataset is None:
             train_dataset = create_rl_dataset(
-                self.config.data.train_files, self.config.data, self.tokenizer, self.processor
+                self.config.data.train_files,
+                self.config.data,
+                self.tokenizer,
+                self.processor,
+                max_samples=self.config.data.get("train_max_samples", -1),
             )
         if val_dataset is None:
             val_dataset = create_rl_dataset(
-                self.config.data.val_files, self.config.data, self.tokenizer, self.processor
+                self.config.data.val_files,
+                self.config.data,
+                self.tokenizer,
+                self.processor,
+                max_samples=self.config.data.get("val_max_samples", -1),
             )
         self.train_dataset, self.val_dataset = train_dataset, val_dataset
diff --git a/verl/trainer/sft_trainer.py b/verl/trainer/sft_trainer.py
index 1fa3bdee1e4..67581232524 100644
--- a/verl/trainer/sft_trainer.py
+++ b/verl/trainer/sft_trainer.py
@@ -145,9 +145,13 @@ def _init_engine(self):
     def _build_dataset(self):
         config = self.config
         tokenizer = self.model_config.tokenizer
-        train_dataset = create_sft_dataset(config.data.train_files, config.data, tokenizer)
+        train_dataset = create_sft_dataset(
+            config.data.train_files, config.data, tokenizer, max_samples=config.data.get("train_max_samples", -1)
+        )
         if config.data.val_files:
-            val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer)
+            val_dataset = create_sft_dataset(
+                config.data.val_files, config.data, tokenizer, max_samples=config.data.get("val_max_samples", -1)
+            )
         else:
             val_dataset = None
 
@@ -372,7 +376,7 @@ def main(config):
     run_sft(config)
 
 
-def create_sft_dataset(data_paths, data_config, tokenizer):
+def create_sft_dataset(data_paths, data_config, tokenizer, max_samples: int = -1):
     """Create a dataset."""
     # build dataset
     # First check if a custom dataset class is specified
@@ -385,7 +389,7 @@ def create_sft_dataset(data_paths, data_config, tokenizer):
         dataset_cls = MultiTurnSFTDataset
 
     # Create datasets based on the selected class
-    dataset = dataset_cls(parquet_files=data_paths, tokenizer=tokenizer, config=data_config)
+    dataset = dataset_cls(parquet_files=data_paths, tokenizer=tokenizer, config=data_config, max_samples=max_samples)
 
     return dataset
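Because the two caps are independent, a run can keep the full training set
while trimming evaluation. A sketch against the `main_ppo` helper above
(model name and paths are placeholders):

    from omegaconf import OmegaConf

    from verl.trainer.main_ppo import create_rl_dataset
    from verl.utils import hf_tokenizer

    data_config = OmegaConf.create({"prompt_key": "prompt", "max_prompt_length": 512})
    tokenizer = hf_tokenizer("deepseek-ai/deepseek-coder-1.3b-instruct")
    processor = None  # text-only model, no multimodal processor

    # Full training set, but only 200 validation prompts to keep eval cheap.
    train_ds = create_rl_dataset(
        "~/data/rlhf/gsm8k/train.parquet", data_config, tokenizer, processor, is_train=True, max_samples=-1
    )
    val_ds = create_rl_dataset(
        "~/data/rlhf/gsm8k/test.parquet", data_config, tokenizer, processor, is_train=False, max_samples=200
    )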
diff --git a/verl/utils/dataset/multiturn_sft_dataset.py b/verl/utils/dataset/multiturn_sft_dataset.py
index 58583c6a853..f3aad961dca 100644
--- a/verl/utils/dataset/multiturn_sft_dataset.py
+++ b/verl/utils/dataset/multiturn_sft_dataset.py
@@ -49,7 +49,7 @@ class MultiTurnSFTDataset(Dataset):
     Dataset for multi-turn conversations where each assistant response should be trained
     """
 
-    def __init__(self, parquet_files: str | list[str], tokenizer, config=None):
+    def __init__(self, parquet_files: str | list[str], tokenizer, config=None, max_samples: int = -1):
         # Set defaults and extract parameters from config if provided
         config = config or {}
         self.pad_mode = config.get("pad_mode", "right")
@@ -65,6 +65,9 @@ def __init__(self, parquet_files: str | list[str], tokenizer, config=None):
         self.tools_key = multiturn_config.get("tools_key", "tools")
         self.enable_thinking_key = multiturn_config.get("enable_thinking_key", "enable_thinking")
         self.apply_chat_template_kwargs = config.get("apply_chat_template_kwargs", {})
+        self.shuffle = config.get("shuffle", False)
+        self.seed = config.get("seed")
+        self.max_samples = max_samples
 
         assert self.truncation in ["error", "left", "right"]
         if not isinstance(parquet_files, list | ListConfig):
@@ -97,6 +100,19 @@ def series_to_item(ls):
             dataframes.append(dataframe)
         self.dataframe = pd.concat(dataframes)
+        total = len(self.dataframe)
+        print(f"dataset len: {len(self.dataframe)}")
+
+        if self.max_samples > 0 and self.max_samples < total:
+            if self.shuffle:
+                rngs_args = (self.seed,) if self.seed is not None else ()
+                rng = np.random.default_rng(*rngs_args)
+                indices = rng.choice(total, size=self.max_samples, replace=False)
+            else:
+                indices = np.arange(self.max_samples)
+            self.dataframe = self.dataframe.iloc[indices.tolist()]
+            print(f"selected {self.max_samples}{' random' if self.shuffle else ''} samples out of {total}")
+
         # Extract messages list from dataframe
         self.messages = self.dataframe[self.messages_key].apply(series_to_item).tolist()
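When `shuffle` is set, each dataset draws its subset from
`np.random.default_rng(seed)`, so a fixed `seed` makes the subset
reproducible across runs. The property in isolation:

    import numpy as np

    total, max_samples, seed = 1000, 5, 42

    a = np.random.default_rng(seed).choice(total, size=max_samples, replace=False)
    b = np.random.default_rng(seed).choice(total, size=max_samples, replace=False)

    # Same seed, same generator -> identical subset on every run.
    assert np.array_equal(a, b)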
diff --git a/verl/utils/dataset/rl_dataset.py b/verl/utils/dataset/rl_dataset.py
index 63d1a3f2735..63bdb099feb 100644
--- a/verl/utils/dataset/rl_dataset.py
+++ b/verl/utils/dataset/rl_dataset.py
@@ -88,6 +88,7 @@ def __init__(
         tokenizer: PreTrainedTokenizer,
         config: DictConfig,
         processor: Optional[ProcessorMixin] = None,
+        max_samples: int = -1,
     ):
         if not isinstance(data_files, list | ListConfig):
             data_files = [data_files]
@@ -96,6 +97,7 @@ def __init__(
         self.original_data_files = copy.deepcopy(data_files)  # use for resume
         self.tokenizer = tokenizer
         self.processor = processor
+        self.max_samples = max_samples
         self.config = config
 
         self.cache_dir = os.path.expanduser(config.get("cache_dir", "~/.cache/verl/rlhf"))
@@ -117,6 +119,8 @@ def __init__(
         self.filter_prompts = config.get("filter_prompts", True)
         self.serialize_dataset = False
         self.return_multi_modal_inputs = config.get("return_multi_modal_inputs", True)
+        self.shuffle = config.get("shuffle", False)
+        self.seed = config.get("seed")
 
         self._download()
         self._read_files_and_tokenize()
@@ -136,8 +140,19 @@ def _read_files_and_tokenize(self):
             dataframes.append(dataframe)
         self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes)
+        total = len(self.dataframe)
         print(f"dataset len: {len(self.dataframe)}")
 
+        if self.max_samples > 0 and self.max_samples < total:
+            if self.shuffle:
+                rngs_args = (self.seed,) if self.seed is not None else ()
+                rng = np.random.default_rng(*rngs_args)
+                indices = rng.choice(total, size=self.max_samples, replace=False)
+            else:
+                indices = np.arange(self.max_samples)
+            self.dataframe = self.dataframe.select(indices.tolist())
+            print(f"selected {self.max_samples}{' random' if self.shuffle else ''} samples out of {total}")
+
         self.dataframe = self.maybe_filter_out_long_prompts(self.dataframe)
 
     def maybe_filter_out_long_prompts(self, dataframe: datasets.Dataset = None):
diff --git a/verl/utils/dataset/rm_dataset.py b/verl/utils/dataset/rm_dataset.py
index 7af792343cb..e55308e3584 100644
--- a/verl/utils/dataset/rm_dataset.py
+++ b/verl/utils/dataset/rm_dataset.py
@@ -13,7 +13,9 @@
 # limitations under the License.
 
 import os
+from typing import Optional
 
+import numpy as np
 import pandas as pd
 import torch
 from torch.utils.data import Dataset
@@ -46,11 +48,17 @@ def __init__(
         max_length=1024,
         add_eos=True,
         cache_dir="~/.cache/verl/rm",
+        max_samples: int = -1,
+        shuffle: bool = False,
+        seed: Optional[int] = None,
     ):
         if not isinstance(parquet_files, list):
             parquet_files = [parquet_files]
 
         self.parquet_files = parquet_files
+        self.max_samples = max_samples
+        self.shuffle = shuffle
+        self.seed = seed
         self.cache_dir = os.path.expanduser(cache_dir)
         if isinstance(tokenizer, str):
             tokenizer = hf_tokenizer(tokenizer)
@@ -88,6 +96,20 @@ def _read_files_and_tokenize(self):
             dataframe = pd.read_parquet(parquet_file)
             dataframes.append(dataframe)
         self.dataframe = pd.concat(dataframes)
+
+        total = len(self.dataframe)
+        print(f"dataset len: {len(self.dataframe)}")
+
+        if self.max_samples > 0 and self.max_samples < total:
+            if self.shuffle:
+                rngs_args = (self.seed,) if self.seed is not None else ()
+                rng = np.random.default_rng(*rngs_args)
+                indices = rng.choice(total, size=self.max_samples, replace=False)
+            else:
+                indices = np.arange(self.max_samples)
+            self.dataframe = self.dataframe.iloc[indices.tolist()]
+            print(f"selected {self.max_samples}{' random' if self.shuffle else ''} samples out of {total}")
+
         self.prompts = self.dataframe[self.prompt_key].tolist()
         self.chosen_responses = self.dataframe[self.chosen_key].tolist()
         self.rejected_responses = self.dataframe[self.rejected_key].tolist()
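Unlike the config-driven dataset classes, `RMDataset` takes the new knobs as
plain constructor arguments. A sketch, assuming the remaining constructor
parameters keep their defaults (the parquet path is a placeholder):

    from verl.utils.dataset.rm_dataset import RMDataset

    dataset = RMDataset(
        parquet_files="~/data/rm/train.parquet",  # placeholder path
        tokenizer="deepseek-ai/deepseek-coder-1.3b-instruct",  # resolved via hf_tokenizer internally
        max_samples=100,
        shuffle=True,  # random subset instead of the first 100 rows
        seed=42,       # make the subset reproducible
    )
    assert len(dataset) <= 100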
diff --git a/verl/utils/dataset/sft_dataset.py b/verl/utils/dataset/sft_dataset.py
index 3bd4d751315..5fa8e07b252 100644
--- a/verl/utils/dataset/sft_dataset.py
+++ b/verl/utils/dataset/sft_dataset.py
@@ -18,6 +18,7 @@
 Each parquet file contains
 """
 
+import numpy as np
 import pandas as pd
 import torch
 from omegaconf.listconfig import ListConfig
@@ -37,7 +38,7 @@ class SFTDataset(Dataset):
         config (OmegaConf): the data config
     """
 
-    def __init__(self, parquet_files: str | ListConfig, tokenizer, config):
+    def __init__(self, parquet_files: str | ListConfig, tokenizer, config, max_samples: int = -1):
         prompt_key = config.get("prompt_key", "prompt")
         prompt_dict_keys = config.get("prompt_dict_keys", None)
         response_key = config.get("response_key", "response")
@@ -45,6 +46,8 @@ def __init__(self, parquet_files: str | ListConfig, tokenizer, config):
         max_length = config.get("max_length", 1024)
         truncation = config.get("truncation", "error")
         use_shm = config.get("use_shm", False)
+        self.shuffle = config.get("shuffle", False)
+        self.seed = config.get("seed")
         self.apply_chat_template_kwargs = config.get("apply_chat_template_kwargs", {})
 
         assert truncation in ["error", "left", "right"]
@@ -55,6 +58,7 @@ def __init__(self, parquet_files: str | ListConfig, tokenizer, config):
             parquet_files = [parquet_files]
 
         self.parquet_files = parquet_files
+        self.max_samples = max_samples
         if isinstance(tokenizer, str):
             tokenizer = hf_tokenizer(tokenizer)
         self.tokenizer: PreTrainedTokenizer = tokenizer
@@ -88,6 +92,20 @@ def series_to_item(ls):
             dataframe = pd.read_parquet(parquet_file)
             dataframes.append(dataframe)
         self.dataframe = pd.concat(dataframes)
+
+        total = len(self.dataframe)
+        print(f"dataset len: {len(self.dataframe)}")
+
+        if self.max_samples > 0 and self.max_samples < total:
+            if self.shuffle:
+                rngs_args = (self.seed,) if self.seed is not None else ()
+                rng = np.random.default_rng(*rngs_args)
+                indices = rng.choice(total, size=self.max_samples, replace=False)
+            else:
+                indices = np.arange(self.max_samples)
+            self.dataframe = self.dataframe.iloc[indices.tolist()]
+            print(f"selected {self.max_samples}{' random' if self.shuffle else ''} samples out of {total}")
+
         self.prompts = self.dataframe[self.prompt_key]
         for key in self.prompt_dict_keys:
             # type(x): pandas.core.series.Series
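In every class the cut happens right after the parquet files are
concatenated and before any further processing; for `RLHFDataset` in
particular it runs before `maybe_filter_out_long_prompts`, which is the
expensive step motivating this patch. A back-of-the-envelope illustration
with made-up numbers:

    # Hypothetical cost illustration, not verl code.
    total_rows = 1_000_000
    max_samples = 10_000

    # Without a cap, filter_overlong_prompts tokenizes every row.
    rows_tokenized_before = total_rows

    # With the cap applied first, only the retained rows are tokenized.
    rows_tokenized_after = min(total_rows, max_samples)

    print(f"tokenization work reduced {rows_tokenized_before // rows_tokenized_after}x")  # 100x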