diff --git a/P3-SAM/model.py b/P3-SAM/model.py index b09aaa1..bdddb95 100644 --- a/P3-SAM/model.py +++ b/P3-SAM/model.py @@ -1,7 +1,8 @@ import os import sys -import torch -import torch.nn as nn +import torch +import torch.nn as nn +from safetensors.torch import load_file as load_safetensors sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'XPart/partgen')) from models import sonata from utils.misc import smart_load_model @@ -15,7 +16,7 @@ ''' def build_P3SAM(self): #build p3sam ######################## Sonata ######################## - self.sonata = sonata.load("sonata", repo_id="facebook/sonata", download_root='/root/sonata') + self.sonata = sonata.load("sonata", repo_id="facebook/sonata", download_root='weights/sonata/') self.mlp = nn.Sequential( nn.Linear(1232, 512), nn.GELU(), @@ -113,14 +114,17 @@ def load_state_dict(self, ignore_seg_s2_mlp=False, ignore_iou_mlp=False): # load checkpoint if ckpt_path is not None: - state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"] + if ckpt_path.endswith('.safetensors'): + state_dict = load_safetensors(ckpt_path) + else: + state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"] elif state_dict is None: # download from huggingface print(f'trying to download model from huggingface...') from huggingface_hub import hf_hub_download - ckpt_path = hf_hub_download(repo_id="tencent/Hunyuan3D-Part", filename="p3sam.ckpt", local_dir='weights') + ckpt_path = hf_hub_download(repo_id="tencent/Hunyuan3D-Part", filename="p3sam/p3sam.safetensors", local_dir='weights') print(f'download model from huggingface to: {ckpt_path}') - state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"] + state_dict = load_safetensors(ckpt_path) local_state_dict = self.state_dict() seen_keys = {k: False for k in local_state_dict.keys()}