diff --git a/datasetsss.py b/datasetsss.py
index 1c7835d..7051838 100644
--- a/datasetsss.py
+++ b/datasetsss.py
@@ -31,7 +31,7 @@ def __init__(self, img_dir, mode):
         self.audio_feats = self.audio_feats.astype(np.float32)

     def __len__(self):
-        return self.audio_feats.shape[0] if self.audio_feats
[...]
+    if right > features.shape[0]:
+        pad_right = right - features.shape[0]
+        right = features.shape[0]
+    auds = features[left:right].copy()  # Ensure we get a copy, not a view
+    if pad_left > 0:
+        auds = np.concatenate([np.zeros_like(auds[:pad_left]), auds], axis=0)
+    if pad_right > 0:
+        auds = np.concatenate([auds, np.zeros_like(auds[:pad_right])], axis=0)  # [8, 16]
+    return auds
+
+
+audio_feats = np.load(audio_feat_path)
+img_dir = os.path.join(dataset_dir, "full_body_img/")
+lms_dir = os.path.join(dataset_dir, "landmarks/")
+len_img = len(os.listdir(img_dir)) - 1
+exm_img = cv2.imread(img_dir + "0.jpg")
+h, w = exm_img.shape[:2]
+
+if mode == "hubert":
+    video_writer = cv2.VideoWriter(
+        save_path, cv2.VideoWriter_fourcc("M", "J", "P", "G"), 25, (w, h)
+    )
+if mode == "wenet":
+    video_writer = cv2.VideoWriter(
+        save_path, cv2.VideoWriter_fourcc("M", "J", "P", "G"), 20, (w, h)
+    )
+step_stride = 0
+img_idx = 0
+
+
+unet = UnetTRT(checkpoint)
+
+import time
+
+s0 = time.time()
+
+for i in range(audio_feats.shape[0]):
+    if img_idx > len_img - 1:
+        # step_stride controls how the frame index advances: walk forward one image
+        # at a time from the start, then walk back one image at a time after the last frame.
+        step_stride = -1
+    if img_idx < 1:
+        step_stride = 1
+    img_idx += step_stride
+    img_path = img_dir + str(img_idx) + ".jpg"
+    lms_path = lms_dir + str(img_idx) + ".lms"
+
+    img = cv2.imread(img_path)
+    lms_list = []
+    with open(lms_path, "r") as f:
+        lines = f.read().splitlines()
+        for line in lines:
+            arr = line.split(" ")
+            arr = np.array(arr, dtype=np.float32)
+            lms_list.append(arr)
+    lms = np.array(lms_list, dtype=np.int32)  # this landmark detection model may be replaced later
+    xmin = lms[1][0]
+    ymin = lms[52][1]
+
+    xmax = lms[31][0]
+    width = xmax - xmin
+    ymax = ymin + width
+    crop_img = img[ymin:ymax, xmin:xmax]
+    h, w = crop_img.shape[:2]
+    crop_img = cv2.resize(crop_img, (168, 168), interpolation=cv2.INTER_AREA)
+    crop_img_ori = crop_img.copy()
+    img_real_ex = crop_img[4:164, 4:164].copy()
+    img_real_ex_ori = img_real_ex.copy()
+    img_masked = cv2.rectangle(img_real_ex_ori, (5, 5, 150, 145), (0, 0, 0), -1)
+
+    img_masked = img_masked.transpose(2, 0, 1).astype(np.float32)
+    img_real_ex = img_real_ex.transpose(2, 0, 1).astype(np.float32)
+    img_real_ex_T = (img_real_ex / 255.0).astype(np.float32)
+    img_masked_T = (img_masked / 255.0).astype(np.float32)
+    img_concat_T = np.concatenate([img_real_ex_T, img_masked_T], axis=0)[np.newaxis]
+
+    audio_feat = get_audio_features(audio_feats, i)
+    if mode == "hubert":
+        audio_feat = audio_feat.reshape(16, 32, 32)
+    if mode == "wenet":
+        audio_feat = audio_feat.reshape(128, 16, 32)
+    audio_feat = audio_feat[None]
+
+    output_host = unet(img_concat_T, audio_feat)
+    pred = np.squeeze(output_host, 0).transpose(1, 2, 0) * 255.0
+
+    pred = np.array(pred, dtype=np.uint8)
+    crop_img_ori[4:164, 4:164] = pred
+    crop_img_ori = cv2.resize(crop_img_ori, (w, h))
+    img[ymin:ymax, xmin:xmax] = crop_img_ori
+    # video_writer.write(img)
+# video_writer.release()
+
+print(audio_feats.shape[0] / (time.time() - s0))
+print(time.time() - s0)
+# ffmpeg -i test_video.mp4 -i test_audio.pcm -c:v libx264 -c:a aac result_test.mp4
diff --git a/train.py b/train.py
index 1b495ba..2a46b36 100644
--- a/train.py
+++ b/train.py
@@ -117,8 +117,8 @@ def train(net, epoch, batch_size, lr):
                 optimizer.step()
                 p.update(imgs.shape[0])

-        if e % 5 == 0:
-            torch.save(net.state_dict(), os.path.join(save_dir, str(e)+'.pth'))
+        if (e + 1) % 5 == 0:
+            torch.save(net.state_dict(), os.path.join(save_dir, str(e + 1).zfill(5)+'.pth'))
         if args.see_res:
             net.eval()
             img_concat_T, img_real_T, audio_feat = dataset.__getitem__(random.randint(0, dataset.__len__()))
@@ -139,4 +139,4 @@ def train(net, epoch, batch_size, lr):

     net = Model(6, args.asr).to(device)

-    train(net, args.epochs, args.batchsize, args.lr)
\ No newline at end of file
+    train(net, args.epochs, args.batchsize, args.lr)
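
The opening of get_audio_features is not visible in the first hunk; judging from the padding branches that are visible and the later reshape(16, 32, 32), it appears to slice a fixed 16-frame window centred on the current audio index and zero-pad at the clip boundaries. Below is a hedged standalone sketch of that behaviour, not part of the diff: the -8/+8 half-window and the 1024-dim hubert frames are assumptions inferred from the reshape.

    import numpy as np

    def get_audio_features(features, index):
        # Assumed window bounds: reshape(16, 32, 32) in the loop implies a 16-frame window.
        left = index - 8
        right = index + 8
        pad_left = 0
        pad_right = 0
        if left < 0:
            pad_left = -left
            left = 0
        if right > features.shape[0]:
            pad_right = right - features.shape[0]
            right = features.shape[0]
        auds = features[left:right].copy()  # copy, not a view, as in the diff
        if pad_left > 0:
            auds = np.concatenate([np.zeros_like(auds[:pad_left]), auds], axis=0)
        if pad_right > 0:
            auds = np.concatenate([auds, np.zeros_like(auds[:pad_right])], axis=0)
        return auds

    feats = np.random.rand(100, 1024).astype(np.float32)  # e.g. hubert features, assumed 1024-dim per frame
    print(get_audio_features(feats, 0).shape)   # (16, 1024): first 8 rows are zero padding
    print(get_audio_features(feats, 99).shape)  # (16, 1024): last 7 rows are zero padding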
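
For reference, a minimal standalone sketch (not part of the diff) of the ping-pong frame indexing used in the inference loop: the image index walks forward one frame at a time and reverses direction at the last and first image, so any number of audio frames can be covered by a finite set of images. The counts below are made-up demo values.

    num_frames = 10  # audio frames to cover (demo value)
    len_img = 4      # as in the script: number of images minus one, i.e. 0.jpg .. 4.jpg exist
    step_stride = 0
    img_idx = 0
    order = []
    for i in range(num_frames):
        if img_idx > len_img - 1:  # passed the last usable image -> walk backwards
            step_stride = -1
        if img_idx < 1:            # back at the first image -> walk forwards again
            step_stride = 1
        img_idx += step_stride
        order.append(img_idx)
    print(order)  # [1, 2, 3, 4, 3, 2, 1, 0, 1, 2]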