diff --git a/P3-SAM/model.py b/P3-SAM/model.py index b09aaa1..1082b1d 100644 --- a/P3-SAM/model.py +++ b/P3-SAM/model.py @@ -113,7 +113,11 @@ def load_state_dict(self, ignore_seg_s2_mlp=False, ignore_iou_mlp=False): # load checkpoint if ckpt_path is not None: - state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"] + if ckpt_path.endswith('.safetensors'): + from safetensors.torch import load_file + state_dict = load_file(ckpt_path, device="cpu") + else: + state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"] elif state_dict is None: # download from huggingface print(f'trying to download model from huggingface...')