From 0ba93ae75c9d1c8afdd37e82c4c9924462d59574 Mon Sep 17 00:00:00 2001
From: isaaccorley <22203655+isaaccorley@users.noreply.github.com>
Date: Wed, 23 Apr 2025 21:36:26 -0500
Subject: [PATCH 1/2] support grayscale and multispectral imagery

---
 rfdetr/config.py    |  3 +++
 rfdetr/detr.py      | 11 +++++++++--
 rfdetr/main.py      | 20 ++++++++++++++++++++
 tests/test_model.py | 11 +++++++++++
 4 files changed, 43 insertions(+), 2 deletions(-)
 create mode 100644 tests/test_model.py

diff --git a/rfdetr/config.py b/rfdetr/config.py
index da57456..8cb0fe9 100644
--- a/rfdetr/config.py
+++ b/rfdetr/config.py
@@ -12,6 +12,7 @@ class ModelConfig(BaseModel):
     encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"]
+    num_channels: int = 3
     out_feature_indexes: List[int]
     dec_layers: int = 3
     two_stage: bool = True
@@ -33,6 +34,7 @@ class ModelConfig(BaseModel):
 
 class RFDETRBaseConfig(ModelConfig):
     encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = "dinov2_windowed_small"
+    num_channels: int = 3
     hidden_dim: int = 256
     sa_nheads: int = 8
     ca_nheads: int = 16
@@ -45,6 +47,7 @@ class RFDETRBaseConfig(ModelConfig):
 
 class RFDETRLargeConfig(RFDETRBaseConfig):
     encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = "dinov2_windowed_base"
+    num_channels: int = 3
     hidden_dim: int = 384
     sa_nheads: int = 12
     ca_nheads: int = 24
diff --git a/rfdetr/detr.py b/rfdetr/detr.py
index 03d1f65..d1899a9 100644
--- a/rfdetr/detr.py
+++ b/rfdetr/detr.py
@@ -7,6 +7,7 @@
 
 import json
 import os
+from itertools import cycle
 from collections import defaultdict
 from logging import getLogger
 from typing import Union, List
@@ -33,6 +34,12 @@ def __init__(self, **kwargs):
         self.model = self.get_model(self.model_config)
         self.callbacks = defaultdict(list)
 
+        # repeat means and stds for non-rgb images
+        if self.model_config.num_channels != 3:
+            self.means = [val for _, val in zip(range(self.model_config.num_channels), cycle(self.means))]
+            self.stds = [val for _, val in zip(range(self.model_config.num_channels), cycle(self.stds))]
+
+
     def maybe_download_pretrain_weights(self):
         download_pretrain_weights(self.model_config.pretrain_weights)
 
@@ -177,9 +184,9 @@ def predict(
                 "Image has pixel values above 1. Please ensure the image is "
                 "normalized (scaled to [0, 1])."
             )
-            if img.shape[0] != 3:
+            if img.shape[0] != self.model_config.num_channels:
                 raise ValueError(
-                    f"Invalid image shape. Expected 3 channels (RGB), but got "
+                    f"Invalid image shape. Expected {self.model_config.num_channels} channels, but got "
                    f"{img.shape[0]} channels."
                )
            img_tensor = img
diff --git a/rfdetr/main.py b/rfdetr/main.py
index d95fe0a..0fc8f45 100644
--- a/rfdetr/main.py
+++ b/rfdetr/main.py
@@ -33,8 +33,10 @@
 
 import numpy as np
 import torch
+import torch.nn as nn
 from peft import LoraConfig, get_peft_model
 from torch.utils.data import DataLoader, DistributedSampler
+from timm.models._manipulate import adapt_input_conv
 
 import rfdetr.util.misc as utils
 from rfdetr.datasets import build_dataset, get_coco_api_from_dataset
@@ -70,6 +72,16 @@ def download_pretrain_weights(pretrain_weights: str, redownload=False):
             pretrain_weights,
         )
 
+def modify_input_conv(conv: nn.Conv2d, num_channels: int) -> nn.Conv2d:
+    """Modify the pretrained input conv layer to accept a different number of input channels."""
+    new_conv = copy.deepcopy(conv)
+    new_conv.in_channels = num_channels
+    new_weight = adapt_input_conv(in_chans=num_channels, conv_weight=conv.weight)
+    new_conv.weight = torch.nn.Parameter(new_weight)
+    new_conv.weight.requires_grad = conv.weight.requires_grad
+    return new_conv
+
+
 class Model:
     def __init__(self, **kwargs):
         args = populate_args(**kwargs)
@@ -126,6 +138,12 @@ def __init__(self, **kwargs):
 
                 self.model.load_state_dict(checkpoint['model'], strict=False)
 
+        # Modify input conv if needed
+        if args.num_channels != 3:
+            conv = self.model.backbone[0].encoder.encoder.embeddings.patch_embeddings.projection
+            self.model.backbone[0].encoder.encoder.embeddings.patch_embeddings.projection = modify_input_conv(conv, args.num_channels)
+            self.model.backbone[0].encoder.encoder.embeddings.patch_embeddings.num_channels = args.num_channels
+
         if args.backbone_lora:
             print("Applying LORA to backbone")
             lora_config = LoraConfig(
@@ -814,6 +832,7 @@ def get_args_parser():
 def populate_args(
     # Basic training parameters
     num_classes=2,
+    num_channels=3,
     grad_accum_steps=1,
     amp=False,
     lr=1e-4,
@@ -936,6 +955,7 @@ def populate_args(
 ):
     args = argparse.Namespace(
         num_classes=num_classes,
+        num_channels=num_channels,
         grad_accum_steps=grad_accum_steps,
         amp=amp,
         lr=lr,
diff --git a/tests/test_model.py b/tests/test_model.py
new file mode 100644
index 0000000..2d32a6e
--- /dev/null
+++ b/tests/test_model.py
@@ -0,0 +1,11 @@
+import torch
+import pytest
+from rfdetr import RFDETRBase, RFDETRLarge
+
+
+@pytest.mark.parametrize("model_class", [RFDETRBase, RFDETRLarge])
+@pytest.mark.parametrize("channels", [1, 4])
+def test_multispectral_support(model_class, channels: int) -> None:
+    model = model_class(num_channels=channels, device="cpu")
+    image = torch.zeros(channels, 224, 224).to("cpu")
+    model.predict(image, threshold=0.5)

From bed9c3dfba6d8c68bd862fcbf6133acccce46be8 Mon Sep 17 00:00:00 2001
From: isaaccorley <22203655+isaaccorley@users.noreply.github.com>
Date: Wed, 23 Apr 2025 21:41:36 -0500
Subject: [PATCH 2/2] remove new line

---
 rfdetr/detr.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/rfdetr/detr.py b/rfdetr/detr.py
index d1899a9..95d6a47 100644
--- a/rfdetr/detr.py
+++ b/rfdetr/detr.py
@@ -39,7 +39,6 @@ def __init__(self, **kwargs):
             self.means = [val for _, val in zip(range(self.model_config.num_channels), cycle(self.means))]
             self.stds = [val for _, val in zip(range(self.model_config.num_channels), cycle(self.stds))]
-
 
     def maybe_download_pretrain_weights(self):
         download_pretrain_weights(self.model_config.pretrain_weights)
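A minimal usage sketch of the num_channels option introduced by this series, mirroring the new test in tests/test_model.py; the 4-channel input shape and the threshold value are illustrative choices from that test, not settings documented elsewhere:

    import torch
    from rfdetr import RFDETRBase

    # Build a 4-channel (e.g. RGB + NIR) model on CPU; the pretrained 3-channel
    # input conv is adapted to the requested channel count at load time.
    model = RFDETRBase(num_channels=4, device="cpu")

    # predict() expects a float tensor scaled to [0, 1], channels first.
    image = torch.zeros(4, 224, 224)
    detections = model.predict(image, threshold=0.5)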