Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mmeval/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.

from .accuracy import Accuracy
from .coco_detection import COCODetectionMetric
from .end_point_error import EndPointError
from .f_metric import F1Metric
from .hmean_iou import HmeanIoU
Expand All @@ -10,5 +11,5 @@

__all__ = [
'Accuracy', 'MeanIoU', 'VOCMeanAP', 'OIDMeanAP', 'EndPointError',
'F1Metric', 'HmeanIoU'
'F1Metric', 'HmeanIoU', 'COCODetectionMetric'
]
618 changes: 618 additions & 0 deletions mmeval/metrics/coco_detection.py

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion mmeval/metrics/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.

from .hmean import compute_hmean
from .polygon import (poly2shapely, poly_intersection, poly_iou,
poly_make_valid, poly_union, polys2shapely)
Expand Down
191 changes: 191 additions & 0 deletions mmeval/metrics/utils/coco_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pycocotools.mask as _mask_util
from collections import defaultdict
from pathlib import Path
from pycocotools.coco import COCO as _COCO
from pycocotools.cocoeval import COCOeval as _COCOeval
from typing import Dict, Optional, Sequence, Union


class COCO(_COCO):
"""This class is almost the same as official pycocotools package.

It implements some snake case function aliases. So that the COCO class has
the same interface as LVIS class.

Args:
annotation_file (str, optional): Path of annotation file.
Defaults to None.
"""

def __init__(self,
annotation_file: Optional[Union[str, Path]] = None) -> None:
super().__init__(annotation_file=annotation_file)
self.img_ann_map = self.imgToAnns
self.cat_img_map = self.catToImgs

def get_ann_ids(self,
img_ids: Union[list, int] = [],
cat_ids: Union[list, int] = [],
area_rng: Union[list, int] = [],
iscrowd: Optional[bool] = None) -> list:
"""Get annotation ids that satisfy given filter conditions.

Args:
img_ids (list | int): Get annotations for given images.
cat_ids (list | int): Get categories for given images.
area_rng (list | int): Get annotations for given area range.
iscrowd (bool, optional): Get annotations for given crowd label.

Returns:
List: Integer array of annotation ids.
"""
return self.getAnnIds(img_ids, cat_ids, area_rng, iscrowd)

def get_cat_ids(self,
cat_names: Union[list, int] = [],
sup_names: Union[list, int] = [],
cat_ids: Union[list, int] = []) -> list:
"""Get category ids that satisfy given filter conditions.

Args:
cat_names (list | int): Get categories for given category names.
sup_names (list | int): Get categories for given supercategory
names.
cat_ids (list | int): Get categories for given category ids.

Returns:
List: Integer array of category ids.
"""
return self.getCatIds(cat_names, sup_names, cat_ids)

def get_img_ids(self,
img_ids: Union[list, int] = [],
cat_ids: Union[list, int] = []) -> list:
"""Get image ids that satisfy given filter conditions.

Args:
img_ids (list | int): Get images for given ids
cat_ids (list | int): Get images with all given cats

Returns:
List: Integer array of image ids.
"""
return self.getImgIds(img_ids, cat_ids)

def load_anns(self, ids: Union[list, int] = []) -> list:
"""Load annotations with the specified ids.

Args:
ids (list | int): Integer ids specifying annotations.

Returns:
List[dict]: Loaded annotation objects.
"""
return self.loadAnns(ids)

def load_cats(self, ids: Union[list, int] = []) -> list:
"""Load categories with the specified ids.

Args:
ids (list | int): Integer ids specifying categories.

Returns:
List[dict]: loaded category objects.
"""
return self.loadCats(ids)

def load_imgs(self, ids: Union[list, int] = []) -> list:
"""Load annotations with the specified ids.

Args:
ids (list): integer ids specifying image.

Returns:
List[dict]: Loaded image objects.
"""
return self.loadImgs(ids)


class COCOPanoptic(COCO):
"""This wrapper is for loading the panoptic style annotation file."""

def createIndex(self) -> None:
"""Create index."""
# create index
print('creating index...')
# anns stores 'segment_id -> annotation'
anns: Dict[int, list] = {}
cats: Dict[int, dict] = {}
imgs: Dict[int, dict] = {}
img_to_anns, cat_to_imgs = defaultdict(list), defaultdict(list)
if 'annotations' in self.dataset:
for ann in self.dataset['annotations']:
for seg_ann in ann['segments_info']:
# to match with instance.json
seg_ann['image_id'] = ann['image_id']
img_to_anns[ann['image_id']].append(seg_ann)
# segment_id is not unique in coco dataset orz...
# annotations from different images but
# may have same segment_id
if seg_ann['id'] in anns.keys():
anns[seg_ann['id']].append(seg_ann)
else:
anns[seg_ann['id']] = [seg_ann]

# filter out annotations from other images
img_to_anns_ = defaultdict(list)
for k, v in img_to_anns.items():
img_to_anns_[k] = [x for x in v if x['image_id'] == k]
img_to_anns = img_to_anns_

if 'images' in self.dataset:
for img_info in self.dataset['images']:
img_info['segm_file'] = img_info['file_name'].replace(
'jpg', 'png')
imgs[img_info['id']] = img_info

