- 
                Notifications
    You must be signed in to change notification settings 
- Fork 7
[Feature] New Matcher Class: Voronoi #229
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
391be00
              ae9d1b8
              e45b32c
              2519be7
              b53e639
              a61873b
              d91b21f
              36c66d8
              d170cff
              f0dd198
              96c538d
              4990799
              34d9bba
              72bb082
              9a86260
              98d60bf
              222c6b4
              30c7918
              adda2c2
              16bc50f
              8725f44
              d9f3550
              bcd8c8b
              6ceed04
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -3,11 +3,13 @@ | |
| from typing import Optional, Tuple, List | ||
|  | ||
| import numpy as np | ||
| from scipy.ndimage import distance_transform_edt | ||
|  | ||
| from panoptica._functionals import ( | ||
| _calc_matching_metric_of_overlapping_labels, | ||
| _calc_matching_metric_of_overlapping_partlabels, | ||
| _map_labels, | ||
| _connected_components, | ||
| ) | ||
| from panoptica.metrics import Metric | ||
| from panoptica.utils.processing_pair import ( | ||
|  | @@ -17,6 +19,7 @@ | |
| from panoptica.utils.instancelabelmap import InstanceLabelMap | ||
| from panoptica.utils.config import SupportsConfig | ||
| from panoptica.utils.label_group import LabelGroup, LabelPartGroup | ||
| from panoptica.utils.constants import CCABackend | ||
|  | ||
|  | ||
| @dataclass | ||
|  | @@ -493,3 +496,118 @@ def _yaml_repr(cls, node) -> dict: | |
| "matching_metric": node._matching_metric, | ||
| "matching_threshold": node._matching_threshold, | ||
| } | ||
|  | ||
|  | ||
class RegionBasedMatching(InstanceMatchingAlgorithm):
    """
    Instance matching algorithm that performs region-based matching using spatial distance.

    This method assigns prediction instances to ground truth regions based on spatial proximity
    rather than traditional overlap-based metrics. It uses connected components and distance
    transforms to create region assignments.

    Note: This matching method does not produce traditional count metrics (TP/FP/FN) as it
    assigns all predictions to regions. Count metrics will be set to NaN.

    Attributes:
        cca_backend (CCABackend): Backend for connected component analysis.
    """

    def __init__(
        self,
        cca_backend: CCABackend = CCABackend.scipy,
    ) -> None:
        """
        Initialize the RegionBasedMatching instance.

        Args:
            cca_backend (CCABackend): Backend for connected component analysis.
        """
        self._cca_backend = cca_backend

    def _get_gt_regions(self, gt: np.ndarray) -> Tuple[np.ndarray, int]:
        """
        Get ground truth regions using connected components and distance transforms.

        Every voxel is assigned the label of the connected component it is closest to
        (Euclidean distance). Voxels that are equidistant to several components are
        assigned to the component with the lowest label.

        Args:
            gt: Ground truth array

        Returns:
            Tuple of (region_map, num_features) where region_map assigns each pixel
            to the closest ground truth region. If no component is found, region_map
            is all zeros.
        """
        # Step 1: Connected Components
        labeled_array, num_features = _connected_components(gt, self._cca_backend)

        # Step 2: Compute distance transform for each region.
        # The running minimum is kept in float64, the dtype distance_transform_edt
        # returns. Storing it in float32 rounds the value, so an exact tie between
        # two regions could compare as "strictly closer" (or not) depending on the
        # rounding direction, making the tie-break depend on the distance value.
        distance_map = np.full(gt.shape, np.inf, dtype=np.float64)
        region_map = np.zeros(gt.shape, dtype=np.int32)

        for region_label in range(1, num_features + 1):
            # Create region mask
            region_mask = labeled_array == region_label

            # Distance of every voxel to the nearest voxel of this region
            # (zero inside the region itself).
            distance = distance_transform_edt(~region_mask)

            # Update pixels where this region is strictly closer; on ties the
            # earlier (lower) label is kept.
            update_mask = distance < distance_map
            distance_map[update_mask] = distance[update_mask]
            region_map[update_mask] = region_label

        return region_map, num_features

    def _match_instances(
        self,
        unmatched_instance_pair: UnmatchedInstancePair,
        context: Optional[MatchingContext] = None,
        **kwargs,
    ) -> InstanceLabelMap:
        """
        Perform region-based instance matching.

        Each prediction instance is mapped to the ground truth region that covers
        the largest number of its voxels.

        Args:
            unmatched_instance_pair (UnmatchedInstancePair): The unmatched instance pair to be matched.
            context (Optional[MatchingContext]): The matching context. Not used by this matcher.
            **kwargs: Additional keyword arguments. Not used by this matcher.

        Returns:
            InstanceLabelMap: The result of the region-based matching. Empty if there are
            no prediction labels or no ground truth regions.
        """
        pred_arr = unmatched_instance_pair.prediction_arr
        ref_arr = unmatched_instance_pair.reference_arr
        pred_labels = unmatched_instance_pair.pred_labels

        labelmap = InstanceLabelMap()

        if len(pred_labels) == 0:
            return labelmap

        # Get ground truth regions
        region_map, num_features = self._get_gt_regions(ref_arr)

        # Without any ground truth region the region map is pure background,
        # so no prediction can be assigned; skip the per-label work.
        if num_features == 0:
            return labelmap

        # For each prediction instance, find which ground truth region it belongs to
        for pred_label in pred_labels:
            pred_mask = pred_arr == pred_label

            # Region assignment of every voxel of this prediction instance
            pred_regions = region_map[pred_mask]

            # Remove background (region 0)
            pred_regions = pred_regions[pred_regions > 0]

            if len(pred_regions) > 0:
                # Assign to the most common region; np.unique returns sorted
                # labels, so argmax picks the lowest label on equal counts.
                unique_regions, counts = np.unique(pred_regions, return_counts=True)
                most_common_region = unique_regions[np.argmax(counts)]

                # Add to labelmap
                labelmap.add_labelmap_entry(int(pred_label), int(most_common_region))

        return labelmap

    @classmethod
    def _yaml_repr(cls, node) -> dict:
        """Return the constructor arguments of ``node`` for YAML serialization."""
        return {
            "cca_backend": node._cca_backend,
        }
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -442,6 +442,11 @@ def panoptic_evaluate( | |
| else instance_metadata["original_num_refs"] | ||
| ) | ||
|  | ||
| # For region-based matching, set TP to NaN since it doesn't use traditional counting | ||
| tp_value = processing_pair.tp | ||
| if instance_matcher.__class__.__name__ == "RegionBasedMatching": | ||
| tp_value = np.nan | ||
|  | ||
| 
      Comment on lines
    
      +445
     to 
      +449
    
   There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ah, I see. You create more labelmap entries and thus indirectly create the region-based loop we discussed. I would much rather say the matcher just takes care of the rest by calling the remainder of the pipeline itself (no need to create labelmap entries; all of that is practically useless computation), and then the matcher itself can set those values to NaN after the corresponding result objects have been created. Do you agree? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Right. This seemed the most minimal approach in hindsight, and that's why I went ahead with it. How do you feel about this then: git diff --no-color panoptica/instance_matcher.py panoptica/utils/processing_pair.py panoptica/instance_evaluator.py panoptica/panoptica_evaluator.py | cat
