diff --git a/training/lightning_module.py b/training/lightning_module.py index 59796f9..926c93f 100644 --- a/training/lightning_module.py +++ b/training/lightning_module.py @@ -388,9 +388,10 @@ def _on_eval_epoch_end_semantic(self, log_prefix, log_per_class=False): block_postfix = self.block_postfix(i) if log_per_class: + original_class_ids_ordered = sorted(list(metric.things)) + sorted(list(metric.stuffs)) for class_idx, iou in enumerate(iou_per_class): self.log( - f"metrics/{log_prefix}_iou_class_{class_idx}{block_postfix}", + f"metrics/{log_prefix}_iou_class_{original_class_ids_ordered[class_idx]}{block_postfix}", iou, ) @@ -440,17 +441,18 @@ def _on_eval_epoch_end_panoptic(self, log_prefix, log_per_class=False): block_postfix = self.block_postfix(i) if log_per_class: + original_class_ids_ordered = sorted(list(metric.things)) + sorted(list(metric.stuffs)) for class_idx in range(len(pq)): self.log( - f"metrics/{log_prefix}_pq_class_{class_idx}{block_postfix}", + f"metrics/{log_prefix}_pq_class_{original_class_ids_ordered[class_idx]}{block_postfix}", pq[class_idx], ) self.log( - f"metrics/{log_prefix}_sq_class_{class_idx}{block_postfix}", + f"metrics/{log_prefix}_sq_class_{original_class_ids_ordered[class_idx]}{block_postfix}", sq[class_idx], ) self.log( - f"metrics/{log_prefix}_rq_class_{class_idx}{block_postfix}", + f"metrics/{log_prefix}_rq_class_{original_class_ids_ordered[class_idx]}{block_postfix}", rq[class_idx], )