
Commit 160a0b8

Commit message: tmp
Signed-off-by: Alexandros Koumparoulis <[email protected]>
1 parent a02e93a commit 160a0b8

2 files changed: 109 additions & 6 deletions

nemo_automodel/components/datasets/utils.py

Lines changed: 30 additions & 0 deletions
@@ -451,3 +451,33 @@ def process(self, raw_dataset, ds):
         )
 
         return tokenized
+
+
+def seq_cls_collater(batch):
+    """
+    Collate function for sequence classification.
+
+    Expects each item in batch to be a dict with keys:
+        - "input_ids": List[int]
+        - "attention_mask": List[int]
+        - "labels": int
+
+    Returns a dict with tensors:
+        - input_ids: LongTensor [batch, seq_len]
+        - attention_mask: LongTensor [batch, seq_len]
+        - labels: LongTensor [batch]
+    """
+    # Extract and stack sequences; assume they are pre-padded to uniform length
+    input_ids = [sample["input_ids"] for sample in batch]
+    attention_mask = [sample.get("attention_mask", [1] * len(sample["input_ids"])) for sample in batch]
+    labels = [int(sample["labels"]) for sample in batch]
+
+    input_ids_tensor = torch.LongTensor(input_ids)
+    attention_mask_tensor = torch.LongTensor(attention_mask)
+    labels_tensor = torch.LongTensor(labels)
+
+    return {
+        "input_ids": input_ids_tensor,
+        "attention_mask": attention_mask_tensor,
+        "labels": labels_tensor,
+    }
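
For reference, a minimal usage sketch (not part of the commit) of how the new collater could be wired into a standard PyTorch DataLoader. The toy samples below are hypothetical and are pre-padded to a uniform length, which the collater assumes since it stacks the Python lists directly into LongTensors.

# Hypothetical usage sketch, not from the commit.
from torch.utils.data import DataLoader

from nemo_automodel.components.datasets.utils import seq_cls_collater

# Toy, pre-padded samples (padding to a common length is assumed by the collater).
samples = [
    {"input_ids": [101, 2023, 2003, 102], "attention_mask": [1, 1, 1, 1], "labels": 1},
    {"input_ids": [101, 2986, 102, 0], "attention_mask": [1, 1, 1, 0], "labels": 0},
]

loader = DataLoader(samples, batch_size=2, collate_fn=seq_cls_collater)
batch = next(iter(loader))
print(batch["input_ids"].shape)       # torch.Size([2, 4])
print(batch["attention_mask"].shape)  # torch.Size([2, 4])
print(batch["labels"].shape)          # torch.Size([2])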

nemo_automodel/recipes/llm/train_seq_cls.py

Lines changed: 79 additions & 6 deletions
@@ -94,12 +94,12 @@ def setup(self):
         )
 
         model, model_state_dict_keys, self.optimizer, _ = build_model_and_optimizer(
-            self.dist_env.device,
-            self.cfg.model,
-            self.cfg.optimizer,
-            use_hf_fa2,
-            None,
-            self.model_wrapper,
+            device=self.dist_env.device,
+            cfg_model=self.cfg.model,
+            cfg_opt=self.cfg.optimizer,
+            cfg_peft=None,
+            has_packed_sequence=use_hf_fa2,
+            model_wrapper=self.model_wrapper,
             seed=self.cfg.get("seed", 42),
             tp_size=self.cfg.get("distributed.tp_size", 1),
             cp_size=self.cfg.get("distributed.cp_size", 1),
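
As an aside on the design choice in this hunk (the function below is a hypothetical stand-in, not the real build_model_and_optimizer signature): keyword arguments pin each value to a named parameter, so the call stays unambiguous even if optional parameters are reordered or new ones are added.

# Generic illustration with a hypothetical function.
def make_parts(device, cfg_model, cfg_opt, cfg_peft=None, has_packed_sequence=False):
    return device, cfg_model, cfg_opt, cfg_peft, has_packed_sequence

# Positional: the reader (and the interpreter) must know the exact parameter order.
make_parts("cuda:0", {"name": "bert"}, {"lr": 3e-4}, None, True)

# Keyword: each value is explicitly attached to its parameter name.
make_parts(
    device="cuda:0",
    cfg_model={"name": "bert"},
    cfg_opt={"lr": 3e-4},
    cfg_peft=None,
    has_packed_sequence=True,
)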
@@ -269,6 +269,79 @@ def _validate_one_epoch(self, dataloader):
         return total_loss
 
 
+    def log_val_metrics(self, log_data):
+        """Log metrics to wandb and other loggers
+        Args:
+            log_data: MetricsSample object, containing:
+                step: int, the current step.
+                epoch: int, the current epoch.
+                metrics: Dict[str, float], containing:
+                    "val_loss": Validation loss.
+                    "lr": Learning rate.
+                    "num_label_tokens": Number of label tokens.
+                    "mem": Memory allocated.
+        """
+
+        # Pipeline parallelism does not support validation -> log_data is None
+        if not self.dist_env.is_main or log_data is None:
+            return
+
+        # if wandb.run is not None:
+        # wandb.log(log_data.to_dict(), step=log_data.step)
+
+        # JSONL validation log
+        self.metric_logger_valid.log(log_data)
+
+        logging.info(
+            "[val] step {} | epoch {} | loss {:.4f} | lr {:.2e} | num_label_tokens {}".format(
+                log_data.step,
+                log_data.epoch,
+                log_data.metrics["val_loss"],
+                log_data.metrics["lr"],
+                log_data.metrics["num_label_tokens"],
+            )
+        )
+
+    def log_train_metrics(self, log_data):
+        """Log metrics to wandb and other loggers.
+
+        Args:
+            log_data: MetricsSample object, containing:
+                step: int, the current step.
+                epoch: int, the current epoch.
+                metrics: Dict[str, float], containing:
+                    "loss": Training loss.
+                    "grad_norm": Grad norm from the training step.
+                    "lr": Learning rate.
+                    "mem": Memory allocated.
+                    "tps": Tokens per second.
+                    "tps_per_gpu": Tokens per second per GPU.
+                    "num_label_tokens": Number of label tokens.
+        """
+        if not self.dist_env.is_main:
+            return
+
+        # if wandb.run is not None:
+        # wandb.log(log_data.to_dict(), step=self.step_scheduler.step)
+        # JSONL training log
+        self.metric_logger_train.log(log_data)
+        logging.info(
+            "step {} | epoch {} | loss {:.4f} | grad_norm {:.4f} | lr {:.2e} | mem {:.2f} GiB | tps {:.2f}({:.2f}/gpu) | num_label_tokens {}".format(
+                log_data.step,
+                log_data.epoch,
+                log_data.metrics["loss"],
+                0,
+                # log_data.metrics["grad_norm"],
+                log_data.metrics["lr"],
+                log_data.metrics["mem"],
+                log_data.metrics["tps"],
+                log_data.metrics["tps_per_gpu"],
+                log_data.metrics["num_label_tokens"],
+            )
+        )
+        torch.cuda.reset_peak_memory_stats()
+
+
 def main(config_path: str | None = None):
     if config_path is None:
         config_path = (
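
To make the expected input concrete, here is a hedged sketch (not from the commit) of the object shape the new logging methods consume. `Sample` is a hypothetical stand-in for the recipe's MetricsSample, carrying only the attributes the methods actually read (step, epoch, metrics); variable names like `recipe` are illustrative.

# Hypothetical stand-in for MetricsSample, for illustration only.
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class Sample:
    step: int
    epoch: int
    metrics: Dict[str, float] = field(default_factory=dict)


train_sample = Sample(
    step=100,
    epoch=1,
    metrics={
        "loss": 2.3451,
        "grad_norm": 1.02,   # note: this tmp commit logs a hardcoded 0 instead
        "lr": 3e-4,
        "mem": 41.27,        # GiB
        "tps": 12850.0,
        "tps_per_gpu": 1606.25,
        "num_label_tokens": 4096,
    },
)
# recipe.log_train_metrics(train_sample)   # called on the training recipe instance

val_sample = Sample(
    step=100,
    epoch=1,
    metrics={"val_loss": 2.1, "lr": 3e-4, "num_label_tokens": 2048, "mem": 40.9},
)
# recipe.log_val_metrics(val_sample)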