diff --git a/panoptica/instance_evaluator.py b/panoptica/instance_evaluator.py
index 8b5727b..fb5707d 100644
--- a/panoptica/instance_evaluator.py
+++ b/panoptica/instance_evaluator.py
@@ -30,7 +30,11 @@ def evaluate_matched_instance(
         ], "decision metric not contained in eval_metrics"
         assert decision_threshold is not None, "decision metric set but no threshold"
     # Initialize variables for True Positives (tp)
-    tp = len(matched_instance_pair.matched_instances)
+    # For non-traditional counting (e.g., region-based matching), tp should be NaN
+    if hasattr(matched_instance_pair, 'use_traditional_counting') and not matched_instance_pair.use_traditional_counting:
+        tp = float('nan')
+    else:
+        tp = len(matched_instance_pair.matched_instances)
     score_dict: dict[Metric, list[float]] = {m: [] for m in eval_metrics}
 
     reference_arr, prediction_arr = (
diff --git a/panoptica/instance_matcher.py b/panoptica/instance_matcher.py
index 0eb4afc..8df679a 100644
--- a/panoptica/instance_matcher.py
+++ b/panoptica/instance_matcher.py
@@ -118,7 +118,7 @@ class InstanceMatchingAlgorithm(SupportsConfig, metaclass=ABCMeta):
             **kwargs,
         )
 
