diff --git a/peft/interpret/interpret.py b/peft/interpret/interpret.py
index 1492c60..bb0028b 100644
--- a/peft/interpret/interpret.py
+++ b/peft/interpret/interpret.py
@@ -1,7 +1,7 @@
 import json
 from transformers import AutoModelForCausalLM
-from peft import get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType
+from peft import PeftModel, get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType
 import torch
 from numpy import array
@@ -10,9 +10,7 @@ def load_model():
-    # NOTE: Currently loading untrained model, will push changes for trained model loading
     model_name_or_path = "meta-llama/Llama-2-7b-chat-hf"
-
     peft_config = PromptTuningConfig(
         task_type=TaskType.CAUSAL_LM,
         prompt_tuning_init=PromptTuningInit.TEXT,
@@ -22,7 +20,8 @@ def load_model():
         tokenizer_name_or_path=model_name_or_path,
     )
     model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map='cpu', torch_dtype=torch.float16)
-    model = get_peft_model(model, peft_config)  # add PEFT pieces to the LLM
+    #model = get_peft_model(model, peft_config)  # add PEFT pieces to the LLM
+    model = PeftModel.from_pretrained(model, 'meyceoz/prompt-llama-2')
     return model
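For context, the patch swaps `get_peft_model` (which attaches a fresh, untrained prompt-tuning adapter) for `PeftModel.from_pretrained` (which loads the already-trained adapter weights from the `meyceoz/prompt-llama-2` Hub repo onto the frozen base model). A minimal usage sketch follows; the tokenizer setup and the example prompt are illustrative assumptions, not part of the patch:

```python
# Sketch: exercising the updated load_model() for inference.
# Assumes the 'meyceoz/prompt-llama-2' adapter repo referenced in the diff is
# reachable and that the caller is responsible for tokenization.
import torch
from transformers import AutoTokenizer

model = load_model()  # base Llama-2 with the trained prompt-tuning adapter attached
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

inputs = tokenizer("What is prompt tuning?", return_tensors="pt")
with torch.no_grad():
    # The PeftModel prepends the learned virtual tokens internally before generation.
    output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```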