360 changes: 360 additions & 0 deletions 0001-NEW-ai-inside-init.patch
@@ -0,0 +1,360 @@
From 28842e6c8ec9125c282373bef0c91196d292ef87 Mon Sep 17 00:00:00 2001
From: "developer" <[email protected]>
Date: Thu, 17 Oct 2024 16:55:38 +0800
Subject: [PATCH] NEW: ai inside init

---
.../yolov8_s_syncbn_fast_8xb16-500e_coco.py | 15 +-
mmyolo/models/backbones/csp_darknet.py | 2 +-
mmyolo/models/dense_heads/yolov8_head.py | 6 +-
mmyolo/models/layers/yolo_bricks.py | 33 ++++-
mmyolo/models/necks/yolov8_pafpn.py | 7 +
mmyolo/utils/deconv_upsampling.py | 134 ++++++++++++++++++
tools/train.py | 6 +
7 files changed, 187 insertions(+), 16 deletions(-)
mode change 100644 => 100755 configs/yolov8/yolov8_s_syncbn_fast_8xb16-500e_coco.py
create mode 100644 mmyolo/utils/deconv_upsampling.py

diff --git a/configs/yolov8/yolov8_s_syncbn_fast_8xb16-500e_coco.py b/configs/yolov8/yolov8_s_syncbn_fast_8xb16-500e_coco.py
old mode 100644
new mode 100755
index 7e4127e..fa07ce5
--- a/configs/yolov8/yolov8_s_syncbn_fast_8xb16-500e_coco.py
+++ b/configs/yolov8/yolov8_s_syncbn_fast_8xb16-500e_coco.py
@@ -2,7 +2,7 @@ _base_ = ['../_base_/default_runtime.py', '../_base_/det_p5_tta.py']

# ========================Frequently modified parameters======================
# -----data related-----
-data_root = 'data/coco/' # Root path of data
+data_root = '/mlcdev/nnsdk/data/coco/' # Root path of data
# Path of train annotation file
train_ann_file = 'annotations/instances_train2017.json'
train_data_prefix = 'train2017/' # Prefix of train image path
@@ -99,8 +99,8 @@ model = dict(
    type='YOLODetector',
    data_preprocessor=dict(
        type='YOLOv5DetDataPreprocessor',
-        mean=[0., 0., 0.],
-        std=[255., 255., 255.],
+        mean=[128., 128., 128.],
+        std=[128., 128., 128.],
        bgr_to_rgb=True),
    backbone=dict(
        type='YOLOv8CSPDarknet',
@@ -109,7 +109,7 @@ model = dict(
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        norm_cfg=norm_cfg,
-        act_cfg=dict(type='SiLU', inplace=True)),
+        act_cfg=dict(type='ReLU', inplace=True)),
    neck=dict(
        type='YOLOv8PAFPN',
        deepen_factor=deepen_factor,
@@ -118,7 +118,7 @@ model = dict(
        out_channels=[256, 512, last_stage_out_channels],
        num_csp_blocks=3,
        norm_cfg=norm_cfg,
-        act_cfg=dict(type='SiLU', inplace=True)),
+        act_cfg=dict(type='ReLU', inplace=True)),
    bbox_head=dict(
        type='YOLOv8Head',
        head_module=dict(
@@ -128,8 +128,9 @@ model = dict(
            widen_factor=widen_factor,
            reg_max=16,
            norm_cfg=norm_cfg,
-            act_cfg=dict(type='SiLU', inplace=True),
-            featmap_strides=strides),
+            act_cfg=dict(type='ReLU', inplace=True),
+            featmap_strides=strides,
+            skip_dfl=False),
        prior_generator=dict(
            type='mmdet.MlvlPointGenerator', offset=0.5, strides=strides),
        bbox_coder=dict(type='DistancePointBBoxCoder'),
diff --git a/mmyolo/models/backbones/csp_darknet.py b/mmyolo/models/backbones/csp_darknet.py
index 92bd69a..1fd6d98 100644
--- a/mmyolo/models/backbones/csp_darknet.py
+++ b/mmyolo/models/backbones/csp_darknet.py
@@ -281,7 +281,7 @@ class YOLOv8CSPDarknet(BaseBackbone):
            spp = SPPFBottleneck(
                out_channels,
                out_channels,
-                kernel_sizes=5,
+                kernel_sizes=3,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
            stage.append(spp)
diff --git a/mmyolo/models/dense_heads/yolov8_head.py b/mmyolo/models/dense_heads/yolov8_head.py
index 2920241..5e9827c 100644
--- a/mmyolo/models/dense_heads/yolov8_head.py
+++ b/mmyolo/models/dense_heads/yolov8_head.py
@@ -54,7 +54,8 @@ class YOLOv8HeadModule(BaseModule):
                 norm_cfg: ConfigType = dict(
                     type='BN', momentum=0.03, eps=0.001),
                 act_cfg: ConfigType = dict(type='SiLU', inplace=True),
-                 init_cfg: OptMultiConfig = None):
+                 init_cfg: OptMultiConfig = None,
+                 skip_dfl: bool = False):
        super().__init__(init_cfg=init_cfg)
        self.num_classes = num_classes
        self.featmap_strides = featmap_strides
@@ -64,6 +65,7 @@ class YOLOv8HeadModule(BaseModule):
        self.act_cfg = act_cfg
        self.in_channels = in_channels
        self.reg_max = reg_max
+        self.skip_dfl = skip_dfl

        in_channels = []
        for channel in self.in_channels:
@@ -162,7 +164,7 @@ class YOLOv8HeadModule(BaseModule):
        b, _, h, w = x.shape
        cls_logit = cls_pred(x)
        bbox_dist_preds = reg_pred(x)
-        if self.reg_max > 1:
+        if self.reg_max > 1 and not self.skip_dfl:
            bbox_dist_preds = bbox_dist_preds.reshape(
                [-1, 4, self.reg_max, h * w]).permute(0, 3, 1, 2)

diff --git a/mmyolo/models/layers/yolo_bricks.py b/mmyolo/models/layers/yolo_bricks.py
index 19175be..783a1f8 100644
--- a/mmyolo/models/layers/yolo_bricks.py
+++ b/mmyolo/models/layers/yolo_bricks.py
@@ -111,9 +111,22 @@ class SPPFBottleneck(BaseModule):
        if self.conv1:
            x = self.conv1(x)
        if isinstance(self.kernel_sizes, int):
-            y1 = self.poolings(x)
-            y2 = self.poolings(y1)
-            x = torch.cat([x, y1, y2, self.poolings(y2)], dim=1)
+            if self.kernel_sizes == 5:
+                y1 = self.poolings(x)
+                y2 = self.poolings(y1)
+                x = torch.cat([x, y1, y2, self.poolings(y2)], dim=1)
+            elif self.kernel_sizes == 3:
+                y1 = self.poolings(x)
+                y1 = self.poolings(y1)
+
+                y2 = self.poolings(y1)
+                y2 = self.poolings(y2)
+
+                y3 = self.poolings(y2)
+                y3 = self.poolings(y3)
+                x = torch.cat([x, y1, y2, y3], dim=1)
+            else:
+                print("Not supported for SPPFBottleneck in yolov8")
        else:
            x = torch.cat(
                [x] + [pooling(x) for pooling in self.poolings], dim=1)
@@ -1505,9 +1518,17 @@ class CSPLayerWithTwoConv(BaseModule):
    def forward(self, x: Tensor) -> Tensor:
        """Forward process."""
        x_main = self.main_conv(x)
-        x_main = list(x_main.split((self.mid_channels, self.mid_channels), 1))
-        x_main.extend(blocks(x_main[-1]) for blocks in self.blocks)
-        return self.final_conv(torch.cat(x_main, 1))
+        chn = x_main.data.shape[1]
+        assert chn % 2 == 0
+        chn = chn // 2
+
+        x_main_half2 = x_main[:, chn:, :, :]
+        tmp = [x_main]
+        for blocks in self.blocks:
+            x_main_half2 = blocks(x_main_half2)
+            tmp.append(x_main_half2)
+
+        return self.final_conv(torch.cat(tmp, 1))


class BiFusion(nn.Module):
diff --git a/mmyolo/models/necks/yolov8_pafpn.py b/mmyolo/models/necks/yolov8_pafpn.py
index e26698b..87ee0b5 100644
--- a/mmyolo/models/necks/yolov8_pafpn.py
+++ b/mmyolo/models/necks/yolov8_pafpn.py
@@ -5,6 +5,7 @@ import torch.nn as nn
from mmdet.utils import ConfigType, OptMultiConfig

from mmyolo.registry import MODELS
+from mmyolo.utils.deconv_upsampling import NearestConvTranspose2d
from .. import CSPLayerWithTwoConv
from ..utils import make_divisible, make_round
from .yolov5_pafpn import YOLOv5PAFPN
@@ -53,6 +54,12 @@ class YOLOv8PAFPN(YOLOv5PAFPN):
            act_cfg=act_cfg,
            init_cfg=init_cfg)

+        self.upsample_channels = [make_divisible(channel, self.widen_factor) for channel in self.in_channels[:0:-1]]
+        self.upsample_layers = nn.ModuleList()
+        for idx in range(len(self.upsample_channels)):
+            upp = NearestConvTranspose2d(self.upsample_channels[idx], scale_factor=2, with_groups=True)
+            self.upsample_layers.append(upp)

    def build_reduce_layer(self, idx: int) -> nn.Module:
        """build reduce layer.

diff --git a/mmyolo/utils/deconv_upsampling.py b/mmyolo/utils/deconv_upsampling.py
new file mode 100644
index 0000000..20b75f3
--- /dev/null
+++ b/mmyolo/utils/deconv_upsampling.py
@@ -0,0 +1,134 @@
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+__all__ = ["BilinearConvTranspose2d", "NearestConvTranspose2d"]
+
+
+class BilinearConvTranspose2d(nn.Module):
+    """A conv transpose initialized to bilinear interpolation.
+    Note that this conv transpose is not exactly equal to bilinear interpolation mathematically.
+
+    Open-source reference: https://gist.github.com/Sundrops/4d4f73f58166b984f5c6bb1723d4e627
+    """
+
+    def __init__(self, channels, scale_factor=2, with_groups=True, freeze_weights=True):
+        super().__init__()
+
+        assert isinstance(
+            scale_factor, int
+        ), f"{type(self).__name__} only supports integer scale factors, but got {scale_factor}"
+
+        if with_groups:
+            groups = channels
+        else:
+            groups = 1
+
+        ksize = 2 * scale_factor - scale_factor % 2
+        pad = math.ceil((scale_factor - 1) / 2.0)
+        self.upsample = nn.ConvTranspose2d(
+            channels,
+            channels,
+            kernel_size=ksize,
+            stride=scale_factor,
+            padding=pad,
+            groups=groups,
+            bias=False,
+        )
+
+        self.init_weights(freeze_weights)
+
+    def init_weights(self, freeze_weights):
+        for m in self.modules():
+            if isinstance(m, nn.ConvTranspose2d):
+                out_channels, in_channels, kh, kw = m.weight.size()
+                m.weight.data.copy_(
+                    self.get_upsampling_weight(in_channels, out_channels, kh)
+                )
+                if freeze_weights:
+                    m.weight.requires_grad = False
+
+    @staticmethod
+    def get_upsampling_weight(in_channels, out_channels, kernel_size):
+        assert (in_channels == 1) or (in_channels == out_channels)
+        factor = (kernel_size + 1) // 2
+        if kernel_size % 2 == 1:
+            center = factor - 1
+        else:
+            center = factor - 0.5
+
+        og = np.ogrid[:kernel_size, :kernel_size]
+        filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
+        weight = np.zeros(
+            (out_channels, in_channels, kernel_size, kernel_size), dtype=np.float64
+        )
+        weight[list(range(out_channels)), list(range(in_channels)), :, :] = filt
+        return torch.from_numpy(weight).float()
+
+    def forward(self, x):
+        return self.upsample(x)
+
+class NearestConvTranspose2d(nn.Module):
+    """
+    A ConvTranspose2d that implements nearest-neighbor interpolation (upsampling).
+    """
+    def __init__(self, channels, scale_factor=2, with_groups=False, freeze_weights=True):
+        super().__init__()
+
+        assert isinstance(
+            scale_factor, int
+        ), f"{type(self).__name__} only supports integer scale factors, but got {scale_factor}"
+        assert scale_factor >= 2, \
+            f"{type(self).__name__} only supports integer scale factors of at least 2, but got {scale_factor}"
+
+        if with_groups:
+            groups = channels
+        else:
+            groups = 1
+
+        kernel_size = scale_factor
+        stride = scale_factor
+        padding = 0  # (kernel_size - stride) / 2
+
+        self.upsample = nn.ConvTranspose2d(
+            channels,
+            channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            groups=groups,
+            bias=False,
+        )
+
+        self.init_weights(freeze_weights, with_groups)
+
+    def init_weights(self, freeze_weights, with_groups):
+        for m in self.modules():
+            if isinstance(m, nn.ConvTranspose2d):
+                out_channels, in_channels, kh, kw = m.weight.size()
+                m.weight.data.copy_(
+                    self.get_upsampling_weight(out_channels, in_channels, kh, with_groups)
+                )
+                if freeze_weights:
+                    m.weight.requires_grad = False
+
+    @staticmethod
+    def get_upsampling_weight(out_channels, in_channels, kernel_size, with_groups):
+        if with_groups:  # depthwise
+            assert in_channels == 1
+            weight = torch.ones(out_channels, 1, kernel_size, kernel_size)
+            # weight[:, :, 1:kernel_size-1, 1:kernel_size-1] = 1
+        else:
+            assert in_channels == out_channels
+            weight = torch.zeros(out_channels, in_channels, kernel_size, kernel_size)
+            for _c_idx in range(out_channels):
+                # weight[_c_idx, _c_idx, 1:kernel_size-1, 1:kernel_size-1] = 1
+                weight[_c_idx, _c_idx, :, :] = 1
+        return weight
+
+    def forward(self, x):
+        return self.upsample(x)
diff --git a/tools/train.py b/tools/train.py
index 61f9498..1ab1f5f 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -4,6 +4,8 @@ import logging
import os
import os.path as osp

+import torch
+
from mmdet.utils import setup_cache_size_limit_of_dynamo
from mmengine.config import Config, DictAction
from mmengine.logging import print_log
@@ -115,6 +117,10 @@ def main():
        # if 'runner_type' is set in the cfg
        runner = RUNNERS.build(cfg)

+    x = torch.rand(1, 3, 320, 320).cuda()
+    torch.onnx.export(runner.model, x, "./yolov8_test.onnx")
+    torch.save(runner.model.state_dict(), "./yolov8_test.pth")
+
    # start training
    runner.train()

--
2.43.0
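
Note on the data_preprocessor change: switching mean/std from [0, 0, 0] / [255, 255, 255] to [128, 128, 128] / [128, 128, 128] moves normalized inputs from roughly [0, 1] to roughly [-1, 1] (presumably a friendlier range for the target runtime). A minimal sketch of the two normalizations; the pixel values are illustrative, not taken from the patch:

import torch

pixels = torch.tensor([0., 128., 255.])   # example uint8 intensities
old = (pixels - 0.) / 255.                # original config -> 0.000, 0.502, 1.000
new = (pixels - 128.) / 128.              # patched config  -> -1.000, 0.000, 0.992
print(old, new)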
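Note on the SPPFBottleneck change: with stride-1 max pooling and implicit -inf padding, two chained 3x3 pools produce exactly the same output as one 5x5 pool, so the kernel_sizes=3 branch (three pairs of 3x3 pools) reproduces the original SPPF's effective 5/9/13 kernels while only ever using 3x3 pooling. A quick standalone check, not part of the patch:

import torch
import torch.nn as nn

x = torch.randn(2, 8, 32, 32)
pool5 = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
pool3 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

# max over a 5x5 window equals the max of 3x3 maxima (max is associative),
# and MaxPool2d padding acts like -inf padding, so the borders match too.
assert torch.equal(pool5(x), pool3(pool3(x)))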
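Note on the CSPLayerWithTwoConv.forward rewrite: the original code split x_main into two channel halves and concatenated [half1, half2, block outputs...]; the patched version keeps x_main whole and concatenates [x_main, block outputs...]. The result is identical because the two halves are contiguous channel slices of x_main; the rewrite presumably just avoids tensor.split for the downstream toolchain. A minimal equivalence check with arbitrary shapes:

import torch

x_main = torch.randn(1, 8, 4, 4)     # stand-in for main_conv output (2 * mid_channels = 8)
block_out = torch.randn(1, 4, 4, 4)  # stand-in for one bottleneck block output

half1, half2 = x_main.split((4, 4), dim=1)
original = torch.cat([half1, half2, block_out], dim=1)  # old forward
patched = torch.cat([x_main, block_out], dim=1)         # new forward
assert torch.equal(original, patched)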
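Note on NearestConvTranspose2d: a depthwise ConvTranspose2d with an all-ones k x k kernel, stride k and no padding copies each input pixel into a k x k output block, i.e. nearest-neighbor upsampling, so it can stand in for nn.Upsample / F.interpolate(mode='nearest') where only deconvolution is supported. A usage sketch, assuming the patch above has been applied so mmyolo.utils.deconv_upsampling is importable; with_groups=True matches how the neck instantiates it:

import torch
import torch.nn.functional as F

from mmyolo.utils.deconv_upsampling import NearestConvTranspose2d

x = torch.randn(1, 16, 20, 20)
up = NearestConvTranspose2d(16, scale_factor=2, with_groups=True)

ref = F.interpolate(x, scale_factor=2, mode='nearest')
assert torch.allclose(up(x), ref)  # 40x40 output, identical to nearest upsampling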

@@ -0,0 +1,27 @@
_base_ = './yolov6_v3_m_syncbn_fast_8xb32-300e_widerface.py'

# ======================= Possible modified parameters =======================
# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 1
# The scaling factor that controls the width of the network structure
widen_factor = 1

# ============================== Unmodified in most cases ===================
model = dict(
    backbone=dict(
        use_cspsppf=False,
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        block_cfg=dict(
            type='ConvWrapper',
            norm_cfg=dict(type='BN', momentum=0.03, eps=0.001)),
        act_cfg=dict(type='SiLU', inplace=True)),
    neck=dict(
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        block_cfg=dict(
            type='ConvWrapper',
            norm_cfg=dict(type='BN', momentum=0.03, eps=0.001)),
        block_act_cfg=dict(type='SiLU', inplace=True)),
    bbox_head=dict(head_module=dict(reg_max=16, widen_factor=widen_factor)))