Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions rfdetr/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

class ModelConfig(BaseModel):
encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"]
num_channels: int = 3
out_feature_indexes: List[int]
dec_layers: int = 3
two_stage: bool = True
Expand All @@ -33,6 +34,7 @@ class ModelConfig(BaseModel):

class RFDETRBaseConfig(ModelConfig):
encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = "dinov2_windowed_small"
num_channels: int = 3
hidden_dim: int = 256
sa_nheads: int = 8
ca_nheads: int = 16
Expand All @@ -45,6 +47,7 @@ class RFDETRBaseConfig(ModelConfig):

class RFDETRLargeConfig(RFDETRBaseConfig):
encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = "dinov2_windowed_base"
num_channels: int = 3
hidden_dim: int = 384
sa_nheads: int = 12
ca_nheads: int = 24
Expand Down
10 changes: 8 additions & 2 deletions rfdetr/detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import json
import os
from itertools import cycle
from collections import defaultdict
from logging import getLogger
from typing import Union, List
Expand All @@ -33,6 +34,11 @@ def __init__(self, **kwargs):
self.model = self.get_model(self.model_config)
self.callbacks = defaultdict(list)

# repeat means and stds for non-rgb images
if self.model_config.num_channels != 3:
self.means = [val for _, val in zip(range(self.model_config.num_channels), cycle(self.means))]
self.stds = [val for _, val in zip(range(self.model_config.num_channels), cycle(self.stds))]

def maybe_download_pretrain_weights(self):
download_pretrain_weights(self.model_config.pretrain_weights)

Expand Down Expand Up @@ -177,9 +183,9 @@ def predict(
"Image has pixel values above 1. Please ensure the image is "
"normalized (scaled to [0, 1])."
)
if img.shape[0] != 3:
if img.shape[0] != self.model_config.num_channels:
raise ValueError(
f"Invalid image shape. Expected 3 channels (RGB), but got "
f"Invalid image shape. Expected {self.model_config.num_channels} channels, but got "
f"{img.shape[0]} channels."
)
img_tensor = img
Expand Down
20 changes: 20 additions & 0 deletions rfdetr/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@

import numpy as np
import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model
from torch.utils.data import DataLoader, DistributedSampler
from timm.models._manipulate import adapt_input_conv

import rfdetr.util.misc as utils
from rfdetr.datasets import build_dataset, get_coco_api_from_dataset
Expand Down Expand Up @@ -70,6 +72,16 @@ def download_pretrain_weights(pretrain_weights: str, redownload=False):
pretrain_weights,
)

def modify_input_conv(conv: nn.Conv2d, num_channels: int) -> nn.Conv2d:
    """Return a copy of a pretrained input conv that accepts ``num_channels`` inputs.

    The source layer is never modified: a deep copy is made, its
    ``in_channels`` attribute is updated, and its weight is replaced by the
    kernel that timm's ``adapt_input_conv`` derives from the pretrained weight.

    Args:
        conv: Pretrained convolution whose weight is used as the starting point.
        num_channels: Number of input channels the returned layer must accept.
            Must be at least 1.

    Returns:
        A new ``nn.Conv2d`` with ``in_channels == num_channels`` whose weight
        has the same ``requires_grad`` setting as ``conv.weight``.

    Raises:
        ValueError: If ``num_channels`` is smaller than 1.
    """
    # Fail early with a clear message; otherwise a value of 0 only surfaces as
    # an opaque arithmetic error deep inside the weight adaptation helper.
    if num_channels < 1:
        raise ValueError(
            f"num_channels must be a positive integer, but got {num_channels}."
        )

    new_conv = copy.deepcopy(conv)
    new_conv.in_channels = num_channels

    # Adapt a detached clone so that (a) no autograd graph is recorded for a
    # one-off weight conversion and (b) the new parameter can never share
    # storage with the original layer's weight.
    with torch.no_grad():
        new_weight = adapt_input_conv(
            in_chans=num_channels,
            conv_weight=conv.weight.detach().clone(),
        )

    # Preserve the trainability of the pretrained weight (e.g. frozen backbone).
    new_conv.weight = torch.nn.Parameter(
        new_weight, requires_grad=conv.weight.requires_grad
    )
    return new_conv


class Model:
def __init__(self, **kwargs):
args = populate_args(**kwargs)
Expand Down Expand Up @@ -126,6 +138,12 @@ def __init__(self, **kwargs):

self.model.load_state_dict(checkpoint['model'], strict=False)

# Modify input conv if needed
if args.num_channels != 3:
conv = self.model.backbone[0].encoder.encoder.embeddings.patch_embeddings.projection
self.model.backbone[0].encoder.encoder.embeddings.patch_embeddings.projection = modify_input_conv(conv, args.num_channels)
self.model.backbone[0].encoder.encoder.embeddings.patch_embeddings.num_channels = args.num_channels

if args.backbone_lora:
print("Applying LORA to backbone")
lora_config = LoraConfig(
Expand Down Expand Up @@ -814,6 +832,7 @@ def get_args_parser():
def populate_args(
# Basic training parameters
num_classes=2,
num_channels=3,
grad_accum_steps=1,
amp=False,
lr=1e-4,
Expand Down Expand Up @@ -936,6 +955,7 @@ def populate_args(
):
args = argparse.Namespace(
num_classes=num_classes,
num_channels=num_channels,
grad_accum_steps=grad_accum_steps,
amp=amp,
lr=lr,
Expand Down
11 changes: 11 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import torch
import pytest
from rfdetr import RFDETRBase, RFDETRLarge


@pytest.mark.parametrize("model_class", [RFDETRBase, RFDETRLarge])
@pytest.mark.parametrize("channels", [1, 4])
def test_multispectral_support(model_class, channels: int) -> None:
    """Smoke-test inference on non-RGB input.

    Builds each model variant with a non-default channel count and runs
    ``predict`` on an all-zero image with that many channels; the test passes
    as long as no exception is raised.
    """
    detector = model_class(num_channels=channels, device="cpu")
    # Channels-first blank image, created directly on the CPU.
    blank_image = torch.zeros((channels, 224, 224), device="cpu")
    detector.predict(blank_image, threshold=0.5)