Skip to content

Commit e7fc3e3

Browse files
committed
fix(accelerate_ppo_trainer): no resizing when using peft reference
1 parent 923ec65 commit e7fc3e3

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

trlx/trainer/accelerate_ppo_trainer.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -71,13 +71,15 @@ def __init__(self, config: TRLConfig, **kwargs):
7171

7272
# Set up a reference model when hydra heads are not used
7373
if not hasattr(self.model, "frozen_head") and not self.model.peft_type:
74+
# Full Reference Copy
7475
self.ref_model = self.get_arch(self.config)
7576
self.ref_model.base_model.resize_token_embeddings(len(self.tokenizer))
7677
self.ref_model.to(self.accelerator.device)
7778
self.ref_model.eval()
78-
else:
79-
# resize hydra heads
79+
elif hasattr(self.model, "frozen_head"):
80+
# Hydra Reference: Use the frozen base layers and head as the reference model, resize hydra heads
8081
self.model.frozen_head.resize_token_embeddings(len(self.tokenizer))
82+
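# TODO: remaining case is a PEFT reference (self.model.peft_type is set and there is no frozen_head): no reference model is created and nothing is resized here; decide whether any setup is needed.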
8183

8284
# Set up the KL controller
8385
# This helps prevent large divergences in the controller (policy)

0 commit comments

Comments
 (0)