-        return map_instance_labels(unmatched_instance_pair.copy(), instance_labelmap)
+        return map_instance_labels(unmatched_instance_pair.copy(), instance_labelmap, use_traditional_counting=True)
 
     def _calculate_matching_metric_pairs(
         self,
@@ -164,7 +164,9 @@ class InstanceMatchingAlgorithm(SupportsConfig, metaclass=ABCMeta):
 
 
 def map_instance_labels(
-    processing_pair: UnmatchedInstancePair, labelmap: InstanceLabelMap
+    processing_pair: UnmatchedInstancePair, 
+    labelmap: InstanceLabelMap,
+    use_traditional_counting: bool = True
 ) -> MatchedInstancePair:
     """
     Map instance labels based on the provided labelmap and create a MatchedInstancePair.
@@ -198,6 +200,7 @@ def map_instance_labels(
     matched_instance_pair = MatchedInstancePair(
         prediction_arr=prediction_arr_relabeled,
         reference_arr=processing_pair.reference_arr,
+        use_traditional_counting=use_traditional_counting,
     )
     return matched_instance_pair
 
@@ -524,6 +527,39 @@ class RegionBasedMatching(InstanceMatchingAlgorithm):
             cca_backend (CCABackend): Backend for connected component analysis.
         """
         self._cca_backend = cca_backend
+        
+    def match_instances(
+        self,
+        unmatched_instance_pair: UnmatchedInstancePair,
+        label_group=None,
+        num_ref_labels=None,
+        processing_pair_orig_shape=None,
+        **kwargs,
+    ) -> MatchedInstancePair:
+        """
+        Override to set use_traditional_counting=False for region-based matching.
+        """
+        # Create context if needed
+        context = None
+        if (
+            label_group is not None
+            or num_ref_labels is not None
+            or processing_pair_orig_shape is not None
+        ):
+            context = MatchingContext(
+                label_group=label_group,
+                num_ref_labels=num_ref_labels,
+                processing_pair_orig_shape=processing_pair_orig_shape,
+            )
+
+        instance_labelmap = self._match_instances(
+            unmatched_instance_pair,
+            context,
+            **kwargs,
+        )
+
+        # Use use_traditional_counting=False since region-based matching doesn't use traditional TP semantics
+        return map_instance_labels(unmatched_instance_pair.copy(), instance_labelmap, use_traditional_counting=False)
 
     def _get_gt_regions(self, gt: np.ndarray) -> Tuple[np.ndarray, int]:
         """
diff --git a/panoptica/panoptica_evaluator.py b/panoptica/panoptica_evaluator.py
index 6b2e2cf..2c8feee 100644
--- a/panoptica/panoptica_evaluator.py
+++ b/panoptica/panoptica_evaluator.py
@@ -442,10 +442,8 @@ def panoptic_evaluate(
             else instance_metadata["original_num_refs"]
         )
 
-        # For region-based matching, set TP to NaN since it doesn't use traditional counting
+        # Use tp from processing_pair (already handles NaN for non-traditional counting methods)
         tp_value = processing_pair.tp
-        if instance_matcher.__class__.__name__ == "RegionBasedMatching":
-            tp_value = np.nan
 
         processing_pair = PanopticaResult(
             reference_arr=processing_pair.reference_arr,
diff --git a/panoptica/utils/processing_pair.py b/panoptica/utils/processing_pair.py
index fed0b99..a061269 100644
--- a/panoptica/utils/processing_pair.py
+++ b/panoptica/utils/processing_pair.py
@@ -322,11 +322,13 @@ class MatchedInstancePair(_ProcessingPairInstanced):
         missed_reference_labels (list[int]): Reference labels with no matching prediction.
         missed_prediction_labels (list[int]): Prediction labels with no matching reference.
         matched_instances (list[int]): Labels matched between prediction and reference arrays.
+        use_traditional_counting (bool): Whether this matching uses traditional TP/FP/FN semantics.
     """
 
     missed_reference_labels: list[int]
     missed_prediction_labels: list[int]
     matched_instances: list[int]
+    use_traditional_counting: bool
 
     def __init__(
         self,
@@ -337,6 +339,7 @@ class MatchedInstancePair(_ProcessingPairInstanced):
         matched_instances: list[int] | None = None,
         n_prediction_instance: int | None = None,
         n_reference_instance: int | None = None,
+        use_traditional_counting: bool = True,
     ) -> None:
         """Initializes a MatchedInstancePair
 
@@ -348,6 +351,7 @@ class MatchedInstancePair(_ProcessingPairInstanced):
             matched_instances (int | None, optional): matched instances labels, i.e. unique matched labels in both maps. Defaults to None.
             n_prediction_instance (int | None, optional): Number of prediction instances. Defaults to None.
             n_reference_instance (int | None, optional): Number of reference instances. Defaults to None.
+            use_traditional_counting (bool, optional): Whether this matching uses traditional TP/FP/FN semantics. Defaults to True.
 
             For each argument: If none, will calculate on initialization.
         """
@@ -360,6 +364,7 @@ class MatchedInstancePair(_ProcessingPairInstanced):
         if matched_instances is None:
             matched_instances = [i for i in self.pred_labels if i in self.ref_labels]
         self.matched_instances = matched_instances
+        self.use_traditional_counting = use_traditional_counting
 
         if missed_reference_labels is None:
             missed_reference_labels = list(
@@ -389,6 +394,7 @@ class MatchedInstancePair(_ProcessingPairInstanced):
             missed_reference_labels=self.missed_reference_labels,
             missed_prediction_labels=self.missed_prediction_labels,
             matched_instances=self.matched_instances,
+            use_traditional_counting=self.use_traditional_counting,
         ) | ||
| processing_pair = PanopticaResult( | ||
| reference_arr=processing_pair.reference_arr, | ||
| prediction_arr=processing_pair.prediction_arr, | ||
|  | @@ -450,7 +455,7 @@ def panoptic_evaluate( | |
| num_ref_instances=final_num_ref_instances, | ||
| num_ref_labels=instance_metadata["num_ref_labels"], | ||
| label_group=label_group, | ||
| tp=processing_pair.tp, | ||
| tp=tp_value, | ||
| list_metrics=processing_pair.list_metrics, | ||
| global_metrics=global_metrics, | ||
| edge_case_handler=edge_case_handler, | ||
|  | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The benchmark_panoptica_cupy function lacks GPU memory cleanup, unlike benchmark_cupy. Consider adding memory pool cleanup:
cp.get_default_memory_pool().free_all_blocks() after the timing to ensure consistent memory usage across benchmarks. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1