Skip to content

Conversation

@picjul
Copy link

@picjul picjul commented Apr 12, 2025

Description

Saving of images after augmentation. All images belonging to the first batch of the first training epoch are considered. Saving takes place within the output folder.

Annotations on images are described with the integer label of the source dataset.

Need of opencv-python to manage images.

Type of change

  • New feature (non-breaking change which adds functionality)

How has this change been tested, please provide a testcase or example of how you tested the change?

Tested on sample Notebook for training (on Colab), after building the package and installing it:

python -m build
pip install [path-to-generated-wheel]

Any specific deployment considerations

Added opencv-python in pyproject.toml

Docs

No

Sample images

4438eae0-d496-4485-ae67-3fa8e60c7942

bf4f4ad3-4082-4f51-bba1-cf752925bcae

@sctrueew
Copy link

@picjul Hi, thanks for your work. I modified it and saved a 3x3 grid for training and validation batches across 3 batches. Could you please update your code to reflect this change? I also wanted to make a PR. Thanks

import cv2
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from rfdetr.util.box_ops import box_cxcywh_to_xyxy

class DatasetGridSaver:
    def __init__(self, data_loader, output_dir, max_batches=3, dataset_type='train'):
        self.data_loader = data_loader
        self.output_dir = output_dir
        self.max_batches = max_batches
        self.dataset_type = dataset_type
        self.save_path = Path(output_dir)
        self.save_path.mkdir(parents=True, exist_ok=True)

    def save_grid(self):
        for batch_idx, (sample, target) in enumerate(self.data_loader):
            if batch_idx >= self.max_batches:
                break
            
            # Create a 3x3 grid for displaying images
            fig, axes = plt.subplots(3, 3, figsize=(12, 12))
            axes = axes.flatten()
            
            # Iterate through each image in the batch
            for sample_index, (single_image, single_target) in enumerate(zip(sample.tensors, target)):
                if sample_index >= 9:  # We only want to display the first 9 images in each batch
                    break

                resized_size = single_target['size']
                
                # Convert image tensor to numpy array for processing
                img_numpy = (np.array(single_image).transpose(1, 2, 0) * 255).copy()

                # Draw bounding boxes and labels on the image
                for (box, label) in zip(single_target['boxes'], single_target['labels']):
                    int_label = int(label)
                    
                    # Convert bounding box from cx,cy,wh format to xyxy
                    b = box_cxcywh_to_xyxy(box)
                    
                    # Scale bounding box coordinates to match the resized image
                    x_min, y_min, x_max, y_max = int(b[0] * resized_size[1]), int(b[1] * resized_size[0]),\
                                                int(b[2] * resized_size[1]), int(b[3] * resized_size[0])
                    
                    # Draw the bounding box on the image
                    cv2.rectangle(img_numpy, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
                    
                    # Add label text near the bounding box
                    text_size = cv2.getTextSize(str(int_label), cv2.FONT_HERSHEY_SIMPLEX, 1, 2)[0]
                    text_x, text_y = x_min, y_min - 10
                    cv2.rectangle(img_numpy, (text_x, text_y - text_size[1] - 5), 
                                (text_x + text_size[0] + 5, text_y + 5), (0, 255, 0), -1)  
                    cv2.putText(img_numpy, str(int_label), (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)

                # Plot image in the grid
                ax = axes[sample_index]
                ax.imshow(img_numpy)
                ax.axis('off')  # Hide axis
            
            # Adjust layout and save the figure
            fig.tight_layout()
            grid_path = self.save_path / f"{self.dataset_type}_batch{batch_idx}_grid.jpg"
            plt.savefig(grid_path, dpi=200)
            plt.close()
            
        print(f"✅ Saved {self.dataset_type} grids to: {self.save_path.resolve()}")

and It’s used in main.py

       from rfdetr.util.save_grids import DatasetGridSaver

        print("Min DP = %.7f, Max DP = %.7f" % (min(schedules['dp']), max(schedules['dp'])))

        grid_saver = DatasetGridSaver(data_loader_train, output_dir, max_batches=3, dataset_type='train')
        grid_saver.save_grid()

        grid_saver = DatasetGridSaver(data_loader_val, output_dir, max_batches=3, dataset_type='val')
        grid_saver.save_grid()
        
        print("Start training")

@picjul
Copy link
Author

picjul commented Apr 12, 2025

Committed the improvements suggested by @sctrueew (with documentation and some changes)
Other features:

  • De-normalized images before displaying
  • Handled the case of batch size < 9, removing unnecessary plots
  • Added title to generated images

An example of generated grid below.
image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants