@@ -145,17 +145,19 @@ def _sample(self) -> None:
                 raise RuntimeError(
                     f" > No documents shorter than {self._parameters.sequence_length + 1} tokens found in dataset {self._indexed_dataset.name}."
                 )
-        # TODO MTP: Produce more labels to provide labels for the multi-token prediction heads?
-        # We produce sequences of length `self._sequence_length + 1` so the last token has a label,
-        # but in case of truncations we also include that last label in the following sample,
-        # so we need `sequence_length * num_samples + 1` tokens in total.
-        num_epochs = math.ceil(
-            (
-                (self._parameters.sequence_length + 1 - self._truncate_documents) * self._parameters.num_samples
-                + 1 * self._truncate_documents
+        # We produce sequences of length `self._sequence_length + extra_tokens` so the last token has a label for all prediction heads,
+        # but in case of truncations we also include those last labels in the following sample,
+        # so we need `sequence_length * num_samples + extra_tokens` tokens in total.
+        if self._truncate_documents:
+            num_epochs = math.ceil(
+                (self._parameters.sequence_length * self._parameters.num_samples + self._parameters.extra_tokens)
+                / tokens_per_epoch
+            )
+        else:
+            num_epochs = math.ceil(
+                ((self._parameters.sequence_length + self._parameters.extra_tokens) * self._parameters.num_samples)
+                / tokens_per_epoch
             )
-            / tokens_per_epoch
-        )

         # Prepare for shuffling.
         generator = torch.Generator(device=self._device)
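
To make the two branches concrete, here is a minimal standalone sketch, not the repository code, using invented values for sequence_length, num_samples, extra_tokens, and tokens_per_epoch:

import math

# Invented values, chosen only to illustrate the two formulas above.
sequence_length, num_samples, extra_tokens = 1000, 100, 4
tokens_per_epoch = 100_100

# Packing with truncations: samples tile the token stream back to back, so a
# single `extra_tokens` allowance at the end covers the whole run.
epochs_truncating = math.ceil(
    (sequence_length * num_samples + extra_tokens) / tokens_per_epoch
)  # ceil(100_004 / 100_100) == 1

# Packing without truncations: boundary tokens belong to exactly one sample,
# so every sample carries its own `extra_tokens` on top of `sequence_length`.
epochs_non_truncating = math.ceil(
    (sequence_length + extra_tokens) * num_samples / tokens_per_epoch
)  # ceil(100_400 / 100_100) == 2

print(epochs_truncating, epochs_non_truncating)  # 1 2

The non-truncating mode pays for extra_tokens once per sample rather than once overall, which is why its epoch count can round up one step earlier.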
@@ -349,8 +351,13 @@ def __getitem__(self, index: int) -> typing.Any:
         self._lazy_load()
         # tokens at the boundary are included in only one sample when we pack without truncations
         # in case of packing with truncations, the last token from the previous sample is also the first token of the next sample
-        token_start = index * (self._parameters.sequence_length + 1 - self._truncate_documents)
-        token_end = token_start + self._parameters.sequence_length + 1
+        sample_length = (
+            self._parameters.sequence_length
+            if self._truncate_documents
+            else self._parameters.sequence_length + self._parameters.extra_tokens
+        )
+        token_start = index * sample_length
+        token_end = token_start + self._parameters.sequence_length + self._parameters.extra_tokens

         if token_start < self._unshuffled_tokens:
             token_start_array = self._token_cumsum_unshuffled.array
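
The effect of the new sample_length is easiest to see with concrete windows. A small sketch with assumed values (not the PR's code) prints the [token_start, token_end) range of the first few samples in each packing mode:

# Assumed values for illustration only.
sequence_length, extra_tokens = 8, 2

for truncate_documents in (True, False):
    sample_length = sequence_length if truncate_documents else sequence_length + extra_tokens
    windows = [
        (index * sample_length, index * sample_length + sequence_length + extra_tokens)
        for index in range(3)
    ]
    print(truncate_documents, windows)

# True  -> [(0, 10), (8, 18), (16, 26)]: consecutive samples overlap by
#          `extra_tokens`, so the last labels of one sample reappear as the
#          first tokens of the next (packing with truncations).
# False -> [(0, 10), (10, 20), (20, 30)]: windows are disjoint, so boundary
#          tokens belong to exactly one sample (packing without truncations).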
@@ -410,7 +417,9 @@ def __getitem__(self, index: int) -> typing.Any:
                 if self._parameters.use_loss_masking_spans:
                     for loss_masking_span in sample.loss_masking_spans:
                         span = np.clip(
-                            loss_masking_span + token_count - token_start, 0, self._parameters.sequence_length + 1
+                            loss_masking_span + token_count - token_start,
+                            0,
+                            self._parameters.sequence_length + self._parameters.extra_tokens,
                         )
                         if span[1] > span[0]:
                             loss_masking_spans.append(span)
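
A worked example with invented offsets shows why the clipping bound moves from sequence_length + 1 to sequence_length + extra_tokens:

import numpy as np

# Invented values: a 10-token window (sequence_length=8, extra_tokens=2) that
# starts at token_start=24, with token_count=20 tokens consumed before the
# current document.
sequence_length, extra_tokens = 8, 2
token_count, token_start = 20, 24

# A loss-masking span in document coordinates, shifted into sample coordinates
# and clipped to the widened window, mirroring the np.clip call in the diff.
loss_masking_span = np.array([10, 18])
span = np.clip(loss_masking_span + token_count - token_start, 0, sequence_length + extra_tokens)

print(span)               # [ 6 10]: the part of the span inside this sample
print(span[1] > span[0])  # True, so it would be appended to loss_masking_spans

With the old bound of sequence_length + 1 = 9, the same span would have been clipped to [6, 9], losing the masking for the extra label position.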
@@ -430,7 +439,7 @@ def __getitem__(self, index: int) -> typing.Any:
             if self._parameters.use_loss_masking_spans
             else None
         )
-        Assert.eq(len(token_ids), self._parameters.sequence_length + 1)
+        Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens)

         return GPTSample(token_ids=token_ids, loss_masking_spans=loss_masking_spans, sequence_lengths=sequence_lengths)