if 'categories' in self.dataset:
for cat in self.dataset['categories']:
cats[cat['id']] = cat

if 'annotations' in self.dataset and 'categories' in self.dataset:
for ann in self.dataset['annotations']:
for seg_ann in ann['segments_info']:
cat_to_imgs[seg_ann['category_id']].append(ann['image_id'])

print('index created!')

self.anns = anns
self.imgToAnns = img_to_anns
self.catToImgs = cat_to_imgs
self.imgs = imgs
self.cats = cats

def load_anns(self, ids: Union[list, int] = []) -> list:
"""Load annotations with the specified ids.

``self.anns`` is a list of annotation lists instead of a
list of annotations.

Args:
ids (Union[List[int], int]): Integer ids specifying annotations.

Returns:
List: Loaded annotation objects.
"""
anns = []

if isinstance(ids, Sequence):
# self.anns is a list of annotation lists instead of
# a list of annotations
for id in ids:
anns += self.anns[id]
return anns
else:
return self.anns[ids]


# just for the ease of import
COCOeval = _COCOeval
mask_util = _mask_util
29 changes: 16 additions & 13 deletions mmeval/metrics/voc_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Dict, List, Optional, Sequence, Tuple, Union

from mmeval.core.base_metric import BaseMetric
from mmeval.utils import is_list_of


def calculate_average_precision(recalls: np.ndarray,
Expand Down Expand Up @@ -199,9 +200,8 @@ class VOCMeanAP(BaseMetric):
Defaults to 4.
drop_class_ap (bool): Whether to drop the class without ground truth
when calculating the average precision for each class.
classwise_result (bool): Whether to return the computed
results of each class.
Defaults to False.
classwise (bool): Whether to return the computed results of each
class. Defaults to False.
**kwargs: Keyword parameters passed to :class:`BaseMetric`.

Examples:
Expand Down Expand Up @@ -241,12 +241,15 @@ def __init__(self,
use_legacy_coordinate: bool = False,
nproc: int = 4,
drop_class_ap: bool = True,
classwise_result: bool = False,
classwise: bool = False,
**kwargs) -> None:
super().__init__(**kwargs)

if isinstance(iou_thrs, float):
iou_thrs = [iou_thrs]
assert is_list_of(iou_thrs, float), \
'`iou_thrs` should be float or a list of float'

self.iou_thrs = iou_thrs

if scale_ranges is None:
Expand All @@ -272,7 +275,7 @@ def __init__(self,
self.nproc = nproc
self.use_legacy_coordinate = use_legacy_coordinate
self.drop_class_ap = drop_class_ap
self.classwise_result = classwise_result
self.classwise = classwise

self.num_iou = len(self.iou_thrs)
self.num_scale = len(self.scale_ranges)
Expand Down Expand Up @@ -321,20 +324,20 @@ def add(self, predictions: Sequence[Dict], groundtruths: Sequence[Dict]) -> None

- bboxes (numpy.ndarray): Shape (M, 4), the ground truth
bounding bboxes of this image, in 'xyxy' foramrt.
- labels (numpy.ndarray): Shape (M, 1), theground truth
- labels (numpy.ndarray): Shape (M, 1), the ground truth
labels of bounding boxes.
- bboxes_ignore (numpy.ndarray): Shape (K, 4), the ground
truth ignored bounding bboxes of this image,
in 'xyxy' foramrt.
- labels_ignore (numpy.ndarray): Shape (K, 1), the ground
truth ignored labels of bounding boxes.
"""
for prediction, label in zip(predictions, groundtruths):
for prediction, groundtruth in zip(predictions, groundtruths):
assert isinstance(prediction, dict), 'The prediciton should be ' \
f'a sequence of dict, but got a sequence of {type(prediction)}.' # noqa: E501
assert isinstance(label, dict), 'The label should be ' \
f'a sequence of dict, but got a sequence of {type(label)}.'
self._results.append((prediction, label))
assert isinstance(groundtruth, dict), 'The label should be ' \
f'a sequence of dict, but got a sequence of {type(groundtruth)}.' # noqa: E501
self._results.append((prediction, groundtruth))

@staticmethod
def _calculate_image_tpfp(
Expand Down Expand Up @@ -562,8 +565,8 @@ def compute_metric(self, results: list) -> dict:
- mAP, the averaged across all IoU thresholds and all class.
- mAP@{IoU}, the mAP of the specified IoU threshold.
- mAP@{scale_range}, the mAP of the specified scale range.
- classwise_result, the evaluation results of each class.
This would be returned if ``self.classwise_result`` is True.
- classwise, the evaluation results of each class.
This would be returned if ``self.classwise`` is True.
"""
predictions, groundtruths = zip(*results)

Expand Down Expand Up @@ -605,7 +608,7 @@ def compute_metric(self, results: list) -> dict:
pool.close()

eval_results = self._aggregate_results(results_per_class)
if self.classwise_result:
if self.classwise:
eval_results['classwise_result'] = results_per_class

return eval_results
Expand Down
1 change: 1 addition & 0 deletions requirements/optional.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pycocotools
scipy
shapely
Loading