Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
type='DetectAndRegress',
backbone=None,
pretrained=None,
keypoint_head=None,
human_detector=dict(
type='VoxelCenterDetector',
image_size=image_size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
model = dict(
type='DetectAndRegress',
backbone=None,
keypoint_head=None,
pretrained=None,
human_detector=dict(
type='VoxelCenterDetector',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ Results on CMU Panoptic dataset.

| Arch | mAP | mAR | MPJPE | Recall@500mm | ckpt | log |
| :--------------------------------------------------------- | :---: | :---: | :---: | :----------: | :--------------------------------------------------------: | :-------------------------------------------------------: |
| [prn64_cpn80_res50](/configs/body/3d_kpt_mview_rgb_img/voxelpose/panoptic/voxelpose_prn64x64x64_cpn80x80x20_panoptic_cam5.py) | 97.31 | 97.99 | 17.57 | 99.85 | [ckpt](https://download.openmmlab.com/mmpose/body3d/voxelpose/voxelpose_prn64x64x64_cpn80x80x20_panoptic_cam5-545c150e_20211103.pth) | [log](https://download.openmmlab.com/mmpose/body3d/voxelpose/voxelpose_prn64x64x64_cpn80x80x20_panoptic_cam5_20211103.log.json) |
| [prn64_cpn80_res50](/configs/body/3d_kpt_mview_rgb_img/voxelpose/panoptic/voxelpose_prn64x64x64_cpn80x80x20_panoptic_cam5.py) | 97.15 | 97.70 | 17.09 | 99.25 | [ckpt](https://download.openmmlab.com/mmpose/body3d/voxelpose/voxelpose_prn64x64x64_cpn80x80x20_panoptic_cam5-358648cb_20230118.pth) | [log](https://download.openmmlab.com/mmpose/body3d/voxelpose/voxelpose_prn64x64x64_cpn80x80x20_panoptic_cam5_20230118.log.json) |
Original file line number Diff line number Diff line change
Expand Up @@ -65,44 +65,30 @@
subset='validation'))

# model settings
backbone = dict(
type='AssociativeEmbedding',
pretrained=None,
backbone=dict(type='ResNet', depth=50),
keypoint_head=dict(
type='DeconvHead',
in_channels=2048,
out_channels=num_joints,
num_deconv_layers=3,
num_deconv_filters=(256, 256, 256),
num_deconv_kernels=(4, 4, 4),
loss_keypoint=dict(
type='MultiLossFactory',
num_joints=15,
num_stages=1,
ae_loss_type='exp',
with_ae_loss=[False],
push_loss_factor=[0.001],
pull_loss_factor=[0.001],
with_heatmaps_loss=[True],
heatmaps_loss_factor=[1.0],
)),
train_cfg=dict(),
test_cfg=dict(
num_joints=num_joints,
nms_kernel=None,
nms_padding=None,
tag_per_joint=None,
max_num_people=None,
detection_threshold=None,
tag_threshold=None,
use_detection_val=None,
ignore_too_much=None,
backbone = dict(type='ResNet', depth=50)
keypoint_head = dict(
type='DeconvHead',
in_channels=2048,
out_channels=num_joints,
num_deconv_layers=3,
num_deconv_filters=(256, 256, 256),
num_deconv_kernels=(4, 4, 4),
loss_keypoint=dict(
type='MultiLossFactory',
num_joints=15,
num_stages=1,
ae_loss_type='exp',
with_ae_loss=[False],
push_loss_factor=[0.001],
pull_loss_factor=[0.001],
with_heatmaps_loss=[True],
heatmaps_loss_factor=[1.0],
))

model = dict(
type='DetectAndRegress',
backbone=backbone,
keypoint_head=keypoint_head,
pretrained='checkpoints/resnet_50_deconv.pth.tar',
human_detector=dict(
type='VoxelCenterDetector',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ Models:
Results:
- Dataset: CMU Panoptic
Metrics:
MPJPE: 17.57
mAP: 97.31
mAR: 97.99
MPJPE: 17.09
mAP: 97.15
mAR: 97.7
Task: Body 3D Keypoint
Weights: https://download.openmmlab.com/mmpose/body3d/voxelpose/voxelpose_prn64x64x64_cpn80x80x20_panoptic_cam5-545c150e_20211103.pth
Weights: https://download.openmmlab.com/mmpose/body3d/voxelpose/voxelpose_prn64x64x64_cpn80x80x20_panoptic_cam5-358648cb_20230118.pth
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
model = dict(
type='DetectAndRegress',
backbone=None,
keypoint_head=None,
pretrained=None,
human_detector=dict(
type='VoxelCenterDetector',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
type='DetectAndRegress',
backbone=None,
pretrained=None,
keypoint_head=None,
human_detector=dict(
type='VoxelCenterDetector',
image_size=image_size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import json_tricks as json
import numpy as np
from scipy.io import loadmat
from torch.utils.data import Dataset

from mmpose.datasets import DatasetInfo
Expand Down Expand Up @@ -249,8 +248,5 @@ def _load_files(self):

assert osp.exists(self.gt_pose_db_file), f'gt_pose_db_file ' \
f"{self.gt_pose_db_file} doesn't exist, please check again"
gt = loadmat(self.gt_pose_db_file)
self.gt_pose_db = np.array(np.array(
gt['actor3D'].tolist()).tolist()).squeeze()

self.gt_pose_db = np.load(self.gt_pose_db_file)
self.num_persons = len(self.gt_pose_db)
42 changes: 30 additions & 12 deletions mmpose/models/detectors/multiview_pose.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from mmpose.core.post_processing.post_transforms import (
affine_transform_torch, get_affine_transform)
from .. import builder
from ..builder import POSENETS
from ..builder import BACKBONES, HEADS, POSENETS
from ..utils.misc import torch_meshgrid_ij
from .base import BasePose

Expand Down Expand Up @@ -138,7 +138,9 @@
"""DetectAndRegress approach for multiview human pose detection.

Args:
backbone (ConfigDict): Dictionary to construct the 2D pose detector
backbone (ConfigDict): Dictionary to construct the backbone.
keypoint_head (ConfigDict): Dictionary to construct the 2d
keypoint head.
human_detector (ConfigDict): dictionary to construct human detector
pose_regressor (ConfigDict): dictionary to construct pose regressor
train_cfg (ConfigDict): Config for training. Default: None.
Expand All @@ -150,6 +152,7 @@

def __init__(self,
backbone,
keypoint_head,
human_detector,
pose_regressor,
train_cfg=None,
Expand All @@ -158,11 +161,16 @@
freeze_2d=True):
super(DetectAndRegress, self).__init__()
if backbone is not None:
self.backbone = builder.build_posenet(backbone)
if self.training and pretrained is not None:
load_checkpoint(self.backbone, pretrained)
self.backbone = BACKBONES.build(backbone)

Check warning on line 164 in mmpose/models/detectors/multiview_pose.py

View check run for this annotation

Codecov / codecov/patch

mmpose/models/detectors/multiview_pose.py#L164

Added line #L164 was not covered by tests
else:
self.backbone = None
if keypoint_head is not None:
self.keypoint_head = HEADS.build(keypoint_head)

Check warning on line 168 in mmpose/models/detectors/multiview_pose.py

View check run for this annotation

Codecov / codecov/patch

mmpose/models/detectors/multiview_pose.py#L168

Added line #L168 was not covered by tests
else:
self.keypoint_head = None

if self.training and pretrained is not None:
load_checkpoint(self, pretrained)

Check warning on line 173 in mmpose/models/detectors/multiview_pose.py

View check run for this annotation

Codecov / codecov/patch

mmpose/models/detectors/multiview_pose.py#L173

Added line #L173 was not covered by tests

self.freeze_2d = freeze_2d
self.human_detector = builder.MODELS.build(human_detector)
Expand All @@ -188,8 +196,11 @@
Module: self
"""
super().train(mode)
if mode and self.freeze_2d and self.backbone is not None:
self._freeze(self.backbone)
if mode and self.freeze_2d:
if self.backbone is not None:
self._freeze(self.backbone)

Check warning on line 201 in mmpose/models/detectors/multiview_pose.py

View check run for this annotation

Codecov / codecov/patch

mmpose/models/detectors/multiview_pose.py#L201

Added line #L201 was not covered by tests
if self.keypoint_head is not None:
self._freeze(self.keypoint_head)

Check warning on line 203 in mmpose/models/detectors/multiview_pose.py

View check run for this annotation

Codecov / codecov/patch

mmpose/models/detectors/multiview_pose.py#L203

Added line #L203 was not covered by tests

return self

Expand Down Expand Up @@ -283,6 +294,12 @@

return outputs

def predict_heatmap(self, img):
output = self.backbone(img)
output = self.keypoint_head(output)

Check warning on line 299 in mmpose/models/detectors/multiview_pose.py

View check run for this annotation

Codecov / codecov/patch

mmpose/models/detectors/multiview_pose.py#L298-L299

Added lines #L298 - L299 were not covered by tests

return output

Check warning on line 301 in mmpose/models/detectors/multiview_pose.py

View check run for this annotation

Codecov / codecov/patch

mmpose/models/detectors/multiview_pose.py#L301

Added line #L301 was not covered by tests

def forward_train(self,
img,
img_metas,
Expand Down Expand Up @@ -331,7 +348,7 @@
feature_maps = []
assert isinstance(img, list)
for img_ in img:
feature_maps.append(self.backbone.forward_dummy(img_)[0])
feature_maps.append(self.predict_heatmap(img_)[0])

Check warning on line 351 in mmpose/models/detectors/multiview_pose.py

View check run for this annotation

Codecov / codecov/patch

mmpose/models/detectors/multiview_pose.py#L351

Added line #L351 was not covered by tests

losses = dict()
human_candidates, human_loss = self.human_detector.forward_train(
Expand All @@ -351,8 +368,9 @@
heatmaps_tensor = torch.cat(feature_maps, dim=0)
targets_tensor = torch.cat(targets, dim=0)
masks_tensor = torch.cat(masks, dim=0)
losses_2d_ = self.backbone.get_loss(heatmaps_tensor,
targets_tensor, masks_tensor)
losses_2d_ = self.keypoint_head.get_loss(heatmaps_tensor,

Check warning on line 371 in mmpose/models/detectors/multiview_pose.py

View check run for this annotation

Codecov / codecov/patch

mmpose/models/detectors/multiview_pose.py#L371

Added line #L371 was not covered by tests
targets_tensor,
masks_tensor)
for k, v in losses_2d_.items():
losses_2d[k + '_2d'] = v
losses.update(losses_2d)
Expand Down Expand Up @@ -400,7 +418,7 @@
feature_maps = []
assert isinstance(img, list)
for img_ in img:
feature_maps.append(self.backbone.forward_dummy(img_)[0])
feature_maps.append(self.predict_heatmap(img_)[0])

Check warning on line 421 in mmpose/models/detectors/multiview_pose.py

View check run for this annotation

Codecov / codecov/patch

mmpose/models/detectors/multiview_pose.py#L421

Added line #L421 was not covered by tests

human_candidates = self.human_detector.forward_test(
None, img_metas, feature_maps)
Expand Down Expand Up @@ -506,7 +524,7 @@
feature_maps = []
assert isinstance(img, list)
for img_ in img:
feature_maps.append(self.backbone.forward_dummy(img_)[0])
feature_maps.append(self.predict_heatmap(img_)[0])

Check warning on line 527 in mmpose/models/detectors/multiview_pose.py

View check run for this annotation

Codecov / codecov/patch

mmpose/models/detectors/multiview_pose.py#L527

Added line #L527 was not covered by tests

_ = self.human_detector.forward_dummy(feature_maps)

Expand Down
2 changes: 1 addition & 1 deletion model-index.yml
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ Import:
- configs/face/2d_kpt_sview_rgb_img/topdown_heatmap/wflw/hrnetv2_dark_wflw.yml
- configs/face/2d_kpt_sview_rgb_img/topdown_heatmap/wflw/hrnetv2_wflw.yml
- configs/fashion/2d_kpt_sview_rgb_img/deeppose/deepfashion/resnet_deepfashion.yml
- configs/fashion/2d_kpt_sview_rgb_img/topdown_heatmap/deepfashion/resnet_deepfashion.yml
- configs/fashion/2d_kpt_sview_rgb_img/topdown_heatmap/deepfashion2/resnet_deepfashion2.yml
- configs/fashion/2d_kpt_sview_rgb_img/topdown_heatmap/deepfashion/resnet_deepfashion.yml
- configs/hand/2d_kpt_sview_rgb_img/deeppose/onehand10k/resnet_onehand10k.yml
- configs/hand/2d_kpt_sview_rgb_img/deeppose/panoptic2d/resnet_panoptic2d.yml
- configs/hand/2d_kpt_sview_rgb_img/deeppose/rhd2d/resnet_rhd2d.yml
Expand Down
Binary file removed tests/data/campus/actorsGT.mat
Binary file not shown.
Binary file added tests/data/campus/actorsGT.npy
Binary file not shown.
Binary file removed tests/data/shelf/actorsGT.mat
Binary file not shown.
Binary file added tests/data/shelf/actorsGT.npy
Binary file not shown.
8 changes: 4 additions & 4 deletions tests/test_datasets/test_body3d_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def test_body3dmview_direct_campus_dataset():
cam_file=f'{data_root}/calibration_campus.json',
train_pose_db_file=f'{data_root}/panoptic_training_pose.pkl',
test_pose_db_file=f'{data_root}/pred_campus_maskrcnn_hrnet_coco.pkl',
gt_pose_db_file=f'{data_root}/actorsGT.mat',
gt_pose_db_file=f'{data_root}/actorsGT.npy',
)

test_data_cfg = dict(
Expand All @@ -398,7 +398,7 @@ def test_body3dmview_direct_campus_dataset():
cam_file=f'{data_root}/calibration_campus.json',
train_pose_db_file=f'{data_root}/panoptic_training_pose.pkl',
test_pose_db_file=f'{data_root}/pred_campus_maskrcnn_hrnet_coco.pkl',
gt_pose_db_file=f'{data_root}/actorsGT.mat',
gt_pose_db_file=f'{data_root}/actorsGT.npy',
)

# test when dataset_info is None
Expand Down Expand Up @@ -507,7 +507,7 @@ def test_body3dmview_direct_shelf_dataset():
cam_file=f'{data_root}/calibration_shelf.json',
train_pose_db_file=f'{data_root}/panoptic_training_pose.pkl',
test_pose_db_file=f'{data_root}/pred_shelf_maskrcnn_hrnet_coco.pkl',
gt_pose_db_file=f'{data_root}/actorsGT.mat',
gt_pose_db_file=f'{data_root}/actorsGT.npy',
)

test_data_cfg = dict(
Expand All @@ -526,7 +526,7 @@ def test_body3dmview_direct_shelf_dataset():
cam_file=f'{data_root}/calibration_shelf.json',
train_pose_db_file=f'{data_root}/panoptic_training_pose.pkl',
test_pose_db_file=f'{data_root}/pred_shelf_maskrcnn_hrnet_coco.pkl',
gt_pose_db_file=f'{data_root}/actorsGT.mat',
gt_pose_db_file=f'{data_root}/actorsGT.npy',
)

# test when dataset_info is None
Expand Down
1 change: 1 addition & 0 deletions tests/test_models/test_multiview_pose.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def test_voxelpose_forward():
model_cfg = dict(
type='DetectAndRegress',
backbone=None,
keypoint_head=None,
human_detector=dict(
type='VoxelCenterDetector',
image_size=[960, 512],
Expand Down