@@ -37,13 +37,13 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](D
 
     _tokenizer: Tokenizer
     _data_type: DataType
-    _data_column: str
+    _text_column: str
     _loss_masking_spans_column: str | None
 
     def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]:
         input_ids = [
             np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy)
-            for text in batch[self._data_column]
+            for text in batch[self._text_column]
         ]
        num_tokens = [len(x) for x in input_ids]
        return {
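
Note: a minimal, self-contained sketch of what _tokenize_batch does with the renamed column. Everything here except the _text_column rename is illustrative: toy_tokenize stands in for Tokenizer.tokenize and is not part of this PR.

    import typing
    import numpy as np

    def toy_tokenize(text: str) -> list[int]:
        # Stand-in for Tokenizer.tokenize: one token id per character.
        return [ord(c) for c in text]

    def tokenize_batch(batch: dict[str, list[typing.Any]], text_column: str = "text") -> dict[str, list[typing.Any]]:
        # Mirrors _tokenize_batch: one int array of token ids per document,
        # read from the configured text column.
        input_ids = [np.array(toy_tokenize(text), dtype=np.int32) for text in batch[text_column]]
        return {"input_ids": input_ids, "num_tokens": [len(x) for x in input_ids]}

    print(tokenize_batch({"text": ["ab", "cde"]})["num_tokens"])  # [2, 3]
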
@@ -63,7 +63,7 @@ def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict
                 for input_ids, token_spans in [
                     self._tokenizer.tokenize_with_spans(text, char_spans)
                     for text, char_spans in zip(
-                        batch[self._data_column], batch[self._loss_masking_spans_column]
+                        batch[self._text_column], batch[self._loss_masking_spans_column]
                     )
                 ]
             ]
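
Note: a sketch of the span-aware path, assuming (as the comprehension above implies) that tokenize_with_spans returns an (input_ids, token_spans) pair per document. toy_tokenize_with_spans is a hypothetical stand-in, not the real Tokenizer method.

    import numpy as np

    def toy_tokenize_with_spans(text: str, char_spans: list[tuple[int, int]]) -> tuple[list[int], list[tuple[int, int]]]:
        # Stand-in for Tokenizer.tokenize_with_spans: one token per character,
        # so each character span maps to the identical token span.
        return [ord(c) for c in text], list(char_spans)

    texts, spans = ["abcd"], [[(1, 2)]]
    pairs = [toy_tokenize_with_spans(text, char_spans) for text, char_spans in zip(texts, spans)]
    for input_ids, token_spans in pairs:
        print(np.array(input_ids, dtype=np.int32), token_spans)  # [97 98 99 100] [(1, 2)]
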
@@ -254,8 +254,8 @@ def run(self) -> None:
             num_shards=self._config.distributed.world_size,
             index=self._config.distributed.rank,
         )
-        if self._data_column not in dataset.column_names:
-            raise ValueError(f"Dataset does not have field '{self._data_column}'.")
+        if self._text_column not in dataset.column_names:
+            raise ValueError(f"Dataset does not have field '{self._text_column}'.")
         if self._loss_masking_spans_column is not None:
             if self._loss_masking_spans_column not in dataset.column_names:
                 raise ValueError(f"Dataset does not have spans field '{self._loss_masking_spans_column}'.")
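
Note: the validation in run() isolated as a standalone function, for reference. check_columns is a hypothetical helper written for this note, not code from the PR; only the error messages and the _text_column rename come from the diff above.

    def check_columns(column_names: list[str], text_column: str, loss_masking_spans_column: str | None) -> None:
        # Fail fast if the configured columns are absent from the dataset,
        # mirroring the guards in run().
        if text_column not in column_names:
            raise ValueError(f"Dataset does not have field '{text_column}'.")
        if loss_masking_spans_column is not None and loss_masking_spans_column not in column_names:
            raise ValueError(f"Dataset does not have spans field '{loss_masking_spans_column}'.")

    check_columns(["text", "spans"], "text", "spans")  # passes silently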