diff --git a/rfdetr/config.py b/rfdetr/config.py index da57456..3ff7ce5 100644 --- a/rfdetr/config.py +++ b/rfdetr/config.py @@ -11,7 +11,7 @@ DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" class ModelConfig(BaseModel): - encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] + encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base", "dinov2_registers_windowed_small"] out_feature_indexes: List[int] dec_layers: int = 3 two_stage: bool = True @@ -32,7 +32,7 @@ class ModelConfig(BaseModel): gradient_checkpointing: bool = False class RFDETRBaseConfig(ModelConfig): - encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = "dinov2_windowed_small" + encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base", "dinov2_registers_windowed_small"] = "dinov2_windowed_small" hidden_dim: int = 256 sa_nheads: int = 8 ca_nheads: int = 16 @@ -44,7 +44,7 @@ class RFDETRBaseConfig(ModelConfig): pretrain_weights: Optional[str] = "rf-detr-base.pth" class RFDETRLargeConfig(RFDETRBaseConfig): - encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = "dinov2_windowed_base" + encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base", "dinov2_registers_windowed_small"] = "dinov2_windowed_base" hidden_dim: int = 384 sa_nheads: int = 12 ca_nheads: int = 24