diff --git a/INSTALL.md b/INSTALL.md index 5a722fb..198e7c7 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -14,7 +14,7 @@ The DensePose-RCNN system is implemented within the [`detectron`](https://github ## Caffe2 -To install Caffe2 with CUDA support, follow the [installation instructions](https://caffe2.ai/docs/getting-started.html) from the [Caffe2 website](https://caffe2.ai/). **If you already have Caffe2 installed, make sure to update your Caffe2 to a version that includes the [Detectron module](https://github.com/caffe2/caffe2/tree/master/modules/detectron).** +To install Caffe2 with CUDA support, follow the [installation instructions](https://caffe2.ai/docs/getting-started.html) from the [Caffe2 website](https://caffe2.ai/). **If you already have Caffe2 installed, make sure to update your Caffe2 to a version that includes the [Detectron module](https://github.com/pytorch/pytorch/tree/master/modules/detectron).** Please ensure that your Caffe2 installation was successful before proceeding by running the following commands and checking their output as directed in the comments. @@ -130,7 +130,7 @@ coco ## Docker Image -We provide a [`Dockerfile`](docker/Dockerfile) that you can use to build a Densepose image on top of a Caffe2 image that satisfies the requirements outlined at the top. If you would like to use a Caffe2 image different from the one we use by default, please make sure that it includes the [Detectron module](https://github.com/caffe2/caffe2/tree/master/modules/detectron). +We provide a [`Dockerfile`](docker/Dockerfile) that you can use to build a Densepose image on top of a Caffe2 image that satisfies the requirements outlined at the top. If you would like to use a Caffe2 image different from the one we use by default, please make sure that it includes the [Detectron module](https://github.com/pytorch/pytorch/tree/master/modules/detectron). Build the image: diff --git a/README.md b/README.md index 33a1b1f..f5b6f03 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ _Rıza Alp Güler, Natalia Neverova, Iasonas Kokkinos_ [[`densepose.org`](https://densepose.org)] [[`arXiv`](https://arxiv.org/abs/1802.00434)] [[`BibTeX`](#CitingDensePose)] Dense human pose estimation aims at mapping all human pixels of an RGB image to the 3D surface of the human body. -DensePose-RCNN is implemented in the [Detectron](https://github.com/facebookresearch/Detectron) framework and is powered by [Caffe2](https://github.com/caffe2/caffe2). +DensePose-RCNN is implemented in the [Detectron](https://github.com/facebookresearch/Detectron) framework and is powered by [Caffe2](https://github.com/pytorch/pytorch/tree/master/caffe2).
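Before moving on to the Detectron-specific changes below, the Caffe2 requirement called out in INSTALL.md can be sanity-checked from Python. The snippet is a minimal sketch (it assumes `caffe2` is already on `PYTHONPATH`), not the full verification procedure from the Caffe2 site:

```python
# Minimal Caffe2 sanity check; a failing import here means the base install is broken.
from caffe2.python import core, workspace  # noqa: F401

# A CUDA-enabled build should report at least one device.
print('CUDA devices visible to Caffe2:', workspace.NumCudaDevices())
```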
diff --git a/detectron/core/config.py b/detectron/core/config.py index 79ee36f..e165e37 100644 --- a/detectron/core/config.py +++ b/detectron/core/config.py @@ -863,7 +863,7 @@ __C.BODY_UV_RCNN.UP_SCALE = -1 # Apply a ConvTranspose layer to the features prior to predicting the heatmaps -__C.KRCNN.USE_DECONV = False +__C.BODY_UV_RCNN.USE_DECONV = False # Channel dimension of the hidden representation produced by the ConvTranspose __C.BODY_UV_RCNN.DECONV_DIM = 256 # Use a ConvTranspose layer to predict the heatmaps @@ -876,6 +876,9 @@ # Number of patches in the dataset __C.BODY_UV_RCNN.NUM_PATCHES = -1 +# Number of semantic parts used to sample annotation points +__C.BODY_UV_RCNN.NUM_SEMANTIC_PARTS = 14 + # Number of stacked Conv layers in body UV head __C.BODY_UV_RCNN.NUM_STACKED_CONVS = 8 # Dimension of the hidden representation output by the body UV head diff --git a/detectron/core/test.py b/detectron/core/test.py index 877d803..4fc5854 100644 --- a/detectron/core/test.py +++ b/detectron/core/test.py @@ -948,7 +948,7 @@ def im_detect_body_uv(model, im_scale, boxes): # Removed squeeze calls due to singleton dimension issues CurAnnIndex = np.argmax(CurAnnIndex, axis=0) CurIndex_UV = np.argmax(CurIndex_UV, axis=0) - CurIndex_UV = CurIndex_UV * (CurAnnIndex>0).astype(np.float32) + CurIndex_UV = CurIndex_UV * (CurAnnIndex > 0).astype(np.float32) output = np.zeros([3, int(by), int(bx)], dtype=np.float32) output[0] = CurIndex_UV @@ -956,8 +956,8 @@ def im_detect_body_uv(model, im_scale, boxes): for part_id in range(1, K): CurrentU = CurU_uv[part_id] CurrentV = CurV_uv[part_id] - output[1, CurIndex_UV==part_id] = CurrentU[CurIndex_UV==part_id] - output[2, CurIndex_UV==part_id] = CurrentV[CurIndex_UV==part_id] + output[1, CurIndex_UV == part_id] = CurrentU[CurIndex_UV == part_id] + output[2, CurIndex_UV == part_id] = CurrentV[CurIndex_UV == part_id] outputs.append(output) num_classes = cfg.MODEL.NUM_CLASSES diff --git a/detectron/core/test_engine.py b/detectron/core/test_engine.py index dc7c92c..bc5d76a 100644 --- a/detectron/core/test_engine.py +++ b/detectron/core/test_engine.py @@ -374,7 +374,7 @@ def get_roidb_and_dataset(dataset_name, proposal_file, ind_range): def empty_results(num_classes, num_images): - """Return empty results lists for boxes, masks, and keypoints. + """Return empty results lists for boxes, masks, keypoints and body IUVs. Box detections are collected into: all_boxes[cls][image] = N x 5 array with columns (x1, y1, x2, y2, score) Instance mask predictions are collected into: @@ -386,8 +386,11 @@ def empty_results(num_classes, num_images): [x, y, logit, prob] (See: utils.keypoints.heatmaps_to_keypoints). Keypoints are recorded for person (cls = 1); they are in 1:1 correspondence with the boxes in all_boxes[cls][image]. - Body uv predictions are collected into: - TODO + Body IUV predictions are collected into: + all_bodys[cls][image] = [...] list of body IUV results that are in 1:1 + correspondence with the boxes in all_boxes['person'][image], each encoded + as a 3D array (3, int(bbox_height), int(bbox_width)) with the 3 rows + corresponding to [Index, U, V]. """ # Note: do not be tempted to use [[] * N], which gives N references to the # *same* empty list. 
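The updated `empty_results()` docstring above describes how body IUV predictions are laid out. As a rough illustration (hypothetical sizes, not code from this patch), each element of `all_bodys[1][image]` is a `(3, box_height, box_width)` array whose first channel is the patch index and whose remaining channels are the U and V coordinates:

```python
import numpy as np

# Hypothetical detection box of width 64 and height 128 pixels.
bx, by = 64, 128
uv = np.zeros((3, by, bx), dtype=np.float32)  # channels: [Index, U, V]

index_map = uv[0]            # per-pixel patch index; 0 denotes background
u_map, v_map = uv[1], uv[2]  # per-pixel UV coordinates in [0, 1]
print('foreground pixels:', int((index_map > 0).sum()))
```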
diff --git a/detectron/datasets/densepose_cocoeval.py b/detectron/datasets/densepose_cocoeval.py index b082a12..6bff367 100644 --- a/detectron/datasets/densepose_cocoeval.py +++ b/detectron/datasets/densepose_cocoeval.py @@ -97,49 +97,52 @@ def __init__(self, cocoGt=None, cocoDt=None, iouType='segm', sigma=1.): self.sigma = sigma self.ignoreThrBB = 0.7 self.ignoreThrUV = 0.9 + self.num_parts = 24 # number of pre-defined body parts def _loadGEval(self): - print('Loading densereg GT..') prefix = os.path.dirname(__file__) + '/../../DensePoseData/eval_data/' - print(prefix) + print('Loading densereg GT from {}'.format(prefix)) SMPL_subdiv = loadmat(prefix + 'SMPL_subdiv.mat') self.PDIST_transform = loadmat(prefix + 'SMPL_SUBDIV_TRANSFORM.mat') + # 1-based index of geodesic distance matrix, + # range: [1, 27554], shape: (num_total_points=29408,) self.PDIST_transform = self.PDIST_transform['index'].squeeze() - UV = np.array([ - SMPL_subdiv['U_subdiv'], - SMPL_subdiv['V_subdiv'] - ]).squeeze() - ClosestVertInds = np.arange(UV.shape[1])+1 + # UV coordinates of all collected points on SMPL model, shape: (2, num_total_points) + UV = np.array([SMPL_subdiv['U_subdiv'], SMPL_subdiv['V_subdiv']]).squeeze() + # body part index (1 ~ 24), shape: (num_total_points,) + Part_ids = SMPL_subdiv['Part_ID_subdiv'].squeeze() + self.Part_ids = np.array(Part_ids) + # 1-based index of the closest vertex for each point, + # range: [1, num_total_points], shape: (num_total_points,) + ClosestVertInds = np.arange(UV.shape[1]) + 1 self.Part_UVs = [] self.Part_ClosestVertInds = [] - for i in np.arange(24): - self.Part_UVs.append( - UV[:, SMPL_subdiv['Part_ID_subdiv'].squeeze()==(i+1)] - ) - self.Part_ClosestVertInds.append( - ClosestVertInds[SMPL_subdiv['Part_ID_subdiv'].squeeze()==(i+1)] - ) + # UV coordinates and closest vertex indices of points in each part + for i in np.arange(1, self.num_parts + 1): + self.Part_UVs.append(UV[:, Part_ids == i]) + self.Part_ClosestVertInds.append(ClosestVertInds[Part_ids == i]) arrays = {} - f = h5py.File( prefix + 'Pdist_matrix.mat') + f = h5py.File(prefix + 'Pdist_matrix.mat') for k, v in f.items(): arrays[k] = np.array(v) + f.close() + # precomputed geodesic distance matrix in compact representation self.Pdist_matrix = arrays['Pdist_matrix'] - self.Part_ids = np.array( SMPL_subdiv['Part_ID_subdiv'].squeeze()) # Mean geodesic distances for parts. - self.Mean_Distances = np.array( [0, 0.351, 0.107, 0.126,0.237,0.173,0.142,0.128,0.150] ) + self.Mean_Distances = np.array([0, 0.351, 0.107, 0.126, 0.237, 0.173, 0.142, 0.128, 0.150]) # Coarse Part labels.
- self.CoarseParts = np.array( [ 0, 1, 1, 2, 2, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, - 6, 6, 6, 6, 7, 7, 7, 7, 8, 8] ) - - print('Loaded') + self.CoarseParts = np.array( + [0, 1, 1, 2, 2, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8] + ) + + print('densereg GT loaded') def _prepare(self): ''' Prepare ._gts and ._dts for evaluation based on params :return: None ''' - def _toMask(anns, coco): # modify ann['segmentation'] by reference for ann in anns: @@ -168,11 +171,11 @@ def _checkIgnore(dt, iregion): return True bb = np.array(dt['bbox']).astype(np.int) - x1,y1,x2,y2 = bb[0],bb[1],bb[0]+bb[2],bb[1]+bb[3] - x2 = min([x2,iregion.shape[1]]) - y2 = min([y2,iregion.shape[0]]) + x1, y1, x2, y2 = bb[0], bb[1], bb[0] + bb[2], bb[1] + bb[3] + x2 = min([x2, iregion.shape[1]]) + y2 = min([y2, iregion.shape[0]]) - if bb[2]* bb[3] == 0: + if bb[2] * bb[3] == 0: return False crop_iregion = iregion[y1:y2, x1:x2] @@ -181,11 +184,11 @@ def _checkIgnore(dt, iregion): return True if not 'uv' in dt.keys(): # filtering boxes - return crop_iregion.sum()/bb[2]/bb[3] < self.ignoreThrBB + return crop_iregion.sum() / bb[2] / bb[3] < self.ignoreThrBB # filtering UVs ignoremask = np.require(crop_iregion, requirements=['F']) - uvmask = np.require(np.asarray(dt['uv'][0]>0), dtype = np.uint8, + uvmask = np.require(np.asarray(dt['uv'][0] > 0), dtype=np.uint8, requirements=['F']) uvmask_ = maskUtils.encode(uvmask) ignoremask_ = maskUtils.encode(ignoremask) @@ -195,13 +198,13 @@ def _checkIgnore(dt, iregion): p = self.params if p.useCats: - gts=self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)) - dts=self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)) + gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)) + dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)) else: - gts=self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds)) - dts=self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds)) + gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds)) + dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds)) - # if iouType == 'uv', add point gt annotations + # add point gt annotations if iouType == 'uv' if p.iouType == 'uv': self._loadGEval() @@ -217,7 +220,7 @@ def _checkIgnore(dt, iregion): if p.iouType == 'keypoints': gt['ignore'] = (gt['num_keypoints'] == 0) or gt['ignore'] if p.iouType == 'uv': - gt['ignore'] = ('dp_x' in gt)==0 + gt['ignore'] = 'dp_x' not in gt self._gts = defaultdict(list) # gt for evaluation self._dts = defaultdict(list) # dt for evaluation @@ -233,8 +236,8 @@ def _checkIgnore(dt, iregion): if _checkIgnore(dt, self._igrgns[dt['image_id']]): self._dts[dt['image_id'], dt['category_id']].append(dt) - self.evalImgs = defaultdict(list) # per-image per-category evaluation results - self.eval = {} # accumulated evaluation results + self.evalImgs = defaultdict(list) # per-image per-category evaluation results + self.eval = {} # accumulated evaluation results def evaluate(self): ''' @@ -253,7 +256,7 @@ def evaluate(self): if p.useCats: p.catIds = list(np.unique(p.catIds)) p.maxDets = sorted(p.maxDets) - self.params=p + self.params = p self._prepare() # loop through images, area range, max detection number @@ -264,37 +267,40 @@ def evaluate(self): elif p.iouType == 'keypoints': computeIoU = self.computeOks elif p.iouType == 'uv': - computeIoU = self.computeOgps + computeIoU = self.computeGps - self.ious = {(imgId, catId): computeIoU(imgId, catId) \ - for imgId in 
p.imgIds - for catId in catIds} + self.ious = { + (imgId, catId): computeIoU(imgId, catId) + for imgId in p.imgIds + for catId in catIds + } evaluateImg = self.evaluateImg maxDet = p.maxDets[-1] - self.evalImgs = [evaluateImg(imgId, catId, areaRng, maxDet) - for catId in catIds - for areaRng in p.areaRng - for imgId in p.imgIds - ] + self.evalImgs = [ + evaluateImg(imgId, catId, areaRng, maxDet) + for catId in catIds + for areaRng in p.areaRng + for imgId in p.imgIds + ] self._paramsEval = copy.deepcopy(self.params) toc = time.time() - print('DONE (t={:0.2f}s).'.format(toc-tic)) + print('DONE (t={:0.2f}s).'.format(toc - tic)) def computeIoU(self, imgId, catId): p = self.params if p.useCats: - gt = self._gts[imgId,catId] - dt = self._dts[imgId,catId] + gt = self._gts[imgId, catId] + dt = self._dts[imgId, catId] else: - gt = [_ for cId in p.catIds for _ in self._gts[imgId,cId]] - dt = [_ for cId in p.catIds for _ in self._dts[imgId,cId]] - if len(gt) == 0 and len(dt) ==0: + gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]] + dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]] + if len(gt) == 0 and len(dt) == 0: return [] inds = np.argsort([-d['score'] for d in dt], kind='mergesort') dt = [dt[i] for i in inds] if len(dt) > p.maxDets[-1]: - dt=dt[0:p.maxDets[-1]] + dt = dt[0:p.maxDets[-1]] if p.iouType == 'segm': g = [g['segmentation'] for g in gt] @@ -312,7 +318,7 @@ def computeIoU(self, imgId, catId): def computeOks(self, imgId, catId): p = self.params - # dimention here should be Nxm + # dimention here should be N x m gts = self._gts[imgId, catId] dts = self._dts[imgId, catId] inds = np.argsort([-d['score'] for d in dts], kind='mergesort') @@ -323,8 +329,8 @@ def computeOks(self, imgId, catId): if len(gts) == 0 or len(dts) == 0: return [] ious = np.zeros((len(dts), len(gts))) - sigmas = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62,.62, 1.07, 1.07, .87, .87, .89, .89])/10.0 - vars = (sigmas * 2)**2 + sigmas = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62,.62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0 + vars = (sigmas * 2) ** 2 k = len(sigmas) # compute oks between each detection and ground truth object for j, gt in enumerate(gts): @@ -338,83 +344,91 @@ def computeOks(self, imgId, catId): for i, dt in enumerate(dts): d = np.array(dt['keypoints']) xd = d[0::3]; yd = d[1::3] - if k1>0: + if k1 > 0: # measure the per-keypoint distance if keypoints visible dx = xd - xg dy = yd - yg else: # measure minimum distance to keypoints in (x0,y0) & (x1,y1) z = np.zeros((k)) - dx = np.max((z, x0-xd), axis=0) + np.max((z, xd-x1), axis=0) - dy = np.max((z, y0-yd), axis=0) + np.max((z, yd-y1), axis=0) - e = (dx**2 + dy**2) / vars / (gt['area'] + np.spacing(1)) / 2 + dx = np.max((z, x0 - xd), axis=0) + np.max((z, xd - x1), axis=0) + dy = np.max((z, y0 - yd), axis=0) + np.max((z, yd - y1), axis=0) + e = (dx ** 2 + dy ** 2) / vars / (gt['area'] + np.spacing(1)) / 2 if k1 > 0: - e=e[vg > 0] + e = e[vg > 0] ious[i, j] = np.sum(np.exp(-e)) / e.shape[0] return ious - def computeOgps(self, imgId, catId): + def computeGps(self, imgId, catId): p = self.params - # dimention here should be Nxm - g = self._gts[imgId, catId] - d = self._dts[imgId, catId] - inds = np.argsort([-d_['score'] for d_ in d], kind='mergesort') - d = [d[i] for i in inds] - if len(d) > p.maxDets[-1]: - d = d[0:p.maxDets[-1]] - # if len(gts) == 0 and len(dts) == 0: - if len(g) == 0 or len(d) == 0: + # dimention here should be N x m + gt = self._gts[imgId, catId] + dt = self._dts[imgId, catId] + inds = 
np.argsort([-d['score'] for d in dt], kind='mergesort') + dt = [dt[i] for i in inds] + if len(dt) > p.maxDets[-1]: + dt = dt[0:p.maxDets[-1]] + + if len(gt) == 0 or len(dt) == 0: return [] - ious = np.zeros((len(d), len(g))) - # compute opgs between each detection and ground truth object - sigma = self.sigma #0.255 # dist = 0.3m corresponds to ogps = 0.5 - # 1 # dist = 0.3m corresponds to ogps = 0.96 - # 1.45 # dist = 1.7m (person height) corresponds to ogps = 0.5) - for j, gt in enumerate(g): - if not gt['ignore']: - g_ = gt['bbox'] - for i, dt in enumerate(d): - # - dx = dt['bbox'][3] - dy = dt['bbox'][2] - dp_x = np.array( gt['dp_x'] )*g_[2]/255. - dp_y = np.array( gt['dp_y'] )*g_[3]/255. - px = ( dp_y + g_[1] - dt['bbox'][1]).astype(np.int) - py = ( dp_x + g_[0] - dt['bbox'][0]).astype(np.int) - # - pts = np.zeros(len(px)) - pts[px>=dx] = -1; pts[py>=dy] = -1 - pts[px<0] = -1; pts[py<0] = -1 - #print(pts.shape) - if len(pts) < 1: - ogps = 0. - elif np.max(pts) == -1: - ogps = 0. - else: - px[pts==-1] = 0; py[pts==-1] = 0; - ipoints = dt['uv'][0, px, py] - upoints = dt['uv'][1, px, py]/255. # convert from uint8 by /255. - vpoints = dt['uv'][2, px, py]/255. - ipoints[pts==-1] = 0 - ## Find closest vertices in subsampled mesh. - cVerts, cVertsGT = self.findAllClosestVerts(gt, upoints, vpoints, ipoints) - ## Get pairwise geodesic distances between gt and estimated mesh points. - dist = self.getDistances(cVertsGT, cVerts) - ## Compute the Ogps measure. + ious = np.zeros((len(dt), len(gt))) + + # compute gps between each detection and ground truth object + # sigma = self.sigma # 0.255 # dist = 0.3m corresponds to gps = 0.5 + # 1 # dist = 0.3m corresponds to gps = 0.96 + # 1.45 # dist = 1.7m (person height) corresponds to gps = 0.5) + for j, g in enumerate(gt): + # gps between any detection and an ignored ground truth person is 0 + if g['ignore']: + continue + bb = g['bbox'] + for i, d in enumerate(dt): + # width and height of detected bbox + dx = d['bbox'][2]; dy = d['bbox'][3] + # (2D) spatial coordinates of annotated points within current gt bbox on the image, + # which are scaled such that the gt bbox size is 256 x 256. + dp_x = np.array(g['dp_x']) * bb[2] / 255. + dp_y = np.array(g['dp_y']) * bb[3] / 255. + # spatial coordinates of annotated points relative to detected bbox + px = (dp_x + bb[0] - d['bbox'][0]).astype(np.int) + py = (dp_y + bb[1] - d['bbox'][1]).astype(np.int) + pts = np.zeros(len(dp_x)) # len(dp_x): number of annotated points + # annotated points outside the range of detected bbox are considered as background + pts[px >= dx] = -1; pts[py >= dy] = -1 + pts[px < 0] = -1; pts[py < 0] = -1 + # print("#collected gt points: ", len(dp_x)) + if len(dp_x) == 0: + gps = 0. + elif pts.max() == -1: + gps = 0. + else: + px[pts == -1] = 0; py[pts == -1] = 0; + ipoints = d['uv'][0, py, px] + upoints = d['uv'][1, py, px] / 255. # convert from uint8 by /255. + vpoints = d['uv'][2, py, px] / 255. + ipoints[pts == -1] = 0 + # Find closest vertices index in subsampled mesh. + cVertInds, cVertIndsGT = self.findAllClosestVertInds(g, upoints, vpoints, ipoints) + # Get pairwise geodesic distances between GT and estimated mesh points. + dist = self.getDistances(cVertInds, cVertIndsGT) + # Compute the GPS measure. + if len(dist) > 0: # Find the mean geodesic normalization distance for each GT point, based on which part it is on. 
- Current_Mean_Distances = self.Mean_Distances[ self.CoarseParts[ self.Part_ids [ cVertsGT[cVertsGT>0].astype(int)-1] ] ] + Current_Mean_Distances = self.Mean_Distances[ + self.CoarseParts[self.Part_ids[cVertIndsGT[cVertIndsGT > 0] - 1]] + ] # Compute gps - ogps_values = np.exp(-(dist**2)/(2*(Current_Mean_Distances**2))) - # - if len(dist)>0: - ogps = np.sum(ogps_values)/ len(dist) - ious[i, j] = ogps + gps = np.exp(-(dist ** 2) / (2 * (Current_Mean_Distances ** 2))) + gps = np.sum(gps) / len(dist) + else: + gps = 0. + ious[i, j] = gps - gbb = [gt['bbox'] for gt in g] - dbb = [dt['bbox'] for dt in d] + gbb = [g['bbox'] for g in gt] + dbb = [d['bbox'] for d in dt] # compute iou between each dt and gt region - iscrowd = [int(o['iscrowd']) for o in g] + iscrowd = [int(o['iscrowd']) for o in gt] ious_bb = maskUtils.iou(dbb, gbb, iscrowd) return ious, ious_bb @@ -423,23 +437,21 @@ def evaluateImg(self, imgId, catId, aRng, maxDet): perform evaluation for single category and image :return: dict (single image results) ''' - p = self.params if p.useCats: - gt = self._gts[imgId,catId] - dt = self._dts[imgId,catId] + gt = self._gts[imgId, catId] + dt = self._dts[imgId, catId] else: - gt = [_ for cId in p.catIds for _ in self._gts[imgId,cId]] - dt = [_ for cId in p.catIds for _ in self._dts[imgId,cId]] + gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]] + dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]] if len(gt) == 0 and len(dt) == 0: return None for g in gt: - #g['_ignore'] = g['ignore'] - if g['ignore'] or (g['area']<aRng[0] or g['area']>aRng[1]): - g['_ignore'] = True + if g['ignore'] or (g['area'] < aRng[0] or g['area'] > aRng[1]): + g['_ignore'] = 1 else: - g['_ignore'] = False + g['_ignore'] = 0 # sort dt highest score first, sort gt ignore last gtind = np.argsort([g['_ignore'] for g in gt], kind='mergesort') @@ -449,9 +461,6 @@ def evaluateImg(self, imgId, catId, aRng, maxDet): iscrowd = [int(o['iscrowd']) for o in gt] # load computed ious if p.iouType == 'uv': - #print('Checking the length', len(self.ious[imgId, catId])) - #if len(self.ious[imgId, catId]) == 0: - # print(self.ious[imgId, catId]) ious = self.ious[imgId, catId][0][:, gtind] if len(self.ious[imgId, catId]) > 0 else self.ious[imgId, catId] ioubs = self.ious[imgId, catId][1][:, gtind] if len(self.ious[imgId, catId]) > 0 else self.ious[imgId, catId] else: @@ -460,31 +469,32 @@ def evaluateImg(self, imgId, catId, aRng, maxDet): T = len(p.iouThrs) G = len(gt) D = len(dt) - gtm = np.zeros((T,G)) - dtm = np.zeros((T,D)) + gtm = np.zeros((T, G)) + dtm = np.zeros((T, D)) gtIg = np.array([g['_ignore'] for g in gt]) - dtIg = np.zeros((T,D)) + dtIg = np.zeros((T, D)) if np.all(gtIg) == True and p.iouType == 'uv': dtIg = np.logical_or(dtIg, True) - if len(ious)>0: # and not p.iouType == 'uv': + if len(ious) > 0: for tind, t in enumerate(p.iouThrs): for dind, d in enumerate(dt): # information about best match so far (m=-1 -> unmatched) - iou = min([t,1-1e-10]) + iou = min([t, 1 - 1e-10]) m = -1 for gind, g in enumerate(gt): # if this gt already matched, and not a crowd, continue - if gtm[tind,gind]>0 and not iscrowd[gind]: + if gtm[tind, gind] > 0 and not iscrowd[gind]: continue # if dt matched to reg gt, and on ignore gt, stop - if m>-1 and gtIg[m]==0 and gtIg[gind]==1: + if m > -1 and gtIg[m] == 0 and gtIg[gind] == 1: break # continue to next gt unless better match made - if ious[dind,gind] < iou: - continue - if ious[dind,gind] == 0.: + if ious[dind, gind] < iou: continue + ## redundant condition after the above one + # if ious[dind, gind]
== 0.: + # continue # if match successful and best so far, store appropriately iou = ious[dind, gind] m = gind @@ -495,49 +505,54 @@ def evaluateImg(self, imgId, catId, aRng, maxDet): dtm[tind, dind] = gt[m]['id'] gtm[tind, m] = d['id'] - if p.iouType == 'uv': - if not len(ioubs)==0: - for dind, d in enumerate(dt): - # information about best match so far (m=-1 -> unmatched) - if dtm[tind, dind] == 0: - ioub = 0.8 - m = -1 - for gind, g in enumerate(gt): - # if this gt already matched, and not a crowd, continue - if gtm[tind,gind]>0 and not iscrowd[gind]: - continue - # continue to next gt unless better match made - if ioubs[dind,gind] < ioub: - continue - # if match successful and best so far, store appropriately - ioub = ioubs[dind,gind] - m = gind - # if match made store id of match for both dt and gt - if m > -1: - dtIg[:, dind] = gtIg[m] - if gtIg[m]: - dtm[tind, dind] = gt[m]['id'] - gtm[tind, m] = d['id'] + """ + When evaluating for body_uv, a detected box (`dbox`) is suppressed/ignored at all GPS thresholds + if it satisfies the following criterion: + GPS(dbox, gbox) = 0 while IoU(dbox, gbox) > 0.8, where `gbox` is an ignored gt box. + """ + if p.iouType == 'uv' and len(ioubs) > 0: + for dind, d in enumerate(dt): + # information about best match so far (m=-1 -> unmatched) + if dtm[tind, dind] == 0: + ioub = 0.8 # a manually set IoU threshold + m = -1 + for gind, g in enumerate(gt): + # if this gt already matched, and not a crowd, continue + if gtm[tind, gind] > 0 and not iscrowd[gind]: + continue + # continue to next gt unless better match made + if ioubs[dind, gind] < ioub: + continue + # if match successful and best so far, store appropriately + ioub = ioubs[dind, gind] + m = gind + # if match made store id of match for both dt and gt + if m == -1: + continue + dtIg[:, dind] = gtIg[m] + if gtIg[m]: + dtm[tind, dind] = gt[m]['id'] + gtm[tind, m] = d['id'] + # set unmatched detections outside of area range to ignore - a = np.array([d['area']<aRng[0] or d['area']>aRng[1] for d in dt]).reshape((1, len(dt))) - dtIg = np.logical_or(dtIg, np.logical_and(dtm==0, np.repeat(a,T,0))) + a = np.array([d['area'] < aRng[0] or d['area'] > aRng[1] for d in dt]).reshape((1, len(dt))) + dtIg = np.logical_or(dtIg, np.logical_and(dtm == 0, np.repeat(a, T, 0))) # store results for given image and category - #print('Done with the function', len(self.ious[imgId, catId])) return { - 'image_id': imgId, - 'category_id': catId, - 'aRng': aRng, - 'maxDet': maxDet, - 'dtIds': [d['id'] for d in dt], - 'gtIds': [g['id'] for g in gt], - 'dtMatches': dtm, - 'gtMatches': gtm, - 'dtScores': [d['score'] for d in dt], - 'gtIgnore': gtIg, - 'dtIgnore': dtIg, - } - - def accumulate(self, p = None): + 'image_id': imgId, + 'category_id': catId, + 'aRng': aRng, + 'maxDet': maxDet, + 'dtIds': [d['id'] for d in dt], + 'gtIds': [g['id'] for g in gt], + 'dtMatches': dtm, + 'gtMatches': gtm, + 'dtScores': [d['score'] for d in dt], + 'gtIgnore': gtIg, + 'dtIgnore': dtIg, + } + + def accumulate(self, p=None): ''' Accumulate per image evaluation results and store the result in self.eval :param p: input params for evaluation @@ -551,16 +566,17 @@ def accumulate(self, p = None): if p is None: p = self.params p.catIds = p.catIds if p.useCats == 1 else [-1] - T = len(p.iouThrs) - R = len(p.recThrs) - K = len(p.catIds) if p.useCats else 1 - A = len(p.areaRng) - M = len(p.maxDets) - precision = -np.ones((T,R,K,A,M)) # -1 for the precision of absent categories - recall = -np.ones((T,K,A,M)) + T = len(p.iouThrs) + R = len(p.recThrs) + K = len(p.catIds) if
p.useCats else 1 + A = len(p.areaRng) + M = len(p.maxDets) + precision = -np.ones((T, R, K, A, M)) # -1 for the precision of absent categories + recall = -np.ones((T, K, A, M)) + scores = -np.ones((T, R, K, A, M)) # create dictionary for future indexing - print('Categories:', p.catIds) + print('Categories ids:', p.catIds) _pe = self._paramsEval catIds = _pe.catIds if _pe.useCats else [-1] setK = set(catIds) @@ -589,58 +605,61 @@ def accumulate(self, p = None): # different sorting method generates slightly different results. # mergesort is used to be consistent as Matlab implementation. inds = np.argsort(-dtScores, kind='mergesort') + dtScoresSorted = dtScores[inds] - dtm = np.concatenate([e['dtMatches'][:,0:maxDet] for e in E], axis=1)[:,inds] - dtIg = np.concatenate([e['dtIgnore'][:,0:maxDet] for e in E], axis=1)[:,inds] + dtm = np.concatenate([e['dtMatches'][:, 0:maxDet] for e in E], axis=1)[:, inds] + dtIg = np.concatenate([e['dtIgnore'][:, 0:maxDet] for e in E], axis=1)[:, inds] gtIg = np.concatenate([e['gtIgnore'] for e in E]) - npig = np.count_nonzero(gtIg==0) - #print('DTIG', np.sum(np.logical_not(dtIg)), len(dtIg)) - #print('GTIG', np.sum(np.logical_not(gtIg)), len(gtIg)) + npig = np.count_nonzero(gtIg == 0) if npig == 0: continue - tps = np.logical_and( dtm, np.logical_not(dtIg)) + tps = np.logical_and( dtm, np.logical_not(dtIg)) fps = np.logical_and(np.logical_not(dtm), np.logical_not(dtIg)) + tp_sum = np.cumsum(tps, axis=1).astype(dtype=np.float) fp_sum = np.cumsum(fps, axis=1).astype(dtype=np.float) - #print('TP_SUM', tp_sum, 'FP_SUM', fp_sum) for t, (tp, fp) in enumerate(zip(tp_sum, fp_sum)): tp = np.array(tp) fp = np.array(fp) nd = len(tp) rc = tp / npig - pr = tp / (fp+tp+np.spacing(1)) + pr = tp / (fp + tp + np.spacing(1)) q = np.zeros((R,)) + ss = np.zeros((R,)) if nd: - recall[t,k,a,m] = rc[-1] + recall[t, k, a, m] = rc[-1] else: - recall[t,k,a,m] = 0 + recall[t, k, a, m] = 0 # numpy is slow without cython optimization for accessing elements # use python array gets significant speed improvement pr = pr.tolist(); q = q.tolist() - for i in range(nd-1, 0, -1): - if pr[i] > pr[i-1]: - pr[i-1] = pr[i] + for i in range(nd - 1, 0, -1): + if pr[i] > pr[i - 1]: + pr[i - 1] = pr[i] inds = np.searchsorted(rc, p.recThrs, side='left') try: for ri, pi in enumerate(inds): q[ri] = pr[pi] + ss[ri] = dtScoresSorted[pi] except: pass - precision[t,:,k,a,m] = np.array(q) - print('Final', np.max(precision), np.min(precision)) + precision[t, :, k, a, m] = np.array(q) + scores[t, :, k, a, m] = np.array(ss) + print('Final precisions, max: {:.2f}, min: {:.2f}'.format(np.max(precision), np.min(precision))) self.eval = { 'params': p, 'counts': [T, R, K, A, M], 'date': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'precision': precision, - 'recall': recall, + 'recall': recall, + 'scores': scores, } toc = time.time() - print('DONE (t={:0.2f}s).'.format( toc-tic)) + print('DONE (t={:0.2f}s).'.format(toc - tic)) def summarize(self): ''' @@ -651,12 +670,12 @@ def _summarize( ap=1, iouThr=None, areaRng='all', maxDets=100 ): p = self.params iStr = ' {:<18} {} @[ {}={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}' titleStr = 'Average Precision' if ap == 1 else 'Average Recall' - typeStr = '(AP)' if ap==1 else '(AR)' + typeStr = '(AP)' if ap == 1 else '(AR)' measure = 'IoU' if self.params.iouType == 'keypoints': measure = 'OKS' - elif self.params.iouType =='uv': - measure = 'OGPS' + elif self.params.iouType == 'uv': + measure = 'GPS' iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \ 
if iouThr is None else '{:0.2f}'.format(iouThr) @@ -667,22 +686,23 @@ def _summarize( ap=1, iouThr=None, areaRng='all', maxDets=100 ): s = self.eval['precision'] # IoU if iouThr is not None: - t = np.where(np.abs(iouThr - p.iouThrs)<0.001)[0] + t = np.where(np.abs(iouThr - p.iouThrs) < 0.001)[0] s = s[t] - s = s[:,:,:,aind,mind] + s = s[:, :, :, aind, mind] else: # dimension of recall: [TxKxAxM] s = self.eval['recall'] if iouThr is not None: t = np.where(iouThr == p.iouThrs)[0] s = s[t] - s = s[:,:,aind,mind] - if len(s[s>-1])==0: + s = s[:, :, aind, mind] + if len(s[s > -1]) == 0: mean_s = -1 else: - mean_s = np.mean(s[s>-1]) + mean_s = np.mean(s[s > -1]) print(iStr.format(titleStr, typeStr, measure, iouStr, areaRng, maxDets, mean_s)) return mean_s + def _summarizeDets(): stats = np.zeros((12,)) stats[0] = _summarize(1) @@ -698,6 +718,7 @@ def _summarizeDets(): stats[10] = _summarize(0, areaRng='medium', maxDets=self.params.maxDets[2]) stats[11] = _summarize(0, areaRng='large', maxDets=self.params.maxDets[2]) return stats + def _summarizeKps(): stats = np.zeros((10,)) stats[0] = _summarize(1, maxDets=20) @@ -711,6 +732,7 @@ def _summarizeKps(): stats[8] = _summarize(0, maxDets=20, areaRng='medium') stats[9] = _summarize(0, maxDets=20, areaRng='large') return stats + def _summarizeUvs(): stats = np.zeros((18,)) stats[0] = _summarize(1, maxDets=self.params.maxDets[0]) @@ -732,6 +754,7 @@ def _summarizeUvs(): stats[16] = _summarize(0, maxDets=self.params.maxDets[0], areaRng='medium') stats[17] = _summarize(0, maxDets=self.params.maxDets[0], areaRng='large') return stats + if not self.eval: raise Exception('Please run accumulate() first') iouType = self.params.iouType @@ -747,74 +770,51 @@ def __str__(self): self.summarize() # ================ functions for dense pose ============================== - def findAllClosestVerts(self, gt, U_points, V_points, Index_points): - # + def findAllClosestVertInds(self, gt, U_points, V_points, Index_points): I_gt = np.array(gt['dp_I']) U_gt = np.array(gt['dp_U']) V_gt = np.array(gt['dp_V']) - # - #print(I_gt) - # - ClosestVerts = np.ones(Index_points.shape)*-1 - for i in np.arange(24): - # - if sum(Index_points == (i+1))>0: - UVs = np.array( [U_points[Index_points == (i+1)],V_points[Index_points == (i+1)]]) - Current_Part_UVs = self.Part_UVs[i] - Current_Part_ClosestVertInds = self.Part_ClosestVertInds[i] - D = ssd.cdist( Current_Part_UVs.transpose(), UVs.transpose()).squeeze() - ClosestVerts[Index_points == (i+1)] = Current_Part_ClosestVertInds[ np.argmin(D,axis=0) ] - # - ClosestVertsGT = np.ones(Index_points.shape)*-1 - for i in np.arange(24): - if sum(I_gt==(i+1))>0: - UVs = np.array([ - U_gt[I_gt==(i+1)], - V_gt[I_gt==(i+1)] - ]) - Current_Part_UVs = self.Part_UVs[i] - Current_Part_ClosestVertInds = self.Part_ClosestVertInds[i] - D = ssd.cdist( Current_Part_UVs.transpose(), UVs.transpose()).squeeze() - ClosestVertsGT[I_gt==(i+1)] = Current_Part_ClosestVertInds[ np.argmin(D,axis=0) ] - # - return ClosestVerts, ClosestVertsGT - - - def getDistances(self, cVertsGT, cVerts): - - ClosestVertsTransformed = self.PDIST_transform[cVerts.astype(int)-1] - ClosestVertsGTTransformed = self.PDIST_transform[cVertsGT.astype(int)-1] - # - ClosestVertsTransformed[cVerts<0] = 0 - ClosestVertsGTTransformed[cVertsGT<0] = 0 - # - cVertsGT = ClosestVertsGTTransformed - cVerts = ClosestVertsTransformed - # - n = 27554 + # find closest vertex for each estimated point and gt point in each part + ClosestVertInds = np.ones(Index_points.shape, dtype=int) * -1 + ClosestVertIndsGT 
= np.ones(Index_points.shape, dtype=int) * -1 + for i in np.arange(1, self.num_parts + 1): + Current_Part_UVs = self.Part_UVs[i - 1] + Current_Part_ClosestVertInds = self.Part_ClosestVertInds[i - 1] + if sum(Index_points == i) > 0: + UVs = np.array([U_points[Index_points == i], V_points[Index_points == i]]) + D = ssd.cdist(Current_Part_UVs.transpose(), UVs.transpose()).squeeze() + ClosestVertInds[Index_points == i] = Current_Part_ClosestVertInds[np.argmin(D, axis=0)] + if sum(I_gt == i) > 0: + UVs = np.array([U_gt[I_gt == i], V_gt[I_gt == i]]) + D = ssd.cdist(Current_Part_UVs.transpose(), UVs.transpose()).squeeze() + ClosestVertIndsGT[I_gt == i] = Current_Part_ClosestVertInds[np.argmin(D, axis=0)] + + return ClosestVertInds, ClosestVertIndsGT + + + def getDistances(self, cVertInds, cVertIndsGT): + cVerts = self.PDIST_transform[cVertInds - 1] + cVertsGT = self.PDIST_transform[cVertIndsGT - 1] + cVerts[cVertInds < 0] = 0 + cVertsGT[cVertIndsGT < 0] = 0 + # `n` is the number of points of which the geodesic distances are precomputed + # n = 27554 dists = [] + # loop through each pair of (gt_point, dt_point) for d in range(len(cVertsGT)): if cVertsGT[d] > 0: if cVerts[d] > 0: i = cVertsGT[d] - 1 j = cVerts[d] - 1 - if j == i: + if i == j: # elements on the diagonal are all zeros dists.append(0) - elif j > i: - ccc = i - i = j - j = ccc - i = n-i-1 - j = n-j-1 - k = (n*(n-1)/2) - (n-i)*((n-i)-1)/2 + j - i - 1 - k = ( n*n - n )/2 -k -1 - dists.append(self.Pdist_matrix[int(k)][0]) + elif i > j: + # find the offset to fetch the precomputed geodesic distance + k = i * (i - 1) / 2 + j + dists.append(self.Pdist_matrix[k][0]) else: - i= n-i-1 - j= n-j-1 - k = (n*(n-1)/2) - (n-i)*((n-i)-1)/2 + j - i - 1 - k = ( n*n - n )/2 -k -1 - dists.append(self.Pdist_matrix[int(k)][0]) + k = j * (j - 1) / 2 + i + dists.append(self.Pdist_matrix[k][0]) else: dists.append(np.inf) return np.array(dists).squeeze() diff --git a/detectron/datasets/json_dataset.py b/detectron/datasets/json_dataset.py index c6757e5..f65117f 100644 --- a/detectron/datasets/json_dataset.py +++ b/detectron/datasets/json_dataset.py @@ -154,7 +154,7 @@ def _prep_roidb_entry(self, entry): (0, 3, self.num_keypoints), dtype=np.int32 ) if cfg.MODEL.BODY_UV_ON: - entry['ignore_UV_body'] = np.empty((0), dtype=np.bool) + entry['ignore_UV_body'] = np.empty((0), dtype=np.bool) # entry['Box_image_links_body'] = [] # Remove unwanted fields that come from the json file (if they exist) for k in ['date_captured', 'url', 'license', 'file_name']: @@ -200,7 +200,7 @@ def _add_gt_annotations(self, entry): valid_objs.append(obj) valid_segms.append(obj['segmentation']) ### - if 'dp_x' in obj.keys(): + if 'dp_x' in obj: valid_dp_x.append(obj['dp_x']) valid_dp_y.append(obj['dp_y']) valid_dp_I.append(obj['dp_I']) @@ -216,7 +216,7 @@ def _add_gt_annotations(self, entry): valid_dp_masks.append([]) ### num_valid_objs = len(valid_objs) - ## + boxes = np.zeros((num_valid_objs, 4), dtype=entry['boxes'].dtype) gt_classes = np.zeros((num_valid_objs), dtype=entry['gt_classes'].dtype) gt_overlaps = np.zeros( @@ -234,7 +234,7 @@ def _add_gt_annotations(self, entry): dtype=entry['gt_keypoints'].dtype ) if cfg.MODEL.BODY_UV_ON: - ignore_UV_body = np.zeros((num_valid_objs)) + ignore_UV_body = np.zeros((num_valid_objs), dtype=entry['ignore_UV_body'].dtype) #Box_image_body = [None]*num_valid_objs im_has_visible_keypoints = False diff --git a/detectron/datasets/json_dataset_evaluator.py b/detectron/datasets/json_dataset_evaluator.py index a950973..4ff94df 100644 --- 
a/detectron/datasets/json_dataset_evaluator.py +++ b/detectron/datasets/json_dataset_evaluator.py @@ -213,9 +213,9 @@ def _get_thr_ind(coco_eval, thr): precision = coco_eval.eval['precision'][ind_lo:(ind_hi + 1), :, :, 0, 2] ap_default = np.mean(precision[precision > -1]) logger.info( - '~~~~ Mean and per-category AP @ IoU=[{:.2f},{:.2f}] ~~~~'.format( + '~~~~ Mean and per-category AP @ IoU=[{:.2f}, {:.2f}] ~~~~'.format( IoU_lo_thresh, IoU_hi_thresh)) - logger.info('{:.1f}'.format(100 * ap_default)) + logger.info('mAP: {:.1f}'.format(100 * ap_default)) for cls_ind, cls in enumerate(json_dataset.classes): if cls == '__background__': continue @@ -223,7 +223,7 @@ def _get_thr_ind(coco_eval, thr): precision = coco_eval.eval['precision'][ ind_lo:(ind_hi + 1), :, cls_ind - 1, 0, 2] ap = np.mean(precision[precision > -1]) - logger.info('{:.1f}'.format(100 * ap)) + logger.info("class '{}' AP: {:.1f}".format(cls, 100 * ap)) logger.info('~~~~ Summary metrics ~~~~') coco_eval.summarize() @@ -439,20 +439,17 @@ def evaluate_body_uv( if use_salt: res_file += '_{}'.format(str(uuid.uuid4())) res_file += '.pkl' - results = _write_coco_body_uv_results_file( + _write_coco_body_uv_results_file( json_dataset, all_boxes, all_bodys, res_file ) # Only do evaluation on non-test sets (annotations are undisclosed on test) if json_dataset.name.find('test') == -1: - # See comment in _write_coco_body_uv_results_file - #coco_eval = _do_body_uv_eval(json_dataset, res_file, output_dir) - coco_eval = _do_body_uv_eval(json_dataset, results, output_dir) + coco_eval = _do_body_uv_eval(json_dataset, res_file, output_dir) else: coco_eval = None - # See comment in _write_coco_body_uv_results_file # Optionally cleanup results json file - #if cleanup: - # os.remove(res_file) + if cleanup: + os.remove(res_file) return coco_eval @@ -460,7 +457,7 @@ def _write_coco_body_uv_results_file( json_dataset, all_boxes, all_bodys, res_file ): results = [] - for cls_ind,cls in enumerate(json_dataset.classes): + for cls_ind, cls in enumerate(json_dataset.classes): if cls == '__background__': continue if cls_ind >= len(all_bodys): @@ -473,20 +470,12 @@ def _write_coco_body_uv_results_file( json_dataset, all_boxes[cls_ind], all_bodys[cls_ind], cat_id)) # Body UV results are stored in 3xHxW ndarray format, # which is not json serializable - #logger.info( - # 'Writing body uv results json to: {}'.format( - # os.path.abspath(res_file))) - #with open(res_file, 'w') as fid: - # json.dump(results, fid) - # logger.info( 'Writing body uv results pkl to: {}'.format( os.path.abspath(res_file))) - with open(res_file, 'wb') as f: pickle.dump(results, f, pickle.HIGHEST_PROTOCOL) - #logger.info('Not writing body uv resuts json') - return res_file + # logger.info('Not writing body uv resuts json') def _coco_body_uv_results_one_category(json_dataset, boxes, body_uvs, cat_id): @@ -502,6 +491,9 @@ def _coco_body_uv_results_one_category(json_dataset, boxes, body_uvs, cat_id): uv_dets = body_uvs[i] box_dets = boxes[i].astype(np.float) scores = box_dets[:, -1] + # Convert the uv fields to uint8. + for uv in uv_dets: + uv[1:3, :, :] = uv[1:3, :, :] * 255 # Don't use xyxy_to_xywh function for consistency with the original imp # Instead, cast to ints and don't add 1 when computing ws and hs # xywh_box_dets = box_utils.xyxy_to_xywh(box_dets[:, 0:4]) @@ -509,11 +501,6 @@ def _coco_body_uv_results_one_category(json_dataset, boxes, body_uvs, cat_id): # ys = xywh_box_dets[:, 1] # ws = xywh_box_dets[:, 2] # hs = xywh_box_dets[:, 3] - - # Convert the uv fields to uint8. 
- for uv in uv_dets: - uv[1:3,:,:] = uv[1:3,:,:]*255 - ### xs = box_dets[:, 0] ys = box_dets[:, 1] ws = (box_dets[:, 2] - xs).astype(np.int) @@ -533,17 +520,18 @@ def _do_body_uv_eval(json_dataset, res_file, output_dir): imgIds = json_dataset.COCO.getImgIds() imgIds.sort() with open(res_file, 'rb') as f: - res=pickle.load(f) + res = pickle.load(f) coco_dt = json_dataset.COCO.loadRes(res) # Non-standard params used by the modified COCO API version # from the DensePose fork + # global normalization factor used in per-instance evaluation test_sigma = 0.255 coco_eval = denseposeCOCOeval(json_dataset.COCO, coco_dt, ann_type, test_sigma) coco_eval.params.imgIds = imgIds coco_eval.evaluate() coco_eval.accumulate() - #eval_file = os.path.join(output_dir, 'body_uv_results.pkl') - #save_object(coco_eval, eval_file) - #logger.info('Wrote json eval results to: {}'.format(eval_file)) + # eval_file = os.path.join(output_dir, 'body_uv_results.pkl') + # save_object(coco_eval, eval_file) + # logger.info('Wrote json eval results to: {}'.format(eval_file)) coco_eval.summarize() return coco_eval diff --git a/detectron/datasets/roidb.py b/detectron/datasets/roidb.py index e37e3d1..83382f2 100644 --- a/detectron/datasets/roidb.py +++ b/detectron/datasets/roidb.py @@ -121,7 +121,7 @@ def is_valid(entry): if cfg.MODEL.BODY_UV_ON and cfg.BODY_UV_RCNN.BODY_UV_IMS: # Exclude images with no body uv valid = valid and entry['has_body_uv'] - return valid + return valid num = len(roidb) filtered_roidb = [entry for entry in roidb if is_valid(entry)] diff --git a/detectron/datasets/task_evaluation.py b/detectron/datasets/task_evaluation.py index bdae2e6..0d076c7 100644 --- a/detectron/datasets/task_evaluation.py +++ b/detectron/datasets/task_evaluation.py @@ -47,7 +47,7 @@ def evaluate_all( output_dir, use_matlab=False ): """Evaluate "all" tasks, where "all" includes box detection, instance - segmentation, and keypoint detection. + segmentation, keypoint detection, and human dense pose estimation. """ all_results = evaluate_boxes( dataset, all_boxes, output_dir, use_matlab=use_matlab diff --git a/detectron/modeling/body_uv_rcnn_heads.py b/detectron/modeling/body_uv_rcnn_heads.py index dd67f43..f95aef5 100644 --- a/detectron/modeling/body_uv_rcnn_heads.py +++ b/detectron/modeling/body_uv_rcnn_heads.py @@ -3,7 +3,22 @@ # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -# +############################################################################## + +"""Various network "heads" for dense human pose estimation in DensePose. + +The design is as follows: + +... -> RoI ----\ /-> mask output -> cls loss + -> RoIFeatureXform -> body UV head -> patch output -> cls loss +... -> Feature / \-> UV output -> reg loss + Map + +The body UV head produces a feature representation of the RoI for the purpose +of dense semantic mask prediction, body surface patch prediction and body UV +coordinates regression. The body UV output module converts the feature +representation into heatmaps for dense mask, patch index and UV coordinates. 
+""" from __future__ import absolute_import from __future__ import division @@ -11,142 +26,123 @@ from __future__ import unicode_literals from caffe2.python import core - from detectron.core.config import cfg - +from detectron.utils.c2 import const_fill import detectron.modeling.ResNet as ResNet import detectron.utils.blob as blob_utils # ---------------------------------------------------------------------------- # -# Body UV heads +# Body UV outputs and losses # ---------------------------------------------------------------------------- # -def add_body_uv_outputs(model, blob_in, dim, pref=''): - #### - model.ConvTranspose(blob_in, 'AnnIndex_lowres'+pref, dim, 15,cfg.BODY_UV_RCNN.DECONV_KERNEL, pad=int(cfg.BODY_UV_RCNN.DECONV_KERNEL / 2 - 1), stride=2, weight_init=(cfg.BODY_UV_RCNN.CONV_INIT, {'std': 0.001}), bias_init=('ConstantFill', {'value': 0.})) - #### - model.ConvTranspose(blob_in, 'Index_UV_lowres'+pref, dim, cfg.BODY_UV_RCNN.NUM_PATCHES+1,cfg.BODY_UV_RCNN.DECONV_KERNEL, pad=int(cfg.BODY_UV_RCNN.DECONV_KERNEL / 2 - 1), stride=2, weight_init=(cfg.BODY_UV_RCNN.CONV_INIT, {'std': 0.001}), bias_init=('ConstantFill', {'value': 0.})) - #### - model.ConvTranspose( - blob_in, 'U_lowres'+pref, dim, (cfg.BODY_UV_RCNN.NUM_PATCHES+1), - cfg.BODY_UV_RCNN.DECONV_KERNEL, - pad=int(cfg.BODY_UV_RCNN.DECONV_KERNEL / 2 - 1), - stride=2, - weight_init=(cfg.BODY_UV_RCNN.CONV_INIT, {'std': 0.001}), - bias_init=('ConstantFill', {'value': 0.})) - ##### - model.ConvTranspose( - blob_in, 'V_lowres'+pref, dim, cfg.BODY_UV_RCNN.NUM_PATCHES+1, +def add_body_uv_outputs(model, blob_in, dim): + """Add DensePose body UV specific outputs: heatmaps of dense mask, patch index + and patch-specific UV coordinates. All dense masks are mapped to labels in + [0, ... S] for S semantically meaningful body parts. + """ + # Apply ConvTranspose to the feature representation; results in 2x upsampling + for name in ['AnnIndex', 'Index_UV', 'U', 'V']: + if name == 'AnnIndex': + dim_out = cfg.BODY_UV_RCNN.NUM_SEMANTIC_PARTS + 1 + else: + dim_out = cfg.BODY_UV_RCNN.NUM_PATCHES + 1 + model.ConvTranspose( + blob_in, + name + '_lowres', + dim, + dim_out, cfg.BODY_UV_RCNN.DECONV_KERNEL, pad=int(cfg.BODY_UV_RCNN.DECONV_KERNEL / 2 - 1), stride=2, weight_init=(cfg.BODY_UV_RCNN.CONV_INIT, {'std': 0.001}), - bias_init=('ConstantFill', {'value': 0.})) - #### - blob_Ann_Index = model.BilinearInterpolation('AnnIndex_lowres'+pref, 'AnnIndex'+pref, cfg.BODY_UV_RCNN.NUM_PATCHES+1 , cfg.BODY_UV_RCNN.NUM_PATCHES+1, cfg.BODY_UV_RCNN.UP_SCALE) - blob_Index = model.BilinearInterpolation('Index_UV_lowres'+pref, 'Index_UV'+pref, cfg.BODY_UV_RCNN.NUM_PATCHES+1 , cfg.BODY_UV_RCNN.NUM_PATCHES+1, cfg.BODY_UV_RCNN.UP_SCALE) - blob_U = model.BilinearInterpolation('U_lowres'+pref, 'U_estimated'+pref, cfg.BODY_UV_RCNN.NUM_PATCHES+1 , cfg.BODY_UV_RCNN.NUM_PATCHES+1, cfg.BODY_UV_RCNN.UP_SCALE) - blob_V = model.BilinearInterpolation('V_lowres'+pref, 'V_estimated'+pref, cfg.BODY_UV_RCNN.NUM_PATCHES+1 , cfg.BODY_UV_RCNN.NUM_PATCHES+1, cfg.BODY_UV_RCNN.UP_SCALE) - ### - return blob_U,blob_V,blob_Index,blob_Ann_Index - - -def add_body_uv_losses(model, pref=''): - - ## Reshape for GT blobs. 
- model.net.Reshape( ['body_uv_X_points'], ['X_points_reshaped'+pref, 'X_points_shape'+pref], shape=( -1 ,1 ) ) - model.net.Reshape( ['body_uv_Y_points'], ['Y_points_reshaped'+pref, 'Y_points_shape'+pref], shape=( -1 ,1 ) ) - model.net.Reshape( ['body_uv_I_points'], ['I_points_reshaped'+pref, 'I_points_shape'+pref], shape=( -1 ,1 ) ) - model.net.Reshape( ['body_uv_Ind_points'], ['Ind_points_reshaped'+pref, 'Ind_points_shape'+pref], shape=( -1 ,1 ) ) - ## Concat Ind,x,y to get Coordinates blob. - model.net.Concat( ['Ind_points_reshaped'+pref,'X_points_reshaped'+pref, \ - 'Y_points_reshaped'+pref],['Coordinates'+pref,'Coordinate_Shapes'+pref ], axis = 1 ) - ## - ### Now reshape UV blobs, such that they are 1x1x(196*NumSamples)xNUM_PATCHES - ## U blob to - ## - model.net.Reshape(['body_uv_U_points'], \ - ['U_points_reshaped'+pref, 'U_points_old_shape'+pref],\ - shape=(-1,cfg.BODY_UV_RCNN.NUM_PATCHES+1,196)) - model.net.Transpose(['U_points_reshaped'+pref] ,['U_points_reshaped_transpose'+pref],axes=(0,2,1) ) - model.net.Reshape(['U_points_reshaped_transpose'+pref], \ - ['U_points'+pref, 'U_points_old_shape2'+pref], \ - shape=(1,1,-1,cfg.BODY_UV_RCNN.NUM_PATCHES+1)) - ## V blob - ## - model.net.Reshape(['body_uv_V_points'], \ - ['V_points_reshaped'+pref, 'V_points_old_shape'+pref],\ - shape=(-1,cfg.BODY_UV_RCNN.NUM_PATCHES+1,196)) - model.net.Transpose(['V_points_reshaped'+pref] ,['V_points_reshaped_transpose'+pref],axes=(0,2,1) ) - model.net.Reshape(['V_points_reshaped_transpose'+pref], \ - ['V_points'+pref, 'V_points_old_shape2'+pref], \ - shape=(1,1,-1,cfg.BODY_UV_RCNN.NUM_PATCHES+1)) - ### - ## UV weights blob - ## - model.net.Reshape(['body_uv_point_weights'], \ - ['Uv_point_weights_reshaped'+pref, 'Uv_point_weights_old_shape'+pref],\ - shape=(-1,cfg.BODY_UV_RCNN.NUM_PATCHES+1,196)) - model.net.Transpose(['Uv_point_weights_reshaped'+pref] ,['Uv_point_weights_reshaped_transpose'+pref],axes=(0,2,1) ) - model.net.Reshape(['Uv_point_weights_reshaped_transpose'+pref], \ - ['Uv_point_weights'+pref, 'Uv_point_weights_old_shape2'+pref], \ - shape=(1,1,-1,cfg.BODY_UV_RCNN.NUM_PATCHES+1)) - - ##################### - ### Pool IUV for points via bilinear interpolation. - model.PoolPointsInterp(['U_estimated','Coordinates'+pref], ['interp_U'+pref]) - model.PoolPointsInterp(['V_estimated','Coordinates'+pref], ['interp_V'+pref]) - model.PoolPointsInterp(['Index_UV'+pref,'Coordinates'+pref], ['interp_Index_UV'+pref]) - - ## Reshape interpolated UV coordinates to apply the loss. - - model.net.Reshape(['interp_U'+pref], \ - ['interp_U_reshaped'+pref, 'interp_U_shape'+pref],\ - shape=(1, 1, -1 , cfg.BODY_UV_RCNN.NUM_PATCHES+1)) - - model.net.Reshape(['interp_V'+pref], \ - ['interp_V_reshaped'+pref, 'interp_V_shape'+pref],\ - shape=(1, 1, -1 , cfg.BODY_UV_RCNN.NUM_PATCHES+1)) - ### - - ### Do the actual labels here !!!! 
- model.net.Reshape( ['body_uv_ann_labels'], \ - ['body_uv_ann_labels_reshaped' +pref, 'body_uv_ann_labels_old_shape'+pref], \ - shape=(-1, cfg.BODY_UV_RCNN.HEATMAP_SIZE , cfg.BODY_UV_RCNN.HEATMAP_SIZE)) - - model.net.Reshape( ['body_uv_ann_weights'], \ - ['body_uv_ann_weights_reshaped' +pref, 'body_uv_ann_weights_old_shape'+pref], \ - shape=( -1 , cfg.BODY_UV_RCNN.HEATMAP_SIZE , cfg.BODY_UV_RCNN.HEATMAP_SIZE)) - ### - model.net.Cast( ['I_points_reshaped'+pref], ['I_points_reshaped_int'+pref], to=core.DataType.INT32) - ### Now add the actual losses - ## The mask segmentation loss (dense) - probs_seg_AnnIndex, loss_seg_AnnIndex = model.net.SpatialSoftmaxWithLoss( \ - ['AnnIndex'+pref, 'body_uv_ann_labels_reshaped'+pref,'body_uv_ann_weights_reshaped'+pref],\ - ['probs_seg_AnnIndex'+pref,'loss_seg_AnnIndex'+pref], \ - scale=cfg.BODY_UV_RCNN.INDEX_WEIGHTS / cfg.NUM_GPUS) - ## Point Patch Index Loss. - probs_IndexUVPoints, loss_IndexUVPoints = model.net.SoftmaxWithLoss(\ - ['interp_Index_UV'+pref,'I_points_reshaped_int'+pref],\ - ['probs_IndexUVPoints'+pref,'loss_IndexUVPoints'+pref], \ - scale=cfg.BODY_UV_RCNN.PART_WEIGHTS / cfg.NUM_GPUS, spatial=0) - ## U and V point losses. - loss_Upoints = model.net.SmoothL1Loss( \ - ['interp_U_reshaped'+pref, 'U_points'+pref, \ - 'Uv_point_weights'+pref, 'Uv_point_weights'+pref], \ - 'loss_Upoints'+pref, \ - scale=cfg.BODY_UV_RCNN.POINT_REGRESSION_WEIGHTS / cfg.NUM_GPUS) - - loss_Vpoints = model.net.SmoothL1Loss( \ - ['interp_V_reshaped'+pref, 'V_points'+pref, \ - 'Uv_point_weights'+pref, 'Uv_point_weights'+pref], \ - 'loss_Vpoints'+pref, scale=cfg.BODY_UV_RCNN.POINT_REGRESSION_WEIGHTS / cfg.NUM_GPUS) - ## Add the losses. - loss_gradients = blob_utils.get_loss_gradients(model, \ - [ loss_Upoints, loss_Vpoints, loss_seg_AnnIndex, loss_IndexUVPoints]) - model.losses = list(set(model.losses + \ - ['loss_Upoints'+pref , 'loss_Vpoints'+pref , \ - 'loss_seg_AnnIndex'+pref ,'loss_IndexUVPoints'+pref])) + bias_init=const_fill(0.0) + ) + # Increase heatmap output size via bilinear upsampling + blob_outputs = [] + for name in ['AnnIndex', 'Index_UV', 'U', 'V']: + blob_outputs.append( + model.BilinearInterpolation( + name + '_lowres', + name + '_estimated' if name in ['U', 'V'] else name, + cfg.BODY_UV_RCNN.NUM_PATCHES + 1, + cfg.BODY_UV_RCNN.NUM_PATCHES + 1, + cfg.BODY_UV_RCNN.UP_SCALE + ) + ) + + return blob_outputs + + +def add_body_uv_losses(model): + """Add DensePose body UV specific losses.""" + # Pool estimated IUV points via bilinear interpolation. + for name in ['U', 'V', 'Index_UV']: + model.PoolPointsInterp( + [ + name + '_estimated' if name in ['U', 'V'] else name, + 'body_uv_coords_xy' + ], + ['interp_' + name] + ) + + # Compute spatial softmax normalized probabilities, after which + # cross-entropy loss is computed for semantic parts classification. + probs_AnnIndex, loss_AnnIndex = model.net.SpatialSoftmaxWithLoss( + [ + 'AnnIndex', + 'body_uv_parts', 'body_uv_parts_weights' + ], + ['probs_AnnIndex', 'loss_AnnIndex'], + scale=model.GetLossScale() * cfg.BODY_UV_RCNN.INDEX_WEIGHTS + ) + # Softmax loss for surface patch classification. + probs_I_points, loss_I_points = model.net.SoftmaxWithLoss( + ['interp_Index_UV', 'body_uv_I_points'], + ['probs_I_points', 'loss_I_points'], + scale=model.GetLossScale() * cfg.BODY_UV_RCNN.PART_WEIGHTS, + spatial=0 + ) + ## Smooth L1 loss for each patch-specific UV coordinates regression. + # Reshape U,V blobs of both interpolated and ground-truth to compute + # summarized (instead of averaged) SmoothL1Loss. 
+ loss_UV = list() + model.net.Reshape( + ['body_uv_point_weights'], + ['UV_point_weights', 'body_uv_point_weights_shape'], + shape=(1, -1, cfg.BODY_UV_RCNN.NUM_PATCHES + 1) + ) + for name in ['U', 'V']: + # Reshape U/V coordinates of both interpolated points and ground-truth + # points from (#points, #patches) to (1, #points, #patches). + model.net.Reshape( + ['body_uv_' + name + '_points'], + [name + '_points', 'body_uv_' + name + '_points_shape'], + shape=(1, -1, cfg.BODY_UV_RCNN.NUM_PATCHES + 1) + ) + model.net.Reshape( + ['interp_' + name], + ['interp_' + name + '_reshaped', 'interp_' + name + 'shape'], + shape=(1, -1, cfg.BODY_UV_RCNN.NUM_PATCHES + 1) + ) + # Compute summarized SmoothL1Loss of all points. + loss_UV.append( + model.net.SmoothL1Loss( + [ + 'interp_' + name + '_reshaped', name + '_points', + 'UV_point_weights', 'UV_point_weights' + ], + 'loss_' + name + '_points', + scale=model.GetLossScale() * cfg.BODY_UV_RCNN.POINT_REGRESSION_WEIGHTS + ) + ) + # Add all losses to compute gradients + loss_gradients = blob_utils.get_loss_gradients( + model, [loss_AnnIndex, loss_I_points] + loss_UV + ) + # Update model training losses + model.AddLosses( + ['loss_' + name for name in ['AnnIndex', 'I_points', 'U_points', 'V_points']] + ) return loss_gradients @@ -155,17 +151,17 @@ def add_body_uv_losses(model, pref=''): # Body UV heads # ---------------------------------------------------------------------------- # -def add_ResNet_roi_conv5_head_for_bodyUV( - model, blob_in, dim_in, spatial_scale -): +def add_ResNet_roi_conv5_head_for_bodyUV(model, blob_in, dim_in, spatial_scale): """Add a ResNet "conv5" / "stage5" head for body UV prediction.""" model.RoIFeatureTransform( - blob_in, '_[body_uv]_pool5', + blob_in, + '_[body_uv]_pool5', blob_rois='body_uv_rois', method=cfg.BODY_UV_RCNN.ROI_XFORM_METHOD, resolution=cfg.BODY_UV_RCNN.ROI_XFORM_RESOLUTION, sampling_ratio=cfg.BODY_UV_RCNN.ROI_XFORM_SAMPLING_RATIO, - spatial_scale=spatial_scale) + spatial_scale=spatial_scale + ) # Using the prefix '_[body_uv]_' to 'res5' enables initializing the head's # parameters using pretrained 'res5' parameters if given (see # utils.net.initialize_from_weights_file) @@ -184,7 +180,7 @@ def add_ResNet_roi_conv5_head_for_bodyUV( def add_roi_body_uv_head_v1convX(model, blob_in, dim_in, spatial_scale): - """v1convX design: X * (conv).""" + """Add a DensePose body UV head. 
v1convX design: X * (conv)."""' hidden_dim = cfg.BODY_UV_RCNN.CONV_HEAD_DIM kernel_size = cfg.BODY_UV_RCNN.CONV_HEAD_KERNEL pad_size = kernel_size // 2 @@ -208,7 +204,7 @@ def add_roi_body_uv_head_v1convX(model, blob_in, dim_in, spatial_scale): stride=1, pad=pad_size, weight_init=(cfg.BODY_UV_RCNN.CONV_INIT, {'std': 0.01}), - bias_init=('ConstantFill', {'value': 0.}) + bias_init=const_fill(0.0) ) current = model.Relu(current, current) dim_in = hidden_dim diff --git a/detectron/modeling/model_builder.py b/detectron/modeling/model_builder.py index 35f9f2c..0eb3899 100644 --- a/detectron/modeling/model_builder.py +++ b/detectron/modeling/model_builder.py @@ -329,7 +329,7 @@ def _add_roi_body_uv_head( model, add_roi_body_uv_head_func, blob_in, dim_in, spatial_scale_in ): """Add a body UV prediction head to the model.""" - # Capture model graph before adding the mask head + # Capture model graph before adding the body UV head bbox_net = copy.deepcopy(model.net.Proto()) # Add the body UV head blob_body_uv_head, dim_body_uv_head = add_roi_body_uv_head_func( @@ -343,7 +343,7 @@ def _add_roi_body_uv_head( if not model.train: # == inference # Inference uses a cascade of box predictions, then body uv predictions # This requires separate nets for box and body uv prediction. - # So we extract the keypoint prediction net, store it as its own + # So we extract the body uv prediction net, store it as its own # network, then restore model.net to be the bbox-only network model.body_uv_net, body_uv_blob_out = c2_utils.SuffixNet( 'body_uv_net', model.net, len(bbox_net.op), blobs_body_uv diff --git a/detectron/ops/pool_points_interp.cc b/detectron/ops/pool_points_interp.cc index 0bbc682..54c6245 100644 --- a/detectron/ops/pool_points_interp.cc +++ b/detectron/ops/pool_points_interp.cc @@ -1,33 +1,66 @@ /** -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the license found in the + * LICENSE file in the root directory of this source tree. */ - #include "pool_points_interp.h" namespace caffe2 { -//namespace { REGISTER_CPU_OPERATOR(PoolPointsInterp, PoolPointsInterpOp<float, CPUContext>); REGISTER_CPU_OPERATOR(PoolPointsInterpGradient, PoolPointsInterpGradientOp<float, CPUContext>); -// Input: X, points; Output: Y -OPERATOR_SCHEMA(PoolPointsInterp).NumInputs(2).NumOutputs(1); -// Input: X, points, dY (aka "gradOutput"); -// Output: dX (aka "gradInput") -OPERATOR_SCHEMA(PoolPointsInterpGradient).NumInputs(3).NumOutputs(1); +OPERATOR_SCHEMA(PoolPointsInterp) + .NumInputs(2) + .NumOutputs(1) + .Input( + 0, + "X", + "4D feature/heat map input of shape (N, C, H, W).") + .Input( + 1, + "coords", + "2D input of shape (P, 2) specifying P points with 2 columns " + "representing 2D coordinates on the image (x, y). The " + "coordinates have been converted to the coordinate system of X.") + .Output( + 0, + "Y", + "2D output of shape (P, K).
The r-th batch element is a " + "pooled/interpolated index or UV coordinate corresponding " + "to the r-th point over all K patches (including background)."); + +OPERATOR_SCHEMA(PoolPointsInterpGradient) + .NumInputs(3) + .NumOutputs(1) + .Input( + 0, + "X", + "See PoolPointsInterp.") + .Input( + 1, + "coords", + "See PoolPointsInterp.") + .Input( + 2, + "dY", + "Gradient of forward output 0 (Y)") + .Output( + 0, + "dX", + "Gradient of forward input 0 (X)"); class GetPoolPointsInterpGradient : public GradientMakerBase { using GradientMakerBase::GradientMakerBase; vector GetGradientDefs() override { return SingleGradientDef( - "PoolPointsInterpGradient", "", + "PoolPointsInterpGradient", + "", vector{I(0), I(1), GO(0)}, vector{GI(0)}); } @@ -35,5 +68,4 @@ class GetPoolPointsInterpGradient : public GradientMakerBase { REGISTER_GRADIENT(PoolPointsInterp, GetPoolPointsInterpGradient); -//} // namespace } // namespace caffe2 diff --git a/detectron/ops/pool_points_interp.cu b/detectron/ops/pool_points_interp.cu index 6286e2c..7a9f755 100644 --- a/detectron/ops/pool_points_interp.cu +++ b/detectron/ops/pool_points_interp.cu @@ -1,9 +1,9 @@ /** -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the license found in the + * LICENSE file in the root directory of this source tree. */ #include @@ -28,24 +28,26 @@ float gpu_atomic_add(const float val, float* address) { template __device__ T bilinear_interpolate(const T* bottom_data, - const int height, const int width, - T y, T x, + const int height, const int width, T x, T y, const int index /* index for debug only*/) { - // deal with cases that inverse elements are out of feature map boundary - if (y < -1.0 || y > height || x < -1.0 || x > width) { - //empty + if (x < -1.0 || x > width || y < -1.0 || y > height) { return 0; } - if (y <= 0) y = 0; if (x <= 0) x = 0; + if (y <= 0) y = 0; - int y_low = (int) y; int x_low = (int) x; - int y_high; - int x_high; + int y_low = (int) y; + int x_high, y_high; + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T) x_low; + } else { + x_high = x_low + 1; + } if (y_low >= height - 1) { y_high = y_low = height - 1; y = (T) y_low; @@ -53,82 +55,62 @@ __device__ T bilinear_interpolate(const T* bottom_data, y_high = y_low + 1; } - if (x_low >= width - 1) { - x_high = x_low = width - 1; - x = (T) x_low; - } else { - x_high = x_low + 1; - } + // lambdas in X, Y axes + T lx = x - x_low, ly = y - y_low; + T hx = 1. - lx, hy = 1. - ly; - T ly = y - y_low; - T lx = x - x_low; - T hy = 1. - ly, hx = 1. 
- lx; // do bilinear interpolation - T v1 = bottom_data[y_low * width + x_low]; - T v2 = bottom_data[y_low * width + x_high]; - T v3 = bottom_data[y_high * width + x_low]; - T v4 = bottom_data[y_high * width + x_high]; + T v1 = bottom_data[y_low * width + x_low]; // top-left point + T v2 = bottom_data[y_low * width + x_high]; // top-right point + T v3 = bottom_data[y_high * width + x_low]; // bottom-left point + T v4 = bottom_data[y_high * width + x_high]; // bottom-right point T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; - T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - - return val; + return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; } template -__global__ void PointWarpForward(const int nthreads, const T* bottom_data, - const T spatial_scale, const int channels, - const int height, const int width, - const T* bottom_rois, T* top_data) { +__global__ void PoolPointsInterpForward(const int nthreads, const T* bottom_data, + const T spatial_scale, const int channels, const int height, const int width, + const T* coords, T* top_data) { CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c) is an element in the pooled/interpolated output int c = index % channels; int n = index / channels; - // - const T* offset_bottom_rois = bottom_rois + n * 3; - - int roi_batch_ind = n/196; // Should be original !! - // - T X_point = offset_bottom_rois[1] * spatial_scale; - T Y_point = offset_bottom_rois[2] * spatial_scale; - - + const T* offset_coords = coords + n * 2; + // Get index of current fg roi among all fg rois in a minibatch + int roi_batch_ind = n / 196; + // Get spatial coordinate (x, y) + T x = offset_coords[0] * spatial_scale; + T y = offset_coords[1] * spatial_scale; const T* offset_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width; - - T val = bilinear_interpolate(offset_bottom_data, height, width, Y_point, X_point, index); - top_data[index] = val; + // Compute interpolated value + top_data[index] = bilinear_interpolate( + offset_bottom_data, height, width, x, y, index); } } template __device__ void bilinear_interpolate_gradient( - const int height, const int width, - T y, T x, + const int height, const int width, T x, T y, T & w1, T & w2, T & w3, T & w4, int & x_low, int & x_high, int & y_low, int & y_high, const int index /* index for debug only*/) { - // deal with cases that inverse elements are out of feature map boundary - if (y < -1.0 || y > height || x < -1.0 || x > width) { - //empty + if (x < -1.0 || x > width || y < -1.0 || y > height) { + // empty w1 = w2 = w3 = w4 = 0.; x_low = x_high = y_low = y_high = -1; return; } - if (y <= 0) y = 0; if (x <= 0) x = 0; + if (y <= 0) y = 0; - y_low = (int) y; x_low = (int) x; - - if (y_low >= height - 1) { - y_high = y_low = height - 1; - y = (T) y_low; - } else { - y_high = y_low + 1; - } + y_low = (int) y; if (x_low >= width - 1) { x_high = x_low = width - 1; @@ -136,11 +118,15 @@ __device__ void bilinear_interpolate_gradient( } else { x_high = x_low + 1; } + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T) y_low; + } else { + y_high = y_low + 1; + } - T ly = y - y_low; - T lx = x - x_low; - T hy = 1. - ly, hx = 1. - lx; - + T lx = x - x_low, ly = y - y_low; + T hx = 1. - lx, hy = 1. 
- ly; w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; @@ -148,30 +134,23 @@ __device__ void bilinear_interpolate_gradient( } template -__global__ void PointWarpBackwardFeature(const int nthreads, const T* top_diff, - const int num_rois, const T spatial_scale, - const int channels, const int height, const int width, - - T* bottom_diff, - const T* bottom_rois) { +__global__ void PoolPointsInterpBackward(const int nthreads, const T* top_diff, + const int num_rois, const T spatial_scale, const int channels, + const int height, const int width, T* bottom_diff, const T* coords) { CUDA_1D_KERNEL_LOOP(index, nthreads) { - int c = index % channels; - int n = index / channels; - - const T* offset_bottom_rois = bottom_rois + n * 3; - // int roi_batch_ind = offset_bottom_rois[0]; - int roi_batch_ind = n/196; // Should be original !! + int c = index % channels; + int n = index / channels; - T X_point = offset_bottom_rois[1] * spatial_scale; - T Y_point = offset_bottom_rois[2] * spatial_scale; + const T* offset_coords = coords + n * 2; + int roi_batch_ind = n / 196; + T x = offset_coords[0] * spatial_scale; + T y = offset_coords[1] * spatial_scale; T w1, w2, w3, w4; int x_low, x_high, y_low, y_high; - bilinear_interpolate_gradient(height, width, Y_point, X_point, - w1, w2, w3, w4, - x_low, x_high, y_low, y_high, - index); + bilinear_interpolate_gradient(height, width, x, y, + w1, w2, w3, w4, x_low, x_high, y_low, y_high, index); T* offset_bottom_diff = bottom_diff + (roi_batch_ind * channels + c) * height * width; // @@ -193,35 +172,30 @@ __global__ void PointWarpBackwardFeature(const int nthreads, const T* top_diff, } // if } // CUDA_1D_KERNEL_LOOP -} // ROIWarpBackward - +} // PoolPointsInterpBackward } // namespace template<> bool PoolPointsInterpOp::RunOnDevice() { auto& X = Input(0); // Input data to pool - auto& R = Input(1); // RoIs - auto* Y = Output(0); // RoI pooled data + auto& R = Input(1); // Spatial coordinates of all points within RoIs if (R.size() == 0) { // Handle empty rois - Y->Resize(0, X.dim32(1)); - // The following mutable_data calls are needed to allocate the tensors - Y->mutable_data(); + std::vector sizes = {0, X.dim32(1)}; + /* auto* Y = */ Output(0, sizes, at::dtype()); return true; } - Y->Resize(R.dim32(0), X.dim32(1)); + auto* Y = Output(0, {R.dim32(0), X.dim32(1)}, at::dtype()); // Pooled interpolated data int output_size = Y->size(); - PointWarpForward<<>>( + PoolPointsInterpForward<<>>( output_size, X.data(), spatial_scale_, X.dim32(1), X.dim32(2), X.dim32(3), - R.data(), - Y->mutable_data() - ); + R.data(), Y->mutable_data()); return true; } @@ -232,59 +206,46 @@ __global__ void SetKernel(const int N, const T alpha, T* Y) { Y[i] = alpha; } } -} - +} // namespace namespace { - - template __global__ void SetEvenIndsToVal(size_t num_even_inds, T val, T* data) { CUDA_1D_KERNEL_LOOP(i, num_even_inds) { data[i << 1] = val; } } -} - +} // namespace - template<> bool PoolPointsInterpGradientOp::RunOnDevice() { auto& X = Input(0); // Input data to pool - auto& R = Input(1); // RoIs + auto& R = Input(1); // 2D Spatial coordinates of all points within RoIs auto& dY = Input(2); // Gradient of net w.r.t. output of "forward" op // (aka "gradOutput") - auto* dX = Output(0); // Gradient of net w.r.t. input to "forward" op - // (aka "gradInput") - - dX->ResizeLike(X); + auto* dX = Output( + 0, X.sizes(), at::dtype()); // Gradient of net w.r.t. 
input to + // "forward" op (aka "gradInput") - SetKernel - <<size()), - CAFFE_CUDA_NUM_THREADS, - 0, - context_.cuda_stream()>>>( - dX->size(), - 0.f, - dX->mutable_data()); + SetKernel<<size()), + CAFFE_CUDA_NUM_THREADS, + 0, context_.cuda_stream()>>>( + dX->size(), 0.f, dX->mutable_data()); if (dY.size() > 0) { // Handle possibly empty gradient if there were no rois - PointWarpBackwardFeature<<>>( + PoolPointsInterpBackward<<>>( dY.size(), dY.data(), R.dim32(0), spatial_scale_, X.dim32(1), X.dim32(2), X.dim32(3), - dX->mutable_data(), - R.data()); + dX->mutable_data(), R.data()); } return true; } -//namespace { REGISTER_CUDA_OPERATOR(PoolPointsInterp, PoolPointsInterpOp); REGISTER_CUDA_OPERATOR(PoolPointsInterpGradient, PoolPointsInterpGradientOp); -//} // namespace } // namespace caffe2 diff --git a/detectron/ops/pool_points_interp.h b/detectron/ops/pool_points_interp.h index 367a29c..02d5c65 100644 --- a/detectron/ops/pool_points_interp.h +++ b/detectron/ops/pool_points_interp.h @@ -1,12 +1,11 @@ /** -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the license found in the + * LICENSE file in the root directory of this source tree. */ - #ifndef POOL_POINTS_INTERP_OP_H_ #define POOL_POINTS_INTERP_OP_H_ @@ -21,13 +20,14 @@ class PoolPointsInterpOp final : public Operator { public: PoolPointsInterpOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws), - spatial_scale_(OperatorBase::GetSingleArgument( + spatial_scale_(this->template GetSingleArgument( "spatial_scale", 1.)) { DCHECK_GT(spatial_scale_, 0); } USE_OPERATOR_CONTEXT_FUNCTIONS; bool RunOnDevice() override { + // No CPU implementation for now CAFFE_NOT_IMPLEMENTED; } @@ -40,13 +40,14 @@ class PoolPointsInterpGradientOp final : public Operator { public: PoolPointsInterpGradientOp(const OperatorDef& def, Workspace* ws) : Operator(def, ws), - spatial_scale_(OperatorBase::GetSingleArgument( - "spatial_scale", 1.)){ + spatial_scale_(this->template GetSingleArgument( + "spatial_scale", 1.)) { DCHECK_GT(spatial_scale_, 0); } USE_OPERATOR_CONTEXT_FUNCTIONS; bool RunOnDevice() override { + // No CPU implementation for now CAFFE_NOT_IMPLEMENTED; } @@ -56,4 +57,4 @@ class PoolPointsInterpGradientOp final : public Operator { } // namespace caffe2 -#endif // PoolPointsInterpOp +#endif // POOL_POINTS_INTERP_OP_H_ diff --git a/detectron/roi_data/body_uv_rcnn.py b/detectron/roi_data/body_uv_rcnn.py index 1cb0f0b..27c817a 100644 --- a/detectron/roi_data/body_uv_rcnn.py +++ b/detectron/roi_data/body_uv_rcnn.py @@ -3,18 +3,21 @@ # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +############################################################################## + +"""Construct minibatches for DensePose training. Handles the minibatch blobs +that are specific to DensePose. Other blobs that are generic to RPN or +Fast/er R-CNN are handled by their respecitive roi_data modules. 
+""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals -# -from scipy.io import loadmat -import copy + import cv2 import logging import numpy as np -# from detectron.core.config import cfg import detectron.utils.blob as blob_utils @@ -22,191 +25,180 @@ import detectron.utils.segms as segm_utils import detectron.utils.densepose_methods as dp_utils -# -from memory_profiler import profile -# -import os -# logger = logging.getLogger(__name__) -# + DP = dp_utils.DensePoseMethods() -# + def add_body_uv_rcnn_blobs(blobs, sampled_boxes, roidb, im_scale, batch_idx): - IsFlipped = roidb['flipped'] + """Add DensePose specific blobs to the given inputs blobs dictionary.""" M = cfg.BODY_UV_RCNN.HEATMAP_SIZE - # + # Prepare the body UV targets by associating one gt box which contains + # body UV annotations to each training roi that has a fg class label. polys_gt_inds = np.where(roidb['ignore_UV_body'] == 0)[0] - boxes_from_polys = [roidb['boxes'][i,:] for i in polys_gt_inds] - if not(boxes_from_polys): - pass - else: - boxes_from_polys = np.vstack(boxes_from_polys) - boxes_from_polys = np.array(boxes_from_polys) - + boxes_from_polys = roidb['boxes'][polys_gt_inds] + # Select foreground RoIs fg_inds = np.where(blobs['labels_int32'] > 0)[0] - roi_has_mask = np.zeros( blobs['labels_int32'].shape ) + roi_has_body_uv = np.zeros_like(blobs['labels_int32'], dtype=np.int32) - if (bool(boxes_from_polys.any()) & (fg_inds.shape[0] > 0) ): + if ((boxes_from_polys.shape[0] > 0) & (fg_inds.shape[0] > 0)): + # Find overlap between all foreground RoIs and the gt bounding boxes + # containing each body UV annotaion. rois_fg = sampled_boxes[fg_inds] - # - rois_fg.astype(np.float32, copy=False) - boxes_from_polys.astype(np.float32, copy=False) - # overlaps_bbfg_bbpolys = box_utils.bbox_overlaps( rois_fg.astype(np.float32, copy=False), - boxes_from_polys.astype(np.float32, copy=False)) + boxes_from_polys.astype(np.float32, copy=False) + ) + # Select foreground RoIs as those with > 0.7 overlap fg_polys_value = np.max(overlaps_bbfg_bbpolys, axis=1) - fg_inds = fg_inds[fg_polys_value>0.7] - - if (bool(boxes_from_polys.any()) & (fg_inds.shape[0] > 0) ): - for jj in fg_inds: - roi_has_mask[jj] = 1 - - # Create blobs for densepose supervision. - ################################################## The mask - All_labels = blob_utils.zeros((fg_inds.shape[0], M ** 2), int32=True) - All_Weights = blob_utils.zeros((fg_inds.shape[0], M ** 2), int32=True) - ################################################# The points - X_points = blob_utils.zeros((fg_inds.shape[0], 196), int32=False) - Y_points = blob_utils.zeros((fg_inds.shape[0], 196), int32=False) - Ind_points = blob_utils.zeros((fg_inds.shape[0], 196), int32=True) + fg_inds = fg_inds[fg_polys_value > 0.7] + + if ((boxes_from_polys.shape[0] > 0) & (fg_inds.shape[0] > 0)): + roi_has_body_uv[fg_inds] = 1 + # Create body UV blobs + # Dense masks, each mask for a given fg roi is of size M x M. + part_inds = blob_utils.zeros((fg_inds.shape[0], M, M), int32=True) + # Weights assigned to each target in `part_inds`. By default, all 1's. + # part_inds_weights = blob_utils.zeros((fg_inds.shape[0], M, M), int32=True) + part_inds_weights = blob_utils.ones((fg_inds.shape[0], M, M), int32=False) + # 2D spatial coordinates (on the image). Shape is (#fg_rois, 2) in format + # (x, y). 
+ coords_xy = blob_utils.zeros((fg_inds.shape[0], 196, 2), int32=False) + # 24 patch indices plus a background class I_points = blob_utils.zeros((fg_inds.shape[0], 196), int32=True) + # UV coordinates in each patch U_points = blob_utils.zeros((fg_inds.shape[0], 196), int32=False) V_points = blob_utils.zeros((fg_inds.shape[0], 196), int32=False) - Uv_point_weights = blob_utils.zeros((fg_inds.shape[0], 196), int32=False) - ################################################# + # Uv_point_weights = blob_utils.zeros((fg_inds.shape[0], 196), int32=False) rois_fg = sampled_boxes[fg_inds] - overlaps_bbfg_bbpolys = box_utils.bbox_overlaps( - rois_fg.astype(np.float32, copy=False), - boxes_from_polys.astype(np.float32, copy=False)) + overlaps_bbfg_bbpolys = overlaps_bbfg_bbpolys[fg_inds] + # Map from each fg roi to the index of the gt box with highest overlap fg_polys_inds = np.argmax(overlaps_bbfg_bbpolys, axis=1) + # Add body UV targets for each fg roi for i in range(rois_fg.shape[0]): - # - fg_polys_ind = polys_gt_inds[ fg_polys_inds[i] ] - # - Ilabel = segm_utils.GetDensePoseMask( roidb['dp_masks'][ fg_polys_ind ] ) - # - GT_I = np.array(roidb['dp_I'][ fg_polys_ind ]) - GT_U = np.array(roidb['dp_U'][ fg_polys_ind ]) - GT_V = np.array(roidb['dp_V'][ fg_polys_ind ]) - GT_x = np.array(roidb['dp_x'][ fg_polys_ind ]) - GT_y = np.array(roidb['dp_y'][ fg_polys_ind ]) - GT_weights = np.ones(GT_I.shape).astype(np.float32) - # - ## Do the flipping of the densepose annotation ! - if(IsFlipped): - GT_I,GT_U,GT_V,GT_x,GT_y,Ilabel = DP.get_symmetric_densepose(GT_I,GT_U,GT_V,GT_x,GT_y,Ilabel) - # + fg_polys_ind = fg_polys_inds[i] + polys_gt_ind = polys_gt_inds[fg_polys_ind] + # RLE encoded dense masks which are of size 256 x 256. + # Map all part masks to 14 labels (i.e., indices of semantic body parts). + dp_masks = dp_utils.GetDensePoseMask( + roidb['dp_masks'][polys_gt_ind], cfg.BODY_UV_RCNN.NUM_SEMANTIC_PARTS + ) + # Surface patch indices of collected points + dp_I = np.array(roidb['dp_I'][polys_gt_ind], dtype=np.int32) + # UV coordinates of collected points + dp_U = np.array(roidb['dp_U'][polys_gt_ind], dtype=np.float32) + dp_V = np.array(roidb['dp_V'][polys_gt_ind], dtype=np.float32) + # dp_UV_weights = np.ones_like(dp_I).astype(np.float32) + # Spatial coordinates on the image which are scaled such that the bbox + # size is 256 x 256. + dp_x = np.array(roidb['dp_x'][polys_gt_ind], dtype=np.float32) + dp_y = np.array(roidb['dp_y'][polys_gt_ind], dtype=np.float32) + # Do the flipping of the densepose annotation + if roidb['flipped']: + dp_I, dp_U, dp_V, dp_x, dp_y, dp_masks = DP.get_symmetric_densepose( + dp_I, dp_U, dp_V, dp_x, dp_y, dp_masks + ) + roi_fg = rois_fg[i] - roi_gt = boxes_from_polys[fg_polys_inds[i],:] - # - x1 = roi_fg[0] ; x2 = roi_fg[2] - y1 = roi_fg[1] ; y2 = roi_fg[3] - # - x1_source = roi_gt[0]; x2_source = roi_gt[2] - y1_source = roi_gt[1]; y2_source = roi_gt[3] - # - x_targets = ( np.arange(x1,x2, (x2 - x1)/M ) - x1_source ) * ( 256. / (x2_source-x1_source) ) - y_targets = ( np.arange(y1,y2, (y2 - y1)/M ) - y1_source ) * ( 256. / (y2_source-y1_source) ) - # - x_targets = x_targets[0:M] ## Strangely sometimes it can be M+1, so make sure size is OK! 
- y_targets = y_targets[0:M] - # - [X_targets,Y_targets] = np.meshgrid( x_targets, y_targets ) - New_Index = cv2.remap(Ilabel,X_targets.astype(np.float32), Y_targets.astype(np.float32), interpolation=cv2.INTER_NEAREST, borderMode= cv2.BORDER_CONSTANT, borderValue=(0)) - # - All_L = np.zeros(New_Index.shape) - All_W = np.ones(New_Index.shape) - # - All_L = New_Index - # - gt_length_x = x2_source - x1_source - gt_length_y = y2_source - y1_source - # - GT_y = (( GT_y / 256. * gt_length_y ) + y1_source - y1 ) * ( M / ( y2 - y1 ) ) - GT_x = (( GT_x / 256. * gt_length_x ) + x1_source - x1 ) * ( M / ( x2 - x1 ) ) - # - GT_I[GT_y<0] = 0 - GT_I[GT_y>(M-1)] = 0 - GT_I[GT_x<0] = 0 - GT_I[GT_x>(M-1)] = 0 - # - points_inside = GT_I>0 - GT_U = GT_U[points_inside] - GT_V = GT_V[points_inside] - GT_x = GT_x[points_inside] - GT_y = GT_y[points_inside] - GT_weights = GT_weights[points_inside] - GT_I = GT_I[points_inside] - # - X_points[i, 0:len(GT_x)] = GT_x - Y_points[i, 0:len(GT_y)] = GT_y - Ind_points[i, 0:len(GT_I)] = i - I_points[i, 0:len(GT_I)] = GT_I - U_points[i, 0:len(GT_U)] = GT_U - V_points[i, 0:len(GT_V)] = GT_V - Uv_point_weights[i, 0:len(GT_weights)] = GT_weights - # - All_labels[i, :] = np.reshape(All_L.astype(np.int32), M ** 2) - All_Weights[i, :] = np.reshape(All_W.astype(np.int32), M ** 2) - ## - else: + gt_box = boxes_from_polys[fg_polys_ind] + fg_x1, fg_y1, fg_x2, fg_y2 = roi_fg[0:4] + gt_x1, gt_y1, gt_x2, gt_y2 = gt_box[0:4] + fg_width = fg_x2 - fg_x1; fg_height = fg_y2 - fg_y1 + gt_width = gt_x2 - gt_x1; gt_height = gt_y2 - gt_y1 + fg_scale_w = float(M) / fg_width + fg_scale_h = float(M) / fg_height + gt_scale_w = 256. / gt_width + gt_scale_h = 256. / gt_height + # Sample M points evenly within the fg roi and scale the relative coordinates + # (to associated gt box) such that the bounding box size is 256 x 256. + x_targets = (np.arange(fg_x1, fg_x2, fg_width / M) - gt_x1) * gt_scale_w + y_targets = (np.arange(fg_y1, fg_y2, fg_height / M) - gt_y1) * gt_scale_h + # Construct 2D coordiante matrices + x_targets, y_targets = np.meshgrid(x_targets[:M], y_targets[:M]) + ## Another implementation option (which results in similar performance) + # x_targets = (np.linspace(fg_x1, fg_x2, M, endpoint=True, dtype=np.float32) - gt_x1) * gt_scale_w + # y_targets = (np.linspace(fg_y1, fg_y2, M, endpoint=True, dtype=np.float32) - gt_y1) * gt_scale_h + # x_targets = (np.linspace(fg_x1, fg_x2, M, endpoint=False) - gt_x1) * gt_scale_w + # y_targets = (np.linspace(fg_y1, fg_y2, M, endpoint=False) - gt_y1) * gt_scale_h + # x_targets, y_targets = np.meshgrid(x_targets, y_targets) + + # Map dense masks of size 256 x 256 to target heatmap of size M x M. + part_inds[i] = cv2.remap( + dp_masks, x_targets.astype(np.float32), y_targets.astype(np.float32), + interpolation=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_CONSTANT, borderValue=(0) + ) + + # Scale annotated spatial coordinates from bbox of size 256 x 256 to target + # heatmap of size M x M. + dp_x = (dp_x / gt_scale_w + gt_x1 - fg_x1) * fg_scale_w + dp_y = (dp_y / gt_scale_h + gt_y1 - fg_y1) * fg_scale_h + # Set patch index of points outside the heatmap as 0 (background). + dp_I[dp_x < 0] = 0; dp_I[dp_x > (M - 1)] = 0 + dp_I[dp_y < 0] = 0; dp_I[dp_y > (M - 1)] = 0 + # Get body UV annotations of points inside the heatmap. 
+ points_inside = dp_I > 0 + dp_x = dp_x[points_inside] + dp_y = dp_y[points_inside] + dp_I = dp_I[points_inside] + dp_U = dp_U[points_inside] + dp_V = dp_V[points_inside] + # dp_UV_weights = dp_UV_weights[points_inside] + + # Update body UV blobs + num_dp_points = len(dp_I) + # coords_xy[i, 0:num_dp_points, 0] = i # fg_roi index + coords_xy[i, 0:num_dp_points, 0] = dp_x + coords_xy[i, 0:num_dp_points, 1] = dp_y + I_points[i, 0:num_dp_points] = dp_I.astype(np.int32) + U_points[i, 0:num_dp_points] = dp_U + V_points[i, 0:num_dp_points] = dp_V + # Uv_point_weights[i, 0:len(dp_UV_weights)] = dp_UV_weights + else: # If there are no fg rois + # The network cannot handle empty blobs, so we must provide a blob. + # We simply take the first bg roi, give it an all 0's body UV annotations + # and label it with class zero (bg). bg_inds = np.where(blobs['labels_int32'] == 0)[0] - # - if(len(bg_inds)==0): + # `rois_fg` is actually one background roi, but that's ok because ... + if len(bg_inds) == 0: rois_fg = sampled_boxes[0].reshape((1, -1)) else: rois_fg = sampled_boxes[bg_inds[0]].reshape((1, -1)) + # Mark that the first roi has body UV annotation + roi_has_body_uv[0] = 1 + # We give it all 0's blobs + part_inds = blob_utils.zeros((1, M, M), int32=True) + part_inds_weights = blob_utils.zeros((1, M, M), int32=False) + coords_xy = blob_utils.zeros((1, 196, 2), int32=False) + I_points = blob_utils.zeros((1, 196), int32=True) + U_points = blob_utils.zeros((1, 196), int32=False) + V_points = blob_utils.zeros((1, 196), int32=False) + # Uv_point_weights = blob_utils.zeros((1, 196), int32=False) - roi_has_mask[0] = 1 - # - X_points = blob_utils.zeros((1, 196), int32=False) - Y_points = blob_utils.zeros((1, 196), int32=False) - Ind_points = blob_utils.zeros((1, 196), int32=True) - I_points = blob_utils.zeros((1,196), int32=True) - U_points = blob_utils.zeros((1, 196), int32=False) - V_points = blob_utils.zeros((1, 196), int32=False) - Uv_point_weights = blob_utils.zeros((1, 196), int32=False) - # - All_labels = -blob_utils.ones((1, M ** 2), int32=True) * 0 ## zeros - All_Weights = -blob_utils.ones((1, M ** 2), int32=True) * 0 ## zeros - # + # Scale rois_fg and format as (batch_idx, x1, y1, x2, y2) rois_fg *= im_scale repeated_batch_idx = batch_idx * blob_utils.ones((rois_fg.shape[0], 1)) rois_fg = np.hstack((repeated_batch_idx, rois_fg)) - # - K = cfg.BODY_UV_RCNN.NUM_PATCHES - # - U_points = np.tile( U_points , [1,K+1] ) - V_points = np.tile( V_points , [1,K+1] ) - Uv_Weight_Points = np.zeros(U_points.shape) - # - for jjj in xrange(1,K+1): - Uv_Weight_Points[ : , jjj * I_points.shape[1] : (jjj+1) * I_points.shape[1] ] = ( I_points == jjj ).astype(np.float32) - # - ################ - # Update blobs dict with Mask R-CNN blobs - ############### - # - blobs['body_uv_rois'] = np.array(rois_fg) - blobs['roi_has_body_uv_int32'] = np.array(roi_has_mask).astype(np.int32) - ## - blobs['body_uv_ann_labels'] = np.array(All_labels).astype(np.int32) - blobs['body_uv_ann_weights'] = np.array(All_Weights).astype(np.float32) - # - ########################## - blobs['body_uv_X_points'] = X_points.astype(np.float32) - blobs['body_uv_Y_points'] = Y_points.astype(np.float32) - blobs['body_uv_Ind_points'] = Ind_points.astype(np.float32) - blobs['body_uv_I_points'] = I_points.astype(np.float32) - blobs['body_uv_U_points'] = U_points.astype(np.float32) #### VERY IMPORTANT : These are switched here : - blobs['body_uv_V_points'] = V_points.astype(np.float32) - blobs['body_uv_point_weights'] = Uv_Weight_Points.astype(np.float32) - 
################### - - - + # Create body UV blobs for all patches (including background) + K = cfg.BODY_UV_RCNN.NUM_PATCHES + 1 + # Construct U/V_points blobs for all patches by repeating it #num_patches times. + # Shape: (#rois, 196, K) + U_points = np.repeat(U_points[:, :, np.newaxis], K, axis=-1) + V_points = np.repeat(V_points[:, :, np.newaxis], K, axis=-1) + uv_point_weights = np.zeros_like(U_points) + # Set binary weights for UV targets in each patch + for i in np.arange(1, K): + uv_point_weights[:, :, i] = (I_points == i).astype(np.float32) + # Update blobs dict with body UV blobs + blobs['body_uv_rois'] = rois_fg + blobs['roi_has_body_uv_int32'] = roi_has_body_uv # shape: (#rois,) + blobs['body_uv_parts'] = part_inds # shape: (#rois, M, M) + blobs['body_uv_parts_weights'] = part_inds_weights + blobs['body_uv_coords_xy'] = coords_xy.reshape(-1, 2) # shape: (#rois * 196, 2) + blobs['body_uv_I_points'] = I_points.reshape(-1, 1) # shape: (#rois * 196, 1) + blobs['body_uv_U_points'] = U_points # shape: (#rois, 196, K) + blobs['body_uv_V_points'] = V_points + blobs['body_uv_point_weights'] = uv_point_weights diff --git a/detectron/roi_data/fast_rcnn.py b/detectron/roi_data/fast_rcnn.py index 2635974..153ce61 100644 --- a/detectron/roi_data/fast_rcnn.py +++ b/detectron/roi_data/fast_rcnn.py @@ -41,7 +41,6 @@ def get_fast_rcnn_blob_names(is_training=True): # labels_int32 blob: R categorical labels in [0, ..., K] for K # foreground classes plus background blob_names += ['labels_int32'] - if is_training: # bbox_targets blob: R bounding-box regression targets with 4 # targets per class blob_names += ['bbox_targets'] @@ -81,19 +80,32 @@ def get_fast_rcnn_blob_names(is_training=True): ######################## if is_training and cfg.MODEL.BODY_UV_ON: + # 'body_uv_rois': RoIs sampled for training the body UV estimation branch. + # Shape is (#fg_rois, 5) in format (batch_idx, x1, y1, x2, y2). blob_names += ['body_uv_rois'] + # 'roi_has_body_uv': binary labels for the RoIs specified in 'rois' + # indicating if each RoI has a body or not. Shape is (#rois). blob_names += ['roi_has_body_uv_int32'] - ######### - # ################################################### - blob_names += ['body_uv_ann_labels'] - blob_names += ['body_uv_ann_weights'] - # ################################################# - blob_names += ['body_uv_X_points'] - blob_names += ['body_uv_Y_points'] - blob_names += ['body_uv_Ind_points'] + # 'body_uv_parts': index of part in [0, ..., S] where S is the number of + # semantic parts used to sample body UV points for the RoIs specified in + # 'body_uv_rois'. Shape is (#rois, M, M) where M is the heat map size. + blob_names += ['body_uv_parts'] + # 'body_uv_parts_weights': weight assigned to each target in 'body_uv_parts'. + # Shape is (#rois, M, M). Used in SpatialSoftmaxWithLoss. + blob_names += ['body_uv_parts_weights'] + # 'body_uv_coords_xy': 2D spatial coordinates of collected points on + # the image. Shape is (#rois * 196, 2) in format (dp_x, dp_y). + # Used in PoolPointsInterp. + blob_names += ['body_uv_coords_xy'] + # 'body_uv_I_points': surface patch indices in [0, ..., K] for K patches + # plus background. Shape is (#rois * 196, 1). Used in SoftmaxWithLoss. blob_names += ['body_uv_I_points'] + # 'body_uv_U/V_points': UV coordinates of collected points in each patch. + # Shape is (#rois, 196, K). Used in PoolPointsInterp and SmoothL1Loss. 
blob_names += ['body_uv_U_points'] blob_names += ['body_uv_V_points'] + # 'body_uv_point_weights': weight assigned to each target in + # 'body_uv_U/V_points'. Shape is (#rois, 196, K). Used in SmoothL1Loss. blob_names += ['body_uv_point_weights'] if cfg.FPN.FPN_ON and cfg.FPN.MULTILEVEL_ROIS: @@ -173,7 +185,7 @@ def _sample_rois(roidb, im_scale, batch_idx): # against there being fewer than desired) bg_rois_per_this_image = rois_per_image - fg_rois_per_this_image bg_rois_per_this_image = np.minimum(bg_rois_per_this_image, bg_inds.size) - # Sample foreground regions without replacement + # Sample background regions without replacement if bg_inds.size > 0: bg_inds = npr.choice( bg_inds, size=bg_rois_per_this_image, replace=False diff --git a/detectron/roi_data/rpn.py b/detectron/roi_data/rpn.py index 63b0166..0aa8266 100644 --- a/detectron/roi_data/rpn.py +++ b/detectron/roi_data/rpn.py @@ -113,7 +113,9 @@ def add_rpn_blobs(blobs, im_scales, roidb): valid_keys = [ 'has_visible_keypoints', 'boxes', 'segms', 'seg_areas', 'gt_classes', - 'gt_overlaps', 'is_crowd', 'box_to_gt_ind_map', 'gt_keypoints','flipped', 'ignore_UV_body','dp_x','dp_y','dp_I','dp_U','dp_V','dp_masks' ] + 'gt_overlaps', 'is_crowd', 'box_to_gt_ind_map', 'gt_keypoints', 'flipped', + 'ignore_UV_body', 'dp_x', 'dp_y', 'dp_I', 'dp_U', 'dp_V', 'dp_masks' + ] minimal_roidb = [{} for _ in range(len(roidb))] for i, e in enumerate(roidb): for k in valid_keys: diff --git a/detectron/utils/densepose_methods.py b/detectron/utils/densepose_methods.py index 4d1dcf9..3c19aa9 100644 --- a/detectron/utils/densepose_methods.py +++ b/detectron/utils/densepose_methods.py @@ -3,141 +3,137 @@ # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +############################################################################## +"""DensePose utilities.""" + +from scipy.io import loadmat +import os.path as osp import numpy as np -import copy -import cv2 -from scipy.io import loadmat -import scipy.spatial.distance -import os +import scipy.spatial.distance as ssd +import pycocotools.mask as mask_util + + +def GetDensePoseMask(Polys, num_parts=14): + """Get dense masks from the encoded masks.""" + MaskGen = np.zeros((256, 256), dtype=np.int32) + for i in range(1, num_parts + 1): + if Polys[i - 1]: + current_mask = mask_util.decode(Polys[i - 1]) + MaskGen[current_mask > 0] = i + return MaskGen class DensePoseMethods: def __init__(self): - # - ALP_UV = loadmat( os.path.join(os.path.dirname(__file__), '../../DensePoseData/UV_data/UV_Processed.mat') ) - self.FaceIndices = np.array( ALP_UV['All_FaceIndices']).squeeze() - self.FacesDensePose = ALP_UV['All_Faces']-1 + ALP_UV = loadmat( + osp.join(osp.dirname(__file__), '../../DensePoseData/UV_data/UV_Processed.mat') + ) + self.FaceIndices = np.array(ALP_UV['All_FaceIndices']).squeeze() + self.FacesDensePose = ALP_UV['All_Faces'] - 1 self.U_norm = ALP_UV['All_U_norm'].squeeze() self.V_norm = ALP_UV['All_V_norm'].squeeze() - self.All_vertices = ALP_UV['All_vertices'][0] - ## Info to compute symmetries. 
- self.SemanticMaskSymmetries = [0,1,3,2,5,4,7,6,9,8,11,10,13,12,14] - self.Index_Symmetry_List = [1,2,4,3,6,5,8,7,10,9,12,11,14,13,16,15,18,17,20,19,22,21,24,23]; - UV_symmetry_filename = os.path.join(os.path.dirname(__file__), '../../DensePoseData/UV_data/UV_symmetry_transforms.mat') - self.UV_symmetry_transformations = loadmat( UV_symmetry_filename ) - - - def get_symmetric_densepose(self,I,U,V,x,y,Mask): - ### This is a function to get the mirror symmetric UV labels. - Labels_sym= np.zeros(I.shape) - U_sym= np.zeros(U.shape) - V_sym= np.zeros(V.shape) - ### - for i in ( range(24)): - if i+1 in I: - Labels_sym[I == (i+1)] = self.Index_Symmetry_List[i] - jj = np.where(I == (i+1)) - ### - U_loc = (U[jj]*255).astype(np.int64) - V_loc = (V[jj]*255).astype(np.int64) - ### - V_sym[jj] = self.UV_symmetry_transformations['V_transforms'][0,i][V_loc,U_loc] - U_sym[jj] = self.UV_symmetry_transformations['U_transforms'][0,i][V_loc,U_loc] - ## - Mask_flip = np.fliplr(Mask) - Mask_flipped = np.zeros(Mask.shape) - # - for i in ( range(14)): - Mask_flipped[Mask_flip == (i+1)] = self.SemanticMaskSymmetries[i+1] - # - [y_max , x_max ] = Mask_flip.shape - y_sym = y - x_sym = x_max-x - # - return Labels_sym , U_sym , V_sym , x_sym , y_sym , Mask_flipped - - - - def barycentric_coordinates_exists(self,P0, P1, P2, P): - u = P1 - P0 - v = P2 - P0 - w = P - P0 - # - vCrossW = np.cross(v,w) + self.All_vertices = ALP_UV['All_vertices'][0] + self.SemanticMaskSymmetries = [ + 0, 1, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 14 + ] + self.Index_Symmetry_List = [ + 1, 2, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15, 18, 17, 20, 19, 22, 21, 24, 23 + ] + self.UV_symmetry_transformations = loadmat( + osp.join(osp.dirname(__file__), '../../DensePoseData/UV_data/UV_symmetry_transforms.mat') + ) + + def get_symmetric_densepose(self, I, U, V, x, y, mask): + """Get the mirror symmetric UV annotations""" + symm_I = np.zeros_like(I) + symm_U = np.zeros_like(U) + symm_V = np.zeros_like(V) + for i in range(24): + inds = np.where(I == (i + 1))[0] + if len(inds) > 0: + symm_I[inds] = self.Index_Symmetry_List[i] + loc_U = (U[inds] * 255).astype(np.int32) + loc_V = (V[inds] * 255).astype(np.int32) + symm_U[inds] = self.UV_symmetry_transformations['U_transforms'][0, i][loc_V, loc_U] + symm_V[inds] = self.UV_symmetry_transformations['V_transforms'][0, i][loc_V, loc_U] + + flip_mask = np.fliplr(mask) + symm_mask = np.zeros_like(mask) + for i in range(1, 15): + symm_mask[flip_mask == i] = self.SemanticMaskSymmetries[i] + x_max = flip_mask.shape[1] + symm_x = x_max - x + symm_y = y + return symm_I, symm_U, symm_V, symm_x, symm_y, symm_mask + + def barycentric_coordinates_exists(self, P0, P1, P2, P): + u = P1 - P0; v = P2 - P0; w = P - P0 + vCrossW = np.cross(v, w) vCrossU = np.cross(v, u) - if (np.dot(vCrossW, vCrossU) < 0): - return False; - # + if np.dot(vCrossW, vCrossU) < 0: + return False + uCrossW = np.cross(u, w) uCrossV = np.cross(u, v) - # if (np.dot(uCrossW, uCrossV) < 0): - return False; - # - denom = np.sqrt((uCrossV**2).sum()) - r = np.sqrt((vCrossW**2).sum())/denom - t = np.sqrt((uCrossW**2).sum())/denom - # - return((r <=1) & (t <= 1) & (r + t <= 1)) - - def barycentric_coordinates(self,P0, P1, P2, P): - u = P1 - P0 - v = P2 - P0 - w = P - P0 - # - vCrossW = np.cross(v,w) + return False + + denom = np.sqrt((uCrossV ** 2).sum()) + r = np.sqrt((vCrossW ** 2).sum()) / denom + t = np.sqrt((uCrossW ** 2).sum()) / denom + return ((r <= 1) & (t <= 1) & (r + t <= 1)) + + def barycentric_coordinates(self, P0, P1, P2, P): + u = P1 
- P0; v = P2 - P0; w = P - P0 + vCrossW = np.cross(v, w) vCrossU = np.cross(v, u) - # + if np.dot(vCrossW, vCrossU) < 0: + return -1, -1, -1 uCrossW = np.cross(u, w) uCrossV = np.cross(u, v) - # - denom = np.sqrt((uCrossV**2).sum()) - r = np.sqrt((vCrossW**2).sum())/denom - t = np.sqrt((uCrossW**2).sum())/denom - # - return(1-(r+t),r,t) - - def IUV2FBC( self, I_point , U_point, V_point): - P = [ U_point , V_point , 0 ] - FaceIndicesNow = np.where( self.FaceIndices == I_point ) - FacesNow = self.FacesDensePose[FaceIndicesNow] - # - P_0 = np.vstack( (self.U_norm[FacesNow][:,0], self.V_norm[FacesNow][:,0], np.zeros(self.U_norm[FacesNow][:,0].shape))).transpose() - P_1 = np.vstack( (self.U_norm[FacesNow][:,1], self.V_norm[FacesNow][:,1], np.zeros(self.U_norm[FacesNow][:,1].shape))).transpose() - P_2 = np.vstack( (self.U_norm[FacesNow][:,2], self.V_norm[FacesNow][:,2], np.zeros(self.U_norm[FacesNow][:,2].shape))).transpose() - # - - for i, [P0,P1,P2] in enumerate( zip(P_0,P_1,P_2)) : - if(self.barycentric_coordinates_exists(P0, P1, P2, P)): - [bc1,bc2,bc3] = self.barycentric_coordinates(P0, P1, P2, P) - return(FaceIndicesNow[0][i],bc1,bc2,bc3) - # + if np.dot(uCrossW, uCrossV) < 0: + return -1, -1, -1 + denom = np.sqrt((uCrossV ** 2).sum()) + r = np.sqrt((vCrossW ** 2).sum()) / denom + t = np.sqrt((uCrossW ** 2).sum()) / denom + if ((r <= 1) & (t <= 1) & (r + t <= 1)): + return 1 - (r + t), r, t + else: + return -1, -1, -1 + + def IUV2FBC(self, I_point, U_point, V_point): + """Convert IUV to FBC (faceIndex and barycentric coordinates).""" + P = [U_point, V_point, 0] + faceIndicesNow = np.where(self.FaceIndices == I_point)[0] + FacesNow = self.FacesDensePose[faceIndicesNow] + v0 = np.zeros_like(self.U_norm[FacesNow][:, 0]) + P_0 = np.vstack((self.U_norm[FacesNow][:, 0], self.V_norm[FacesNow][:, 0], v0)).transpose() + P_1 = np.vstack((self.U_norm[FacesNow][:, 1], self.V_norm[FacesNow][:, 1], v0)).transpose() + P_2 = np.vstack((self.U_norm[FacesNow][:, 2], self.V_norm[FacesNow][:, 2], v0)).transpose() + + for i, [P0, P1, P2] in enumerate(zip(P_0, P_1, P_2)) : + bc1, bc2, bc3 = self.barycentric_coordinates(P0, P1, P2, P) + if bc1 != -1: + return faceIndicesNow[i], bc1, bc2, bc3 + # If the found UV is not inside any faces, select the vertex that is closest! - # - D1 = scipy.spatial.distance.cdist( np.array( [U_point,V_point])[np.newaxis,:] , P_0[:,0:2]).squeeze() - D2 = scipy.spatial.distance.cdist( np.array( [U_point,V_point])[np.newaxis,:] , P_1[:,0:2]).squeeze() - D3 = scipy.spatial.distance.cdist( np.array( [U_point,V_point])[np.newaxis,:] , P_2[:,0:2]).squeeze() - # - minD1 = D1.min() - minD2 = D2.min() - minD3 = D3.min() - # - if((minD1< minD2) & (minD1< minD3)): - return( FaceIndicesNow[0][np.argmin(D1)] , 1.,0.,0. ) - elif((minD2< minD1) & (minD2< minD3)): - return( FaceIndicesNow[0][np.argmin(D2)] , 0.,1.,0. ) + D1 = ssd.cdist(np.array([U_point, V_point])[np.newaxis, :], P_0[:, 0:2]).squeeze() + D2 = ssd.cdist(np.array([U_point, V_point])[np.newaxis, :], P_1[:, 0:2]).squeeze() + D3 = ssd.cdist(np.array([U_point, V_point])[np.newaxis, :], P_2[:, 0:2]).squeeze() + minD1 = D1.min(); minD2 = D2.min(); minD3 = D3.min() + if ((minD1 < minD2) & (minD1 < minD3)): + return faceIndicesNow[np.argmin(D1)], 1., 0., 0. + elif ((minD2 < minD1) & (minD2 < minD3)): + return faceIndicesNow[np.argmin(D2)], 0., 1., 0. else: - return( FaceIndicesNow[0][np.argmin(D3)] , 0.,0.,1. 
) - - - def FBC2PointOnSurface( self, FaceIndex, bc1,bc2,bc3,Vertices ): - ## - Vert_indices = self.All_vertices[self.FacesDensePose[FaceIndex]]-1 - ## - p = Vertices[Vert_indices[0],:] * bc1 + \ - Vertices[Vert_indices[1],:] * bc2 + \ - Vertices[Vert_indices[2],:] * bc3 - ## - return(p) - + return faceIndicesNow[np.argmin(D3)], 0., 0., 1. + + def FBC2PointOnSurface(self, face_ind, bc1, bc2, bc3, vertices): + """Use FBC to get 3D coordinates on the surface.""" + Vert_indices = self.All_vertices[self.FacesDensePose[face_ind]] - 1 + # p = vertices[Vert_indices[0], :] * bc1 + \ + # vertices[Vert_indices[1], :] * bc2 + \ + # vertices[Vert_indices[2], :] * bc3 + p = np.matmul(np.array([[bc1, bc2, bc3]]), vertices[Vert_indices]).squeeze() + return p diff --git a/detectron/utils/io.py b/detectron/utils/io.py index 3ec5e22..2ac2fdc 100644 --- a/detectron/utils/io.py +++ b/detectron/utils/io.py @@ -35,6 +35,15 @@ def cache_url(url_or_file, cache_dir): path to the cached file. If the argument is not a URL, simply return it as is. """ + if re.match(r'^\$', url_or_file, re.IGNORECASE) is not None: + url_or_file = os.path.expandvars(url_or_file) + assert os.path.exists(url_or_file) + return url_or_file + elif re.match(r'^~', url_or_file, re.IGNORECASE) is not None: + url_or_file = os.path.expanduser(url_or_file) + assert os.path.exists(url_or_file) + return url_or_file + is_url = re.match(r'^(?:http)s?://', url_or_file, re.IGNORECASE) is not None if not is_url: @@ -42,8 +51,8 @@ def cache_url(url_or_file, cache_dir): # url = url_or_file # - Len_filename = len( url.split('/')[-1] ) - BASE_URL = url[0:-Len_filename-1] + Len_filename = len(url.split('/')[-1]) + BASE_URL = url[0:-Len_filename - 1] # cache_file_path = url.replace(BASE_URL, cache_dir) if os.path.exists(cache_file_path): diff --git a/detectron/utils/segms.py b/detectron/utils/segms.py index bf1ac3f..9967a24 100644 --- a/detectron/utils/segms.py +++ b/detectron/utils/segms.py @@ -20,19 +20,9 @@ from __future__ import unicode_literals import numpy as np - import pycocotools.mask as mask_util -def GetDensePoseMask(Polys): - MaskGen = np.zeros([256,256]) - for i in range(1,15): - if(Polys[i-1]): - current_mask = mask_util.decode(Polys[i-1]) - MaskGen[current_mask>0] = i - return MaskGen - - def flip_segms(segms, height, width): """Left/right flip each mask in a list of masks.""" def _flip_poly(poly, width): diff --git a/detectron/utils/vis.py b/detectron/utils/vis.py index 31b2e79..13556cd 100644 --- a/detectron/utils/vis.py +++ b/detectron/utils/vis.py @@ -376,45 +376,54 @@ def vis_one_image( line, color=colors[len(kp_lines) + 1], linewidth=1.0, alpha=0.7) - # DensePose Visualization Starts!! - ## Get full IUV image out + ### DensePose Visualization Starts!! + # get full IUV image outputs for all bboxes IUV_fields = body_uv[1] - # - All_Coords = np.zeros(im.shape) - All_inds = np.zeros([im.shape[0],im.shape[1]]) - K = 26 - ## - inds = np.argsort(boxes[:,4]) - ## + # initialize IUV output and INDS output images with zeros + All_coords = np.zeros(im.shape, dtype=np.float32) # shape: (im_height, im_width, 3) + All_inds = np.zeros([im.shape[0], im.shape[1]], dtype=np.float32) # shape: (im_height, im_width) + + # display in smallest to largest class scores order, however, this may cause sharpness in some body parts + # due to the output of an inaccurate bbox with lower score will not be overlapped by the output of a more + # precise bbox with higher score which will be discarded. 
+ inds = np.argsort(boxes[:, 4]) for i, ind in enumerate(inds): - entry = boxes[ind,:] - if entry[4] > 0.65: - entry=entry[0:4].astype(int) - #### + score = boxes[ind, 4] + bbox = boxes[ind, :4] + if score > 0.65: + # top left corner (x1, y1) of current bbox in image space + x1, y1 = boxes[ind, :2].astype(int) + # get IUV output for current bbox output = IUV_fields[ind] - #### - All_Coords_Old = All_Coords[ entry[1] : entry[1]+output.shape[1],entry[0]:entry[0]+output.shape[2],:] - All_Coords_Old[All_Coords_Old==0]=output.transpose([1,2,0])[All_Coords_Old==0] - All_Coords[ entry[1] : entry[1]+output.shape[1],entry[0]:entry[0]+output.shape[2],:]= All_Coords_Old - ### - CurrentMask = (output[0,:,:]>0).astype(np.float32) - All_inds_old = All_inds[ entry[1] : entry[1]+output.shape[1],entry[0]:entry[0]+output.shape[2]] - All_inds_old[All_inds_old==0] = CurrentMask[All_inds_old==0]*i - All_inds[ entry[1] : entry[1]+output.shape[1],entry[0]:entry[0]+output.shape[2]] = All_inds_old - # - All_Coords[:,:,1:3] = 255. * All_Coords[:,:,1:3] - All_Coords[All_Coords>255] = 255. - All_Coords = All_Coords.astype(np.uint8) + out_height, out_width = output.shape[1:3] + + # first, locate the region of current bbox on final IUV output image + All_coords_tmp = All_coords[y1:y1 + out_height, x1:x1 + out_width] + # then, extract IUV output of pixels within this bbox in which have not been filled with IUV of other bbox + All_coords_tmp[All_coords_tmp == 0] = output.transpose([1, 2, 0])[All_coords_tmp == 0] + # update final IUV output image + All_coords[y1:y1 + out_height, x1:x1 + out_width] = All_coords_tmp + + # get (distinct) human-body FG mask indices for each bbox + index_UV = output[0] # predicted part index: 0 ~ 24 + CurrentMask = (index_UV > 0).astype(np.float32) + All_inds_tmp = All_inds[y1:y1 + out_height, x1:x1 + out_width] + All_inds_tmp[All_inds_tmp == 0] = CurrentMask[All_inds_tmp == 0] * (i + 1) # avoid `i` starting with 0 + All_inds[y1:y1 + out_height, x1:x1 + out_width] = All_inds_tmp + + # scale predicted UV coordinates to [0, 255] + All_coords[:, :, 1:3] = All_coords[:, :, 1:3] * 255. + All_coords[All_coords > 255] = 255. + All_coords = All_coords.astype(np.uint8) All_inds = All_inds.astype(np.uint8) - # - IUV_SaveName = os.path.basename(im_name).split('.')[0]+'_IUV.png' - INDS_SaveName = os.path.basename(im_name).split('.')[0]+'_INDS.png' - cv2.imwrite(os.path.join(output_dir, '{}'.format(IUV_SaveName)), All_Coords ) - cv2.imwrite(os.path.join(output_dir, '{}'.format(INDS_SaveName)), All_inds ) - print('IUV written to: ' , os.path.join(output_dir, '{}'.format(IUV_SaveName)) ) - ### + # save IUV images into files + IUV_SaveName = os.path.basename(im_name).split('.')[0] + '_IUV.png' + INDS_SaveName = os.path.basename(im_name).split('.')[0] + '_INDS.png' + cv2.imwrite(os.path.join(output_dir, '{}'.format(IUV_SaveName)), All_coords) + cv2.imwrite(os.path.join(output_dir, '{}'.format(INDS_SaveName)), All_inds) + print('IUV written to: ', os.path.join(output_dir, '{}'.format(IUV_SaveName))) ### DensePose Visualization Done!! - # + output_name = os.path.basename(im_name) + '.' + ext fig.savefig(os.path.join(output_dir, '{}'.format(output_name)), dpi=dpi) plt.close('all')
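The U/V regression targets built above are replicated across all K = NUM_PATCHES + 1 channels and paired with binary weights that are non-zero only in the channel of the annotated patch, so the SmoothL1 losses added in add_body_uv_losses penalize each point only in its ground-truth patch. Below is a rough, self-contained NumPy sketch of that layout with random data; the masked smooth-L1 reduction is only an approximation of Caffe2's SmoothL1Loss, shown to make the role of 'body_uv_point_weights' concrete, and the shapes assume NUM_PATCHES = 24 and 196 sampled points per RoI as in the blobs above.

import numpy as np

num_rois, num_points, num_patches = 2, 196, 24
K = num_patches + 1                       # 24 surface patches plus background

# Per-point patch index (0 = background) and scalar U targets, random here.
I_points = np.random.randint(0, K, size=(num_rois, num_points))
U_scalar = np.random.rand(num_rois, num_points).astype(np.float32)

# Replicate the scalar target across all K patch channels, as done for the
# 'body_uv_U_points' / 'body_uv_V_points' blobs ...
U_points = np.repeat(U_scalar[:, :, np.newaxis], K, axis=-1)

# ... and build binary weights that are 1 only in the channel of the annotated
# patch, so only that channel contributes to the regression loss.
uv_point_weights = np.zeros_like(U_points)
for i in range(1, K):
    uv_point_weights[:, :, i] = (I_points == i).astype(np.float32)

def masked_smooth_l1(pred, target, weights):
    # Rough stand-in for SmoothL1Loss(pred, target, w_in, w_out) with w_in == w_out
    d = weights * (pred - target)
    per_elem = np.where(np.abs(d) < 1.0, 0.5 * d ** 2, np.abs(d) - 0.5)
    return float((weights * per_elem).sum())

U_pred = np.random.rand(num_rois, num_points, K).astype(np.float32)
print(masked_smooth_l1(U_pred, U_points, uv_point_weights))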
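The PoolPointsInterp operator samples every channel of the RoI feature map at the annotated point coordinates using bilinear interpolation. The following is a minimal NumPy sketch of the per-point interpolation performed by the CUDA kernels above; the real kernel additionally selects the feature slice of the point's RoI (via n / 196) and scales the coordinates by spatial_scale. It is illustrative only, not the shipped operator.

import numpy as np

def bilinear_interpolate(feat, x, y):
    """Interpolate one (H, W) channel at a single continuous point (x, y)."""
    height, width = feat.shape
    # Points falling outside the feature map contribute nothing
    if x < -1.0 or x > width or y < -1.0 or y > height:
        return 0.0
    x = max(x, 0.0)
    y = max(y, 0.0)
    x_low, y_low = int(x), int(y)
    if x_low >= width - 1:
        x_high = x_low = width - 1
        x = float(x_low)
    else:
        x_high = x_low + 1
    if y_low >= height - 1:
        y_high = y_low = height - 1
        y = float(y_low)
    else:
        y_high = y_low + 1
    lx, ly = x - x_low, y - y_low        # lambdas in x, y
    hx, hy = 1.0 - lx, 1.0 - ly
    v1 = feat[y_low, x_low]              # top-left
    v2 = feat[y_low, x_high]             # top-right
    v3 = feat[y_high, x_low]             # bottom-left
    v4 = feat[y_high, x_high]            # bottom-right
    return hy * hx * v1 + hy * lx * v2 + ly * hx * v3 + ly * lx * v4

# On a linear ramp the interpolated value is exact: 4 * 2.25 + 1.5 = 10.5
feat = np.arange(16, dtype=np.float32).reshape(4, 4)
print(bilinear_interpolate(feat, 1.5, 2.25))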
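In add_body_uv_rcnn_blobs, the annotated (dp_x, dp_y) coordinates are stored relative to the ground-truth box and normalized so that the box spans 256 x 256; they must be re-expressed on the M x M heatmap grid of the sampled foreground RoI, and points that fall outside the grid are relabelled as background. A self-contained NumPy sketch of that mapping, using made-up boxes and points and M = 56 purely for illustration:

import numpy as np

M = 56                                              # heatmap size, illustrative
gt_x1, gt_y1, gt_x2, gt_y2 = 40., 30., 200., 270.   # gt box carrying annotations
fg_x1, fg_y1, fg_x2, fg_y2 = 50., 40., 190., 250.   # sampled foreground RoI

gt_scale_w = 256. / (gt_x2 - gt_x1)
gt_scale_h = 256. / (gt_y2 - gt_y1)
fg_scale_w = float(M) / (fg_x2 - fg_x1)
fg_scale_h = float(M) / (fg_y2 - fg_y1)

dp_x = np.array([10., 128., 250.])                  # points in the 256 x 256 frame
dp_y = np.array([20., 100., 240.])
dp_I = np.array([3, 15, 24], dtype=np.int32)        # their surface patch indices

# gt-box frame -> image coordinates -> RoI heatmap coordinates
hm_x = (dp_x / gt_scale_w + gt_x1 - fg_x1) * fg_scale_w
hm_y = (dp_y / gt_scale_h + gt_y1 - fg_y1) * fg_scale_h

# Points falling outside the M x M heatmap are relabelled as background (0)
dp_I[(hm_x < 0) | (hm_x > M - 1) | (hm_y < 0) | (hm_y > M - 1)] = 0
print(np.stack([hm_x, hm_y]), dp_I)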
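IUV2FBC maps a (patch index, U, V) annotation to a face index plus barycentric coordinates, and FBC2PointOnSurface then mixes the three 3D face vertices with those weights. A short NumPy check, on a hypothetical triangle in the UV plane, that weights computed the same way as in barycentric_coordinates reconstruct the query point, which is the property FBC2PointOnSurface relies on:

import numpy as np

P0 = np.array([0.1, 0.1, 0.0])
P1 = np.array([0.6, 0.2, 0.0])
P2 = np.array([0.3, 0.7, 0.0])
P  = np.array([0.35, 0.30, 0.0])            # query point inside the triangle

u, v, w = P1 - P0, P2 - P0, P - P0
denom = np.linalg.norm(np.cross(u, v))
r = np.linalg.norm(np.cross(v, w)) / denom  # weight of P1
t = np.linalg.norm(np.cross(u, w)) / denom  # weight of P2
bc1, bc2, bc3 = 1 - (r + t), r, t

# The weights reproduce P exactly; FBC2PointOnSurface applies the same
# (bc1, bc2, bc3) to the three 3D vertices of the matched face.
recon = np.matmul(np.array([[bc1, bc2, bc3]]), np.stack([P0, P1, P2])).squeeze()
print(np.allclose(recon, P))                # True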
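cache_url now returns local paths directly when they start with an environment variable or '~' instead of treating them as URLs to download. A tiny sketch of that behaviour, with a hypothetical path:

import os

def expand_local_path(url_or_file):
    # Paths beginning with '$VAR' or '~' are expanded and used as-is
    if url_or_file.startswith('$'):
        return os.path.expandvars(url_or_file)
    if url_or_file.startswith('~'):
        return os.path.expanduser(url_or_file)
    return url_or_file

print(expand_local_path('~/DensePoseData/weights.pkl'))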