Update test_in_custom.py #53

Open · wants to merge 4 commits into main
49 changes: 30 additions & 19 deletions PromptPAR/test_in_custom.py
@@ -2,6 +2,7 @@
import pprint
from collections import OrderedDict
import time
import json
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
@@ -16,6 +17,7 @@
from torchvision import transforms
from PIL import Image
from config import argument_parser
from utils.train_utils import AverageMeter
set_seed(605)
device = "cuda" if torch.cuda.is_available() else "cpu"
# The attr_words setting depends on which dataset's checkpoint you load, e.g. PETA
@@ -26,13 +28,12 @@
'lower Casual', 'lower Formal', 'lower Jeans', 'lower Shorts', 'lower Short Skirt','lower Trousers',
'shoes Leather', 'shoes Sandals', 'shoes other', 'shoes sneaker',
'attach Backpack', 'attach Other', 'attach messenger bag', 'attach nothing', 'attach plastic bags',
'age less 30','age 30 45','age 45 60','age over 60',
'male'
] # 54
'age less 30','age 30 45','age 45 60','age over 60',] # 34

class CustomDataset(Dataset):
def __init__(self, image_root, transform=None):
self.image_root = image_root
print(f"image_root: {image_root}")
self.image_list = sorted([os.path.join(image_root, f) for f in os.listdir(image_root) if f.endswith(('.jpg', '.png'))])
self.transform = transform

@@ -76,35 +77,45 @@ def main(args, image_root):
else:
print(f"Warning: Checkpoint {args.dir} not found, skipping model loading.")
return


print("Starting evaluation...")
start = time.time()

model.eval()
loss_meter = AverageMeter()
# loss_meter = AverageMeter() # ----> used nowhere
preds_probs = []
gt_list = []
imgnames = []
# gt_list = [] # ----> used nowhere
with torch.no_grad():
for step, (imgs, imgname) in enumerate(valid_loader):
for step, (imgs, imgname) in enumerate(data_loader):
imgs = imgs.cuda()
valid_logits,_ = model(imgs, clip_model=clip_model)

valid_logits, _ = model(imgs, clip_model=clip_model)
valid_probs = torch.sigmoid(valid_logits)
preds_probs.append(valid_probs.cpu().numpy())
preds_attrs = [[] for _ in range(len(preds_probs))]

# preds_probs can be processed here to obtain the prediction results
for pidx, ppreds in enumerate(preds_probs):
for aidx, pattr in enumerate(ppreds):
if pattr > 0.45: # our threshold is set to 0.45 and can be changed
preds_attrs[pidx].append(attributes[aidx])

imgnames.append(imgname)

preds_attrs = {}
for batch_preds, batch_imgs in zip(preds_probs, imgnames):
for i, sample_preds in enumerate(batch_preds):
sample_attrs = []
# Now iterate over each attribute score in the sample
for aidx, attr_score in enumerate(sample_preds):
if attr_score > 0.45: # now attr_score is a scalar value
sample_attrs.append(attributes[aidx])


preds_attrs[batch_imgs[i]] = sample_attrs

end = time.time()
print(f'Total test time: {end - start:.2f} seconds')

with open("preds.json" , "w") as f:
json.dump(preds_attrs, f, indent=4)



if __name__ == '__main__':
parser = argument_parser()
image_root = ''
args = parser.parse_args()
main(args)
main(args, image_root)
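
For reference, a minimal standalone sketch of the thresholding and JSON export that the new per-sample loop performs. The attribute names, image names, and scores below are invented for illustration only; the actual script uses the model's sigmoid outputs and the files found under image_root.

import json
import numpy as np

# Hypothetical attribute vocabulary and sigmoid scores, made up purely for illustration.
attributes = ['upper Casual', 'upper Formal', 'lower Jeans', 'attach Backpack']
batch_probs = np.array([[0.91, 0.10, 0.62, 0.30],   # scores for image_001.jpg
                        [0.20, 0.80, 0.05, 0.55]])  # scores for image_002.jpg
batch_names = ['image_001.jpg', 'image_002.jpg']

preds_attrs = {}
for sample_probs, name in zip(batch_probs, batch_names):
    # Keep every attribute whose score clears the 0.45 threshold,
    # mirroring the per-sample loop added in this PR.
    preds_attrs[name] = [attributes[i] for i, p in enumerate(sample_probs) if p > 0.45]

with open("preds.json", "w") as f:
    json.dump(preds_attrs, f, indent=4)
# preds.json -> {"image_001.jpg": ["upper Casual", "lower Jeans"], "image_002.jpg": ["upper Formal", "attach Backpack"]}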