diff --git a/inference.py b/inference.py index f018fe3..80721b6 100644 --- a/inference.py +++ b/inference.py @@ -27,14 +27,13 @@ def inference(img_path: Path, img_size: tuple[int, int], # Prepare model vit_pose = ViTPose(model_cfg) - ckpt = torch.load(ckpt_path) if 'state_dict' in ckpt: vit_pose.load_state_dict(ckpt['state_dict']) else: vit_pose.load_state_dict(ckpt) - vit_pose.to(device) + vit_pose.to(device).eval() print(f">>> Model loaded: {ckpt_path}") # Prepare input data @@ -92,4 +91,4 @@ def inference(img_path: Path, img_size: tuple[int, int], print(img_path) keypoints = inference(img_path=img_path, img_size=img_size, model_cfg=model_cfg, ckpt_path=CKPT_PATH, device=torch.device("cuda") if torch.cuda.is_available() else torch.device('cpu'), - save_result=True) \ No newline at end of file + save_result=True)