diff --git a/large_language_models/alpaca-qlora/finetune_pp.py b/large_language_models/alpaca-qlora/finetune_pp.py
index ba2492e..5030031 100644
--- a/large_language_models/alpaca-qlora/finetune_pp.py
+++ b/large_language_models/alpaca-qlora/finetune_pp.py
@@ -120,7 +120,7 @@ def tokenize(prompt):
     batch_size=args.micro_batch_size,
     sampler=RandomSampler(train_dataset, generator=generator),
     collate_fn=data_collator,
-    drop_last=False,
+    drop_last=True,
     num_workers=0,
     pin_memory=True,
     worker_init_fn=seed_worker,
@@ -148,7 +148,7 @@ def tokenize(prompt):
 lr_scheduler = get_linear_schedule_with_warmup(
     optimizer=optimizer,
     num_warmup_steps=100,
-    num_training_steps=len(epoch_iterator) // GRADIENT_ACCUMULATION_STEPS,
+    num_training_steps=EPOCHS * len(epoch_iterator) // GRADIENT_ACCUMULATION_STEPS,
 )
 
 scaler = torch.cuda.amp.GradScaler()
@@ -158,47 +158,48 @@ def tokenize(prompt):
 model.eval()
 model.train()
 
-step = 0
-accumulated_loss = 0
-for inputs in tqdm(epoch_iterator):
-    step += 1
-    with torch.cuda.amp.autocast(cache_enabled=True, dtype=torch.float16):
-        labels = inputs.pop("labels")
-        outputs = model(**inputs)
-        outputs["logits"] = outputs["logits"].float()
-        labels = labels.to(outputs["logits"].device)
-        shift_logits = outputs["logits"][..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-        # Flatten the tokens
-        shift_logits = shift_logits.view(-1, shift_logits.shape[-1])
-        shift_labels = shift_labels.view(-1)
-        # Enable model parallelism
-        shift_labels = shift_labels.to(shift_logits.device)
-        loss = loss_fct(shift_logits, shift_labels)
-
-    loss /= GRADIENT_ACCUMULATION_STEPS
-    accumulated_loss += loss.item()
-    scaler.scale(loss).backward()
-
-    if step % GRADIENT_ACCUMULATION_STEPS == 0:
-        scaler.unscale_(optimizer)
-        torch.nn.utils.clip_grad_norm_(
-            filter(lambda p: p.requires_grad, model.parameters()), max_norm=1.0
-        )
-        scaler.step(optimizer)
-        scaler.update()
-        lr_scheduler.step()
-        model.zero_grad()
-        tqdm.write(
-            "{"
-            + "'loss': {0:1.4f}, 'learning_rate': {1:2.6f}, 'epoch': {2:3.2f}".format(
-                accumulated_loss,
-                optimizer.param_groups[0]["lr"],
-                step / len(epoch_iterator),
+for epoch in range(EPOCHS):
+    step = 0
+    accumulated_loss = 0
+    for inputs in tqdm(epoch_iterator):
+        step += 1
+        with torch.cuda.amp.autocast(cache_enabled=True, dtype=torch.float16):
+            labels = inputs.pop("labels")
+            outputs = model(**inputs)
+            outputs["logits"] = outputs["logits"].float()
+            labels = labels.to(outputs["logits"].device)
+            shift_logits = outputs["logits"][..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            shift_logits = shift_logits.view(-1, shift_logits.shape[-1])
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+        loss /= GRADIENT_ACCUMULATION_STEPS
+        accumulated_loss += loss.item()
+        scaler.scale(loss).backward()
+
+        if step % GRADIENT_ACCUMULATION_STEPS == 0:
+            scaler.unscale_(optimizer)
+            torch.nn.utils.clip_grad_norm_(
+                filter(lambda p: p.requires_grad, model.parameters()), max_norm=1.0
             )
-            + "}"
-        )
-        accumulated_loss = 0
+            scaler.step(optimizer)
+            scaler.update()
+            lr_scheduler.step()
+            model.zero_grad()
+            tqdm.write(
+                "{"
+                + "'loss': {0:1.4f}, 'learning_rate': {1:2.6f}, 'epoch': {2:3.2f}".format(
+                    accumulated_loss,
+                    optimizer.param_groups[0]["lr"],
+                    step / len(epoch_iterator) + epoch,
+                )
+                + "}"
+            )
+            accumulated_loss = 0
 
 print("Peak memory usage for GPUs: ", end="")
 for i in range(len(model.model.devices)):