@@ -94,12 +94,12 @@ def setup(self):
         )
 
         model, model_state_dict_keys, self.optimizer, _ = build_model_and_optimizer(
-            self.dist_env.device,
-            self.cfg.model,
-            self.cfg.optimizer,
-            use_hf_fa2,
-            None,
-            self.model_wrapper,
+            device=self.dist_env.device,
+            cfg_model=self.cfg.model,
+            cfg_opt=self.cfg.optimizer,
+            cfg_peft=None,
+            has_packed_sequence=use_hf_fa2,
+            model_wrapper=self.model_wrapper,
             seed=self.cfg.get("seed", 42),
             tp_size=self.cfg.get("distributed.tp_size", 1),
             cp_size=self.cfg.get("distributed.cp_size", 1),
@@ -269,6 +269,79 @@ def _validate_one_epoch(self, dataloader):
         return total_loss
 
 
+    def log_val_metrics(self, log_data):
+        """Log metrics to wandb and other loggers.
+        Args:
+            log_data: MetricsSample object, containing:
+                step: int, the current step.
+                epoch: int, the current epoch.
+                metrics: Dict[str, float], containing:
+                    "val_loss": Validation loss.
+                    "lr": Learning rate.
+                    "num_label_tokens": Number of label tokens.
+                    "mem": Memory allocated.
+        """
+
+        # Pipeline parallelism does not support validation -> log_data is None
+        if not self.dist_env.is_main or log_data is None:
+            return
+
+        # if wandb.run is not None:
+        #     wandb.log(log_data.to_dict(), step=log_data.step)
+
+        # JSONL validation log
+        self.metric_logger_valid.log(log_data)
+
+        logging.info(
+            "[val] step {} | epoch {} | loss {:.4f} | lr {:.2e} | num_label_tokens {}".format(
+                log_data.step,
+                log_data.epoch,
+                log_data.metrics["val_loss"],
+                log_data.metrics["lr"],
+                log_data.metrics["num_label_tokens"],
+            )
+        )
+
+    def log_train_metrics(self, log_data):
+        """Log metrics to wandb and other loggers.
+
+        Args:
+            log_data: MetricsSample object, containing:
+                step: int, the current step.
+                epoch: int, the current epoch.
+                metrics: Dict[str, float], containing:
+                    "loss": Training loss.
+                    "grad_norm": Grad norm from the training step.
+                    "lr": Learning rate.
+                    "mem": Memory allocated.
+                    "tps": Tokens per second.
+                    "tps_per_gpu": Tokens per second per GPU.
+                    "num_label_tokens": Number of label tokens.
+        """
+        if not self.dist_env.is_main:
+            return
+
+        # if wandb.run is not None:
+        #     wandb.log(log_data.to_dict(), step=self.step_scheduler.step)
+        # JSONL training log
+        self.metric_logger_train.log(log_data)
+        logging.info(
+            "step {} | epoch {} | loss {:.4f} | grad_norm {:.4f} | lr {:.2e} | mem {:.2f} GiB | tps {:.2f}({:.2f}/gpu) | num_label_tokens {}".format(
+                log_data.step,
+                log_data.epoch,
+                log_data.metrics["loss"],
+                # grad_norm reported by the training step
+                log_data.metrics["grad_norm"],
+                log_data.metrics["lr"],
+                log_data.metrics["mem"],
+                log_data.metrics["tps"],
+                log_data.metrics["tps_per_gpu"],
+                log_data.metrics["num_label_tokens"],
+            )
+        )
+        torch.cuda.reset_peak_memory_stats()
+
+
 def main(config_path: str | None = None):
     if config_path is None:
         config_path = (
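
For reference, the new log_val_metrics/log_train_metrics methods only rely on log_data exposing step, epoch, metrics, and (for the commented-out wandb path) to_dict(). Below is a minimal sketch of what such a MetricsSample container could look like, assuming a plain dataclass; the actual class used in this repo may differ.

from dataclasses import dataclass, field
from typing import Dict


@dataclass
class MetricsSample:
    """Illustrative container consumed by log_train_metrics / log_val_metrics."""

    step: int    # global training step
    epoch: int   # current epoch
    # e.g. "loss", "grad_norm", "lr", "mem", "tps", "tps_per_gpu", "num_label_tokens"
    metrics: Dict[str, float] = field(default_factory=dict)

    def to_dict(self) -> dict:
        # Flatten for a single wandb.log() call: metrics plus the step/epoch scalars.
        return {"step": self.step, "epoch": self.epoch, **self.metrics}


# Example of a training-step sample with the keys the train log line expects.
sample = MetricsSample(
    step=100,
    epoch=1,
    metrics={
        "loss": 1.234,
        "grad_norm": 0.56,
        "lr": 2e-5,
        "mem": 41.3,
        "tps": 18000.0,
        "tps_per_gpu": 2250.0,
        "num_label_tokens": 52341,
    },
)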