Skip to content

Commit bc0cb30

Browse files
committed
include loss masking span always available
1 parent 4c84d20 commit bc0cb30

File tree

2 files changed

+13
-19
lines changed

2 files changed

+13
-19
lines changed

fast_llm/data/preparator/gpt_memmap/config.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,6 @@ class PromptCompletionConfig(SourceSchemaConfig):
4646
desc="Delimiter between prompt and completion.",
4747
hint=FieldHint.optional,
4848
)
49-
set_loss_masking_for_prompt: bool = Field(
50-
default=False,
51-
desc="Create loss mask spans based on prompt.",
52-
hint=FieldHint.optional,
53-
)
5449

5550
@config_class()
5651
class TextColumnConfig(SourceSchemaConfig):

fast_llm/data/preparator/gpt_memmap/prepare.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -232,20 +232,19 @@ def run(self) -> None:
232232
)
233233
logger.info(f"Sample after combining fields:\n{dataset[0]}")
234234
self._data_column = new_combined_column
235-
236-
if source_schema.set_loss_masking_for_prompt:
237-
loss_masking_column = f"{source_schema.prompt_columns}_loss_masking_spans"
238-
dataset = dataset.map(
239-
lambda example: {
240-
loss_masking_column: [
241-
(0, len(str(example[source_schema.prompt_column])) - 1)
242-
]# spans are inclusive
243-
},
244-
batched=False,
245-
desc="Setting loss masking spans",
246-
)
247-
logger.info(f"Sample after setting loss masking spans:\n{dataset[0]}")
248-
self._loss_masking_spans_column = loss_masking_column
235+
# Set loss masking spans (handle chat template prior to this)
236+
loss_masking_column = f"{source_schema.prompt_columns}_loss_masking_spans"
237+
dataset = dataset.map(
238+
lambda example: {
239+
loss_masking_column: [
240+
(0, len(str(example[source_schema.prompt_column])) - 1)
241+
]# spans are inclusive
242+
},
243+
batched=False,
244+
desc="Setting loss masking spans",
245+
)
246+
logger.info(f"Sample after setting loss masking spans:\n{dataset[0]}")
247+
self._loss_masking_spans_column = loss_masking_column
249248
else:
250249
raise ValueError(f"Dataset source_schema set incorrectly. source_schema: '{self._config.dataset.data_source}'.")
251250

0 commit comments

Comments
 (0)