
Commit db24a57
merge update
2 parents: c8e20ea + 662f318

2 files changed: 8 additions (+8), 8 deletions (-8)


fast_llm/data/preparator/gpt_memmap/config.py

Lines changed: 3 additions & 3 deletions

@@ -30,7 +30,7 @@ class SourceSchemaConfig(Config):
 
 @config_class()
 class PromptCompletionConfig(SourceSchemaConfig):
-    type: typing.ClassVar[str] = "prompt_completion"
+    type: typing.ClassVar[str] = "prompt_completion" #TODO: Register PromptCompletionConfig for this type for dynamic loading PR #245
     prompt_column: str = Field(
         default="prompt",
         desc="Field of the dataset to use.",
@@ -49,7 +49,7 @@ class PromptCompletionConfig(SourceSchemaConfig):
 
 @config_class()
 class TextColumnConfig(SourceSchemaConfig):
-    type: typing.ClassVar[str] = "text_column"
+    type: typing.ClassVar[str] = "text_column" #TODO: Register TextColumnConfig for this type for dynamic loading PR #245
     input_column: str = Field(
         default="text",
         desc="Field of the dataset to use.",
@@ -88,7 +88,7 @@ class GPTHuggingfaceDatasetConfig(Config):
         hint=FieldHint.optional,
     )
     source_schema: SourceSchemaConfig = Field(
-        default_factory=TextColumnConfig,
+        #TODO: Default should be from subclass TextColumnConfig (waiting for PR #245)
         desc="Configuration for the data source.",
         hint=FieldHint.optional,
     )
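
Both TODOs above anticipate the same follow-up: once PR #245 lands, each SourceSchemaConfig subclass should be resolvable from its `type` string rather than referenced directly. The sketch below is only an illustration of what such type-keyed registration could look like; the `register_source_schema` decorator and `_SOURCE_SCHEMA_REGISTRY` dict are hypothetical stand-ins (not Fast-LLM's actual mechanism), and the config classes are trimmed to their `type` attribute so the snippet runs on its own.

import typing


class SourceSchemaConfig:
    type: typing.ClassVar[str] = "source_schema"


# Hypothetical registry mapping a `type` string to its config class.
_SOURCE_SCHEMA_REGISTRY: dict[str, type[SourceSchemaConfig]] = {}


def register_source_schema(cls: type[SourceSchemaConfig]) -> type[SourceSchemaConfig]:
    # Key each subclass by its `type` ClassVar so a serialized config such as
    # {"type": "text_column", ...} can be mapped back to the right class.
    _SOURCE_SCHEMA_REGISTRY[cls.type] = cls
    return cls


@register_source_schema
class TextColumnConfig(SourceSchemaConfig):
    type: typing.ClassVar[str] = "text_column"


@register_source_schema
class PromptCompletionConfig(SourceSchemaConfig):
    type: typing.ClassVar[str] = "prompt_completion"


# Dynamic lookup by type string, as the TODOs suggest.
assert _SOURCE_SCHEMA_REGISTRY["text_column"] is TextColumnConfig
assert _SOURCE_SCHEMA_REGISTRY["prompt_completion"] is PromptCompletionConfig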

fast_llm/data/preparator/gpt_memmap/prepare.py

Lines changed: 5 additions & 5 deletions

@@ -37,13 +37,13 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](D
 
     _tokenizer: Tokenizer
     _data_type: DataType
-    _data_column: str
+    _text_column: str
     _loss_masking_spans_column: str | None
 
     def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]:
         input_ids = [
             np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy)
-            for text in batch[self._data_column]
+            for text in batch[self._text_column]
         ]
         num_tokens = [len(x) for x in input_ids]
         return {
@@ -63,7 +63,7 @@ def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict
             for input_ids, token_spans in [
                 self._tokenizer.tokenize_with_spans(text, char_spans)
                 for text, char_spans in zip(
-                    batch[self._data_column], batch[self._loss_masking_spans_column]
+                    batch[self._text_column], batch[self._loss_masking_spans_column]
                 )
             ]
         ]
@@ -254,8 +254,8 @@ def run(self) -> None:
             num_shards=self._config.distributed.world_size,
             index=self._config.distributed.rank,
         )
-        if self._data_column not in dataset.column_names:
-            raise ValueError(f"Dataset does not have field '{self._data_column}'.")
+        if self._text_column not in dataset.column_names:
+            raise ValueError(f"Dataset does not have field '{self._text_column}'.")
         if self._loss_masking_spans_column is not None:
             if self._loss_masking_spans_column not in dataset.column_names:
                 raise ValueError(f"Dataset does not have spans field '{self._loss_masking_spans_column}'.")
