4 changes: 3 additions & 1 deletion nemo_automodel/components/config/loader.py
@@ -246,11 +246,13 @@ def instantiate(self, *args, **kwargs):
             "Instantiation failed for `{}`\n"
             "Accepted signature : {}\n"
             "Positional args : {}\n"
-            "Keyword args : {}\n".format(
+            "Keyword args : {}\n"
+            "Exception : {}\n".format(
                 func.__name__,
                 sig,
                 args,
                 pprint.pformat(config_kwargs, compact=True, indent=4),
+                e,
             ),
             file=sys.stderr,
         )
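For context, a minimal sketch of the failure path this hunk improves: the caught exception is now echoed alongside the accepted signature, so a bad config surfaces its root cause directly. The wrapper below is hypothetical; only the message format comes from the diff.

import inspect
import pprint
import sys

def instantiate_with_report(func, *args, **config_kwargs):
    # Hypothetical stand-in for the loader's instantiate(); structure assumed.
    sig = inspect.signature(func)
    try:
        return func(*args, **config_kwargs)
    except Exception as e:
        print(
            "Instantiation failed for `{}`\n"
            "Accepted signature : {}\n"
            "Positional args : {}\n"
            "Keyword args : {}\n"
            "Exception : {}\n".format(
                func.__name__, sig, args,
                pprint.pformat(config_kwargs, compact=True, indent=4), e,
            ),
            file=sys.stderr,
        )
        raise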
6 changes: 6 additions & 0 deletions nemo_automodel/components/datasets/llm/chat_dataset.py
@@ -133,6 +133,8 @@ def __init__(
         split: Optional[str] = None,
         name: Optional[str] = None,
         seq_length: Optional[int] = None,
+        padding: Union[str, bool] = "do_not_pad",
+        truncation: Union[str, bool] = "do_not_truncate",
         start_of_turn_token: Optional[str] = None,
         chat_template: Optional[str] = None,
     ) -> None:
@@ -149,6 +151,8 @@ def __init__(
 
         self.tokenizer = tokenizer
         self.seq_length = seq_length
+        self.padding = padding
+        self.truncation = truncation
         self.start_of_turn_token = start_of_turn_token
 
         self.dataset = _load_openai_messages(path_or_dataset_id, split=split, name=name)
@@ -178,6 +182,8 @@ def __getitem__(self, idx: int) -> Dict[str, List[int]]:
             eos_token_id,
             self.pad_token_id,
             seq_length=self.seq_length,
+            padding=self.padding,
+            truncation=self.truncation,
             tools=tools,
         )
         return sample
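A usage sketch for the new knobs; the class name, import path, checkpoint, and dataset id below are assumptions for illustration. The accepted values mirror Hugging Face's padding/truncation strategies ("max_length", "longest", booleans, etc.), which the dataset forwards untouched to format_chat_template.

from transformers import AutoTokenizer
# Import path and class name are assumed, not confirmed by the diff.
from nemo_automodel.components.datasets.llm.chat_dataset import ChatDataset

tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")
ds = ChatDataset(
    "org/openai-messages-dataset",  # placeholder dataset id
    tokenizer=tok,
    seq_length=2048,
    padding="max_length",  # pad every sample up to seq_length
    truncation=True,       # truncate samples that exceed seq_length
)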
@@ -28,6 +28,8 @@
     format_prompt_completion,
 )
 
+logger = logging.getLogger(__name__)
+
 # Supported cases:
 # Format:
 # - Context + question + answer
@@ -165,6 +167,8 @@ def __init__(
         name: Optional[str] = None,
         answer_only_loss_mask: bool = True,
         seq_length: Optional[int] = None,
+        padding: Union[str, bool] = "do_not_pad",
+        truncation: Union[str, bool] = "do_not_truncate",
         start_of_turn_token: Optional[str] = None,
         limit_dataset_samples: Optional[int] = None,
     ) -> None:
@@ -193,6 +197,12 @@ def __init__(
 
         assert tokenizer is not None, "Tokenizer is required"
         self.tokenizer = tokenizer
+        if getattr(self.tokenizer, "pad_token", None) is None:
+            if hasattr(self.tokenizer, "eos_token"):
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+            else:
+                logger.warning("Tokenizer has no `eos_token`; setting pad_token to ' '.")
+                self.tokenizer.pad_token = " "
 
         self.dataset = _load_dataset(path_or_dataset_id, split=split, streaming=False, name=name)
 
@@ -226,6 +236,8 @@ def __init__(
         self.answer_only_loss_mask = answer_only_loss_mask
         self.start_of_turn_token = start_of_turn_token
         self.seq_length = seq_length
+        self.padding = padding
+        self.truncation = truncation
 
     def __len__(self) -> int:  # noqa: D401
         """
@@ -255,6 +267,8 @@ def __getitem__(self, idx):  # noqa: D401
         row = self.dataset[idx]
         mapped = {dest: row[src] for dest, src in self.column_mapping.items() if src in row}
         mapped = self._apply_tokenizer(mapped)
+        if not any(label != -100 for label in mapped["labels"]):
+            return self.__getitem__((idx + 1) % len(self.dataset))
         assert _check_all_values_equal_length(mapped), "All values must be of the same length"
         return mapped
 
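The guard above skips rows whose labels are entirely the ignore index (-100), since such rows contribute nothing to the loss; it probes forward with wrap-around, which assumes at least one row in the dataset has trainable tokens. In plain form:

def has_trainable_tokens(labels, ignore_index=-100):
    # A row is useful for training only if some label survives masking.
    return any(label != ignore_index for label in labels)

assert not has_trainable_tokens([-100, -100, -100])
assert has_trainable_tokens([-100, 42, -100])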
@@ -293,6 +307,8 @@ def _apply_tokenizer(self, sample: Dict[str, str]) -> Dict[str, List[int]]:
                 eos_token_id,
                 pad_token_id,
                 seq_length=self.seq_length,
+                padding=self.padding,
+                truncation=self.truncation,
             )
         else:
             prompt = " ".join(filter(lambda x: x is not None, (context, question, "")))
@@ -304,5 +320,7 @@ def _apply_tokenizer(self, sample: Dict[str, str]) -> Dict[str, List[int]]:
                 eos_token_id,
                 pad_token_id,
                 seq_length=self.seq_length,
+                padding=self.padding,
+                truncation=self.truncation,
                 answer_only_loss_mask=self.answer_only_loss_mask,
             )
19 changes: 15 additions & 4 deletions nemo_automodel/components/datasets/llm/formatting_utils.py
@@ -14,7 +14,7 @@
 
 import logging
 import re
-from typing import TYPE_CHECKING, Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 logger = logging.getLogger(__name__)
 
@@ -66,6 +66,7 @@ def _package_tokenized_example(
     eos_token_id,
     pad_token_id,
     seq_length,
+    truncation=None,
 ):
     """
     Package a tokenized example with proper masking and padding.
@@ -77,7 +78,7 @@
         eos_token_id: The end-of-sequence token id.
         pad_token_id: The padding token id.
         seq_length: Optional sequence length for padding.
-
+        truncation: Optional truncation strategy.
     Returns:
         A dictionary with input_ids, labels, and attention_mask.
     """
@@ -86,6 +87,8 @@
     if not _has_chat_template(tokenizer) and eos_token_id != input_ids[-1]:
         input_ids += [eos_token_id]
         assistant_masks += [1]
+    if not _has_chat_template(tokenizer) and pad_token_id is not None:
+        assistant_masks += [pad_token_id]
 
     labels = input_ids.copy()
     input_ids = input_ids[:-1]
@@ -95,7 +98,7 @@
     labels[:] = [label if bool(m) else -100 for label, m in zip(labels, assistant_masks)]
     # remove BOS
     labels = labels[1:]
-    if not _has_chat_template(tokenizer):
+    if not _has_chat_template(tokenizer) and truncation is None:
         assert labels[-1] == eos_token_id, f"labels[-1]={labels[-1]} != eos_token_id={eos_token_id}"
         assert input_ids[-1] != eos_token_id, f"input_ids[-1]={input_ids[-1]} == eos_token_id={eos_token_id}"
         assert len(input_ids) == len(labels), f"len(input_ids)={len(input_ids)} != len(labels)={len(labels)}"
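A toy walk-through of the shift-and-mask these assertions protect, with made-up token ids (BOS=1, EOS=2):

input_ids = [1, 5, 6, 7, 2]          # BOS, prompt (5, 6), answer (7), EOS
assistant_masks = [0, 0, 0, 1, 1]    # loss only on answer and EOS

labels = input_ids.copy()
input_ids = input_ids[:-1]           # the final target is never consumed as input
labels = [t if m else -100 for t, m in zip(labels, assistant_masks)]
labels = labels[1:]                  # drop BOS so labels are next-token targets

assert input_ids == [1, 5, 6, 7]
assert labels == [-100, -100, 7, 2]  # labels[-1] is EOS, matching the assert above

With truncation enabled the sequence may be cut before EOS, which is why the assertions are now skipped when truncation is not None.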
@@ -125,6 +128,8 @@ def format_prompt_completion(
    eos_token_id: int,
    pad_token_id: int,
    seq_length: Optional[int] = None,
+    padding: Union[str, bool] = "do_not_pad",
+    truncation: Union[str, bool] = "do_not_truncate",
    answer_only_loss_mask: bool = True,
 ) -> Dict[str, List[int]]:
     """
@@ -150,7 +155,7 @@
     else:
         len_prompt_ids = 0
     # Tokenize full text
-    input_ids = tokenizer(full_text)["input_ids"]
+    input_ids = tokenizer(full_text, padding=padding, truncation=truncation, max_length=seq_length)["input_ids"]
 
     # Create assistant_masks: 0 for prompt tokens, 1 for answer tokens
     assistant_masks = [0] * len_prompt_ids + [1] * (len(input_ids) - len_prompt_ids)
@@ -162,6 +167,7 @@
         eos_token_id=eos_token_id,
         pad_token_id=pad_token_id,
         seq_length=seq_length,
+        truncation=truncation,
     )
 
 
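The prompt-completion path now forwards padding, truncation, and max_length straight to the tokenizer call; a sketch of the effect with a real tokenizer:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
text = "Question: what is 2+2? Answer: 4"

ids_plain = tok(text)["input_ids"]
ids_trunc = tok(text, truncation=True, max_length=4)["input_ids"]
assert ids_trunc == ids_plain[:4]  # truncation simply cuts the tail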
@@ -171,6 +177,8 @@ def format_chat_template(
     eos_token_id: int,
     pad_token_id: int,
     seq_length: Optional[int] = None,
+    padding: Union[str, bool] = "do_not_pad",
+    truncation: Union[str, bool] = "do_not_truncate",
     tools: Optional[List[Dict]] = None,
 ) -> Dict[str, List[int]]:
     """
@@ -199,6 +207,9 @@
         tokenize=True,
         return_dict=True,
         return_assistant_tokens_mask=template_has_generation_kwd,
+        padding=padding,
+        truncation=truncation,
[Inline review on `truncation=truncation,`]
Contributor: If I understand correctly we truncate regardless of the template structure here.
Author: yes

+        max_length=seq_length,
     )
 
     # Choose the last conversation as answer other history are context by finding the last masked token
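As the review thread above confirms, truncation is applied inside apply_chat_template regardless of template structure. A sketch of the resulting call, with an illustrative chat-templated checkpoint:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")  # illustrative
messages = [
    {"role": "user", "content": "Summarize the plot of Hamlet."},
    {"role": "assistant", "content": "A prince avenges his father."},
]
out = tok.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    truncation=True,  # cuts the rendered conversation, even mid-turn
    max_length=16,
)
assert len(out["input_ids"]) <= 16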