36 changes: 34 additions & 2 deletions .github/workflows/tests.yml
@@ -10,7 +10,7 @@ on:
branches: ["main"]

jobs:
build:
test:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
@@ -32,7 +32,7 @@ jobs:
python -m pip install poetry
- name: Install dependencies
run: |
python -m poetry install
python -m poetry install --extras test
- name: Test with pytest and create coverage report
run: |
python -m poetry run coverage run --source=panoptica -m pytest
@@ -43,3 +43,35 @@ jobs:
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}

test-cuda:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: "pip"

- name: Configure poetry
run: |
python -m pip install --upgrade pip
python -m pip install poetry
- name: Install dependencies with GPU extras
run: |
python -m poetry install --extras "gpu test" || python -m poetry install --extras test
- name: Test CUDA functionality (CPU fallback)
run: |
python -m poetry run pytest unit_tests/test_cupy_connected_components.py -v
- name: Upload coverage results to Codecov (Only on merge to main)
# Only upload to Codecov after a merge to the main branch
if: github.ref == 'refs/heads/main' && github.event_name == 'push'
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
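
The new test-cuda job installs the gpu extra when it resolves and falls back to the plain test extra otherwise, then runs unit_tests/test_cupy_connected_components.py on ubuntu-latest, which has no CUDA device. That test file is not part of this diff, so the following skip guard for CPU-only runners is only a sketch; the helper names (_has_gpu, requires_gpu) are assumptions.

    # Hypothetical guard for CPU-only CI runners; the real contents of
    # unit_tests/test_cupy_connected_components.py are not shown in this diff.
    import pytest

    cp = pytest.importorskip("cupy")

    try:
        _has_gpu = cp.cuda.runtime.getDeviceCount() > 0
    except cp.cuda.runtime.CUDARuntimeError:
        _has_gpu = False

    # Tests that need an actual device can carry this marker; CPU-fallback
    # assertions can still run on the GitHub runner.
    requires_gpu = pytest.mark.skipif(not _has_gpu, reason="no CUDA device available")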
73 changes: 73 additions & 0 deletions benchmark/benchmark.py
@@ -7,6 +7,16 @@
# scipy needs to be installed to run this benchmark, we use cc3d as it is quicker for 3D data
from scipy import ndimage

# Try to import cupy for GPU acceleration
try:
import cupy as cp
from cupyx.scipy import ndimage as cp_ndimage

CUPY_AVAILABLE = True
except ImportError:
CUPY_AVAILABLE = False
print("CuPy not available. GPU benchmarks will be skipped.")


def generate_random_binary_mask(size: Tuple[int, int, Union[int, None]]) -> np.ndarray:
"""
@@ -64,6 +74,64 @@ def label_cc3d():
return cc3d_time


def benchmark_cupy(mask: np.ndarray):
"""
Benchmark the performance of cupy.ndimage.label for connected component labeling on GPU.

Args:
mask (np.ndarray): Binary mask to label.

Returns:
float: Time taken to label the mask in seconds, or None if CuPy is not available.
"""
if not CUPY_AVAILABLE:
return None

# Transfer data to GPU
mask_gpu = cp.asarray(mask)

# Warmup phase
for _ in range(3):
cp_ndimage.label(mask_gpu)
cp.cuda.Stream.null.synchronize()

def label_cupy():
cp_ndimage.label(mask_gpu)
cp.cuda.Stream.null.synchronize() # Ensure GPU computation is complete

cupy_time = timeit.timeit(label_cupy, number=10)

# Clean up GPU memory
del mask_gpu
cp.get_default_memory_pool().free_all_blocks()

return cupy_time


def benchmark_panoptica_cupy(mask: np.ndarray):
"""
Benchmark the performance of panoptica's CuPy backend for connected component labeling.

Args:
mask (np.ndarray): Binary mask to label.

Returns:
float: Time taken to label the mask in seconds, or None if CuPy is not available.
"""
if not CUPY_AVAILABLE:
return None

from panoptica._functionals import _connected_components
from panoptica.utils.constants import CCABackend

def label_panoptica_cupy():
_connected_components(mask, CCABackend.cupy)

panoptica_cupy_time = timeit.timeit(label_panoptica_cupy, number=10)

Copilot AI (Sep 16, 2025):

The benchmark_panoptica_cupy function lacks GPU memory cleanup, unlike benchmark_cupy. Consider adding memory pool cleanup (cp.get_default_memory_pool().free_all_blocks()) after the timing to ensure consistent memory usage across benchmarks.

Suggested change:

    # Clean up GPU memory
    cp.get_default_memory_pool().free_all_blocks()
Collaborator:

+1

return panoptica_cupy_time


def run_benchmarks(volume_sizes: Tuple[Tuple[int, int, Union[int, None]]]) -> None:
"""
Run benchmark tests for connected component labeling with different volume sizes.
@@ -80,10 +148,15 @@ def run_benchmarks(volume_sizes: Tuple[Tuple[int, int, Union[int, None]]]) -> No

scipy_time = benchmark_scipy(mask)
cc3d_time = benchmark_cc3d(mask)
cupy_time = benchmark_cupy(mask)

print(f"Volume Size: {size}")
print(f"Scipy Time: {scipy_time:.4f} seconds")
print(f"CC3D Time: {cc3d_time:.4f} seconds")
if cupy_time is not None:
print(f"CuPy Time: {cupy_time:.4f} seconds")
else:
print("CuPy Time: Not available")
print()


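The CuPy benchmark warms up for three iterations and synchronizes the null stream inside the timed closure, so timeit measures completed GPU work rather than asynchronous kernel launches. A minimal driver for the script could look like the sketch below; the volume sizes are made up, and the script's actual __main__ block lies outside this hunk.

    # Hypothetical driver; run_benchmarks and the benchmark_* helpers are defined above.
    if __name__ == "__main__":
        run_benchmarks(
            (
                (512, 512, None),   # 2D mask
                (128, 128, 128),    # small 3D volume
                (256, 256, 256),    # larger 3D volume
            )
        )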
6 changes: 5 additions & 1 deletion panoptica/__init__.py
@@ -2,7 +2,11 @@
ConnectedComponentsInstanceApproximator,
CCABackend,
)
from panoptica.instance_matcher import NaiveThresholdMatching, MaxBipartiteMatching
from panoptica.instance_matcher import (
NaiveThresholdMatching,
MaxBipartiteMatching,
RegionBasedMatching,
)
from panoptica.panoptica_statistics import Panoptica_Statistic, ValueSummary
from panoptica.panoptica_aggregator import Panoptica_Aggregator
from panoptica.panoptica_evaluator import Panoptica_Evaluator
13 changes: 13 additions & 0 deletions panoptica/_functionals.py
@@ -63,6 +63,19 @@ def _connected_components(
from scipy.ndimage import label

cc_arr, n_instances = label(array)
elif cca_backend == CCABackend.cupy:
try:
import cupy as cp
from cupyx.scipy.ndimage import label as cp_label

array_gpu = cp.asarray(array)
cc_arr, n_instances = cp_label(array_gpu)
cc_arr = cp.asnumpy(cc_arr)
except ImportError:
raise ImportError(
"CuPy is not installed. Please install CuPy to use the GPU backend. "
"You can install it using: pip install cupy-cuda11x or cupy-cuda12x depending on your CUDA version."
)
else:
raise NotImplementedError(cca_backend)

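The new CCABackend.cupy branch mirrors the scipy path: it moves the array to the GPU, labels it with cupyx.scipy.ndimage.label, and converts the result back to a NumPy array, so callers see the same (cc_arr, n_instances) return. A minimal call site, as a sketch (swap in CCABackend.cupy only when CuPy and a CUDA device are available):

    import numpy as np

    from panoptica._functionals import _connected_components
    from panoptica.utils.constants import CCABackend

    mask = (np.random.rand(64, 64, 64) > 0.7).astype(np.uint8)

    # CCABackend.cupy labels on the GPU; CCABackend.scipy works everywhere.
    cc_arr, n_instances = _connected_components(mask, CCABackend.scipy)
    print(n_instances, cc_arr.dtype)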
118 changes: 118 additions & 0 deletions panoptica/instance_matcher.py
@@ -3,11 +3,13 @@
from typing import Optional, Tuple, List

import numpy as np
from scipy.ndimage import distance_transform_edt

from panoptica._functionals import (
_calc_matching_metric_of_overlapping_labels,
_calc_matching_metric_of_overlapping_partlabels,
_map_labels,
_connected_components,
)
from panoptica.metrics import Metric
from panoptica.utils.processing_pair import (
@@ -17,6 +19,7 @@
from panoptica.utils.instancelabelmap import InstanceLabelMap
from panoptica.utils.config import SupportsConfig
from panoptica.utils.label_group import LabelGroup, LabelPartGroup
from panoptica.utils.constants import CCABackend


@dataclass
@@ -493,3 +496,118 @@ def _yaml_repr(cls, node) -> dict:
"matching_metric": node._matching_metric,
"matching_threshold": node._matching_threshold,
}


class RegionBasedMatching(InstanceMatchingAlgorithm):
"""
Instance matching algorithm that performs region-based matching using spatial distance.

This method assigns prediction instances to ground truth regions based on spatial proximity
rather than traditional overlap-based metrics. It uses connected components and distance
transforms to create region assignments.

Note: This matching method does not produce traditional count metrics (TP/FP/FN) as it
assigns all predictions to regions. Count metrics will be set to NaN.

Attributes:
cca_backend (CCABackend): Backend for connected component analysis.
"""

def __init__(
self,
cca_backend: CCABackend = CCABackend.scipy,
) -> None:
"""
Initialize the RegionBasedMatching instance.

Args:
cca_backend (CCABackend): Backend for connected component analysis.
"""
self._cca_backend = cca_backend

def _get_gt_regions(self, gt: np.ndarray) -> Tuple[np.ndarray, int]:
"""
Get ground truth regions using connected components and distance transforms.

Args:
gt: Ground truth array

Returns:
Tuple of (region_map, num_features) where region_map assigns each pixel
to the closest ground truth region.
"""
# Step 1: Connected Components
labeled_array, num_features = _connected_components(gt, self._cca_backend)

# Step 2: Compute distance transform for each region
distance_map = np.full(gt.shape, np.inf, dtype=np.float32)
region_map = np.zeros(gt.shape, dtype=np.int32)

for region_label in range(1, num_features + 1):
# Create region mask
region_mask = labeled_array == region_label

# Compute distance transform
distance = distance_transform_edt(~region_mask)
Copilot AI (Sep 16, 2025):

Computing distance transforms for all regions sequentially can be inefficient for large numbers of regions. Consider vectorizing this operation or using parallel processing for better performance with many ground truth regions.

Collaborator:

I think I would move this to a functionals function and import that here.
Otherwise, I agree with Copilot: this can be made faster, I think.
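
A vectorized alternative (a sketch only, not part of this PR) computes the nearest-region assignment with a single EDT over the background via return_indices=True, instead of one distance transform per region:

    import numpy as np
    from scipy.ndimage import distance_transform_edt

    def nearest_region_map(labeled_array: np.ndarray) -> np.ndarray:
        """Return, for every voxel, the label of the spatially nearest region.

        One EDT over the background replaces the per-region loop above; ties
        between equidistant regions may be broken differently.
        """
        _, nearest_idx = distance_transform_edt(labeled_array == 0, return_indices=True)
        return labeled_array[tuple(nearest_idx)]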


# Update pixels where this region is closer
update_mask = distance < distance_map
distance_map[update_mask] = distance[update_mask]
region_map[update_mask] = region_label

return region_map, num_features

def _match_instances(
self,
unmatched_instance_pair: UnmatchedInstancePair,
context: Optional[MatchingContext] = None,
**kwargs,
) -> InstanceLabelMap:
"""
Perform region-based instance matching.

Args:
unmatched_instance_pair (UnmatchedInstancePair): The unmatched instance pair to be matched.
context (Optional[MatchingContext]): The matching context.
**kwargs: Additional keyword arguments.

Returns:
InstanceLabelMap: The result of the region-based matching.
"""
pred_arr = unmatched_instance_pair.prediction_arr
ref_arr = unmatched_instance_pair.reference_arr
pred_labels = unmatched_instance_pair.pred_labels

labelmap = InstanceLabelMap()

if len(pred_labels) == 0:
return labelmap

# Get ground truth regions
region_map, num_features = self._get_gt_regions(ref_arr)

# For each prediction instance, find which ground truth region it belongs to
for pred_label in pred_labels:
pred_mask = pred_arr == pred_label

# Find the most common region assignment for this prediction instance
pred_regions = region_map[pred_mask]

# Remove background (region 0)
pred_regions = pred_regions[pred_regions > 0]

if len(pred_regions) > 0:
# Assign to the most common region
unique_regions, counts = np.unique(pred_regions, return_counts=True)
most_common_region = unique_regions[np.argmax(counts)]

# Add to labelmap
labelmap.add_labelmap_entry(int(pred_label), int(most_common_region))

return labelmap

@classmethod
def _yaml_repr(cls, node) -> dict:
return {
"cca_backend": node._cca_backend,
}
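As a quick illustration of what the new matcher computes (a sketch assuming this branch is installed): _get_gt_regions assigns every pixel to the closest ground-truth component, and _match_instances then maps each prediction label to the region holding the majority of its pixels.

    import numpy as np

    from panoptica import CCABackend, RegionBasedMatching

    # Two small blobs in a toy 2D ground truth; every pixel is assigned to the closer blob.
    gt = np.zeros((8, 8), dtype=np.uint8)
    gt[1:3, 1:3] = 1
    gt[5:7, 5:7] = 1

    matcher = RegionBasedMatching(cca_backend=CCABackend.scipy)
    region_map, num_regions = matcher._get_gt_regions(gt)

    print(num_regions)       # 2
    print(region_map[0, 0])  # 1 -> closer to the top-left blob
    print(region_map[7, 7])  # 2 -> closer to the bottom-right blob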
7 changes: 6 additions & 1 deletion panoptica/panoptica_evaluator.py
@@ -442,6 +442,11 @@ def panoptic_evaluate(
else instance_metadata["original_num_refs"]
)

# For region-based matching, set TP to NaN since it doesn't use traditional counting
tp_value = processing_pair.tp
if instance_matcher.__class__.__name__ == "RegionBasedMatching":
tp_value = np.nan

Comment on lines +445 to +449

Collaborator:

Ahh, I see. You create more labelmap entries and thus indirectly create the region-based loop we discussed.
Although that is fine, I really don't like something hacky like this.

I'd much rather have the matcher take care of the rest by calling the remainder of the pipeline itself (no need to create labelmap entries; that is all practically useless computation), and then the matcher itself can set those values to NaN after the corresponding result objects have been created. Do you agree?

Collaborator (Author):

Right. In hindsight, this seemed like the most minimal approach, and that's why I went ahead with it. How do you feel about this then:

git diff --no-color panoptica/instance_matcher.py panoptica/utils/processing_pair.py panoptica/instance_evaluator.py panoptica/panoptica_evaluator.py | cat
diff --git a/panoptica/instance_evaluator.py b/panoptica/instance_evaluator.py
index 8b5727b..fb5707d 100644
--- a/panoptica/instance_evaluator.py
+++ b/panoptica/instance_evaluator.py
@@ -30,7 +30,11 @@ def evaluate_matched_instance(
         ], "decision metric not contained in eval_metrics"
         assert decision_threshold is not None, "decision metric set but no threshold"
     # Initialize variables for True Positives (tp)
-    tp = len(matched_instance_pair.matched_instances)
+    # For non-traditional counting (e.g., region-based matching), tp should be NaN
+    if hasattr(matched_instance_pair, 'use_traditional_counting') and not matched_instance_pair.use_traditional_counting:
+        tp = float('nan')
+    else:
+        tp = len(matched_instance_pair.matched_instances)
     score_dict: dict[Metric, list[float]] = {m: [] for m in eval_metrics}
 
     reference_arr, prediction_arr = (
diff --git a/panoptica/instance_matcher.py b/panoptica/instance_matcher.py
index 0eb4afc..8df679a 100644
--- a/panoptica/instance_matcher.py
+++ b/panoptica/instance_matcher.py
@@ -118,7 +118,7 @@ class InstanceMatchingAlgorithm(SupportsConfig, metaclass=ABCMeta):
             **kwargs,
         )
 
-        return map_instance_labels(unmatched_instance_pair.copy(), instance_labelmap)
+        return map_instance_labels(unmatched_instance_pair.copy(), instance_labelmap, use_traditional_counting=True)
 
     def _calculate_matching_metric_pairs(
         self,
@@ -164,7 +164,9 @@ class InstanceMatchingAlgorithm(SupportsConfig, metaclass=ABCMeta):
 
 
 def map_instance_labels(
-    processing_pair: UnmatchedInstancePair, labelmap: InstanceLabelMap
+    processing_pair: UnmatchedInstancePair, 
+    labelmap: InstanceLabelMap,
+    use_traditional_counting: bool = True
 ) -> MatchedInstancePair:
     """
     Map instance labels based on the provided labelmap and create a MatchedInstancePair.
@@ -198,6 +200,7 @@ def map_instance_labels(
     matched_instance_pair = MatchedInstancePair(
         prediction_arr=prediction_arr_relabeled,
         reference_arr=processing_pair.reference_arr,
+        use_traditional_counting=use_traditional_counting,
     )
     return matched_instance_pair
 
@@ -524,6 +527,39 @@ class RegionBasedMatching(InstanceMatchingAlgorithm):
             cca_backend (CCABackend): Backend for connected component analysis.
         """
         self._cca_backend = cca_backend
+        
+    def match_instances(
+        self,
+        unmatched_instance_pair: UnmatchedInstancePair,
+        label_group=None,
+        num_ref_labels=None,
+        processing_pair_orig_shape=None,
+        **kwargs,
+    ) -> MatchedInstancePair:
+        """
+        Override to set use_traditional_counting=False for region-based matching.
+        """
+        # Create context if needed
+        context = None
+        if (
+            label_group is not None
+            or num_ref_labels is not None
+            or processing_pair_orig_shape is not None
+        ):
+            context = MatchingContext(
+                label_group=label_group,
+                num_ref_labels=num_ref_labels,
+                processing_pair_orig_shape=processing_pair_orig_shape,
+            )
+
+        instance_labelmap = self._match_instances(
+            unmatched_instance_pair,
+            context,
+            **kwargs,
+        )
+
+        # Use use_traditional_counting=False since region-based matching doesn't use traditional TP semantics
+        return map_instance_labels(unmatched_instance_pair.copy(), instance_labelmap, use_traditional_counting=False)
 
     def _get_gt_regions(self, gt: np.ndarray) -> Tuple[np.ndarray, int]:
         """
diff --git a/panoptica/panoptica_evaluator.py b/panoptica/panoptica_evaluator.py
index 6b2e2cf..2c8feee 100644
--- a/panoptica/panoptica_evaluator.py
+++ b/panoptica/panoptica_evaluator.py
@@ -442,10 +442,8 @@ def panoptic_evaluate(
             else instance_metadata["original_num_refs"]
         )
 
-        # For region-based matching, set TP to NaN since it doesn't use traditional counting
+        # Use tp from processing_pair (already handles NaN for non-traditional counting methods)
         tp_value = processing_pair.tp
-        if instance_matcher.__class__.__name__ == "RegionBasedMatching":
-            tp_value = np.nan
 
         processing_pair = PanopticaResult(
             reference_arr=processing_pair.reference_arr,
diff --git a/panoptica/utils/processing_pair.py b/panoptica/utils/processing_pair.py
index fed0b99..a061269 100644
--- a/panoptica/utils/processing_pair.py
+++ b/panoptica/utils/processing_pair.py
@@ -322,11 +322,13 @@ class MatchedInstancePair(_ProcessingPairInstanced):
         missed_reference_labels (list[int]): Reference labels with no matching prediction.
         missed_prediction_labels (list[int]): Prediction labels with no matching reference.
         matched_instances (list[int]): Labels matched between prediction and reference arrays.
+        use_traditional_counting (bool): Whether this matching uses traditional TP/FP/FN semantics.
     """
 
     missed_reference_labels: list[int]
     missed_prediction_labels: list[int]
     matched_instances: list[int]
+    use_traditional_counting: bool
 
     def __init__(
         self,
@@ -337,6 +339,7 @@ class MatchedInstancePair(_ProcessingPairInstanced):
         matched_instances: list[int] | None = None,
         n_prediction_instance: int | None = None,
         n_reference_instance: int | None = None,
+        use_traditional_counting: bool = True,
     ) -> None:
         """Initializes a MatchedInstancePair
 
@@ -348,6 +351,7 @@ class MatchedInstancePair(_ProcessingPairInstanced):
             matched_instances (int | None, optional): matched instances labels, i.e. unique matched labels in both maps. Defaults to None.
             n_prediction_instance (int | None, optional): Number of prediction instances. Defaults to None.
             n_reference_instance (int | None, optional): Number of reference instances. Defaults to None.
+            use_traditional_counting (bool, optional): Whether this matching uses traditional TP/FP/FN semantics. Defaults to True.
 
             For each argument: If none, will calculate on initialization.
         """
@@ -360,6 +364,7 @@ class MatchedInstancePair(_ProcessingPairInstanced):
         if matched_instances is None:
             matched_instances = [i for i in self.pred_labels if i in self.ref_labels]
         self.matched_instances = matched_instances
+        self.use_traditional_counting = use_traditional_counting
 
         if missed_reference_labels is None:
             missed_reference_labels = list(
@@ -389,6 +394,7 @@ class MatchedInstancePair(_ProcessingPairInstanced):
             missed_reference_labels=self.missed_reference_labels,
             missed_prediction_labels=self.missed_prediction_labels,
             matched_instances=self.matched_instances,
+            use_traditional_counting=self.use_traditional_counting,
         )

processing_pair = PanopticaResult(
reference_arr=processing_pair.reference_arr,
prediction_arr=processing_pair.prediction_arr,
@@ -450,7 +455,7 @@ def panoptic_evaluate(
num_ref_instances=final_num_ref_instances,
num_ref_labels=instance_metadata["num_ref_labels"],
label_group=label_group,
tp=processing_pair.tp,
tp=tp_value,
list_metrics=processing_pair.list_metrics,
global_metrics=global_metrics,
edge_case_handler=edge_case_handler,