diff --git a/INSTALL.md b/INSTALL.md
index 5a722fb..198e7c7 100644
--- a/INSTALL.md
+++ b/INSTALL.md
@@ -14,7 +14,7 @@ The DensePose-RCNN system is implemented within the [`detectron`](https://github
## Caffe2
-To install Caffe2 with CUDA support, follow the [installation instructions](https://caffe2.ai/docs/getting-started.html) from the [Caffe2 website](https://caffe2.ai/). **If you already have Caffe2 installed, make sure to update your Caffe2 to a version that includes the [Detectron module](https://github.com/caffe2/caffe2/tree/master/modules/detectron).**
+To install Caffe2 with CUDA support, follow the [installation instructions](https://caffe2.ai/docs/getting-started.html) from the [Caffe2 website](https://caffe2.ai/). **If you already have Caffe2 installed, make sure to update your Caffe2 to a version that includes the [Detectron module](https://github.com/pytorch/pytorch/tree/master/modules/detectron).**
Please ensure that your Caffe2 installation was successful before proceeding by running the following commands and checking their output as directed in the comments.
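For example, a minimal Python-side sanity check (illustrative only; it assumes Caffe2 is importable in the Python environment you will use for DensePose and, for GPU builds, that CUDA is visible):

```python
# Minimal Caffe2 sanity check -- illustrative only.
from caffe2.python import core, workspace

# The import above should succeed without errors.
# For CUDA builds, at least one GPU should be reported:
print('CUDA devices found:', workspace.NumCudaDevices())
```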
@@ -130,7 +130,7 @@ coco
## Docker Image
-We provide a [`Dockerfile`](docker/Dockerfile) that you can use to build a Densepose image on top of a Caffe2 image that satisfies the requirements outlined at the top. If you would like to use a Caffe2 image different from the one we use by default, please make sure that it includes the [Detectron module](https://github.com/caffe2/caffe2/tree/master/modules/detectron).
+We provide a [`Dockerfile`](docker/Dockerfile) that you can use to build a Densepose image on top of a Caffe2 image that satisfies the requirements outlined at the top. If you would like to use a Caffe2 image different from the one we use by default, please make sure that it includes the [Detectron module](https://github.com/pytorch/pytorch/tree/master/modules/detectron).
Build the image:
diff --git a/README.md b/README.md
index 33a1b1f..f5b6f03 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,7 @@ _Rıza Alp Güler, Natalia Neverova, Iasonas Kokkinos_
[[`densepose.org`](https://densepose.org)] [[`arXiv`](https://arxiv.org/abs/1802.00434)] [[`BibTeX`](#CitingDensePose)]
Dense human pose estimation aims at mapping all human pixels of an RGB image to the 3D surface of the human body.
-DensePose-RCNN is implemented in the [Detectron](https://github.com/facebookresearch/Detectron) framework and is powered by [Caffe2](https://github.com/caffe2/caffe2).
+DensePose-RCNN is implemented in the [Detectron](https://github.com/facebookresearch/Detectron) framework and is powered by [Caffe2](https://github.com/pytorch/pytorch/tree/master/caffe2).

diff --git a/detectron/core/config.py b/detectron/core/config.py
index 79ee36f..e165e37 100644
--- a/detectron/core/config.py
+++ b/detectron/core/config.py
@@ -863,7 +863,7 @@
__C.BODY_UV_RCNN.UP_SCALE = -1
# Apply a ConvTranspose layer to the features prior to predicting the heatmaps
-__C.KRCNN.USE_DECONV = False
+__C.BODY_UV_RCNN.USE_DECONV = False
# Channel dimension of the hidden representation produced by the ConvTranspose
__C.BODY_UV_RCNN.DECONV_DIM = 256
# Use a ConvTranspose layer to predict the heatmaps
@@ -876,6 +876,9 @@
# Number of patches in the dataset
__C.BODY_UV_RCNN.NUM_PATCHES = -1
+# Number of semantic parts used to sample annotation points
+__C.BODY_UV_RCNN.NUM_SEMANTIC_PARTS = 14
+
# Number of stacked Conv layers in body UV head
__C.BODY_UV_RCNN.NUM_STACKED_CONVS = 8
# Dimension of the hidden representation output by the body UV head
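As a rough sketch of how these options fit together (values illustrative; the actual numbers come from the model's YAML config), the semantic-part count drives the `AnnIndex` output width while the patch count drives the `Index_UV`/`U`/`V` outputs:

```python
# Illustrative only: how the BODY_UV_RCNN options are consumed downstream.
from detectron.core.config import cfg

S = cfg.BODY_UV_RCNN.NUM_SEMANTIC_PARTS   # 14 coarse parts used to sample annotation points
P = cfg.BODY_UV_RCNN.NUM_PATCHES          # typically 24 surface patches, set in the YAML config
ann_index_channels = S + 1                # background + semantic parts (AnnIndex head)
patch_channels = P + 1                    # background + surface patches (Index_UV, U, V heads)
```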
diff --git a/detectron/core/test.py b/detectron/core/test.py
index 877d803..4fc5854 100644
--- a/detectron/core/test.py
+++ b/detectron/core/test.py
@@ -948,7 +948,7 @@ def im_detect_body_uv(model, im_scale, boxes):
# Removed squeeze calls due to singleton dimension issues
CurAnnIndex = np.argmax(CurAnnIndex, axis=0)
CurIndex_UV = np.argmax(CurIndex_UV, axis=0)
- CurIndex_UV = CurIndex_UV * (CurAnnIndex>0).astype(np.float32)
+ CurIndex_UV = CurIndex_UV * (CurAnnIndex > 0).astype(np.float32)
output = np.zeros([3, int(by), int(bx)], dtype=np.float32)
output[0] = CurIndex_UV
@@ -956,8 +956,8 @@ def im_detect_body_uv(model, im_scale, boxes):
for part_id in range(1, K):
CurrentU = CurU_uv[part_id]
CurrentV = CurV_uv[part_id]
- output[1, CurIndex_UV==part_id] = CurrentU[CurIndex_UV==part_id]
- output[2, CurIndex_UV==part_id] = CurrentV[CurIndex_UV==part_id]
+ output[1, CurIndex_UV == part_id] = CurrentU[CurIndex_UV == part_id]
+ output[2, CurIndex_UV == part_id] = CurrentV[CurIndex_UV == part_id]
outputs.append(output)
num_classes = cfg.MODEL.NUM_CLASSES
diff --git a/detectron/core/test_engine.py b/detectron/core/test_engine.py
index dc7c92c..bc5d76a 100644
--- a/detectron/core/test_engine.py
+++ b/detectron/core/test_engine.py
@@ -374,7 +374,7 @@ def get_roidb_and_dataset(dataset_name, proposal_file, ind_range):
def empty_results(num_classes, num_images):
- """Return empty results lists for boxes, masks, and keypoints.
+ """Return empty results lists for boxes, masks, keypoints and body IUVs.
Box detections are collected into:
all_boxes[cls][image] = N x 5 array with columns (x1, y1, x2, y2, score)
Instance mask predictions are collected into:
@@ -386,8 +386,11 @@ def empty_results(num_classes, num_images):
[x, y, logit, prob] (See: utils.keypoints.heatmaps_to_keypoints).
Keypoints are recorded for person (cls = 1); they are in 1:1
correspondence with the boxes in all_boxes[cls][image].
- Body uv predictions are collected into:
- TODO
+ Body IUV predictions are collected into:
+ all_bodys[cls][image] = [...] list of body IUV results that are in 1:1
+ correspondence with the boxes in all_boxes[cls][image], each encoded
+ as a 3D array (3, int(bbox_height), int(bbox_width)) with the 3 rows
+ corresponding to [Index, U, V].
"""
# Note: do not be tempted to use [[] * N], which gives N references to the
# *same* empty list.
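To make the updated docstring concrete, here is a small hedged sketch of the nested structure (class index first, then image index), assuming one 'person' detection in a single image:

```python
import numpy as np

num_classes, num_images = 2, 1   # background + person, one image
all_boxes = [[[] for _ in range(num_images)] for _ in range(num_classes)]
all_bodys = [[[] for _ in range(num_images)] for _ in range(num_classes)]

# One detection for class 1 ('person') in image 0: box coordinates plus score ...
all_boxes[1][0] = np.array([[10., 20., 110., 220., 0.9]], dtype=np.float32)
# ... and its body IUV prediction: (3, bbox_height, bbox_width) with rows [Index, U, V].
all_bodys[1][0] = [np.zeros((3, 200, 100), dtype=np.float32)]
```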
diff --git a/detectron/datasets/densepose_cocoeval.py b/detectron/datasets/densepose_cocoeval.py
index b082a12..6bff367 100644
--- a/detectron/datasets/densepose_cocoeval.py
+++ b/detectron/datasets/densepose_cocoeval.py
@@ -97,49 +97,52 @@ def __init__(self, cocoGt=None, cocoDt=None, iouType='segm', sigma=1.):
self.sigma = sigma
self.ignoreThrBB = 0.7
self.ignoreThrUV = 0.9
+ self.num_parts = 24 # number of pre-defined body parts
def _loadGEval(self):
- print('Loading densereg GT..')
prefix = os.path.dirname(__file__) + '/../../DensePoseData/eval_data/'
- print(prefix)
+ print('Loading densereg GT from {}'.format(prefix))
SMPL_subdiv = loadmat(prefix + 'SMPL_subdiv.mat')
self.PDIST_transform = loadmat(prefix + 'SMPL_SUBDIV_TRANSFORM.mat')
+ # 1-based indices into the geodesic distance matrix,
+ # range: [1, 27554], shape: (num_total_points=29408,)
self.PDIST_transform = self.PDIST_transform['index'].squeeze()
- UV = np.array([
- SMPL_subdiv['U_subdiv'],
- SMPL_subdiv['V_subdiv']
- ]).squeeze()
- ClosestVertInds = np.arange(UV.shape[1])+1
+ # UV coordinates of all collected points on the SMPL model, shape: (2, num_total_points)
+ UV = np.array([SMPL_subdiv['U_subdiv'], SMPL_subdiv['V_subdiv']]).squeeze()
+ # body part index (1 ~ 24), shape: (num_total_points,)
+ Part_ids = SMPL_subdiv['Part_ID_subdiv'].squeeze()
+ self.Part_ids = np.array(Part_ids)
+ # 1-based index of the closest vertex for each point,
+ # range: [1, num_total_points], shape: (num_total_points,)
+ ClosestVertInds = np.arange(UV.shape[1]) + 1
self.Part_UVs = []
self.Part_ClosestVertInds = []
- for i in np.arange(24):
- self.Part_UVs.append(
- UV[:, SMPL_subdiv['Part_ID_subdiv'].squeeze()==(i+1)]
- )
- self.Part_ClosestVertInds.append(
- ClosestVertInds[SMPL_subdiv['Part_ID_subdiv'].squeeze()==(i+1)]
- )
+ # UV coordinates and closest vertex indices of points in each part
+ for i in np.arange(1, self.num_parts + 1):
+ self.Part_UVs.append(UV[:, Part_ids == i])
+ self.Part_ClosestVertInds.append(ClosestVertInds[Part_ids == i])
arrays = {}
- f = h5py.File( prefix + 'Pdist_matrix.mat')
+ f = h5py.File(prefix + 'Pdist_matrix.mat')
for k, v in f.items():
arrays[k] = np.array(v)
+ f.close()
+ # precomputed geodesic distance matrix, stored in a compact (packed lower-triangular) representation
self.Pdist_matrix = arrays['Pdist_matrix']
- self.Part_ids = np.array( SMPL_subdiv['Part_ID_subdiv'].squeeze())
# Mean geodesic distances for parts.
- self.Mean_Distances = np.array( [0, 0.351, 0.107, 0.126,0.237,0.173,0.142,0.128,0.150] )
+ self.Mean_Distances = np.array([0, 0.351, 0.107, 0.126, 0.237, 0.173, 0.142, 0.128, 0.150])
# Coarse Part labels.
- self.CoarseParts = np.array( [ 0, 1, 1, 2, 2, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5,
- 6, 6, 6, 6, 7, 7, 7, 7, 8, 8] )
-
- print('Loaded')
+ self.CoarseParts = np.array(
+ [0, 1, 1, 2, 2, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8]
+ )
+
+ print('densereg GT loaded')
def _prepare(self):
'''
Prepare ._gts and ._dts for evaluation based on params
:return: None
'''
-
def _toMask(anns, coco):
# modify ann['segmentation'] by reference
for ann in anns:
@@ -168,11 +171,11 @@ def _checkIgnore(dt, iregion):
return True
bb = np.array(dt['bbox']).astype(np.int)
- x1,y1,x2,y2 = bb[0],bb[1],bb[0]+bb[2],bb[1]+bb[3]
- x2 = min([x2,iregion.shape[1]])
- y2 = min([y2,iregion.shape[0]])
+ x1, y1, x2, y2 = bb[0], bb[1], bb[0] + bb[2], bb[1] + bb[3]
+ x2 = min([x2, iregion.shape[1]])
+ y2 = min([y2, iregion.shape[0]])
- if bb[2]* bb[3] == 0:
+ if bb[2] * bb[3] == 0:
return False
crop_iregion = iregion[y1:y2, x1:x2]
@@ -181,11 +184,11 @@ def _checkIgnore(dt, iregion):
return True
if not 'uv' in dt.keys(): # filtering boxes
- return crop_iregion.sum()/bb[2]/bb[3] < self.ignoreThrBB
+ return crop_iregion.sum() / bb[2] / bb[3] < self.ignoreThrBB
# filtering UVs
ignoremask = np.require(crop_iregion, requirements=['F'])
- uvmask = np.require(np.asarray(dt['uv'][0]>0), dtype = np.uint8,
+ uvmask = np.require(np.asarray(dt['uv'][0] > 0), dtype=np.uint8,
requirements=['F'])
uvmask_ = maskUtils.encode(uvmask)
ignoremask_ = maskUtils.encode(ignoremask)
@@ -195,13 +198,13 @@ def _checkIgnore(dt, iregion):
p = self.params
if p.useCats:
- gts=self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds))
- dts=self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds))
+ gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds))
+ dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds))
else:
- gts=self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds))
- dts=self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds))
+ gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds))
+ dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds))
- # if iouType == 'uv', add point gt annotations
+ # add point gt annotations if iouType == 'uv'
if p.iouType == 'uv':
self._loadGEval()
@@ -217,7 +220,7 @@ def _checkIgnore(dt, iregion):
if p.iouType == 'keypoints':
gt['ignore'] = (gt['num_keypoints'] == 0) or gt['ignore']
if p.iouType == 'uv':
- gt['ignore'] = ('dp_x' in gt)==0
+ gt['ignore'] = 'dp_x' not in gt
self._gts = defaultdict(list) # gt for evaluation
self._dts = defaultdict(list) # dt for evaluation
@@ -233,8 +236,8 @@ def _checkIgnore(dt, iregion):
if _checkIgnore(dt, self._igrgns[dt['image_id']]):
self._dts[dt['image_id'], dt['category_id']].append(dt)
- self.evalImgs = defaultdict(list) # per-image per-category evaluation results
- self.eval = {} # accumulated evaluation results
+ self.evalImgs = defaultdict(list) # per-image per-category evaluation results
+ self.eval = {} # accumulated evaluation results
def evaluate(self):
'''
@@ -253,7 +256,7 @@ def evaluate(self):
if p.useCats:
p.catIds = list(np.unique(p.catIds))
p.maxDets = sorted(p.maxDets)
- self.params=p
+ self.params = p
self._prepare()
# loop through images, area range, max detection number
@@ -264,37 +267,40 @@ def evaluate(self):
elif p.iouType == 'keypoints':
computeIoU = self.computeOks
elif p.iouType == 'uv':
- computeIoU = self.computeOgps
+ computeIoU = self.computeGps
- self.ious = {(imgId, catId): computeIoU(imgId, catId) \
- for imgId in p.imgIds
- for catId in catIds}
+ self.ious = {
+ (imgId, catId): computeIoU(imgId, catId)
+ for imgId in p.imgIds
+ for catId in catIds
+ }
evaluateImg = self.evaluateImg
maxDet = p.maxDets[-1]
- self.evalImgs = [evaluateImg(imgId, catId, areaRng, maxDet)
- for catId in catIds
- for areaRng in p.areaRng
- for imgId in p.imgIds
- ]
+ self.evalImgs = [
+ evaluateImg(imgId, catId, areaRng, maxDet)
+ for catId in catIds
+ for areaRng in p.areaRng
+ for imgId in p.imgIds
+ ]
self._paramsEval = copy.deepcopy(self.params)
toc = time.time()
- print('DONE (t={:0.2f}s).'.format(toc-tic))
+ print('DONE (t={:0.2f}s).'.format(toc - tic))
def computeIoU(self, imgId, catId):
p = self.params
if p.useCats:
- gt = self._gts[imgId,catId]
- dt = self._dts[imgId,catId]
+ gt = self._gts[imgId, catId]
+ dt = self._dts[imgId, catId]
else:
- gt = [_ for cId in p.catIds for _ in self._gts[imgId,cId]]
- dt = [_ for cId in p.catIds for _ in self._dts[imgId,cId]]
- if len(gt) == 0 and len(dt) ==0:
+ gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
+ dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
+ if len(gt) == 0 and len(dt) == 0:
return []
inds = np.argsort([-d['score'] for d in dt], kind='mergesort')
dt = [dt[i] for i in inds]
if len(dt) > p.maxDets[-1]:
- dt=dt[0:p.maxDets[-1]]
+ dt = dt[0:p.maxDets[-1]]
if p.iouType == 'segm':
g = [g['segmentation'] for g in gt]
@@ -312,7 +318,7 @@ def computeIoU(self, imgId, catId):
def computeOks(self, imgId, catId):
p = self.params
- # dimention here should be Nxm
+ # dimension here should be N x m
gts = self._gts[imgId, catId]
dts = self._dts[imgId, catId]
inds = np.argsort([-d['score'] for d in dts], kind='mergesort')
@@ -323,8 +329,8 @@ def computeOks(self, imgId, catId):
if len(gts) == 0 or len(dts) == 0:
return []
ious = np.zeros((len(dts), len(gts)))
- sigmas = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62,.62, 1.07, 1.07, .87, .87, .89, .89])/10.0
- vars = (sigmas * 2)**2
+ sigmas = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62,.62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0
+ vars = (sigmas * 2) ** 2
k = len(sigmas)
# compute oks between each detection and ground truth object
for j, gt in enumerate(gts):
@@ -338,83 +344,91 @@ def computeOks(self, imgId, catId):
for i, dt in enumerate(dts):
d = np.array(dt['keypoints'])
xd = d[0::3]; yd = d[1::3]
- if k1>0:
+ if k1 > 0:
# measure the per-keypoint distance if keypoints visible
dx = xd - xg
dy = yd - yg
else:
# measure minimum distance to keypoints in (x0,y0) & (x1,y1)
z = np.zeros((k))
- dx = np.max((z, x0-xd), axis=0) + np.max((z, xd-x1), axis=0)
- dy = np.max((z, y0-yd), axis=0) + np.max((z, yd-y1), axis=0)
- e = (dx**2 + dy**2) / vars / (gt['area'] + np.spacing(1)) / 2
+ dx = np.max((z, x0 - xd), axis=0) + np.max((z, xd - x1), axis=0)
+ dy = np.max((z, y0 - yd), axis=0) + np.max((z, yd - y1), axis=0)
+ e = (dx ** 2 + dy ** 2) / vars / (gt['area'] + np.spacing(1)) / 2
if k1 > 0:
- e=e[vg > 0]
+ e = e[vg > 0]
ious[i, j] = np.sum(np.exp(-e)) / e.shape[0]
return ious
- def computeOgps(self, imgId, catId):
+ def computeGps(self, imgId, catId):
p = self.params
- # dimention here should be Nxm
- g = self._gts[imgId, catId]
- d = self._dts[imgId, catId]
- inds = np.argsort([-d_['score'] for d_ in d], kind='mergesort')
- d = [d[i] for i in inds]
- if len(d) > p.maxDets[-1]:
- d = d[0:p.maxDets[-1]]
- # if len(gts) == 0 and len(dts) == 0:
- if len(g) == 0 or len(d) == 0:
+ # dimension here should be N x m
+ gt = self._gts[imgId, catId]
+ dt = self._dts[imgId, catId]
+ inds = np.argsort([-d['score'] for d in dt], kind='mergesort')
+ dt = [dt[i] for i in inds]
+ if len(dt) > p.maxDets[-1]:
+ dt = dt[0:p.maxDets[-1]]
+
+ if len(gt) == 0 or len(dt) == 0:
return []
- ious = np.zeros((len(d), len(g)))
- # compute opgs between each detection and ground truth object
- sigma = self.sigma #0.255 # dist = 0.3m corresponds to ogps = 0.5
- # 1 # dist = 0.3m corresponds to ogps = 0.96
- # 1.45 # dist = 1.7m (person height) corresponds to ogps = 0.5)
- for j, gt in enumerate(g):
- if not gt['ignore']:
- g_ = gt['bbox']
- for i, dt in enumerate(d):
- #
- dx = dt['bbox'][3]
- dy = dt['bbox'][2]
- dp_x = np.array( gt['dp_x'] )*g_[2]/255.
- dp_y = np.array( gt['dp_y'] )*g_[3]/255.
- px = ( dp_y + g_[1] - dt['bbox'][1]).astype(np.int)
- py = ( dp_x + g_[0] - dt['bbox'][0]).astype(np.int)
- #
- pts = np.zeros(len(px))
- pts[px>=dx] = -1; pts[py>=dy] = -1
- pts[px<0] = -1; pts[py<0] = -1
- #print(pts.shape)
- if len(pts) < 1:
- ogps = 0.
- elif np.max(pts) == -1:
- ogps = 0.
- else:
- px[pts==-1] = 0; py[pts==-1] = 0;
- ipoints = dt['uv'][0, px, py]
- upoints = dt['uv'][1, px, py]/255. # convert from uint8 by /255.
- vpoints = dt['uv'][2, px, py]/255.
- ipoints[pts==-1] = 0
- ## Find closest vertices in subsampled mesh.
- cVerts, cVertsGT = self.findAllClosestVerts(gt, upoints, vpoints, ipoints)
- ## Get pairwise geodesic distances between gt and estimated mesh points.
- dist = self.getDistances(cVertsGT, cVerts)
- ## Compute the Ogps measure.
+ ious = np.zeros((len(dt), len(gt)))
+
+ # compute gps between each detection and ground truth object
+ # sigma = self.sigma # 0.255 # dist = 0.3m corresponds to gps = 0.5
+ # 1 # dist = 0.3m corresponds to gps = 0.96
+ # 1.45 # dist = 1.7m (person height) corresponds to gps = 0.5)
+ for j, g in enumerate(gt):
+ # gps between any detection and an ignored ground truth person is 0
+ if g['ignore']:
+ continue
+ bb = g['bbox']
+ for i, d in enumerate(dt):
+ # width and height of detected bbox
+ dx = d['bbox'][2]; dy = d['bbox'][3]
+ # (2D) spatial coordinates of annotated points within current gt bbox on the image,
+ # which are scaled such that the gt bbox size is 256 x 256.
+ dp_x = np.array(g['dp_x']) * bb[2] / 255.
+ dp_y = np.array(g['dp_y']) * bb[3] / 255.
+ # spatial coordinates of annotated points relative to detected bbox
+ px = (dp_x + bb[0] - d['bbox'][0]).astype(np.int)
+ py = (dp_y + bb[1] - d['bbox'][1]).astype(np.int)
+ pts = np.zeros(len(dp_x)) # len(dp_x): number of annotated points
+ # annotated points outside the range of detected bbox are considered as background
+ pts[px >= dx] = -1; pts[py >= dy] = -1
+ pts[px < 0] = -1; pts[py < 0] = -1
+ # print("#collected gt points: ", len(dp_x))
+ if len(dp_x) == 0:
+ gps = 0.
+ elif pts.max() == -1:
+ gps = 0.
+ else:
+ px[pts == -1] = 0; py[pts == -1] = 0;
+ ipoints = d['uv'][0, py, px]
+ upoints = d['uv'][1, py, px] / 255. # convert from uint8 by /255.
+ vpoints = d['uv'][2, py, px] / 255.
+ ipoints[pts == -1] = 0
+ # Find the closest vertex indices in the subsampled mesh.
+ cVertInds, cVertIndsGT = self.findAllClosestVertInds(g, upoints, vpoints, ipoints)
+ # Get pairwise geodesic distances between GT and estimated mesh points.
+ dist = self.getDistances(cVertInds, cVertIndsGT)
+ # Compute the GPS measure.
+ if len(dist) > 0:
# Find the mean geodesic normalization distance for each GT point, based on which part it is on.
- Current_Mean_Distances = self.Mean_Distances[ self.CoarseParts[ self.Part_ids [ cVertsGT[cVertsGT>0].astype(int)-1] ] ]
+ Current_Mean_Distances = self.Mean_Distances[
+ self.CoarseParts[self.Part_ids[cVertIndsGT[cVertIndsGT > 0] - 1]]
+ ]
# Compute gps
- ogps_values = np.exp(-(dist**2)/(2*(Current_Mean_Distances**2)))
- #
- if len(dist)>0:
- ogps = np.sum(ogps_values)/ len(dist)
- ious[i, j] = ogps
+ gps = np.exp(-(dist ** 2) / (2 * (Current_Mean_Distances ** 2)))
+ gps = np.sum(gps) / len(dist)
+ else:
+ gps = 0.
+ ious[i, j] = gps
- gbb = [gt['bbox'] for gt in g]
- dbb = [dt['bbox'] for dt in d]
+ gbb = [g['bbox'] for g in gt]
+ dbb = [d['bbox'] for d in dt]
# compute iou between each dt and gt region
- iscrowd = [int(o['iscrowd']) for o in g]
+ iscrowd = [int(o['iscrowd']) for o in gt]
ious_bb = maskUtils.iou(dbb, gbb, iscrowd)
return ious, ious_bb
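In other words, each annotated point contributes exp(-d^2 / (2 * kappa^2)), where d is the geodesic distance between the predicted and ground-truth mesh vertices and kappa is the per-part mean distance looked up via `Mean_Distances[CoarseParts[Part_ids[...]]]`; the instance GPS is the mean over points. A toy sketch with made-up numbers:

```python
import numpy as np

# Made-up geodesic distances between predicted and ground-truth vertices (one per point)
dist = np.array([0.0, 0.10, 0.30])
# Per-point normalization constants kappa (illustrative; drawn from Mean_Distances in the code)
kappa = np.array([0.255, 0.255, 0.107])

gps = np.sum(np.exp(-(dist ** 2) / (2 * kappa ** 2))) / len(dist)
# A perfectly matched point contributes 1.0; larger distances decay towards 0.
```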
@@ -423,23 +437,21 @@ def evaluateImg(self, imgId, catId, aRng, maxDet):
perform evaluation for single category and image
:return: dict (single image results)
'''
-
p = self.params
if p.useCats:
- gt = self._gts[imgId,catId]
- dt = self._dts[imgId,catId]
+ gt = self._gts[imgId, catId]
+ dt = self._dts[imgId, catId]
else:
- gt = [_ for cId in p.catIds for _ in self._gts[imgId,cId]]
- dt = [_ for cId in p.catIds for _ in self._dts[imgId,cId]]
+ gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
+ dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
if len(gt) == 0 and len(dt) == 0:
return None
for g in gt:
- #g['_ignore'] = g['ignore']
- if g['ignore'] or (g['area']<aRng[0] or g['area']>aRng[1]):
- g['_ignore'] = True
+ if g['ignore'] or (g['area'] < aRng[0] or g['area'] > aRng[1]):
+ g['_ignore'] = 1
else:
- g['_ignore'] = False
+ g['_ignore'] = 0
# sort dt highest score first, sort gt ignore last
gtind = np.argsort([g['_ignore'] for g in gt], kind='mergesort')
@@ -449,9 +461,6 @@ def evaluateImg(self, imgId, catId, aRng, maxDet):
iscrowd = [int(o['iscrowd']) for o in gt]
# load computed ious
if p.iouType == 'uv':
- #print('Checking the length', len(self.ious[imgId, catId]))
- #if len(self.ious[imgId, catId]) == 0:
- # print(self.ious[imgId, catId])
ious = self.ious[imgId, catId][0][:, gtind] if len(self.ious[imgId, catId]) > 0 else self.ious[imgId, catId]
ioubs = self.ious[imgId, catId][1][:, gtind] if len(self.ious[imgId, catId]) > 0 else self.ious[imgId, catId]
else:
@@ -460,31 +469,32 @@ def evaluateImg(self, imgId, catId, aRng, maxDet):
T = len(p.iouThrs)
G = len(gt)
D = len(dt)
- gtm = np.zeros((T,G))
- dtm = np.zeros((T,D))
+ gtm = np.zeros((T, G))
+ dtm = np.zeros((T, D))
gtIg = np.array([g['_ignore'] for g in gt])
- dtIg = np.zeros((T,D))
+ dtIg = np.zeros((T, D))
if np.all(gtIg) == True and p.iouType == 'uv':
dtIg = np.logical_or(dtIg, True)
- if len(ious)>0: # and not p.iouType == 'uv':
+ if len(ious) > 0:
for tind, t in enumerate(p.iouThrs):
for dind, d in enumerate(dt):
# information about best match so far (m=-1 -> unmatched)
- iou = min([t,1-1e-10])
+ iou = min([t, 1 - 1e-10])
m = -1
for gind, g in enumerate(gt):
# if this gt already matched, and not a crowd, continue
- if gtm[tind,gind]>0 and not iscrowd[gind]:
+ if gtm[tind, gind] > 0 and not iscrowd[gind]:
continue
# if dt matched to reg gt, and on ignore gt, stop
- if m>-1 and gtIg[m]==0 and gtIg[gind]==1:
+ if m > -1 and gtIg[m] == 0 and gtIg[gind] == 1:
break
# continue to next gt unless better match made
- if ious[dind,gind] < iou:
- continue
- if ious[dind,gind] == 0.:
+ if ious[dind, gind] < iou:
continue
+ # The following check is redundant given the threshold test above:
+ # if ious[dind, gind] == 0.:
+ # continue
# if match successful and best so far, store appropriately
iou = ious[dind, gind]
m = gind
@@ -495,49 +505,54 @@ def evaluateImg(self, imgId, catId, aRng, maxDet):
dtm[tind, dind] = gt[m]['id']
gtm[tind, m] = d['id']
- if p.iouType == 'uv':
- if not len(ioubs)==0:
- for dind, d in enumerate(dt):
- # information about best match so far (m=-1 -> unmatched)
- if dtm[tind, dind] == 0:
- ioub = 0.8
- m = -1
- for gind, g in enumerate(gt):
- # if this gt already matched, and not a crowd, continue
- if gtm[tind,gind]>0 and not iscrowd[gind]:
- continue
- # continue to next gt unless better match made
- if ioubs[dind,gind] < ioub:
- continue
- # if match successful and best so far, store appropriately
- ioub = ioubs[dind,gind]
- m = gind
- # if match made store id of match for both dt and gt
- if m > -1:
- dtIg[:, dind] = gtIg[m]
- if gtIg[m]:
- dtm[tind, dind] = gt[m]['id']
- gtm[tind, m] = d['id']
+ """
+ When evaluating body_uv, suppress/ignore a detected box (`dbox`) at all GPS thresholds
+ if it satisfies the following criterion:
+ GPS(dbox, gbox) = 0 while IoU(dbox, gbox) > 0.8, where `gbox` is an ignored gt box.
+ """
+ if p.iouType == 'uv' and len(ioubs) > 0:
+ for dind, d in enumerate(dt):
+ # information about best match so far (m=-1 -> unmatched)
+ if dtm[tind, dind] == 0:
+ ioub = 0.8 # a manually set IoU threshold
+ m = -1
+ for gind, g in enumerate(gt):
+ # if this gt already matched, and not a crowd, continue
+ if gtm[tind, gind] > 0 and not iscrowd[gind]:
+ continue
+ # continue to next gt unless better match made
+ if ioubs[dind, gind] < ioub:
+ continue
+ # if match successful and best so far, store appropriately
+ ioub = ioubs[dind, gind]
+ m = gind
+ # if match made store id of match for both dt and gt
+ if m == -1:
+ continue
+ dtIg[:, dind] = gtIg[m]
+ if gtIg[m]:
+ dtm[tind, dind] = gt[m]['id']
+ gtm[tind, m] = d['id']
+
# set unmatched detections outside of area range to ignore
- a = np.array([d['area']<aRng[0] or d['area']>aRng[1] for d in dt]).reshape((1, len(dt)))
- dtIg = np.logical_or(dtIg, np.logical_and(dtm==0, np.repeat(a,T,0)))
+ a = np.array([d['area'] < aRng[0] or d['area'] > aRng[1] for d in dt]).reshape((1, len(dt)))
+ dtIg = np.logical_or(dtIg, np.logical_and(dtm == 0, np.repeat(a, T, 0)))
# store results for given image and category
- #print('Done with the function', len(self.ious[imgId, catId]))
return {
- 'image_id': imgId,
- 'category_id': catId,
- 'aRng': aRng,
- 'maxDet': maxDet,
- 'dtIds': [d['id'] for d in dt],
- 'gtIds': [g['id'] for g in gt],
- 'dtMatches': dtm,
- 'gtMatches': gtm,
- 'dtScores': [d['score'] for d in dt],
- 'gtIgnore': gtIg,
- 'dtIgnore': dtIg,
- }
-
- def accumulate(self, p = None):
+ 'image_id': imgId,
+ 'category_id': catId,
+ 'aRng': aRng,
+ 'maxDet': maxDet,
+ 'dtIds': [d['id'] for d in dt],
+ 'gtIds': [g['id'] for g in gt],
+ 'dtMatches': dtm,
+ 'gtMatches': gtm,
+ 'dtScores': [d['score'] for d in dt],
+ 'gtIgnore': gtIg,
+ 'dtIgnore': dtIg,
+ }
+
+ def accumulate(self, p=None):
'''
Accumulate per image evaluation results and store the result in self.eval
:param p: input params for evaluation
@@ -551,16 +566,17 @@ def accumulate(self, p = None):
if p is None:
p = self.params
p.catIds = p.catIds if p.useCats == 1 else [-1]
- T = len(p.iouThrs)
- R = len(p.recThrs)
- K = len(p.catIds) if p.useCats else 1
- A = len(p.areaRng)
- M = len(p.maxDets)
- precision = -np.ones((T,R,K,A,M)) # -1 for the precision of absent categories
- recall = -np.ones((T,K,A,M))
+ T = len(p.iouThrs)
+ R = len(p.recThrs)
+ K = len(p.catIds) if p.useCats else 1
+ A = len(p.areaRng)
+ M = len(p.maxDets)
+ precision = -np.ones((T, R, K, A, M)) # -1 for the precision of absent categories
+ recall = -np.ones((T, K, A, M))
+ scores = -np.ones((T, R, K, A, M))
# create dictionary for future indexing
- print('Categories:', p.catIds)
+ print('Category ids:', p.catIds)
_pe = self._paramsEval
catIds = _pe.catIds if _pe.useCats else [-1]
setK = set(catIds)
@@ -589,58 +605,61 @@ def accumulate(self, p = None):
# different sorting method generates slightly different results.
# mergesort is used to be consistent as Matlab implementation.
inds = np.argsort(-dtScores, kind='mergesort')
+ dtScoresSorted = dtScores[inds]
- dtm = np.concatenate([e['dtMatches'][:,0:maxDet] for e in E], axis=1)[:,inds]
- dtIg = np.concatenate([e['dtIgnore'][:,0:maxDet] for e in E], axis=1)[:,inds]
+ dtm = np.concatenate([e['dtMatches'][:, 0:maxDet] for e in E], axis=1)[:, inds]
+ dtIg = np.concatenate([e['dtIgnore'][:, 0:maxDet] for e in E], axis=1)[:, inds]
gtIg = np.concatenate([e['gtIgnore'] for e in E])
- npig = np.count_nonzero(gtIg==0)
- #print('DTIG', np.sum(np.logical_not(dtIg)), len(dtIg))
- #print('GTIG', np.sum(np.logical_not(gtIg)), len(gtIg))
+ npig = np.count_nonzero(gtIg == 0)
if npig == 0:
continue
- tps = np.logical_and( dtm, np.logical_not(dtIg))
+ tps = np.logical_and(dtm, np.logical_not(dtIg))
fps = np.logical_and(np.logical_not(dtm), np.logical_not(dtIg))
+
tp_sum = np.cumsum(tps, axis=1).astype(dtype=np.float)
fp_sum = np.cumsum(fps, axis=1).astype(dtype=np.float)
- #print('TP_SUM', tp_sum, 'FP_SUM', fp_sum)
for t, (tp, fp) in enumerate(zip(tp_sum, fp_sum)):
tp = np.array(tp)
fp = np.array(fp)
nd = len(tp)
rc = tp / npig
- pr = tp / (fp+tp+np.spacing(1))
+ pr = tp / (fp + tp + np.spacing(1))
q = np.zeros((R,))
+ ss = np.zeros((R,))
if nd:
- recall[t,k,a,m] = rc[-1]
+ recall[t, k, a, m] = rc[-1]
else:
- recall[t,k,a,m] = 0
+ recall[t, k, a, m] = 0
# numpy is slow without cython optimization for accessing elements
# use python array gets significant speed improvement
pr = pr.tolist(); q = q.tolist()
- for i in range(nd-1, 0, -1):
- if pr[i] > pr[i-1]:
- pr[i-1] = pr[i]
+ for i in range(nd - 1, 0, -1):
+ if pr[i] > pr[i - 1]:
+ pr[i - 1] = pr[i]
inds = np.searchsorted(rc, p.recThrs, side='left')
try:
for ri, pi in enumerate(inds):
q[ri] = pr[pi]
+ ss[ri] = dtScoresSorted[pi]
except:
pass
- precision[t,:,k,a,m] = np.array(q)
- print('Final', np.max(precision), np.min(precision))
+ precision[t, :, k, a, m] = np.array(q)
+ scores[t, :, k, a, m] = np.array(ss)
+ print('Final precisions, max: {:.2f}, min: {:.2f}'.format(np.max(precision), np.min(precision)))
self.eval = {
'params': p,
'counts': [T, R, K, A, M],
'date': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
'precision': precision,
- 'recall': recall,
+ 'recall': recall,
+ 'scores': scores,
}
toc = time.time()
- print('DONE (t={:0.2f}s).'.format( toc-tic))
+ print('DONE (t={:0.2f}s).'.format(toc - tic))
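A condensed standalone sketch of the COCO-style interpolation done above, assuming toy cumulative values: the precision curve is first made monotonically non-increasing from the right, then sampled (together with the detection scores that are newly recorded here) at the fixed recall thresholds:

```python
import numpy as np

rc = np.array([0.2, 0.2, 0.4, 0.6])            # cumulative recall per ranked detection
pr = np.array([1.0, 0.5, 0.667, 0.75])         # cumulative precision per ranked detection
dtScoresSorted = np.array([0.9, 0.8, 0.7, 0.6])
recThrs = np.linspace(0.0, 1.0, 101)           # p.recThrs

# precision envelope: pr[i-1] = max(pr[i-1], pr[i]), right to left
for i in range(len(pr) - 1, 0, -1):
    if pr[i] > pr[i - 1]:
        pr[i - 1] = pr[i]

q, ss = np.zeros_like(recThrs), np.zeros_like(recThrs)
for ri, pi in enumerate(np.searchsorted(rc, recThrs, side='left')):
    if pi < len(pr):                            # the code above guards this with try/except
        q[ri], ss[ri] = pr[pi], dtScoresSorted[pi]
```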
def summarize(self):
'''
@@ -651,12 +670,12 @@ def _summarize( ap=1, iouThr=None, areaRng='all', maxDets=100 ):
p = self.params
iStr = ' {:<18} {} @[ {}={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}'
titleStr = 'Average Precision' if ap == 1 else 'Average Recall'
- typeStr = '(AP)' if ap==1 else '(AR)'
+ typeStr = '(AP)' if ap == 1 else '(AR)'
measure = 'IoU'
if self.params.iouType == 'keypoints':
measure = 'OKS'
- elif self.params.iouType =='uv':
- measure = 'OGPS'
+ elif self.params.iouType == 'uv':
+ measure = 'GPS'
iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \
if iouThr is None else '{:0.2f}'.format(iouThr)
@@ -667,22 +686,23 @@ def _summarize( ap=1, iouThr=None, areaRng='all', maxDets=100 ):
s = self.eval['precision']
# IoU
if iouThr is not None:
- t = np.where(np.abs(iouThr - p.iouThrs)<0.001)[0]
+ t = np.where(np.abs(iouThr - p.iouThrs) < 0.001)[0]
s = s[t]
- s = s[:,:,:,aind,mind]
+ s = s[:, :, :, aind, mind]
else:
# dimension of recall: [TxKxAxM]
s = self.eval['recall']
if iouThr is not None:
t = np.where(iouThr == p.iouThrs)[0]
s = s[t]
- s = s[:,:,aind,mind]
- if len(s[s>-1])==0:
+ s = s[:, :, aind, mind]
+ if len(s[s > -1]) == 0:
mean_s = -1
else:
- mean_s = np.mean(s[s>-1])
+ mean_s = np.mean(s[s > -1])
print(iStr.format(titleStr, typeStr, measure, iouStr, areaRng, maxDets, mean_s))
return mean_s
+
def _summarizeDets():
stats = np.zeros((12,))
stats[0] = _summarize(1)
@@ -698,6 +718,7 @@ def _summarizeDets():
stats[10] = _summarize(0, areaRng='medium', maxDets=self.params.maxDets[2])
stats[11] = _summarize(0, areaRng='large', maxDets=self.params.maxDets[2])
return stats
+
def _summarizeKps():
stats = np.zeros((10,))
stats[0] = _summarize(1, maxDets=20)
@@ -711,6 +732,7 @@ def _summarizeKps():
stats[8] = _summarize(0, maxDets=20, areaRng='medium')
stats[9] = _summarize(0, maxDets=20, areaRng='large')
return stats
+
def _summarizeUvs():
stats = np.zeros((18,))
stats[0] = _summarize(1, maxDets=self.params.maxDets[0])
@@ -732,6 +754,7 @@ def _summarizeUvs():
stats[16] = _summarize(0, maxDets=self.params.maxDets[0], areaRng='medium')
stats[17] = _summarize(0, maxDets=self.params.maxDets[0], areaRng='large')
return stats
+
if not self.eval:
raise Exception('Please run accumulate() first')
iouType = self.params.iouType
@@ -747,74 +770,51 @@ def __str__(self):
self.summarize()
# ================ functions for dense pose ==============================
- def findAllClosestVerts(self, gt, U_points, V_points, Index_points):
- #
+ def findAllClosestVertInds(self, gt, U_points, V_points, Index_points):
I_gt = np.array(gt['dp_I'])
U_gt = np.array(gt['dp_U'])
V_gt = np.array(gt['dp_V'])
- #
- #print(I_gt)
- #
- ClosestVerts = np.ones(Index_points.shape)*-1
- for i in np.arange(24):
- #
- if sum(Index_points == (i+1))>0:
- UVs = np.array( [U_points[Index_points == (i+1)],V_points[Index_points == (i+1)]])
- Current_Part_UVs = self.Part_UVs[i]
- Current_Part_ClosestVertInds = self.Part_ClosestVertInds[i]
- D = ssd.cdist( Current_Part_UVs.transpose(), UVs.transpose()).squeeze()
- ClosestVerts[Index_points == (i+1)] = Current_Part_ClosestVertInds[ np.argmin(D,axis=0) ]
- #
- ClosestVertsGT = np.ones(Index_points.shape)*-1
- for i in np.arange(24):
- if sum(I_gt==(i+1))>0:
- UVs = np.array([
- U_gt[I_gt==(i+1)],
- V_gt[I_gt==(i+1)]
- ])
- Current_Part_UVs = self.Part_UVs[i]
- Current_Part_ClosestVertInds = self.Part_ClosestVertInds[i]
- D = ssd.cdist( Current_Part_UVs.transpose(), UVs.transpose()).squeeze()
- ClosestVertsGT[I_gt==(i+1)] = Current_Part_ClosestVertInds[ np.argmin(D,axis=0) ]
- #
- return ClosestVerts, ClosestVertsGT
-
-
- def getDistances(self, cVertsGT, cVerts):
-
- ClosestVertsTransformed = self.PDIST_transform[cVerts.astype(int)-1]
- ClosestVertsGTTransformed = self.PDIST_transform[cVertsGT.astype(int)-1]
- #
- ClosestVertsTransformed[cVerts<0] = 0
- ClosestVertsGTTransformed[cVertsGT<0] = 0
- #
- cVertsGT = ClosestVertsGTTransformed
- cVerts = ClosestVertsTransformed
- #
- n = 27554
+ # find closest vertex for each estimated point and gt point in each part
+ ClosestVertInds = np.ones(Index_points.shape, dtype=int) * -1
+ ClosestVertIndsGT = np.ones(Index_points.shape, dtype=int) * -1
+ for i in np.arange(1, self.num_parts + 1):
+ Current_Part_UVs = self.Part_UVs[i - 1]
+ Current_Part_ClosestVertInds = self.Part_ClosestVertInds[i - 1]
+ if sum(Index_points == i) > 0:
+ UVs = np.array([U_points[Index_points == i], V_points[Index_points == i]])
+ D = ssd.cdist(Current_Part_UVs.transpose(), UVs.transpose()).squeeze()
+ ClosestVertInds[Index_points == i] = Current_Part_ClosestVertInds[np.argmin(D, axis=0)]
+ if sum(I_gt == i) > 0:
+ UVs = np.array([U_gt[I_gt == i], V_gt[I_gt == i]])
+ D = ssd.cdist(Current_Part_UVs.transpose(), UVs.transpose()).squeeze()
+ ClosestVertIndsGT[I_gt == i] = Current_Part_ClosestVertInds[np.argmin(D, axis=0)]
+
+ return ClosestVertInds, ClosestVertIndsGT
+
+
+ def getDistances(self, cVertInds, cVertIndsGT):
+ cVerts = self.PDIST_transform[cVertInds - 1]
+ cVertsGT = self.PDIST_transform[cVertIndsGT - 1]
+ cVerts[cVertInds < 0] = 0
+ cVertsGT[cVertIndsGT < 0] = 0
+ # `n` is the number of points for which the geodesic distances are precomputed
+ # n = 27554
dists = []
+ # loop through each pair of (gt_point, dt_point)
for d in range(len(cVertsGT)):
if cVertsGT[d] > 0:
if cVerts[d] > 0:
i = cVertsGT[d] - 1
j = cVerts[d] - 1
- if j == i:
+ if i == j: # elements on the diagonal are all zeros
dists.append(0)
- elif j > i:
- ccc = i
- i = j
- j = ccc
- i = n-i-1
- j = n-j-1
- k = (n*(n-1)/2) - (n-i)*((n-i)-1)/2 + j - i - 1
- k = ( n*n - n )/2 -k -1
- dists.append(self.Pdist_matrix[int(k)][0])
+ elif i > j:
+ # find the offset to fetch the precomputed geodesic distance
+ k = i * (i - 1) / 2 + j
+ dists.append(self.Pdist_matrix[k][0])
else:
- i= n-i-1
- j= n-j-1
- k = (n*(n-1)/2) - (n-i)*((n-i)-1)/2 + j - i - 1
- k = ( n*n - n )/2 -k -1
- dists.append(self.Pdist_matrix[int(k)][0])
+ k = j * (j - 1) / 2 + i
+ dists.append(self.Pdist_matrix[k][0])
else:
dists.append(np.inf)
return np.array(dists).squeeze()
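The simplified offset `k = i * (i - 1) / 2 + j` assumes `Pdist_matrix` stores the strictly lower triangle of the symmetric geodesic distance matrix row by row (which is what the original two-step index arithmetic reduces to). A tiny self-contained check of that indexing on a toy symmetric matrix:

```python
import numpy as np

D = np.array([[0., 1., 2.],
              [1., 0., 3.],
              [2., 3., 0.]])                  # toy symmetric distance matrix
packed = D[np.tril_indices(3, k=-1)]          # row-major strict lower triangle: [1., 2., 3.]

def lookup(i, j):
    """Distance between 0-based vertices i and j, read from the packed triangle."""
    if i == j:
        return 0.0
    if i < j:
        i, j = j, i
    return packed[i * (i - 1) // 2 + j]

assert lookup(2, 1) == D[2, 1] and lookup(2, 0) == D[2, 0] and lookup(1, 0) == D[1, 0]
```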
diff --git a/detectron/datasets/json_dataset.py b/detectron/datasets/json_dataset.py
index c6757e5..f65117f 100644
--- a/detectron/datasets/json_dataset.py
+++ b/detectron/datasets/json_dataset.py
@@ -154,7 +154,7 @@ def _prep_roidb_entry(self, entry):
(0, 3, self.num_keypoints), dtype=np.int32
)
if cfg.MODEL.BODY_UV_ON:
- entry['ignore_UV_body'] = np.empty((0), dtype=np.bool)
+ entry['ignore_UV_body'] = np.empty((0), dtype=np.bool)
# entry['Box_image_links_body'] = []
# Remove unwanted fields that come from the json file (if they exist)
for k in ['date_captured', 'url', 'license', 'file_name']:
@@ -200,7 +200,7 @@ def _add_gt_annotations(self, entry):
valid_objs.append(obj)
valid_segms.append(obj['segmentation'])
###
- if 'dp_x' in obj.keys():
+ if 'dp_x' in obj:
valid_dp_x.append(obj['dp_x'])
valid_dp_y.append(obj['dp_y'])
valid_dp_I.append(obj['dp_I'])
@@ -216,7 +216,7 @@ def _add_gt_annotations(self, entry):
valid_dp_masks.append([])
###
num_valid_objs = len(valid_objs)
- ##
+
boxes = np.zeros((num_valid_objs, 4), dtype=entry['boxes'].dtype)
gt_classes = np.zeros((num_valid_objs), dtype=entry['gt_classes'].dtype)
gt_overlaps = np.zeros(
@@ -234,7 +234,7 @@ def _add_gt_annotations(self, entry):
dtype=entry['gt_keypoints'].dtype
)
if cfg.MODEL.BODY_UV_ON:
- ignore_UV_body = np.zeros((num_valid_objs))
+ ignore_UV_body = np.zeros((num_valid_objs), dtype=entry['ignore_UV_body'].dtype)
#Box_image_body = [None]*num_valid_objs
im_has_visible_keypoints = False
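For reference, a hedged sketch of the DensePose-specific fields that `_add_gt_annotations` looks for on a COCO-style annotation (all values below are invented; real instances carry up to 196 collected points):

```python
# Illustrative DensePose-COCO annotation fields (values are made up for this sketch).
ann = {
    'bbox': [40.0, 50.0, 120.0, 300.0],  # COCO-style [x, y, w, h]
    'dp_x': [12.5, 200.0],               # point coords, stored as if the box were 256 x 256
    'dp_y': [30.0, 128.0],
    'dp_I': [2, 15],                     # surface patch index per point (1..24)
    'dp_U': [0.31, 0.77],                # patch-specific U coordinate in [0, 1]
    'dp_V': [0.58, 0.12],                # patch-specific V coordinate in [0, 1]
    'dp_masks': [],                      # RLE masks for the coarse semantic parts
}
has_densepose = 'dp_x' in ann            # the membership test used above
```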
diff --git a/detectron/datasets/json_dataset_evaluator.py b/detectron/datasets/json_dataset_evaluator.py
index a950973..4ff94df 100644
--- a/detectron/datasets/json_dataset_evaluator.py
+++ b/detectron/datasets/json_dataset_evaluator.py
@@ -213,9 +213,9 @@ def _get_thr_ind(coco_eval, thr):
precision = coco_eval.eval['precision'][ind_lo:(ind_hi + 1), :, :, 0, 2]
ap_default = np.mean(precision[precision > -1])
logger.info(
- '~~~~ Mean and per-category AP @ IoU=[{:.2f},{:.2f}] ~~~~'.format(
+ '~~~~ Mean and per-category AP @ IoU=[{:.2f}, {:.2f}] ~~~~'.format(
IoU_lo_thresh, IoU_hi_thresh))
- logger.info('{:.1f}'.format(100 * ap_default))
+ logger.info('mAP: {:.1f}'.format(100 * ap_default))
for cls_ind, cls in enumerate(json_dataset.classes):
if cls == '__background__':
continue
@@ -223,7 +223,7 @@ def _get_thr_ind(coco_eval, thr):
precision = coco_eval.eval['precision'][
ind_lo:(ind_hi + 1), :, cls_ind - 1, 0, 2]
ap = np.mean(precision[precision > -1])
- logger.info('{:.1f}'.format(100 * ap))
+ logger.info("class '{}' AP: {:.1f}".format(cls, 100 * ap))
logger.info('~~~~ Summary metrics ~~~~')
coco_eval.summarize()
@@ -439,20 +439,17 @@ def evaluate_body_uv(
if use_salt:
res_file += '_{}'.format(str(uuid.uuid4()))
res_file += '.pkl'
- results = _write_coco_body_uv_results_file(
+ _write_coco_body_uv_results_file(
json_dataset, all_boxes, all_bodys, res_file
)
# Only do evaluation on non-test sets (annotations are undisclosed on test)
if json_dataset.name.find('test') == -1:
- # See comment in _write_coco_body_uv_results_file
- #coco_eval = _do_body_uv_eval(json_dataset, res_file, output_dir)
- coco_eval = _do_body_uv_eval(json_dataset, results, output_dir)
+ coco_eval = _do_body_uv_eval(json_dataset, res_file, output_dir)
else:
coco_eval = None
- # See comment in _write_coco_body_uv_results_file
# Optionally cleanup results json file
- #if cleanup:
- # os.remove(res_file)
+ if cleanup:
+ os.remove(res_file)
return coco_eval
@@ -460,7 +457,7 @@ def _write_coco_body_uv_results_file(
json_dataset, all_boxes, all_bodys, res_file
):
results = []
- for cls_ind,cls in enumerate(json_dataset.classes):
+ for cls_ind, cls in enumerate(json_dataset.classes):
if cls == '__background__':
continue
if cls_ind >= len(all_bodys):
@@ -473,20 +470,12 @@ def _write_coco_body_uv_results_file(
json_dataset, all_boxes[cls_ind], all_bodys[cls_ind], cat_id))
# Body UV results are stored in 3xHxW ndarray format,
# which is not json serializable
- #logger.info(
- # 'Writing body uv results json to: {}'.format(
- # os.path.abspath(res_file)))
- #with open(res_file, 'w') as fid:
- # json.dump(results, fid)
- #
logger.info(
'Writing body uv results pkl to: {}'.format(
os.path.abspath(res_file)))
-
with open(res_file, 'wb') as f:
pickle.dump(results, f, pickle.HIGHEST_PROTOCOL)
- #logger.info('Not writing body uv resuts json')
- return res_file
+ # logger.info('Not writing body uv results json')
def _coco_body_uv_results_one_category(json_dataset, boxes, body_uvs, cat_id):
@@ -502,6 +491,9 @@ def _coco_body_uv_results_one_category(json_dataset, boxes, body_uvs, cat_id):
uv_dets = body_uvs[i]
box_dets = boxes[i].astype(np.float)
scores = box_dets[:, -1]
+ # Convert the uv fields to uint8.
+ for uv in uv_dets:
+ uv[1:3, :, :] = uv[1:3, :, :] * 255
# Don't use xyxy_to_xywh function for consistency with the original imp
# Instead, cast to ints and don't add 1 when computing ws and hs
# xywh_box_dets = box_utils.xyxy_to_xywh(box_dets[:, 0:4])
@@ -509,11 +501,6 @@ def _coco_body_uv_results_one_category(json_dataset, boxes, body_uvs, cat_id):
# ys = xywh_box_dets[:, 1]
# ws = xywh_box_dets[:, 2]
# hs = xywh_box_dets[:, 3]
-
- # Convert the uv fields to uint8.
- for uv in uv_dets:
- uv[1:3,:,:] = uv[1:3,:,:]*255
- ###
xs = box_dets[:, 0]
ys = box_dets[:, 1]
ws = (box_dets[:, 2] - xs).astype(np.int)
@@ -533,17 +520,18 @@ def _do_body_uv_eval(json_dataset, res_file, output_dir):
imgIds = json_dataset.COCO.getImgIds()
imgIds.sort()
with open(res_file, 'rb') as f:
- res=pickle.load(f)
+ res = pickle.load(f)
coco_dt = json_dataset.COCO.loadRes(res)
# Non-standard params used by the modified COCO API version
# from the DensePose fork
+ # global normalization factor used in per-instance evaluation
test_sigma = 0.255
coco_eval = denseposeCOCOeval(json_dataset.COCO, coco_dt, ann_type, test_sigma)
coco_eval.params.imgIds = imgIds
coco_eval.evaluate()
coco_eval.accumulate()
- #eval_file = os.path.join(output_dir, 'body_uv_results.pkl')
- #save_object(coco_eval, eval_file)
- #logger.info('Wrote json eval results to: {}'.format(eval_file))
+ # eval_file = os.path.join(output_dir, 'body_uv_results.pkl')
+ # save_object(coco_eval, eval_file)
+ # logger.info('Wrote json eval results to: {}'.format(eval_file))
coco_eval.summarize()
return coco_eval
diff --git a/detectron/datasets/roidb.py b/detectron/datasets/roidb.py
index e37e3d1..83382f2 100644
--- a/detectron/datasets/roidb.py
+++ b/detectron/datasets/roidb.py
@@ -121,7 +121,7 @@ def is_valid(entry):
if cfg.MODEL.BODY_UV_ON and cfg.BODY_UV_RCNN.BODY_UV_IMS:
# Exclude images with no body uv
valid = valid and entry['has_body_uv']
- return valid
+ return valid
num = len(roidb)
filtered_roidb = [entry for entry in roidb if is_valid(entry)]
diff --git a/detectron/datasets/task_evaluation.py b/detectron/datasets/task_evaluation.py
index bdae2e6..0d076c7 100644
--- a/detectron/datasets/task_evaluation.py
+++ b/detectron/datasets/task_evaluation.py
@@ -47,7 +47,7 @@ def evaluate_all(
output_dir, use_matlab=False
):
"""Evaluate "all" tasks, where "all" includes box detection, instance
- segmentation, and keypoint detection.
+ segmentation, keypoint detection, and human dense pose estimation.
"""
all_results = evaluate_boxes(
dataset, all_boxes, output_dir, use_matlab=use_matlab
diff --git a/detectron/modeling/body_uv_rcnn_heads.py b/detectron/modeling/body_uv_rcnn_heads.py
index dd67f43..f95aef5 100644
--- a/detectron/modeling/body_uv_rcnn_heads.py
+++ b/detectron/modeling/body_uv_rcnn_heads.py
@@ -3,7 +3,22 @@
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
-#
+##############################################################################
+
+"""Various network "heads" for dense human pose estimation in DensePose.
+
+The design is as follows:
+
+... -> RoI ----\ /-> mask output -> cls loss
+ -> RoIFeatureXform -> body UV head -> patch output -> cls loss
+... -> Feature / \-> UV output -> reg loss
+ Map
+
+The body UV head produces a feature representation of the RoI for the purpose
+of dense semantic mask prediction, body surface patch prediction and body UV
+coordinates regression. The body UV output module converts the feature
+representation into heatmaps for dense mask, patch index and UV coordinates.
+"""
from __future__ import absolute_import
from __future__ import division
@@ -11,142 +26,123 @@
from __future__ import unicode_literals
from caffe2.python import core
-
from detectron.core.config import cfg
-
+from detectron.utils.c2 import const_fill
import detectron.modeling.ResNet as ResNet
import detectron.utils.blob as blob_utils
# ---------------------------------------------------------------------------- #
-# Body UV heads
+# Body UV outputs and losses
# ---------------------------------------------------------------------------- #
-def add_body_uv_outputs(model, blob_in, dim, pref=''):
- ####
- model.ConvTranspose(blob_in, 'AnnIndex_lowres'+pref, dim, 15,cfg.BODY_UV_RCNN.DECONV_KERNEL, pad=int(cfg.BODY_UV_RCNN.DECONV_KERNEL / 2 - 1), stride=2, weight_init=(cfg.BODY_UV_RCNN.CONV_INIT, {'std': 0.001}), bias_init=('ConstantFill', {'value': 0.}))
- ####
- model.ConvTranspose(blob_in, 'Index_UV_lowres'+pref, dim, cfg.BODY_UV_RCNN.NUM_PATCHES+1,cfg.BODY_UV_RCNN.DECONV_KERNEL, pad=int(cfg.BODY_UV_RCNN.DECONV_KERNEL / 2 - 1), stride=2, weight_init=(cfg.BODY_UV_RCNN.CONV_INIT, {'std': 0.001}), bias_init=('ConstantFill', {'value': 0.}))
- ####
- model.ConvTranspose(
- blob_in, 'U_lowres'+pref, dim, (cfg.BODY_UV_RCNN.NUM_PATCHES+1),
- cfg.BODY_UV_RCNN.DECONV_KERNEL,
- pad=int(cfg.BODY_UV_RCNN.DECONV_KERNEL / 2 - 1),
- stride=2,
- weight_init=(cfg.BODY_UV_RCNN.CONV_INIT, {'std': 0.001}),
- bias_init=('ConstantFill', {'value': 0.}))
- #####
- model.ConvTranspose(
- blob_in, 'V_lowres'+pref, dim, cfg.BODY_UV_RCNN.NUM_PATCHES+1,
+def add_body_uv_outputs(model, blob_in, dim):
+ """Add DensePose body UV specific outputs: heatmaps of dense mask, patch index
+ and patch-specific UV coordinates. All dense masks are mapped to labels in
+ [0, ... S] for S semantically meaningful body parts.
+ """
+ # Apply ConvTranspose to the feature representation; results in 2x upsampling
+ for name in ['AnnIndex', 'Index_UV', 'U', 'V']:
+ if name == 'AnnIndex':
+ dim_out = cfg.BODY_UV_RCNN.NUM_SEMANTIC_PARTS + 1
+ else:
+ dim_out = cfg.BODY_UV_RCNN.NUM_PATCHES + 1
+ model.ConvTranspose(
+ blob_in,
+ name + '_lowres',
+ dim,
+ dim_out,
cfg.BODY_UV_RCNN.DECONV_KERNEL,
pad=int(cfg.BODY_UV_RCNN.DECONV_KERNEL / 2 - 1),
stride=2,
weight_init=(cfg.BODY_UV_RCNN.CONV_INIT, {'std': 0.001}),
- bias_init=('ConstantFill', {'value': 0.}))
- ####
- blob_Ann_Index = model.BilinearInterpolation('AnnIndex_lowres'+pref, 'AnnIndex'+pref, cfg.BODY_UV_RCNN.NUM_PATCHES+1 , cfg.BODY_UV_RCNN.NUM_PATCHES+1, cfg.BODY_UV_RCNN.UP_SCALE)
- blob_Index = model.BilinearInterpolation('Index_UV_lowres'+pref, 'Index_UV'+pref, cfg.BODY_UV_RCNN.NUM_PATCHES+1 , cfg.BODY_UV_RCNN.NUM_PATCHES+1, cfg.BODY_UV_RCNN.UP_SCALE)
- blob_U = model.BilinearInterpolation('U_lowres'+pref, 'U_estimated'+pref, cfg.BODY_UV_RCNN.NUM_PATCHES+1 , cfg.BODY_UV_RCNN.NUM_PATCHES+1, cfg.BODY_UV_RCNN.UP_SCALE)
- blob_V = model.BilinearInterpolation('V_lowres'+pref, 'V_estimated'+pref, cfg.BODY_UV_RCNN.NUM_PATCHES+1 , cfg.BODY_UV_RCNN.NUM_PATCHES+1, cfg.BODY_UV_RCNN.UP_SCALE)
- ###
- return blob_U,blob_V,blob_Index,blob_Ann_Index
-
-
-def add_body_uv_losses(model, pref=''):
-
- ## Reshape for GT blobs.
- model.net.Reshape( ['body_uv_X_points'], ['X_points_reshaped'+pref, 'X_points_shape'+pref], shape=( -1 ,1 ) )
- model.net.Reshape( ['body_uv_Y_points'], ['Y_points_reshaped'+pref, 'Y_points_shape'+pref], shape=( -1 ,1 ) )
- model.net.Reshape( ['body_uv_I_points'], ['I_points_reshaped'+pref, 'I_points_shape'+pref], shape=( -1 ,1 ) )
- model.net.Reshape( ['body_uv_Ind_points'], ['Ind_points_reshaped'+pref, 'Ind_points_shape'+pref], shape=( -1 ,1 ) )
- ## Concat Ind,x,y to get Coordinates blob.
- model.net.Concat( ['Ind_points_reshaped'+pref,'X_points_reshaped'+pref, \
- 'Y_points_reshaped'+pref],['Coordinates'+pref,'Coordinate_Shapes'+pref ], axis = 1 )
- ##
- ### Now reshape UV blobs, such that they are 1x1x(196*NumSamples)xNUM_PATCHES
- ## U blob to
- ##
- model.net.Reshape(['body_uv_U_points'], \
- ['U_points_reshaped'+pref, 'U_points_old_shape'+pref],\
- shape=(-1,cfg.BODY_UV_RCNN.NUM_PATCHES+1,196))
- model.net.Transpose(['U_points_reshaped'+pref] ,['U_points_reshaped_transpose'+pref],axes=(0,2,1) )
- model.net.Reshape(['U_points_reshaped_transpose'+pref], \
- ['U_points'+pref, 'U_points_old_shape2'+pref], \
- shape=(1,1,-1,cfg.BODY_UV_RCNN.NUM_PATCHES+1))
- ## V blob
- ##
- model.net.Reshape(['body_uv_V_points'], \
- ['V_points_reshaped'+pref, 'V_points_old_shape'+pref],\
- shape=(-1,cfg.BODY_UV_RCNN.NUM_PATCHES+1,196))
- model.net.Transpose(['V_points_reshaped'+pref] ,['V_points_reshaped_transpose'+pref],axes=(0,2,1) )
- model.net.Reshape(['V_points_reshaped_transpose'+pref], \
- ['V_points'+pref, 'V_points_old_shape2'+pref], \
- shape=(1,1,-1,cfg.BODY_UV_RCNN.NUM_PATCHES+1))
- ###
- ## UV weights blob
- ##
- model.net.Reshape(['body_uv_point_weights'], \
- ['Uv_point_weights_reshaped'+pref, 'Uv_point_weights_old_shape'+pref],\
- shape=(-1,cfg.BODY_UV_RCNN.NUM_PATCHES+1,196))
- model.net.Transpose(['Uv_point_weights_reshaped'+pref] ,['Uv_point_weights_reshaped_transpose'+pref],axes=(0,2,1) )
- model.net.Reshape(['Uv_point_weights_reshaped_transpose'+pref], \
- ['Uv_point_weights'+pref, 'Uv_point_weights_old_shape2'+pref], \
- shape=(1,1,-1,cfg.BODY_UV_RCNN.NUM_PATCHES+1))
-
- #####################
- ### Pool IUV for points via bilinear interpolation.
- model.PoolPointsInterp(['U_estimated','Coordinates'+pref], ['interp_U'+pref])
- model.PoolPointsInterp(['V_estimated','Coordinates'+pref], ['interp_V'+pref])
- model.PoolPointsInterp(['Index_UV'+pref,'Coordinates'+pref], ['interp_Index_UV'+pref])
-
- ## Reshape interpolated UV coordinates to apply the loss.
-
- model.net.Reshape(['interp_U'+pref], \
- ['interp_U_reshaped'+pref, 'interp_U_shape'+pref],\
- shape=(1, 1, -1 , cfg.BODY_UV_RCNN.NUM_PATCHES+1))
-
- model.net.Reshape(['interp_V'+pref], \
- ['interp_V_reshaped'+pref, 'interp_V_shape'+pref],\
- shape=(1, 1, -1 , cfg.BODY_UV_RCNN.NUM_PATCHES+1))
- ###
-
- ### Do the actual labels here !!!!
- model.net.Reshape( ['body_uv_ann_labels'], \
- ['body_uv_ann_labels_reshaped' +pref, 'body_uv_ann_labels_old_shape'+pref], \
- shape=(-1, cfg.BODY_UV_RCNN.HEATMAP_SIZE , cfg.BODY_UV_RCNN.HEATMAP_SIZE))
-
- model.net.Reshape( ['body_uv_ann_weights'], \
- ['body_uv_ann_weights_reshaped' +pref, 'body_uv_ann_weights_old_shape'+pref], \
- shape=( -1 , cfg.BODY_UV_RCNN.HEATMAP_SIZE , cfg.BODY_UV_RCNN.HEATMAP_SIZE))
- ###
- model.net.Cast( ['I_points_reshaped'+pref], ['I_points_reshaped_int'+pref], to=core.DataType.INT32)
- ### Now add the actual losses
- ## The mask segmentation loss (dense)
- probs_seg_AnnIndex, loss_seg_AnnIndex = model.net.SpatialSoftmaxWithLoss( \
- ['AnnIndex'+pref, 'body_uv_ann_labels_reshaped'+pref,'body_uv_ann_weights_reshaped'+pref],\
- ['probs_seg_AnnIndex'+pref,'loss_seg_AnnIndex'+pref], \
- scale=cfg.BODY_UV_RCNN.INDEX_WEIGHTS / cfg.NUM_GPUS)
- ## Point Patch Index Loss.
- probs_IndexUVPoints, loss_IndexUVPoints = model.net.SoftmaxWithLoss(\
- ['interp_Index_UV'+pref,'I_points_reshaped_int'+pref],\
- ['probs_IndexUVPoints'+pref,'loss_IndexUVPoints'+pref], \
- scale=cfg.BODY_UV_RCNN.PART_WEIGHTS / cfg.NUM_GPUS, spatial=0)
- ## U and V point losses.
- loss_Upoints = model.net.SmoothL1Loss( \
- ['interp_U_reshaped'+pref, 'U_points'+pref, \
- 'Uv_point_weights'+pref, 'Uv_point_weights'+pref], \
- 'loss_Upoints'+pref, \
- scale=cfg.BODY_UV_RCNN.POINT_REGRESSION_WEIGHTS / cfg.NUM_GPUS)
-
- loss_Vpoints = model.net.SmoothL1Loss( \
- ['interp_V_reshaped'+pref, 'V_points'+pref, \
- 'Uv_point_weights'+pref, 'Uv_point_weights'+pref], \
- 'loss_Vpoints'+pref, scale=cfg.BODY_UV_RCNN.POINT_REGRESSION_WEIGHTS / cfg.NUM_GPUS)
- ## Add the losses.
- loss_gradients = blob_utils.get_loss_gradients(model, \
- [ loss_Upoints, loss_Vpoints, loss_seg_AnnIndex, loss_IndexUVPoints])
- model.losses = list(set(model.losses + \
- ['loss_Upoints'+pref , 'loss_Vpoints'+pref , \
- 'loss_seg_AnnIndex'+pref ,'loss_IndexUVPoints'+pref]))
+ bias_init=const_fill(0.0)
+ )
+ # Increase heatmap output size via bilinear upsampling
+ blob_outputs = []
+ for name in ['AnnIndex', 'Index_UV', 'U', 'V']:
+ blob_outputs.append(
+ model.BilinearInterpolation(
+ name + '_lowres',
+ name + '_estimated' if name in ['U', 'V'] else name,
+ cfg.BODY_UV_RCNN.NUM_PATCHES + 1,
+ cfg.BODY_UV_RCNN.NUM_PATCHES + 1,
+ cfg.BODY_UV_RCNN.UP_SCALE
+ )
+ )
+
+ return blob_outputs
+
+
+def add_body_uv_losses(model):
+ """Add DensePose body UV specific losses."""
+ # Pool estimated IUV points via bilinear interpolation.
+ for name in ['U', 'V', 'Index_UV']:
+ model.PoolPointsInterp(
+ [
+ name + '_estimated' if name in ['U', 'V'] else name,
+ 'body_uv_coords_xy'
+ ],
+ ['interp_' + name]
+ )
+
+ # Compute spatial softmax normalized probabilities, after which
+ # cross-entropy loss is computed for semantic parts classification.
+ probs_AnnIndex, loss_AnnIndex = model.net.SpatialSoftmaxWithLoss(
+ [
+ 'AnnIndex',
+ 'body_uv_parts', 'body_uv_parts_weights'
+ ],
+ ['probs_AnnIndex', 'loss_AnnIndex'],
+ scale=model.GetLossScale() * cfg.BODY_UV_RCNN.INDEX_WEIGHTS
+ )
+ # Softmax loss for surface patch classification.
+ probs_I_points, loss_I_points = model.net.SoftmaxWithLoss(
+ ['interp_Index_UV', 'body_uv_I_points'],
+ ['probs_I_points', 'loss_I_points'],
+ scale=model.GetLossScale() * cfg.BODY_UV_RCNN.PART_WEIGHTS,
+ spatial=0
+ )
+ ## Smooth L1 loss for patch-specific UV coordinate regression.
+ # Reshape the interpolated and ground-truth U,V blobs to compute a
+ # summed (instead of averaged) SmoothL1Loss.
+ loss_UV = list()
+ model.net.Reshape(
+ ['body_uv_point_weights'],
+ ['UV_point_weights', 'body_uv_point_weights_shape'],
+ shape=(1, -1, cfg.BODY_UV_RCNN.NUM_PATCHES + 1)
+ )
+ for name in ['U', 'V']:
+ # Reshape U/V coordinates of both interpolated points and ground-truth
+ # points from (#points, #patches) to (1, #points, #patches).
+ model.net.Reshape(
+ ['body_uv_' + name + '_points'],
+ [name + '_points', 'body_uv_' + name + '_points_shape'],
+ shape=(1, -1, cfg.BODY_UV_RCNN.NUM_PATCHES + 1)
+ )
+ model.net.Reshape(
+ ['interp_' + name],
+ ['interp_' + name + '_reshaped', 'interp_' + name + '_shape'],
+ shape=(1, -1, cfg.BODY_UV_RCNN.NUM_PATCHES + 1)
+ )
+ # Compute the summed SmoothL1Loss over all points.
+ loss_UV.append(
+ model.net.SmoothL1Loss(
+ [
+ 'interp_' + name + '_reshaped', name + '_points',
+ 'UV_point_weights', 'UV_point_weights'
+ ],
+ 'loss_' + name + '_points',
+ scale=model.GetLossScale() * cfg.BODY_UV_RCNN.POINT_REGRESSION_WEIGHTS
+ )
+ )
+ # Add all losses to compute gradients
+ loss_gradients = blob_utils.get_loss_gradients(
+ model, [loss_AnnIndex, loss_I_points] + loss_UV
+ )
+ # Update model training losses
+ model.AddLosses(
+ ['loss_' + name for name in ['AnnIndex', 'I_points', 'U_points', 'V_points']]
+ )
return loss_gradients
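As a hedged summary of the refactor above, the output channel counts are now determined by the two config options rather than hard-coded literals (the NUM_PATCHES value below is the typical DensePose setting supplied by the YAML config):

```python
# Illustrative only: channel counts implied by add_body_uv_outputs above.
S = 14   # cfg.BODY_UV_RCNN.NUM_SEMANTIC_PARTS
P = 24   # cfg.BODY_UV_RCNN.NUM_PATCHES, typically 24 in DensePose configs
out_channels = {
    'AnnIndex_lowres': S + 1,   # background + semantic parts -> SpatialSoftmaxWithLoss
    'Index_UV_lowres': P + 1,   # background + surface patches -> SoftmaxWithLoss on points
    'U_lowres': P + 1,          # per-patch U maps -> SmoothL1Loss on interpolated points
    'V_lowres': P + 1,          # per-patch V maps -> SmoothL1Loss on interpolated points
}
```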
@@ -155,17 +151,17 @@ def add_body_uv_losses(model, pref=''):
# Body UV heads
# ---------------------------------------------------------------------------- #
-def add_ResNet_roi_conv5_head_for_bodyUV(
- model, blob_in, dim_in, spatial_scale
-):
+def add_ResNet_roi_conv5_head_for_bodyUV(model, blob_in, dim_in, spatial_scale):
"""Add a ResNet "conv5" / "stage5" head for body UV prediction."""
model.RoIFeatureTransform(
- blob_in, '_[body_uv]_pool5',
+ blob_in,
+ '_[body_uv]_pool5',
blob_rois='body_uv_rois',
method=cfg.BODY_UV_RCNN.ROI_XFORM_METHOD,
resolution=cfg.BODY_UV_RCNN.ROI_XFORM_RESOLUTION,
sampling_ratio=cfg.BODY_UV_RCNN.ROI_XFORM_SAMPLING_RATIO,
- spatial_scale=spatial_scale)
+ spatial_scale=spatial_scale
+ )
# Using the prefix '_[body_uv]_' to 'res5' enables initializing the head's
# parameters using pretrained 'res5' parameters if given (see
# utils.net.initialize_from_weights_file)
@@ -184,7 +180,7 @@ def add_ResNet_roi_conv5_head_for_bodyUV(
def add_roi_body_uv_head_v1convX(model, blob_in, dim_in, spatial_scale):
- """v1convX design: X * (conv)."""
+ """Add a DensePose body UV head. v1convX design: X * (conv)."""
hidden_dim = cfg.BODY_UV_RCNN.CONV_HEAD_DIM
kernel_size = cfg.BODY_UV_RCNN.CONV_HEAD_KERNEL
pad_size = kernel_size // 2
@@ -208,7 +204,7 @@ def add_roi_body_uv_head_v1convX(model, blob_in, dim_in, spatial_scale):
stride=1,
pad=pad_size,
weight_init=(cfg.BODY_UV_RCNN.CONV_INIT, {'std': 0.01}),
- bias_init=('ConstantFill', {'value': 0.})
+ bias_init=const_fill(0.0)
)
current = model.Relu(current, current)
dim_in = hidden_dim
diff --git a/detectron/modeling/model_builder.py b/detectron/modeling/model_builder.py
index 35f9f2c..0eb3899 100644
--- a/detectron/modeling/model_builder.py
+++ b/detectron/modeling/model_builder.py
@@ -329,7 +329,7 @@ def _add_roi_body_uv_head(
model, add_roi_body_uv_head_func, blob_in, dim_in, spatial_scale_in
):
"""Add a body UV prediction head to the model."""
- # Capture model graph before adding the mask head
+ # Capture model graph before adding the body UV head
bbox_net = copy.deepcopy(model.net.Proto())
# Add the body UV head
blob_body_uv_head, dim_body_uv_head = add_roi_body_uv_head_func(
@@ -343,7 +343,7 @@ def _add_roi_body_uv_head(
if not model.train: # == inference
# Inference uses a cascade of box predictions, then body uv predictions
# This requires separate nets for box and body uv prediction.
- # So we extract the keypoint prediction net, store it as its own
+ # So we extract the body uv prediction net, store it as its own
# network, then restore model.net to be the bbox-only network
model.body_uv_net, body_uv_blob_out = c2_utils.SuffixNet(
'body_uv_net', model.net, len(bbox_net.op), blobs_body_uv
diff --git a/detectron/ops/pool_points_interp.cc b/detectron/ops/pool_points_interp.cc
index 0bbc682..54c6245 100644
--- a/detectron/ops/pool_points_interp.cc
+++ b/detectron/ops/pool_points_interp.cc
@@ -1,33 +1,66 @@
/**
-# Copyright (c) Facebook, Inc. and its affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the license found in the
+ * LICENSE file in the root directory of this source tree.
*/
-
#include "pool_points_interp.h"
namespace caffe2 {
-//namespace {
REGISTER_CPU_OPERATOR(PoolPointsInterp,
PoolPointsInterpOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(PoolPointsInterpGradient,
PoolPointsInterpGradientOp<float, CPUContext>);
-// Input: X, points; Output: Y
-OPERATOR_SCHEMA(PoolPointsInterp).NumInputs(2).NumOutputs(1);
-// Input: X, points, dY (aka "gradOutput");
-// Output: dX (aka "gradInput")
-OPERATOR_SCHEMA(PoolPointsInterpGradient).NumInputs(3).NumOutputs(1);
+OPERATOR_SCHEMA(PoolPointsInterp)
+ .NumInputs(2)
+ .NumOutputs(1)
+ .Input(
+ 0,
+ "X",
+ "4D feature/heat map input of shape (N, C, H, W).")
+ .Input(
+ 1,
+ "coords",
+ "2D input of shape (P, 2) specifying P points with 2 columns "
+ "representing 2D coordinates on the image (x, y). The "
+        "coordinates have been converted into the coordinate system of X.")
+ .Output(
+ 0,
+ "Y",
+        "2D output of shape (P, K). The r-th row contains the "
+        "pooled/interpolated index or UV values of the r-th point "
+        "over all K patches (including background).");
+
+OPERATOR_SCHEMA(PoolPointsInterpGradient)
+ .NumInputs(3)
+ .NumOutputs(1)
+ .Input(
+ 0,
+ "X",
+ "See PoolPointsInterp.")
+ .Input(
+ 1,
+ "coords",
+ "See PoolPointsInterp.")
+ .Input(
+ 2,
+ "dY",
+ "Gradient of forward output 0 (Y)")
+ .Output(
+ 0,
+ "dX",
+ "Gradient of forward input 0 (X)");
class GetPoolPointsInterpGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
vector<OperatorDef> GetGradientDefs() override {
return SingleGradientDef(
- "PoolPointsInterpGradient", "",
+ "PoolPointsInterpGradient",
+ "",
vector<string>{I(0), I(1), GO(0)},
vector<string>{GI(0)});
}
@@ -35,5 +68,4 @@ class GetPoolPointsInterpGradient : public GradientMakerBase {
REGISTER_GRADIENT(PoolPointsInterp, GetPoolPointsInterpGradient);
-//} // namespace
} // namespace caffe2
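
Given the schema above, the op can be exercised from the Caffe2 Python frontend in the usual way. A minimal sketch; the blob names are illustrative, and spatial_scale is assumed to be 1.0 because the coordinates are documented as already being in X's coordinate system:

    from caffe2.python import core

    # X: (N, C, H, W) feature/heat map; coords: (P, 2) points in (x, y) format.
    op = core.CreateOperator(
        'PoolPointsInterp',
        ['U_estimated', 'body_uv_coords_xy'],  # inputs: X, coords
        ['interp_U'],                          # output: Y of shape (P, K), K = channels of X
        spatial_scale=1.0,
    )
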
diff --git a/detectron/ops/pool_points_interp.cu b/detectron/ops/pool_points_interp.cu
index 6286e2c..7a9f755 100644
--- a/detectron/ops/pool_points_interp.cu
+++ b/detectron/ops/pool_points_interp.cu
@@ -1,9 +1,9 @@
/**
-# Copyright (c) Facebook, Inc. and its affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the license found in the
+ * LICENSE file in the root directory of this source tree.
*/
#include <cfloat>
@@ -28,24 +28,26 @@ float gpu_atomic_add(const float val, float* address) {
template <typename T>
__device__ T bilinear_interpolate(const T* bottom_data,
- const int height, const int width,
- T y, T x,
+ const int height, const int width, T x, T y,
const int index /* index for debug only*/) {
-
// deal with cases that inverse elements are out of feature map boundary
- if (y < -1.0 || y > height || x < -1.0 || x > width) {
- //empty
+ if (x < -1.0 || x > width || y < -1.0 || y > height) {
return 0;
}
- if (y <= 0) y = 0;
if (x <= 0) x = 0;
+ if (y <= 0) y = 0;
- int y_low = (int) y;
int x_low = (int) x;
- int y_high;
- int x_high;
+ int y_low = (int) y;
+ int x_high, y_high;
+ if (x_low >= width - 1) {
+ x_high = x_low = width - 1;
+ x = (T) x_low;
+ } else {
+ x_high = x_low + 1;
+ }
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (T) y_low;
@@ -53,82 +55,62 @@ __device__ T bilinear_interpolate(const T* bottom_data,
y_high = y_low + 1;
}
- if (x_low >= width - 1) {
- x_high = x_low = width - 1;
- x = (T) x_low;
- } else {
- x_high = x_low + 1;
- }
+ // lambdas in X, Y axes
+ T lx = x - x_low, ly = y - y_low;
+ T hx = 1. - lx, hy = 1. - ly;
- T ly = y - y_low;
- T lx = x - x_low;
- T hy = 1. - ly, hx = 1. - lx;
// do bilinear interpolation
- T v1 = bottom_data[y_low * width + x_low];
- T v2 = bottom_data[y_low * width + x_high];
- T v3 = bottom_data[y_high * width + x_low];
- T v4 = bottom_data[y_high * width + x_high];
+ T v1 = bottom_data[y_low * width + x_low]; // top-left point
+ T v2 = bottom_data[y_low * width + x_high]; // top-right point
+ T v3 = bottom_data[y_high * width + x_low]; // bottom-left point
+ T v4 = bottom_data[y_high * width + x_high]; // bottom-right point
T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
- T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
-
- return val;
+ return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
}
template <typename T>
-__global__ void PointWarpForward(const int nthreads, const T* bottom_data,
- const T spatial_scale, const int channels,
- const int height, const int width,
- const T* bottom_rois, T* top_data) {
+__global__ void PoolPointsInterpForward(const int nthreads, const T* bottom_data,
+ const T spatial_scale, const int channels, const int height, const int width,
+ const T* coords, T* top_data) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ // (n, c) is an element in the pooled/interpolated output
int c = index % channels;
int n = index / channels;
- //
- const T* offset_bottom_rois = bottom_rois + n * 3;
-
- int roi_batch_ind = n/196; // Should be original !!
- //
- T X_point = offset_bottom_rois[1] * spatial_scale;
- T Y_point = offset_bottom_rois[2] * spatial_scale;
-
-
+ const T* offset_coords = coords + n * 2;
+ // Get index of current fg roi among all fg rois in a minibatch
+ int roi_batch_ind = n / 196;
+ // Get spatial coordinate (x, y)
+ T x = offset_coords[0] * spatial_scale;
+ T y = offset_coords[1] * spatial_scale;
const T* offset_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width;
-
- T val = bilinear_interpolate(offset_bottom_data, height, width, Y_point, X_point, index);
- top_data[index] = val;
+ // Compute interpolated value
+ top_data[index] = bilinear_interpolate(
+ offset_bottom_data, height, width, x, y, index);
}
}
template <typename T>
__device__ void bilinear_interpolate_gradient(
- const int height, const int width,
- T y, T x,
+ const int height, const int width, T x, T y,
T & w1, T & w2, T & w3, T & w4,
int & x_low, int & x_high, int & y_low, int & y_high,
const int index /* index for debug only*/) {
-
// deal with cases that inverse elements are out of feature map boundary
- if (y < -1.0 || y > height || x < -1.0 || x > width) {
- //empty
+ if (x < -1.0 || x > width || y < -1.0 || y > height) {
+ // empty
w1 = w2 = w3 = w4 = 0.;
x_low = x_high = y_low = y_high = -1;
return;
}
- if (y <= 0) y = 0;
if (x <= 0) x = 0;
+ if (y <= 0) y = 0;
- y_low = (int) y;
x_low = (int) x;
-
- if (y_low >= height - 1) {
- y_high = y_low = height - 1;
- y = (T) y_low;
- } else {
- y_high = y_low + 1;
- }
+ y_low = (int) y;
if (x_low >= width - 1) {
x_high = x_low = width - 1;
@@ -136,11 +118,15 @@ __device__ void bilinear_interpolate_gradient(
} else {
x_high = x_low + 1;
}
+ if (y_low >= height - 1) {
+ y_high = y_low = height - 1;
+ y = (T) y_low;
+ } else {
+ y_high = y_low + 1;
+ }
- T ly = y - y_low;
- T lx = x - x_low;
- T hy = 1. - ly, hx = 1. - lx;
-
+ T lx = x - x_low, ly = y - y_low;
+ T hx = 1. - lx, hy = 1. - ly;
w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
@@ -148,30 +134,23 @@ __device__ void bilinear_interpolate_gradient(
}
template <typename T>
-__global__ void PointWarpBackwardFeature(const int nthreads, const T* top_diff,
- const int num_rois, const T spatial_scale,
- const int channels, const int height, const int width,
-
- T* bottom_diff,
- const T* bottom_rois) {
+__global__ void PoolPointsInterpBackward(const int nthreads, const T* top_diff,
+ const int num_rois, const T spatial_scale, const int channels,
+ const int height, const int width, T* bottom_diff, const T* coords) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
- int c = index % channels;
- int n = index / channels;
-
- const T* offset_bottom_rois = bottom_rois + n * 3;
- // int roi_batch_ind = offset_bottom_rois[0];
- int roi_batch_ind = n/196; // Should be original !!
+ int c = index % channels;
+ int n = index / channels;
- T X_point = offset_bottom_rois[1] * spatial_scale;
- T Y_point = offset_bottom_rois[2] * spatial_scale;
+ const T* offset_coords = coords + n * 2;
+ int roi_batch_ind = n / 196;
+ T x = offset_coords[0] * spatial_scale;
+ T y = offset_coords[1] * spatial_scale;
T w1, w2, w3, w4;
int x_low, x_high, y_low, y_high;
- bilinear_interpolate_gradient(height, width, Y_point, X_point,
- w1, w2, w3, w4,
- x_low, x_high, y_low, y_high,
- index);
+ bilinear_interpolate_gradient(height, width, x, y,
+ w1, w2, w3, w4, x_low, x_high, y_low, y_high, index);
T* offset_bottom_diff = bottom_diff + (roi_batch_ind * channels + c) * height * width;
//
@@ -193,35 +172,30 @@ __global__ void PointWarpBackwardFeature(const int nthreads, const T* top_diff,
} // if
} // CUDA_1D_KERNEL_LOOP
-} // ROIWarpBackward
-
+} // PoolPointsInterpBackward
} // namespace
template<>
bool PoolPointsInterpOp<float, CUDAContext>::RunOnDevice() {
auto& X = Input(0); // Input data to pool
- auto& R = Input(1); // RoIs
- auto* Y = Output(0); // RoI pooled data
+ auto& R = Input(1); // Spatial coordinates of all points within RoIs
if (R.size() == 0) {
// Handle empty rois
- Y->Resize(0, X.dim32(1));
- // The following mutable_data calls are needed to allocate the tensors
-    Y->mutable_data<float>();
+ std::vector sizes = {0, X.dim32(1)};
+    /* auto* Y = */ Output(0, sizes, at::dtype<float>());
return true;
}
- Y->Resize(R.dim32(0), X.dim32(1));
+  auto* Y = Output(0, {R.dim32(0), X.dim32(1)}, at::dtype<float>()); // Pooled interpolated data
int output_size = Y->size();
-  PointWarpForward<float><<<CAFFE_GET_BLOCKS(output_size), CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+  PoolPointsInterpForward<float><<<CAFFE_GET_BLOCKS(output_size), CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
      output_size, X.data<float>(), spatial_scale_,
X.dim32(1), X.dim32(2), X.dim32(3),
-      R.data<float>(),
-      Y->mutable_data<float>()
- );
+      R.data<float>(), Y->mutable_data<float>());
return true;
}
@@ -232,59 +206,46 @@ __global__ void SetKernel(const int N, const T alpha, T* Y) {
Y[i] = alpha;
}
}
-}
-
+} // namespace
namespace {
-
-
template <typename T>
__global__ void SetEvenIndsToVal(size_t num_even_inds, T val, T* data) {
CUDA_1D_KERNEL_LOOP(i, num_even_inds) {
data[i << 1] = val;
}
}
-}
-
+} // namespace
-
template<>
bool PoolPointsInterpGradientOp<float, CUDAContext>::RunOnDevice() {
auto& X = Input(0); // Input data to pool
- auto& R = Input(1); // RoIs
+ auto& R = Input(1); // 2D Spatial coordinates of all points within RoIs
auto& dY = Input(2); // Gradient of net w.r.t. output of "forward" op
// (aka "gradOutput")
- auto* dX = Output(0); // Gradient of net w.r.t. input to "forward" op
- // (aka "gradInput")
-
- dX->ResizeLike(X);
+ auto* dX = Output(
+      0, X.sizes(), at::dtype<float>()); // Gradient of net w.r.t. input to
+ // "forward" op (aka "gradInput")
-  SetKernel<float>
-      <<<CAFFE_GET_BLOCKS(dX->size()),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context_.cuda_stream()>>>(
- dX->size(),
- 0.f,
-          dX->mutable_data<float>());
+  SetKernel<float><<<CAFFE_GET_BLOCKS(dX->size()),
+ CAFFE_CUDA_NUM_THREADS,
+ 0, context_.cuda_stream()>>>(
+      dX->size(), 0.f, dX->mutable_data<float>());
if (dY.size() > 0) { // Handle possibly empty gradient if there were no rois
-    PointWarpBackwardFeature<float><<<CAFFE_GET_BLOCKS(dY.size()), CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+    PoolPointsInterpBackward<float><<<CAFFE_GET_BLOCKS(dY.size()), CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
      dY.size(), dY.data<float>(), R.dim32(0), spatial_scale_,
X.dim32(1), X.dim32(2), X.dim32(3),
-      dX->mutable_data<float>(),
-      R.data<float>());
+      dX->mutable_data<float>(), R.data<float>());
}
return true;
}
-//namespace {
REGISTER_CUDA_OPERATOR(PoolPointsInterp,
PoolPointsInterpOp<float, CUDAContext>);
REGISTER_CUDA_OPERATOR(PoolPointsInterpGradient,
PoolPointsInterpGradientOp<float, CUDAContext>);
-//} // namespace
} // namespace caffe2
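
For reference, the forward kernel amounts to per-point bilinear interpolation of one (roi, channel) plane of X, with the boundary handling shown above. A minimal NumPy sketch; the hard-coded 196 mirrors the kernel's n / 196 mapping from point index to fg RoI:

    import numpy as np

    def bilinear_interpolate(plane, x, y):
        # Mirrors the CUDA kernel: plane is (H, W); points outside the map give 0.
        H, W = plane.shape
        if x < -1.0 or x > W or y < -1.0 or y > H:
            return 0.0
        x, y = max(x, 0.0), max(y, 0.0)
        x_low, y_low = int(x), int(y)
        if x_low >= W - 1:
            x_high = x_low = W - 1
            x = float(x_low)
        else:
            x_high = x_low + 1
        if y_low >= H - 1:
            y_high = y_low = H - 1
            y = float(y_low)
        else:
            y_high = y_low + 1
        lx, ly = x - x_low, y - y_low
        hx, hy = 1.0 - lx, 1.0 - ly
        return (hy * hx * plane[y_low, x_low] + hy * lx * plane[y_low, x_high] +
                ly * hx * plane[y_high, x_low] + ly * lx * plane[y_high, x_high])

    def pool_points_interp_forward(X, coords, spatial_scale=1.0):
        # X: (N, C, H, W); coords: (P, 2) with P = 196 * number of fg rois.
        N, C, H, W = X.shape
        Y = np.zeros((coords.shape[0], C), dtype=np.float32)
        for n in range(coords.shape[0]):
            roi_batch_ind = n // 196  # 196 annotated points are stored per fg roi
            x, y = coords[n] * spatial_scale
            for c in range(C):
                Y[n, c] = bilinear_interpolate(X[roi_batch_ind, c], x, y)
        return Y
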
diff --git a/detectron/ops/pool_points_interp.h b/detectron/ops/pool_points_interp.h
index 367a29c..02d5c65 100644
--- a/detectron/ops/pool_points_interp.h
+++ b/detectron/ops/pool_points_interp.h
@@ -1,12 +1,11 @@
/**
-# Copyright (c) Facebook, Inc. and its affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the license found in the
+ * LICENSE file in the root directory of this source tree.
*/
-
#ifndef POOL_POINTS_INTERP_OP_H_
#define POOL_POINTS_INTERP_OP_H_
@@ -21,13 +20,14 @@ class PoolPointsInterpOp final : public Operator<Context> {
public:
PoolPointsInterpOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
-        spatial_scale_(OperatorBase::GetSingleArgument<float>(
+        spatial_scale_(this->template GetSingleArgument<float>(
"spatial_scale", 1.)) {
DCHECK_GT(spatial_scale_, 0);
}
USE_OPERATOR_CONTEXT_FUNCTIONS;
bool RunOnDevice() override {
+ // No CPU implementation for now
CAFFE_NOT_IMPLEMENTED;
}
@@ -40,13 +40,14 @@ class PoolPointsInterpGradientOp final : public Operator<Context> {
public:
PoolPointsInterpGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
-        spatial_scale_(OperatorBase::GetSingleArgument<float>(
-            "spatial_scale", 1.)){
+        spatial_scale_(this->template GetSingleArgument<float>(
+ "spatial_scale", 1.)) {
DCHECK_GT(spatial_scale_, 0);
}
USE_OPERATOR_CONTEXT_FUNCTIONS;
bool RunOnDevice() override {
+ // No CPU implementation for now
CAFFE_NOT_IMPLEMENTED;
}
@@ -56,4 +57,4 @@ class PoolPointsInterpGradientOp final : public Operator {
} // namespace caffe2
-#endif // PoolPointsInterpOp
+#endif // POOL_POINTS_INTERP_OP_H_
diff --git a/detectron/roi_data/body_uv_rcnn.py b/detectron/roi_data/body_uv_rcnn.py
index 1cb0f0b..27c817a 100644
--- a/detectron/roi_data/body_uv_rcnn.py
+++ b/detectron/roi_data/body_uv_rcnn.py
@@ -3,18 +3,21 @@
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
+##############################################################################
+
+"""Construct minibatches for DensePose training. Handles the minibatch blobs
+that are specific to DensePose. Other blobs that are generic to RPN or
+Fast/er R-CNN are handled by their respective roi_data modules.
+"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
-#
-from scipy.io import loadmat
-import copy
+
import cv2
import logging
import numpy as np
-#
from detectron.core.config import cfg
import detectron.utils.blob as blob_utils
@@ -22,191 +25,180 @@
import detectron.utils.segms as segm_utils
import detectron.utils.densepose_methods as dp_utils
-#
-from memory_profiler import profile
-#
-import os
-#
logger = logging.getLogger(__name__)
-#
+
DP = dp_utils.DensePoseMethods()
-#
+
def add_body_uv_rcnn_blobs(blobs, sampled_boxes, roidb, im_scale, batch_idx):
- IsFlipped = roidb['flipped']
+    """Add DensePose-specific blobs to the given blobs dictionary."""
M = cfg.BODY_UV_RCNN.HEATMAP_SIZE
- #
+ # Prepare the body UV targets by associating one gt box which contains
+ # body UV annotations to each training roi that has a fg class label.
polys_gt_inds = np.where(roidb['ignore_UV_body'] == 0)[0]
- boxes_from_polys = [roidb['boxes'][i,:] for i in polys_gt_inds]
- if not(boxes_from_polys):
- pass
- else:
- boxes_from_polys = np.vstack(boxes_from_polys)
- boxes_from_polys = np.array(boxes_from_polys)
-
+ boxes_from_polys = roidb['boxes'][polys_gt_inds]
+ # Select foreground RoIs
fg_inds = np.where(blobs['labels_int32'] > 0)[0]
- roi_has_mask = np.zeros( blobs['labels_int32'].shape )
+ roi_has_body_uv = np.zeros_like(blobs['labels_int32'], dtype=np.int32)
- if (bool(boxes_from_polys.any()) & (fg_inds.shape[0] > 0) ):
+ if ((boxes_from_polys.shape[0] > 0) & (fg_inds.shape[0] > 0)):
+ # Find overlap between all foreground RoIs and the gt bounding boxes
+        # containing each body UV annotation.
rois_fg = sampled_boxes[fg_inds]
- #
- rois_fg.astype(np.float32, copy=False)
- boxes_from_polys.astype(np.float32, copy=False)
- #
overlaps_bbfg_bbpolys = box_utils.bbox_overlaps(
rois_fg.astype(np.float32, copy=False),
- boxes_from_polys.astype(np.float32, copy=False))
+ boxes_from_polys.astype(np.float32, copy=False)
+ )
+ # Select foreground RoIs as those with > 0.7 overlap
fg_polys_value = np.max(overlaps_bbfg_bbpolys, axis=1)
- fg_inds = fg_inds[fg_polys_value>0.7]
-
- if (bool(boxes_from_polys.any()) & (fg_inds.shape[0] > 0) ):
- for jj in fg_inds:
- roi_has_mask[jj] = 1
-
- # Create blobs for densepose supervision.
- ################################################## The mask
- All_labels = blob_utils.zeros((fg_inds.shape[0], M ** 2), int32=True)
- All_Weights = blob_utils.zeros((fg_inds.shape[0], M ** 2), int32=True)
- ################################################# The points
- X_points = blob_utils.zeros((fg_inds.shape[0], 196), int32=False)
- Y_points = blob_utils.zeros((fg_inds.shape[0], 196), int32=False)
- Ind_points = blob_utils.zeros((fg_inds.shape[0], 196), int32=True)
+        fg_keep = fg_polys_value > 0.7
+        fg_inds = fg_inds[fg_keep]
+
+ if ((boxes_from_polys.shape[0] > 0) & (fg_inds.shape[0] > 0)):
+ roi_has_body_uv[fg_inds] = 1
+ # Create body UV blobs
+ # Dense masks, each mask for a given fg roi is of size M x M.
+ part_inds = blob_utils.zeros((fg_inds.shape[0], M, M), int32=True)
+ # Weights assigned to each target in `part_inds`. By default, all 1's.
+ # part_inds_weights = blob_utils.zeros((fg_inds.shape[0], M, M), int32=True)
+ part_inds_weights = blob_utils.ones((fg_inds.shape[0], M, M), int32=False)
+        # 2D spatial coordinates of collected points within each fg roi's M x M
+        # heatmap, in format (x, y). Shape is (#fg_rois, 196, 2).
+ coords_xy = blob_utils.zeros((fg_inds.shape[0], 196, 2), int32=False)
+ # 24 patch indices plus a background class
I_points = blob_utils.zeros((fg_inds.shape[0], 196), int32=True)
+ # UV coordinates in each patch
U_points = blob_utils.zeros((fg_inds.shape[0], 196), int32=False)
V_points = blob_utils.zeros((fg_inds.shape[0], 196), int32=False)
- Uv_point_weights = blob_utils.zeros((fg_inds.shape[0], 196), int32=False)
- #################################################
+ # Uv_point_weights = blob_utils.zeros((fg_inds.shape[0], 196), int32=False)
rois_fg = sampled_boxes[fg_inds]
- overlaps_bbfg_bbpolys = box_utils.bbox_overlaps(
- rois_fg.astype(np.float32, copy=False),
- boxes_from_polys.astype(np.float32, copy=False))
+        # Keep only the overlap rows of the retained fg rois
+        overlaps_bbfg_bbpolys = overlaps_bbfg_bbpolys[fg_keep]
+ # Map from each fg roi to the index of the gt box with highest overlap
fg_polys_inds = np.argmax(overlaps_bbfg_bbpolys, axis=1)
+ # Add body UV targets for each fg roi
for i in range(rois_fg.shape[0]):
- #
- fg_polys_ind = polys_gt_inds[ fg_polys_inds[i] ]
- #
- Ilabel = segm_utils.GetDensePoseMask( roidb['dp_masks'][ fg_polys_ind ] )
- #
- GT_I = np.array(roidb['dp_I'][ fg_polys_ind ])
- GT_U = np.array(roidb['dp_U'][ fg_polys_ind ])
- GT_V = np.array(roidb['dp_V'][ fg_polys_ind ])
- GT_x = np.array(roidb['dp_x'][ fg_polys_ind ])
- GT_y = np.array(roidb['dp_y'][ fg_polys_ind ])
- GT_weights = np.ones(GT_I.shape).astype(np.float32)
- #
- ## Do the flipping of the densepose annotation !
- if(IsFlipped):
- GT_I,GT_U,GT_V,GT_x,GT_y,Ilabel = DP.get_symmetric_densepose(GT_I,GT_U,GT_V,GT_x,GT_y,Ilabel)
- #
+ fg_polys_ind = fg_polys_inds[i]
+ polys_gt_ind = polys_gt_inds[fg_polys_ind]
+ # RLE encoded dense masks which are of size 256 x 256.
+ # Map all part masks to 14 labels (i.e., indices of semantic body parts).
+ dp_masks = dp_utils.GetDensePoseMask(
+ roidb['dp_masks'][polys_gt_ind], cfg.BODY_UV_RCNN.NUM_SEMANTIC_PARTS
+ )
+ # Surface patch indices of collected points
+ dp_I = np.array(roidb['dp_I'][polys_gt_ind], dtype=np.int32)
+ # UV coordinates of collected points
+ dp_U = np.array(roidb['dp_U'][polys_gt_ind], dtype=np.float32)
+ dp_V = np.array(roidb['dp_V'][polys_gt_ind], dtype=np.float32)
+ # dp_UV_weights = np.ones_like(dp_I).astype(np.float32)
+ # Spatial coordinates on the image which are scaled such that the bbox
+ # size is 256 x 256.
+ dp_x = np.array(roidb['dp_x'][polys_gt_ind], dtype=np.float32)
+ dp_y = np.array(roidb['dp_y'][polys_gt_ind], dtype=np.float32)
+ # Do the flipping of the densepose annotation
+ if roidb['flipped']:
+ dp_I, dp_U, dp_V, dp_x, dp_y, dp_masks = DP.get_symmetric_densepose(
+ dp_I, dp_U, dp_V, dp_x, dp_y, dp_masks
+ )
+
roi_fg = rois_fg[i]
- roi_gt = boxes_from_polys[fg_polys_inds[i],:]
- #
- x1 = roi_fg[0] ; x2 = roi_fg[2]
- y1 = roi_fg[1] ; y2 = roi_fg[3]
- #
- x1_source = roi_gt[0]; x2_source = roi_gt[2]
- y1_source = roi_gt[1]; y2_source = roi_gt[3]
- #
- x_targets = ( np.arange(x1,x2, (x2 - x1)/M ) - x1_source ) * ( 256. / (x2_source-x1_source) )
- y_targets = ( np.arange(y1,y2, (y2 - y1)/M ) - y1_source ) * ( 256. / (y2_source-y1_source) )
- #
- x_targets = x_targets[0:M] ## Strangely sometimes it can be M+1, so make sure size is OK!
- y_targets = y_targets[0:M]
- #
- [X_targets,Y_targets] = np.meshgrid( x_targets, y_targets )
- New_Index = cv2.remap(Ilabel,X_targets.astype(np.float32), Y_targets.astype(np.float32), interpolation=cv2.INTER_NEAREST, borderMode= cv2.BORDER_CONSTANT, borderValue=(0))
- #
- All_L = np.zeros(New_Index.shape)
- All_W = np.ones(New_Index.shape)
- #
- All_L = New_Index
- #
- gt_length_x = x2_source - x1_source
- gt_length_y = y2_source - y1_source
- #
- GT_y = (( GT_y / 256. * gt_length_y ) + y1_source - y1 ) * ( M / ( y2 - y1 ) )
- GT_x = (( GT_x / 256. * gt_length_x ) + x1_source - x1 ) * ( M / ( x2 - x1 ) )
- #
- GT_I[GT_y<0] = 0
- GT_I[GT_y>(M-1)] = 0
- GT_I[GT_x<0] = 0
- GT_I[GT_x>(M-1)] = 0
- #
- points_inside = GT_I>0
- GT_U = GT_U[points_inside]
- GT_V = GT_V[points_inside]
- GT_x = GT_x[points_inside]
- GT_y = GT_y[points_inside]
- GT_weights = GT_weights[points_inside]
- GT_I = GT_I[points_inside]
- #
- X_points[i, 0:len(GT_x)] = GT_x
- Y_points[i, 0:len(GT_y)] = GT_y
- Ind_points[i, 0:len(GT_I)] = i
- I_points[i, 0:len(GT_I)] = GT_I
- U_points[i, 0:len(GT_U)] = GT_U
- V_points[i, 0:len(GT_V)] = GT_V
- Uv_point_weights[i, 0:len(GT_weights)] = GT_weights
- #
- All_labels[i, :] = np.reshape(All_L.astype(np.int32), M ** 2)
- All_Weights[i, :] = np.reshape(All_W.astype(np.int32), M ** 2)
- ##
- else:
+ gt_box = boxes_from_polys[fg_polys_ind]
+ fg_x1, fg_y1, fg_x2, fg_y2 = roi_fg[0:4]
+ gt_x1, gt_y1, gt_x2, gt_y2 = gt_box[0:4]
+ fg_width = fg_x2 - fg_x1; fg_height = fg_y2 - fg_y1
+ gt_width = gt_x2 - gt_x1; gt_height = gt_y2 - gt_y1
+ fg_scale_w = float(M) / fg_width
+ fg_scale_h = float(M) / fg_height
+ gt_scale_w = 256. / gt_width
+ gt_scale_h = 256. / gt_height
+            # Sample M points evenly within the fg roi and express them relative to the
+            # associated gt box, scaled as if that gt box were of size 256 x 256.
+ x_targets = (np.arange(fg_x1, fg_x2, fg_width / M) - gt_x1) * gt_scale_w
+ y_targets = (np.arange(fg_y1, fg_y2, fg_height / M) - gt_y1) * gt_scale_h
+            # Construct 2D coordinate matrices
+ x_targets, y_targets = np.meshgrid(x_targets[:M], y_targets[:M])
+ ## Another implementation option (which results in similar performance)
+ # x_targets = (np.linspace(fg_x1, fg_x2, M, endpoint=True, dtype=np.float32) - gt_x1) * gt_scale_w
+ # y_targets = (np.linspace(fg_y1, fg_y2, M, endpoint=True, dtype=np.float32) - gt_y1) * gt_scale_h
+ # x_targets = (np.linspace(fg_x1, fg_x2, M, endpoint=False) - gt_x1) * gt_scale_w
+ # y_targets = (np.linspace(fg_y1, fg_y2, M, endpoint=False) - gt_y1) * gt_scale_h
+ # x_targets, y_targets = np.meshgrid(x_targets, y_targets)
+
+ # Map dense masks of size 256 x 256 to target heatmap of size M x M.
+ part_inds[i] = cv2.remap(
+ dp_masks, x_targets.astype(np.float32), y_targets.astype(np.float32),
+ interpolation=cv2.INTER_NEAREST,
+ borderMode=cv2.BORDER_CONSTANT, borderValue=(0)
+ )
+
+ # Scale annotated spatial coordinates from bbox of size 256 x 256 to target
+ # heatmap of size M x M.
+ dp_x = (dp_x / gt_scale_w + gt_x1 - fg_x1) * fg_scale_w
+ dp_y = (dp_y / gt_scale_h + gt_y1 - fg_y1) * fg_scale_h
+ # Set patch index of points outside the heatmap as 0 (background).
+ dp_I[dp_x < 0] = 0; dp_I[dp_x > (M - 1)] = 0
+ dp_I[dp_y < 0] = 0; dp_I[dp_y > (M - 1)] = 0
+ # Get body UV annotations of points inside the heatmap.
+ points_inside = dp_I > 0
+ dp_x = dp_x[points_inside]
+ dp_y = dp_y[points_inside]
+ dp_I = dp_I[points_inside]
+ dp_U = dp_U[points_inside]
+ dp_V = dp_V[points_inside]
+ # dp_UV_weights = dp_UV_weights[points_inside]
+
+ # Update body UV blobs
+ num_dp_points = len(dp_I)
+ # coords_xy[i, 0:num_dp_points, 0] = i # fg_roi index
+ coords_xy[i, 0:num_dp_points, 0] = dp_x
+ coords_xy[i, 0:num_dp_points, 1] = dp_y
+ I_points[i, 0:num_dp_points] = dp_I.astype(np.int32)
+ U_points[i, 0:num_dp_points] = dp_U
+ V_points[i, 0:num_dp_points] = dp_V
+ # Uv_point_weights[i, 0:len(dp_UV_weights)] = dp_UV_weights
+ else: # If there are no fg rois
+ # The network cannot handle empty blobs, so we must provide a blob.
+        # We simply take the first bg roi, give it all-zero body UV annotations
+ # and label it with class zero (bg).
bg_inds = np.where(blobs['labels_int32'] == 0)[0]
- #
- if(len(bg_inds)==0):
+ # `rois_fg` is actually one background roi, but that's ok because ...
+ if len(bg_inds) == 0:
rois_fg = sampled_boxes[0].reshape((1, -1))
else:
rois_fg = sampled_boxes[bg_inds[0]].reshape((1, -1))
+ # Mark that the first roi has body UV annotation
+ roi_has_body_uv[0] = 1
+ # We give it all 0's blobs
+ part_inds = blob_utils.zeros((1, M, M), int32=True)
+ part_inds_weights = blob_utils.zeros((1, M, M), int32=False)
+ coords_xy = blob_utils.zeros((1, 196, 2), int32=False)
+ I_points = blob_utils.zeros((1, 196), int32=True)
+ U_points = blob_utils.zeros((1, 196), int32=False)
+ V_points = blob_utils.zeros((1, 196), int32=False)
+ # Uv_point_weights = blob_utils.zeros((1, 196), int32=False)
- roi_has_mask[0] = 1
- #
- X_points = blob_utils.zeros((1, 196), int32=False)
- Y_points = blob_utils.zeros((1, 196), int32=False)
- Ind_points = blob_utils.zeros((1, 196), int32=True)
- I_points = blob_utils.zeros((1,196), int32=True)
- U_points = blob_utils.zeros((1, 196), int32=False)
- V_points = blob_utils.zeros((1, 196), int32=False)
- Uv_point_weights = blob_utils.zeros((1, 196), int32=False)
- #
- All_labels = -blob_utils.ones((1, M ** 2), int32=True) * 0 ## zeros
- All_Weights = -blob_utils.ones((1, M ** 2), int32=True) * 0 ## zeros
- #
+ # Scale rois_fg and format as (batch_idx, x1, y1, x2, y2)
rois_fg *= im_scale
repeated_batch_idx = batch_idx * blob_utils.ones((rois_fg.shape[0], 1))
rois_fg = np.hstack((repeated_batch_idx, rois_fg))
- #
- K = cfg.BODY_UV_RCNN.NUM_PATCHES
- #
- U_points = np.tile( U_points , [1,K+1] )
- V_points = np.tile( V_points , [1,K+1] )
- Uv_Weight_Points = np.zeros(U_points.shape)
- #
- for jjj in xrange(1,K+1):
- Uv_Weight_Points[ : , jjj * I_points.shape[1] : (jjj+1) * I_points.shape[1] ] = ( I_points == jjj ).astype(np.float32)
- #
- ################
- # Update blobs dict with Mask R-CNN blobs
- ###############
- #
- blobs['body_uv_rois'] = np.array(rois_fg)
- blobs['roi_has_body_uv_int32'] = np.array(roi_has_mask).astype(np.int32)
- ##
- blobs['body_uv_ann_labels'] = np.array(All_labels).astype(np.int32)
- blobs['body_uv_ann_weights'] = np.array(All_Weights).astype(np.float32)
- #
- ##########################
- blobs['body_uv_X_points'] = X_points.astype(np.float32)
- blobs['body_uv_Y_points'] = Y_points.astype(np.float32)
- blobs['body_uv_Ind_points'] = Ind_points.astype(np.float32)
- blobs['body_uv_I_points'] = I_points.astype(np.float32)
- blobs['body_uv_U_points'] = U_points.astype(np.float32) #### VERY IMPORTANT : These are switched here :
- blobs['body_uv_V_points'] = V_points.astype(np.float32)
- blobs['body_uv_point_weights'] = Uv_Weight_Points.astype(np.float32)
- ###################
-
-
-
+ # Create body UV blobs for all patches (including background)
+ K = cfg.BODY_UV_RCNN.NUM_PATCHES + 1
+    # Construct U/V_points blobs for all patches by repeating each point's target K times.
+ # Shape: (#rois, 196, K)
+ U_points = np.repeat(U_points[:, :, np.newaxis], K, axis=-1)
+ V_points = np.repeat(V_points[:, :, np.newaxis], K, axis=-1)
+ uv_point_weights = np.zeros_like(U_points)
+ # Set binary weights for UV targets in each patch
+ for i in np.arange(1, K):
+ uv_point_weights[:, :, i] = (I_points == i).astype(np.float32)
+ # Update blobs dict with body UV blobs
+ blobs['body_uv_rois'] = rois_fg
+ blobs['roi_has_body_uv_int32'] = roi_has_body_uv # shape: (#rois,)
+ blobs['body_uv_parts'] = part_inds # shape: (#rois, M, M)
+ blobs['body_uv_parts_weights'] = part_inds_weights
+ blobs['body_uv_coords_xy'] = coords_xy.reshape(-1, 2) # shape: (#rois * 196, 2)
+ blobs['body_uv_I_points'] = I_points.reshape(-1, 1) # shape: (#rois * 196, 1)
+ blobs['body_uv_U_points'] = U_points # shape: (#rois, 196, K)
+ blobs['body_uv_V_points'] = V_points
+ blobs['body_uv_point_weights'] = uv_point_weights
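
The coordinate mapping above takes annotated points from the gt-box frame (rescaled to 256 x 256) back into image space and then into the fg RoI's M x M heatmap frame. A small worked example with made-up boxes; M = 56 is only an illustrative HEATMAP_SIZE:

    import numpy as np

    M = 56                                                # illustrative heatmap size
    gt_x1, gt_y1, gt_x2, gt_y2 = 100., 50., 300., 450.    # gt box carrying the annotations
    fg_x1, fg_y1, fg_x2, fg_y2 = 110., 60., 290., 440.    # sampled fg roi

    gt_scale_w, gt_scale_h = 256. / (gt_x2 - gt_x1), 256. / (gt_y2 - gt_y1)
    fg_scale_w, fg_scale_h = float(M) / (fg_x2 - fg_x1), float(M) / (fg_y2 - fg_y1)

    # One annotated point, stored relative to the 256 x 256 gt-box frame.
    dp_x, dp_y = np.array([128.]), np.array([64.])

    # 256 x 256 gt frame -> image frame -> M x M fg-roi heatmap frame
    hm_x = (dp_x / gt_scale_w + gt_x1 - fg_x1) * fg_scale_w   # -> [28.]
    hm_y = (dp_y / gt_scale_h + gt_y1 - fg_y1) * fg_scale_h
    # Points that land outside [0, M - 1] get their patch index set to 0 (background).
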
diff --git a/detectron/roi_data/fast_rcnn.py b/detectron/roi_data/fast_rcnn.py
index 2635974..153ce61 100644
--- a/detectron/roi_data/fast_rcnn.py
+++ b/detectron/roi_data/fast_rcnn.py
@@ -41,7 +41,6 @@ def get_fast_rcnn_blob_names(is_training=True):
# labels_int32 blob: R categorical labels in [0, ..., K] for K
# foreground classes plus background
blob_names += ['labels_int32']
- if is_training:
# bbox_targets blob: R bounding-box regression targets with 4
# targets per class
blob_names += ['bbox_targets']
@@ -81,19 +80,32 @@ def get_fast_rcnn_blob_names(is_training=True):
########################
if is_training and cfg.MODEL.BODY_UV_ON:
+ # 'body_uv_rois': RoIs sampled for training the body UV estimation branch.
+ # Shape is (#fg_rois, 5) in format (batch_idx, x1, y1, x2, y2).
blob_names += ['body_uv_rois']
+    # 'roi_has_body_uv_int32': binary labels for the RoIs specified in 'rois'
+    # indicating whether each RoI has body UV annotations or not. Shape is (#rois,).
blob_names += ['roi_has_body_uv_int32']
- #########
- # ###################################################
- blob_names += ['body_uv_ann_labels']
- blob_names += ['body_uv_ann_weights']
- # #################################################
- blob_names += ['body_uv_X_points']
- blob_names += ['body_uv_Y_points']
- blob_names += ['body_uv_Ind_points']
+ # 'body_uv_parts': index of part in [0, ..., S] where S is the number of
+ # semantic parts used to sample body UV points for the RoIs specified in
+ # 'body_uv_rois'. Shape is (#rois, M, M) where M is the heat map size.
+ blob_names += ['body_uv_parts']
+ # 'body_uv_parts_weights': weight assigned to each target in 'body_uv_parts'.
+ # Shape is (#rois, M, M). Used in SpatialSoftmaxWithLoss.
+ blob_names += ['body_uv_parts_weights']
+    # 'body_uv_coords_xy': 2D spatial coordinates of collected points within
+    # each RoI's M x M heatmap. Shape is (#rois * 196, 2) in format (dp_x, dp_y).
+ # Used in PoolPointsInterp.
+ blob_names += ['body_uv_coords_xy']
+ # 'body_uv_I_points': surface patch indices in [0, ..., K] for K patches
+ # plus background. Shape is (#rois * 196, 1). Used in SoftmaxWithLoss.
blob_names += ['body_uv_I_points']
+ # 'body_uv_U/V_points': UV coordinates of collected points in each patch.
+    # Shape is (#rois, 196, K). Used as regression targets in SmoothL1Loss.
blob_names += ['body_uv_U_points']
blob_names += ['body_uv_V_points']
+ # 'body_uv_point_weights': weight assigned to each target in
+ # 'body_uv_U/V_points'. Shape is (#rois, 196, K). Used in SmoothL1Loss.
blob_names += ['body_uv_point_weights']
if cfg.FPN.FPN_ON and cfg.FPN.MULTILEVEL_ROIS:
@@ -173,7 +185,7 @@ def _sample_rois(roidb, im_scale, batch_idx):
# against there being fewer than desired)
bg_rois_per_this_image = rois_per_image - fg_rois_per_this_image
bg_rois_per_this_image = np.minimum(bg_rois_per_this_image, bg_inds.size)
- # Sample foreground regions without replacement
+ # Sample background regions without replacement
if bg_inds.size > 0:
bg_inds = npr.choice(
bg_inds, size=bg_rois_per_this_image, replace=False
diff --git a/detectron/roi_data/rpn.py b/detectron/roi_data/rpn.py
index 63b0166..0aa8266 100644
--- a/detectron/roi_data/rpn.py
+++ b/detectron/roi_data/rpn.py
@@ -113,7 +113,9 @@ def add_rpn_blobs(blobs, im_scales, roidb):
valid_keys = [
'has_visible_keypoints', 'boxes', 'segms', 'seg_areas', 'gt_classes',
- 'gt_overlaps', 'is_crowd', 'box_to_gt_ind_map', 'gt_keypoints','flipped', 'ignore_UV_body','dp_x','dp_y','dp_I','dp_U','dp_V','dp_masks' ]
+ 'gt_overlaps', 'is_crowd', 'box_to_gt_ind_map', 'gt_keypoints', 'flipped',
+ 'ignore_UV_body', 'dp_x', 'dp_y', 'dp_I', 'dp_U', 'dp_V', 'dp_masks'
+ ]
minimal_roidb = [{} for _ in range(len(roidb))]
for i, e in enumerate(roidb):
for k in valid_keys:
diff --git a/detectron/utils/densepose_methods.py b/detectron/utils/densepose_methods.py
index 4d1dcf9..3c19aa9 100644
--- a/detectron/utils/densepose_methods.py
+++ b/detectron/utils/densepose_methods.py
@@ -3,141 +3,137 @@
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
+##############################################################################
+"""DensePose utilities."""
+
+from scipy.io import loadmat
+import os.path as osp
import numpy as np
-import copy
-import cv2
-from scipy.io import loadmat
-import scipy.spatial.distance
-import os
+import scipy.spatial.distance as ssd
+import pycocotools.mask as mask_util
+
+
+def GetDensePoseMask(Polys, num_parts=14):
+ """Get dense masks from the encoded masks."""
+ MaskGen = np.zeros((256, 256), dtype=np.int32)
+ for i in range(1, num_parts + 1):
+ if Polys[i - 1]:
+ current_mask = mask_util.decode(Polys[i - 1])
+ MaskGen[current_mask > 0] = i
+ return MaskGen
class DensePoseMethods:
def __init__(self):
- #
- ALP_UV = loadmat( os.path.join(os.path.dirname(__file__), '../../DensePoseData/UV_data/UV_Processed.mat') )
- self.FaceIndices = np.array( ALP_UV['All_FaceIndices']).squeeze()
- self.FacesDensePose = ALP_UV['All_Faces']-1
+ ALP_UV = loadmat(
+ osp.join(osp.dirname(__file__), '../../DensePoseData/UV_data/UV_Processed.mat')
+ )
+ self.FaceIndices = np.array(ALP_UV['All_FaceIndices']).squeeze()
+ self.FacesDensePose = ALP_UV['All_Faces'] - 1
self.U_norm = ALP_UV['All_U_norm'].squeeze()
self.V_norm = ALP_UV['All_V_norm'].squeeze()
- self.All_vertices = ALP_UV['All_vertices'][0]
- ## Info to compute symmetries.
- self.SemanticMaskSymmetries = [0,1,3,2,5,4,7,6,9,8,11,10,13,12,14]
- self.Index_Symmetry_List = [1,2,4,3,6,5,8,7,10,9,12,11,14,13,16,15,18,17,20,19,22,21,24,23];
- UV_symmetry_filename = os.path.join(os.path.dirname(__file__), '../../DensePoseData/UV_data/UV_symmetry_transforms.mat')
- self.UV_symmetry_transformations = loadmat( UV_symmetry_filename )
-
-
- def get_symmetric_densepose(self,I,U,V,x,y,Mask):
- ### This is a function to get the mirror symmetric UV labels.
- Labels_sym= np.zeros(I.shape)
- U_sym= np.zeros(U.shape)
- V_sym= np.zeros(V.shape)
- ###
- for i in ( range(24)):
- if i+1 in I:
- Labels_sym[I == (i+1)] = self.Index_Symmetry_List[i]
- jj = np.where(I == (i+1))
- ###
- U_loc = (U[jj]*255).astype(np.int64)
- V_loc = (V[jj]*255).astype(np.int64)
- ###
- V_sym[jj] = self.UV_symmetry_transformations['V_transforms'][0,i][V_loc,U_loc]
- U_sym[jj] = self.UV_symmetry_transformations['U_transforms'][0,i][V_loc,U_loc]
- ##
- Mask_flip = np.fliplr(Mask)
- Mask_flipped = np.zeros(Mask.shape)
- #
- for i in ( range(14)):
- Mask_flipped[Mask_flip == (i+1)] = self.SemanticMaskSymmetries[i+1]
- #
- [y_max , x_max ] = Mask_flip.shape
- y_sym = y
- x_sym = x_max-x
- #
- return Labels_sym , U_sym , V_sym , x_sym , y_sym , Mask_flipped
-
-
-
- def barycentric_coordinates_exists(self,P0, P1, P2, P):
- u = P1 - P0
- v = P2 - P0
- w = P - P0
- #
- vCrossW = np.cross(v,w)
+ self.All_vertices = ALP_UV['All_vertices'][0]
+ self.SemanticMaskSymmetries = [
+ 0, 1, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 14
+ ]
+ self.Index_Symmetry_List = [
+ 1, 2, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15, 18, 17, 20, 19, 22, 21, 24, 23
+ ]
+ self.UV_symmetry_transformations = loadmat(
+ osp.join(osp.dirname(__file__), '../../DensePoseData/UV_data/UV_symmetry_transforms.mat')
+ )
+
+ def get_symmetric_densepose(self, I, U, V, x, y, mask):
+ """Get the mirror symmetric UV annotations"""
+ symm_I = np.zeros_like(I)
+ symm_U = np.zeros_like(U)
+ symm_V = np.zeros_like(V)
+ for i in range(24):
+ inds = np.where(I == (i + 1))[0]
+ if len(inds) > 0:
+ symm_I[inds] = self.Index_Symmetry_List[i]
+ loc_U = (U[inds] * 255).astype(np.int32)
+ loc_V = (V[inds] * 255).astype(np.int32)
+ symm_U[inds] = self.UV_symmetry_transformations['U_transforms'][0, i][loc_V, loc_U]
+ symm_V[inds] = self.UV_symmetry_transformations['V_transforms'][0, i][loc_V, loc_U]
+
+ flip_mask = np.fliplr(mask)
+ symm_mask = np.zeros_like(mask)
+ for i in range(1, 15):
+ symm_mask[flip_mask == i] = self.SemanticMaskSymmetries[i]
+ x_max = flip_mask.shape[1]
+ symm_x = x_max - x
+ symm_y = y
+ return symm_I, symm_U, symm_V, symm_x, symm_y, symm_mask
+
+ def barycentric_coordinates_exists(self, P0, P1, P2, P):
+ u = P1 - P0; v = P2 - P0; w = P - P0
+ vCrossW = np.cross(v, w)
vCrossU = np.cross(v, u)
- if (np.dot(vCrossW, vCrossU) < 0):
- return False;
- #
+ if np.dot(vCrossW, vCrossU) < 0:
+ return False
+
uCrossW = np.cross(u, w)
uCrossV = np.cross(u, v)
- #
if (np.dot(uCrossW, uCrossV) < 0):
- return False;
- #
- denom = np.sqrt((uCrossV**2).sum())
- r = np.sqrt((vCrossW**2).sum())/denom
- t = np.sqrt((uCrossW**2).sum())/denom
- #
- return((r <=1) & (t <= 1) & (r + t <= 1))
-
- def barycentric_coordinates(self,P0, P1, P2, P):
- u = P1 - P0
- v = P2 - P0
- w = P - P0
- #
- vCrossW = np.cross(v,w)
+ return False
+
+ denom = np.sqrt((uCrossV ** 2).sum())
+ r = np.sqrt((vCrossW ** 2).sum()) / denom
+ t = np.sqrt((uCrossW ** 2).sum()) / denom
+ return ((r <= 1) & (t <= 1) & (r + t <= 1))
+
+ def barycentric_coordinates(self, P0, P1, P2, P):
+ u = P1 - P0; v = P2 - P0; w = P - P0
+ vCrossW = np.cross(v, w)
vCrossU = np.cross(v, u)
- #
+ if np.dot(vCrossW, vCrossU) < 0:
+ return -1, -1, -1
uCrossW = np.cross(u, w)
uCrossV = np.cross(u, v)
- #
- denom = np.sqrt((uCrossV**2).sum())
- r = np.sqrt((vCrossW**2).sum())/denom
- t = np.sqrt((uCrossW**2).sum())/denom
- #
- return(1-(r+t),r,t)
-
- def IUV2FBC( self, I_point , U_point, V_point):
- P = [ U_point , V_point , 0 ]
- FaceIndicesNow = np.where( self.FaceIndices == I_point )
- FacesNow = self.FacesDensePose[FaceIndicesNow]
- #
- P_0 = np.vstack( (self.U_norm[FacesNow][:,0], self.V_norm[FacesNow][:,0], np.zeros(self.U_norm[FacesNow][:,0].shape))).transpose()
- P_1 = np.vstack( (self.U_norm[FacesNow][:,1], self.V_norm[FacesNow][:,1], np.zeros(self.U_norm[FacesNow][:,1].shape))).transpose()
- P_2 = np.vstack( (self.U_norm[FacesNow][:,2], self.V_norm[FacesNow][:,2], np.zeros(self.U_norm[FacesNow][:,2].shape))).transpose()
- #
-
- for i, [P0,P1,P2] in enumerate( zip(P_0,P_1,P_2)) :
- if(self.barycentric_coordinates_exists(P0, P1, P2, P)):
- [bc1,bc2,bc3] = self.barycentric_coordinates(P0, P1, P2, P)
- return(FaceIndicesNow[0][i],bc1,bc2,bc3)
- #
+ if np.dot(uCrossW, uCrossV) < 0:
+ return -1, -1, -1
+ denom = np.sqrt((uCrossV ** 2).sum())
+ r = np.sqrt((vCrossW ** 2).sum()) / denom
+ t = np.sqrt((uCrossW ** 2).sum()) / denom
+ if ((r <= 1) & (t <= 1) & (r + t <= 1)):
+ return 1 - (r + t), r, t
+ else:
+ return -1, -1, -1
+
+ def IUV2FBC(self, I_point, U_point, V_point):
+ """Convert IUV to FBC (faceIndex and barycentric coordinates)."""
+ P = [U_point, V_point, 0]
+ faceIndicesNow = np.where(self.FaceIndices == I_point)[0]
+ FacesNow = self.FacesDensePose[faceIndicesNow]
+ v0 = np.zeros_like(self.U_norm[FacesNow][:, 0])
+ P_0 = np.vstack((self.U_norm[FacesNow][:, 0], self.V_norm[FacesNow][:, 0], v0)).transpose()
+ P_1 = np.vstack((self.U_norm[FacesNow][:, 1], self.V_norm[FacesNow][:, 1], v0)).transpose()
+ P_2 = np.vstack((self.U_norm[FacesNow][:, 2], self.V_norm[FacesNow][:, 2], v0)).transpose()
+
+        for i, [P0, P1, P2] in enumerate(zip(P_0, P_1, P_2)):
+ bc1, bc2, bc3 = self.barycentric_coordinates(P0, P1, P2, P)
+ if bc1 != -1:
+ return faceIndicesNow[i], bc1, bc2, bc3
+
# If the found UV is not inside any faces, select the vertex that is closest!
- #
- D1 = scipy.spatial.distance.cdist( np.array( [U_point,V_point])[np.newaxis,:] , P_0[:,0:2]).squeeze()
- D2 = scipy.spatial.distance.cdist( np.array( [U_point,V_point])[np.newaxis,:] , P_1[:,0:2]).squeeze()
- D3 = scipy.spatial.distance.cdist( np.array( [U_point,V_point])[np.newaxis,:] , P_2[:,0:2]).squeeze()
- #
- minD1 = D1.min()
- minD2 = D2.min()
- minD3 = D3.min()
- #
- if((minD1< minD2) & (minD1< minD3)):
- return( FaceIndicesNow[0][np.argmin(D1)] , 1.,0.,0. )
- elif((minD2< minD1) & (minD2< minD3)):
- return( FaceIndicesNow[0][np.argmin(D2)] , 0.,1.,0. )
+ D1 = ssd.cdist(np.array([U_point, V_point])[np.newaxis, :], P_0[:, 0:2]).squeeze()
+ D2 = ssd.cdist(np.array([U_point, V_point])[np.newaxis, :], P_1[:, 0:2]).squeeze()
+ D3 = ssd.cdist(np.array([U_point, V_point])[np.newaxis, :], P_2[:, 0:2]).squeeze()
+ minD1 = D1.min(); minD2 = D2.min(); minD3 = D3.min()
+ if ((minD1 < minD2) & (minD1 < minD3)):
+ return faceIndicesNow[np.argmin(D1)], 1., 0., 0.
+ elif ((minD2 < minD1) & (minD2 < minD3)):
+ return faceIndicesNow[np.argmin(D2)], 0., 1., 0.
else:
- return( FaceIndicesNow[0][np.argmin(D3)] , 0.,0.,1. )
-
-
- def FBC2PointOnSurface( self, FaceIndex, bc1,bc2,bc3,Vertices ):
- ##
- Vert_indices = self.All_vertices[self.FacesDensePose[FaceIndex]]-1
- ##
- p = Vertices[Vert_indices[0],:] * bc1 + \
- Vertices[Vert_indices[1],:] * bc2 + \
- Vertices[Vert_indices[2],:] * bc3
- ##
- return(p)
-
+ return faceIndicesNow[np.argmin(D3)], 0., 0., 1.
+
+ def FBC2PointOnSurface(self, face_ind, bc1, bc2, bc3, vertices):
+ """Use FBC to get 3D coordinates on the surface."""
+ Vert_indices = self.All_vertices[self.FacesDensePose[face_ind]] - 1
+ # p = vertices[Vert_indices[0], :] * bc1 + \
+ # vertices[Vert_indices[1], :] * bc2 + \
+ # vertices[Vert_indices[2], :] * bc3
+ p = np.matmul(np.array([[bc1, bc2, bc3]]), vertices[Vert_indices]).squeeze()
+ return p
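
The last two methods compose into the usual annotation-to-mesh mapping: IUV2FBC finds the mesh face containing a point given its patch index I and (U, V), and FBC2PointOnSurface converts the face index plus barycentric coordinates into a 3D point on a template mesh. A minimal usage sketch, assuming the DensePoseData/UV_data files loaded in __init__ are present; the all-zero vertex array is only a placeholder for the 6890 SMPL template vertices normally passed in:

    import numpy as np
    import detectron.utils.densepose_methods as dp_utils

    DP = dp_utils.DensePoseMethods()

    # Placeholder for the SMPL template vertices the UV data is registered against.
    vertices = np.zeros((6890, 3), dtype=np.float32)

    I, U, V = 2, 0.4, 0.7                        # patch index and UV of one point
    face_ind, bc1, bc2, bc3 = DP.IUV2FBC(I, U, V)
    p = DP.FBC2PointOnSurface(face_ind, bc1, bc2, bc3, vertices)
    print(face_ind, bc1 + bc2 + bc3, p.shape)    # barycentric weights sum to 1; p is (3,)
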
diff --git a/detectron/utils/io.py b/detectron/utils/io.py
index 3ec5e22..2ac2fdc 100644
--- a/detectron/utils/io.py
+++ b/detectron/utils/io.py
@@ -35,6 +35,15 @@ def cache_url(url_or_file, cache_dir):
path to the cached file. If the argument is not a URL, simply return it as
is.
"""
+ if re.match(r'^\$', url_or_file, re.IGNORECASE) is not None:
+ url_or_file = os.path.expandvars(url_or_file)
+ assert os.path.exists(url_or_file)
+ return url_or_file
+ elif re.match(r'^~', url_or_file, re.IGNORECASE) is not None:
+ url_or_file = os.path.expanduser(url_or_file)
+ assert os.path.exists(url_or_file)
+ return url_or_file
+
is_url = re.match(r'^(?:http)s?://', url_or_file, re.IGNORECASE) is not None
if not is_url:
@@ -42,8 +51,8 @@ def cache_url(url_or_file, cache_dir):
#
url = url_or_file
#
- Len_filename = len( url.split('/')[-1] )
- BASE_URL = url[0:-Len_filename-1]
+ Len_filename = len(url.split('/')[-1])
+ BASE_URL = url[0:-Len_filename - 1]
#
cache_file_path = url.replace(BASE_URL, cache_dir)
if os.path.exists(cache_file_path):
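
With this addition, cache_url returns early for weight paths given through an environment variable or '~', after expanding them to a local path. A small sketch of the expansion it relies on; the variable name and paths are illustrative:

    import os

    os.environ['DENSEPOSE_WEIGHTS'] = '/tmp/DensePoseData/weights'
    print(os.path.expandvars('$DENSEPOSE_WEIGHTS/model_final.pkl'))
    # -> /tmp/DensePoseData/weights/model_final.pkl
    print(os.path.expanduser('~/model_final.pkl'))
    # -> e.g. /home/<user>/model_final.pkl
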
diff --git a/detectron/utils/segms.py b/detectron/utils/segms.py
index bf1ac3f..9967a24 100644
--- a/detectron/utils/segms.py
+++ b/detectron/utils/segms.py
@@ -20,19 +20,9 @@
from __future__ import unicode_literals
import numpy as np
-
import pycocotools.mask as mask_util
-def GetDensePoseMask(Polys):
- MaskGen = np.zeros([256,256])
- for i in range(1,15):
- if(Polys[i-1]):
- current_mask = mask_util.decode(Polys[i-1])
- MaskGen[current_mask>0] = i
- return MaskGen
-
-
def flip_segms(segms, height, width):
"""Left/right flip each mask in a list of masks."""
def _flip_poly(poly, width):
diff --git a/detectron/utils/vis.py b/detectron/utils/vis.py
index 31b2e79..13556cd 100644
--- a/detectron/utils/vis.py
+++ b/detectron/utils/vis.py
@@ -376,45 +376,54 @@ def vis_one_image(
line, color=colors[len(kp_lines) + 1], linewidth=1.0,
alpha=0.7)
- # DensePose Visualization Starts!!
- ## Get full IUV image out
+ ### DensePose Visualization Starts!!
+ # get full IUV image outputs for all bboxes
IUV_fields = body_uv[1]
- #
- All_Coords = np.zeros(im.shape)
- All_inds = np.zeros([im.shape[0],im.shape[1]])
- K = 26
- ##
- inds = np.argsort(boxes[:,4])
- ##
+ # initialize IUV output and INDS output images with zeros
+ All_coords = np.zeros(im.shape, dtype=np.float32) # shape: (im_height, im_width, 3)
+ All_inds = np.zeros([im.shape[0], im.shape[1]], dtype=np.float32) # shape: (im_height, im_width)
+
+        # Process detections from lowest to highest class score. Note that this may
+        # cause artifacts in some body parts: pixels already filled by a less accurate,
+        # lower-scoring bbox are not overwritten, so the corresponding output of a more
+        # precise, higher-scoring bbox is discarded for those pixels.
+ inds = np.argsort(boxes[:, 4])
for i, ind in enumerate(inds):
- entry = boxes[ind,:]
- if entry[4] > 0.65:
- entry=entry[0:4].astype(int)
- ####
+ score = boxes[ind, 4]
+ bbox = boxes[ind, :4]
+ if score > 0.65:
+ # top left corner (x1, y1) of current bbox in image space
+ x1, y1 = boxes[ind, :2].astype(int)
+ # get IUV output for current bbox
output = IUV_fields[ind]
- ####
- All_Coords_Old = All_Coords[ entry[1] : entry[1]+output.shape[1],entry[0]:entry[0]+output.shape[2],:]
- All_Coords_Old[All_Coords_Old==0]=output.transpose([1,2,0])[All_Coords_Old==0]
- All_Coords[ entry[1] : entry[1]+output.shape[1],entry[0]:entry[0]+output.shape[2],:]= All_Coords_Old
- ###
- CurrentMask = (output[0,:,:]>0).astype(np.float32)
- All_inds_old = All_inds[ entry[1] : entry[1]+output.shape[1],entry[0]:entry[0]+output.shape[2]]
- All_inds_old[All_inds_old==0] = CurrentMask[All_inds_old==0]*i
- All_inds[ entry[1] : entry[1]+output.shape[1],entry[0]:entry[0]+output.shape[2]] = All_inds_old
- #
- All_Coords[:,:,1:3] = 255. * All_Coords[:,:,1:3]
- All_Coords[All_Coords>255] = 255.
- All_Coords = All_Coords.astype(np.uint8)
+ out_height, out_width = output.shape[1:3]
+
+ # first, locate the region of current bbox on final IUV output image
+ All_coords_tmp = All_coords[y1:y1 + out_height, x1:x1 + out_width]
+            # then, fill the pixels in this region that have not already been filled by the IUV output of another bbox
+ All_coords_tmp[All_coords_tmp == 0] = output.transpose([1, 2, 0])[All_coords_tmp == 0]
+ # update final IUV output image
+ All_coords[y1:y1 + out_height, x1:x1 + out_width] = All_coords_tmp
+
+ # get (distinct) human-body FG mask indices for each bbox
+ index_UV = output[0] # predicted part index: 0 ~ 24
+ CurrentMask = (index_UV > 0).astype(np.float32)
+ All_inds_tmp = All_inds[y1:y1 + out_height, x1:x1 + out_width]
+            All_inds_tmp[All_inds_tmp == 0] = CurrentMask[All_inds_tmp == 0] * (i + 1)  # offset by 1 so instance indices start at 1
+ All_inds[y1:y1 + out_height, x1:x1 + out_width] = All_inds_tmp
+
+        # scale predicted UV coordinates to [0, 255]
+ All_coords[:, :, 1:3] = All_coords[:, :, 1:3] * 255.
+ All_coords[All_coords > 255] = 255.
+ All_coords = All_coords.astype(np.uint8)
All_inds = All_inds.astype(np.uint8)
- #
- IUV_SaveName = os.path.basename(im_name).split('.')[0]+'_IUV.png'
- INDS_SaveName = os.path.basename(im_name).split('.')[0]+'_INDS.png'
- cv2.imwrite(os.path.join(output_dir, '{}'.format(IUV_SaveName)), All_Coords )
- cv2.imwrite(os.path.join(output_dir, '{}'.format(INDS_SaveName)), All_inds )
- print('IUV written to: ' , os.path.join(output_dir, '{}'.format(IUV_SaveName)) )
- ###
+ # save IUV images into files
+ IUV_SaveName = os.path.basename(im_name).split('.')[0] + '_IUV.png'
+ INDS_SaveName = os.path.basename(im_name).split('.')[0] + '_INDS.png'
+ cv2.imwrite(os.path.join(output_dir, '{}'.format(IUV_SaveName)), All_coords)
+ cv2.imwrite(os.path.join(output_dir, '{}'.format(INDS_SaveName)), All_inds)
+ print('IUV written to: ', os.path.join(output_dir, '{}'.format(IUV_SaveName)))
### DensePose Visualization Done!!
- #
+
output_name = os.path.basename(im_name) + '.' + ext
fig.savefig(os.path.join(output_dir, '{}'.format(output_name)), dpi=dpi)
plt.close('all')
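
The DensePose visualization saves two images per input: '<image>_IUV.png', whose channel 0 holds the predicted part index (0-24) and whose channels 1 and 2 hold U and V scaled to [0, 255], and '<image>_INDS.png', which holds a per-pixel instance index starting at 1. A minimal sketch for reading them back; file names are placeholders:

    import cv2
    import numpy as np

    iuv = cv2.imread('demo_im_IUV.png')        # (H, W, 3) uint8, channels as written above
    inds = cv2.imread('demo_im_INDS.png', 0)   # (H, W) uint8 instance indices

    part_index = iuv[:, :, 0]                  # 0 = background, 1..24 = surface patch
    U = iuv[:, :, 1].astype(np.float32) / 255.
    V = iuv[:, :, 2].astype(np.float32) / 255.
    num_people = int(inds.max())
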