diff --git a/PromptPAR/test_in_custom.py b/PromptPAR/test_in_custom.py index 2c95401..eb181c1 100644 --- a/PromptPAR/test_in_custom.py +++ b/PromptPAR/test_in_custom.py @@ -2,6 +2,7 @@ import pprint from collections import OrderedDict import time +import json import numpy as np import torch from torch.utils.data import DataLoader, Dataset @@ -16,6 +17,7 @@ from torchvision import transforms from PIL import Image from config import argument_parser +from utils.train_utils import AverageMeter set_seed(605) device = "cuda" if torch.cuda.is_available() else "cpu" # attr_words的设置取决于你加载哪个数据集训练的checkpoint 例如PETA @@ -26,13 +28,12 @@ 'lower Casual', 'lower Formal', 'lower Jeans', 'lower Shorts', 'lower Short Skirt','lower Trousers', 'shoes Leather', 'shoes Sandals', 'shoes other', 'shoes sneaker', 'attach Backpack', 'attach Other', 'attach messenger bag', 'attach nothing', 'attach plastic bags', - 'age less 30','age 30 45','age 45 60','age over 60', - 'male' -] # 54 + 'age less 30','age 30 45','age 45 60','age over 60',] # 34 class CustomDataset(Dataset): def __init__(self, image_root, transform=None): self.image_root = image_root + print(f"image_root: {image_root}") self.image_list = sorted([os.path.join(image_root, f) for f in os.listdir(image_root) if f.endswith(('.jpg', '.png'))]) self.transform = transform @@ -76,35 +77,45 @@ def main(args, image_root): else: print(f"Warning: Checkpoint {args.dir} not found, skipping model loading.") return - - print("Starting evaluation...") start = time.time() model.eval() - loss_meter = AverageMeter() + # loss_meter = AverageMeter() # ----> used nowhere preds_probs = [] - gt_list = [] + imgnames= [] + # gt_list = [] # ----> used nowhere with torch.no_grad(): - for step, (imgs, imgname) in enumerate(valid_loader): + for step, (imgs, imgname) in enumerate(data_loader): imgs = imgs.cuda() - valid_logits,_ = model(imgs, clip_model=clip_model) - + valid_logits, _ = model(imgs, clip_model=clip_model) valid_probs = torch.sigmoid(valid_logits) preds_probs.append(valid_probs.cpu().numpy()) - preds_attrs = [[] for _ in range(len(preds_probs))] - - # 这里可以对preds_probs处理后得到预测结果 - for pidx, ppreds in enumerate(preds_probs): - for aidx, pattr in enumerate(ppreds): - if pattr >0.45:# 我们的阈值设为0.45 可以修改 - preds_attrs[pidx].append(attributes[aidx]) - + imgnames.append(imgname) + + preds_attrs = {} + for batch_preds, batch_imgs in zip(preds_probs, imgnames): + for i, sample_preds in enumerate(batch_preds): + sample_attrs = [] + # Now iterate over each attribute score in the sample + for aidx, attr_score in enumerate(sample_preds): + if attr_score > 0.45: # now attr_score is a scalar value + sample_attrs.append(attributes[aidx]) + + + preds_attrs[batch_imgs[i]]= sample_attrs + end = time.time() print(f'Total test time: {end - start:.2f} seconds') + + with open("preds.json" , "w") as f: + json.dump(preds_attrs, f, indent=4) + + if __name__ == '__main__': parser = argument_parser() image_root = '' args = parser.parse_args() - main(args) + main(args, image_root) +