From 3fb2325962cc471390f01e73d0a0fdccdc999da6 Mon Sep 17 00:00:00 2001
From: wangshier108 <2310016173@qq.com>
Date: Sat, 15 Mar 2025 05:17:07 +0000
Subject: [PATCH 1/2] test dynamic_conv

---
 configs/train-div2k/train_edsr-baseline-liif.yaml | 5 +++--
 models/edsr.py                                    | 8 ++++++--
 models/liif.py                                    | 5 ++++-
 train_liif.py                                     | 2 +-
 4 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/configs/train-div2k/train_edsr-baseline-liif.yaml b/configs/train-div2k/train_edsr-baseline-liif.yaml
index 1d34f77..6be98dd 100644
--- a/configs/train-div2k/train_edsr-baseline-liif.yaml
+++ b/configs/train-div2k/train_edsr-baseline-liif.yaml
@@ -2,7 +2,7 @@ train_dataset:
   dataset:
     name: image-folder
     args:
-      root_path: ./load/div2k/DIV2K_train_HR
+      root_path: /data/home/wanghanying/work/dataset/DIV2K/DIV2K_train_HR
       repeat: 20
       cache: in_memory
   wrapper:
@@ -18,7 +18,7 @@ val_dataset:
   dataset:
     name: image-folder
     args:
-      root_path: ./load/div2k/DIV2K_valid_HR
+      root_path: /data/home/wanghanying/work/dataset/DIV2K/DIV2K_valid_HR
       first_k: 10
       repeat: 160
       cache: in_memory
@@ -58,3 +58,4 @@ multi_step_lr:
 
 epoch_val: 1
 epoch_save: 100
+# resume: /data/home/wanghanying/work/liif/save/train_edsr_dytest/epoch-last.pth
diff --git a/models/edsr.py b/models/edsr.py
index 965402a..cd6b0f2 100644
--- a/models/edsr.py
+++ b/models/edsr.py
@@ -8,7 +8,9 @@ import torch.nn.functional as F
 
 from models import register
 
-
+# from .dynamic_conv import conv3x3
+# from .dynet import DyNet2D
+from .dytest import DyNet2D
 
 def default_conv(in_channels, out_channels, kernel_size, bias=True):
     return nn.Conv2d(
@@ -90,7 +92,9 @@ def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
 }
 
 class EDSR(nn.Module):
-    def __init__(self, args, conv=default_conv):
+    # def __init__(self, args, conv=default_conv):
+    # def __init__(self, args, conv=conv3x3):
+    def __init__(self, args, conv=DyNet2D):
         super(EDSR, self).__init__()
         self.args = args
         n_resblocks = args.n_resblocks
diff --git a/models/liif.py b/models/liif.py
index 11925bb..53c4d06 100644
--- a/models/liif.py
+++ b/models/liif.py
@@ -16,7 +16,8 @@ def __init__(self, encoder_spec, imnet_spec=None,
         self.local_ensemble = local_ensemble
         self.feat_unfold = feat_unfold
         self.cell_decode = cell_decode
-
+        # self.quant = torch.ao.quantization.QuantStub()
+        # self.dequant = torch.ao.quantization.DeQuantStub()
         self.encoder = models.make(encoder_spec)
 
         if imnet_spec is not None:
@@ -106,5 +107,7 @@ def query_rgb(self, coord, cell=None):
         return ret
 
     def forward(self, inp, coord, cell):
+        # x = self.quant(x)
         self.gen_feat(inp)
+        # x = self.dequant(x)
         return self.query_rgb(coord, cell)
diff --git a/train_liif.py b/train_liif.py
index f7f6537..cee828c 100644
--- a/train_liif.py
+++ b/train_liif.py
@@ -123,7 +123,7 @@ def train(train_loader, model, optimizer):
 
 def main(config_, save_path):
     global config, log, writer
     config = config_
-    log, writer = utils.set_save_path(save_path)
+    log, writer = utils.set_save_path(save_path, False)
     with open(os.path.join(save_path, 'config.yaml'), 'w') as f:
         yaml.dump(config, f, sort_keys=False)
 

From b7c93e3b56be9725bc16522b02de3b3405ded725 Mon Sep 17 00:00:00 2001
From: wangshier108 <2310016173@qq.com>
Date: Sat, 15 Mar 2025 05:18:34 +0000
Subject: [PATCH 2/2] dynamic model file

---
 models/dynamic_conv.py | 342 +++++++++++++++++++++++++++++++++++++++++
 models/dynet.py        | 102 ++++++++++++
 models/dytest.py       | 139 +++++++++++++++++
 3 files changed, 583 insertions(+)
 create mode 100644 models/dynamic_conv.py
 create mode 100644 models/dynet.py
 create mode 100644 models/dytest.py
diff --git a/models/dynamic_conv.py b/models/dynamic_conv.py
new file mode 100644
index 0000000..b896341
--- /dev/null
+++ b/models/dynamic_conv.py
@@ -0,0 +1,342 @@
+# https://github.com/kaijieshi7/Dynamic-convolution-Pytorch/blob/master/dynamic_conv.py
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class attention1d(nn.Module):
+    def __init__(self, in_planes, ratios, K, temperature, init_weight=True):
+        super(attention1d, self).__init__()
+        assert temperature % 3 == 1
+        self.avgpool = nn.AdaptiveAvgPool1d(1)
+        if in_planes != 3:
+            hidden_planes = int(in_planes * ratios) + 1
+        else:
+            hidden_planes = K
+        self.fc1 = nn.Conv1d(in_planes, hidden_planes, 1, bias=False)
+        # self.bn = nn.BatchNorm2d(hidden_planes)
+        self.fc2 = nn.Conv1d(hidden_planes, K, 1, bias=True)
+        self.temperature = temperature
+        if init_weight:
+            self._initialize_weights()
+
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv1d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            if isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+    def update_temperature(self):
+        if self.temperature != 1:
+            self.temperature -= 3
+            print('Change temperature to:', str(self.temperature))
+
+    def forward(self, x):
+        x = self.avgpool(x)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.fc2(x).view(x.size(0), -1)
+        return F.softmax(x / self.temperature, 1)
+
+
+class Dynamic_conv1d(nn.Module):
+    def __init__(self, in_planes, out_planes, kernel_size, ratio=0.25, stride=1, padding=0, dilation=1, groups=1, bias=True, K=4, temperature=34, init_weight=True):
+        super(Dynamic_conv1d, self).__init__()
+        assert in_planes % groups == 0
+        self.in_planes = in_planes
+        self.out_planes = out_planes
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.groups = groups
+        self.bias = bias
+        self.K = K
+        self.attention = attention1d(in_planes, ratio, K, temperature)
+
+        self.weight = nn.Parameter(torch.randn(K, out_planes, in_planes//groups, kernel_size), requires_grad=True)
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(K, out_planes))
+        else:
+            self.bias = None
+        if init_weight:
+            self._initialize_weights()
+
+    # TODO: initialization
+    def _initialize_weights(self):
+        for i in range(self.K):
+            nn.init.kaiming_uniform_(self.weight[i])
+
+    def update_temperature(self):
+        self.attention.update_temperature()
+
+    def forward(self, x):
+        # Treat the batch as the group dimension of a grouped convolution:
+        # grouped convolutions use different weights per group, just as
+        # dynamic convolution uses different weights per sample.
+        softmax_attention = self.attention(x)
+        batch_size, in_planes, height = x.size()
+        x = x.view(1, -1, height)  # fold the batch into the channel dimension for the grouped conv
+        weight = self.weight.view(self.K, -1)
+
+        # Generate the dynamic weights: batch_size kernels, one per sample.
+        aggregate_weight = torch.mm(softmax_attention, weight).view(batch_size*self.out_planes, self.in_planes//self.groups, self.kernel_size)
+        if self.bias is not None:
+            aggregate_bias = torch.mm(softmax_attention, self.bias).view(-1)
+            output = F.conv1d(x, weight=aggregate_weight, bias=aggregate_bias, stride=self.stride, padding=self.padding,
+                              dilation=self.dilation, groups=self.groups*batch_size)
+        else:
+            output = F.conv1d(x, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,
+                              dilation=self.dilation, groups=self.groups * batch_size)
+
+        output = output.view(batch_size, self.out_planes, output.size(-1))
+        return output
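+
+# Illustrative sketch (not part of the referenced repo): a quick check that
+# the batch-as-groups trick above matches applying each sample's aggregated
+# kernel with a plain per-sample conv1d. Call manually to verify.
+def _demo_dynamic_conv1d():
+    torch.manual_seed(0)
+    conv = Dynamic_conv1d(in_planes=8, out_planes=16, kernel_size=3, padding=1)
+    x = torch.randn(2, 8, 20)
+    batched = conv(x)
+    att = conv.attention(x)                # [2, K] softmax routing weights
+    flat_w = conv.weight.view(conv.K, -1)  # [K, out*in*k]
+    for i in range(x.size(0)):
+        w_i = torch.mm(att[i:i+1], flat_w).view(conv.out_planes, conv.in_planes, conv.kernel_size)
+        b_i = torch.mm(att[i:i+1], conv.bias).view(-1)
+        ref = F.conv1d(x[i:i+1], w_i, b_i, padding=1)
+        assert torch.allclose(batched[i:i+1], ref, atol=1e-5)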
+
+
+class attention2d(nn.Module):
+    def __init__(self, in_planes, ratios, K, temperature, init_weight=True):
+        super(attention2d, self).__init__()
+        assert temperature % 3 == 1
+        self.avgpool = nn.AdaptiveAvgPool2d(1)
+        if in_planes != 3:
+            hidden_planes = int(in_planes * ratios) + 1
+        else:
+            hidden_planes = K
+        self.fc1 = nn.Conv2d(in_planes, hidden_planes, 1, bias=False)
+        # self.bn = nn.BatchNorm2d(hidden_planes)
+        self.fc2 = nn.Conv2d(hidden_planes, K, 1, bias=True)
+        self.temperature = temperature
+        if init_weight:
+            self._initialize_weights()
+
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            if isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+    def update_temperature(self):
+        if self.temperature != 1:
+            self.temperature -= 3
+            print('Change temperature to:', str(self.temperature))
+
+    def forward(self, x):
+        x = self.avgpool(x)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.fc2(x).view(x.size(0), -1)
+        return F.softmax(x / self.temperature, 1)
+
+
+class Dynamic_conv2d(nn.Module):
+    def __init__(self, in_planes, out_planes, kernel_size, ratio=0.25, stride=1, padding=0, dilation=1, groups=1, bias=True, K=4, temperature=34, init_weight=True):
+        super(Dynamic_conv2d, self).__init__()
+        assert in_planes % groups == 0
+        self.in_planes = in_planes
+        self.out_planes = out_planes
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.groups = groups
+        self.bias = bias
+        self.K = K
+        self.attention = attention2d(in_planes, ratio, K, temperature)
+
+        self.weight = nn.Parameter(torch.randn(K, out_planes, in_planes//groups, kernel_size, kernel_size), requires_grad=True)
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(K, out_planes))
+        else:
+            self.bias = None
+        if init_weight:
+            self._initialize_weights()
+
+    # TODO: initialization
+    def _initialize_weights(self):
+        for i in range(self.K):
+            nn.init.kaiming_uniform_(self.weight[i])
+
+    def update_temperature(self):
+        self.attention.update_temperature()
+
+    def forward(self, x):
+        # Treat the batch as the group dimension of a grouped convolution:
+        # grouped convolutions use different weights per group, just as
+        # dynamic convolution uses different weights per sample.
+        softmax_attention = self.attention(x)
+        batch_size, in_planes, height, width = x.size()
+        x = x.view(1, -1, height, width)  # fold the batch into the channel dimension for the grouped conv
+        weight = self.weight.view(self.K, -1)
+
+        # Generate the dynamic weights: batch_size kernels, one per sample.
+        aggregate_weight = torch.mm(softmax_attention, weight).view(batch_size*self.out_planes, self.in_planes//self.groups, self.kernel_size, self.kernel_size)
+        if self.bias is not None:
+            aggregate_bias = torch.mm(softmax_attention, self.bias).view(-1)
+            output = F.conv2d(x, weight=aggregate_weight, bias=aggregate_bias, stride=self.stride, padding=self.padding,
+                              dilation=self.dilation, groups=self.groups*batch_size)
+        else:
+            output = F.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,
+                              dilation=self.dilation, groups=self.groups * batch_size)
+
+        output = output.view(batch_size, self.out_planes, output.size(-2), output.size(-1))
+        return output
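+
+# Illustrative sketch (not part of the referenced repo): why update_temperature
+# anneals the attention temperature from 34 toward 1. A high temperature
+# flattens the softmax so all K kernels receive gradient early in training;
+# temperature 1 lets the routing specialize. Call manually to see the effect.
+def _demo_temperature_annealing():
+    logits = torch.tensor([[1.0, 0.5, -0.2, 0.1]])
+    for t in (34, 16, 4, 1):
+        print(t, F.softmax(logits / t, dim=1))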
+
+
+class attention3d(nn.Module):
+    def __init__(self, in_planes, ratios, K, temperature):
+        super(attention3d, self).__init__()
+        assert temperature % 3 == 1
+        self.avgpool = nn.AdaptiveAvgPool3d(1)
+        if in_planes != 3:
+            hidden_planes = int(in_planes * ratios) + 1
+        else:
+            hidden_planes = K
+        self.fc1 = nn.Conv3d(in_planes, hidden_planes, 1, bias=False)
+        self.fc2 = nn.Conv3d(hidden_planes, K, 1, bias=False)
+        self.temperature = temperature
+
+    def update_temperature(self):
+        if self.temperature != 1:
+            self.temperature -= 3
+            print('Change temperature to:', str(self.temperature))
+
+    def forward(self, x):
+        x = self.avgpool(x)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.fc2(x).view(x.size(0), -1)
+        return F.softmax(x / self.temperature, 1)
+
+
+class Dynamic_conv3d(nn.Module):
+    def __init__(self, in_planes, out_planes, kernel_size, ratio=0.25, stride=1, padding=0, dilation=1, groups=1, bias=True, K=4, temperature=34):
+        super(Dynamic_conv3d, self).__init__()
+        assert in_planes % groups == 0
+        self.in_planes = in_planes
+        self.out_planes = out_planes
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.groups = groups
+        self.bias = bias
+        self.K = K
+        self.attention = attention3d(in_planes, ratio, K, temperature)
+
+        self.weight = nn.Parameter(torch.randn(K, out_planes, in_planes//groups, kernel_size, kernel_size, kernel_size), requires_grad=True)
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(K, out_planes))
+        else:
+            self.bias = None
+
+    # TODO: initialization
+    # nn.init.kaiming_uniform_(self.weight, )
+
+    def update_temperature(self):
+        self.attention.update_temperature()
+
+    def forward(self, x):
+        # Same batch-as-groups trick as the 1d/2d variants.
+        softmax_attention = self.attention(x)
+        batch_size, in_planes, depth, height, width = x.size()
+        x = x.view(1, -1, depth, height, width)  # fold the batch into the channel dimension for the grouped conv
+        weight = self.weight.view(self.K, -1)
+
+        # Generate the dynamic weights: batch_size kernels, one per sample.
+        aggregate_weight = torch.mm(softmax_attention, weight).view(batch_size*self.out_planes, self.in_planes//self.groups, self.kernel_size, self.kernel_size, self.kernel_size)
+        if self.bias is not None:
+            aggregate_bias = torch.mm(softmax_attention, self.bias).view(-1)
+            output = F.conv3d(x, weight=aggregate_weight, bias=aggregate_bias, stride=self.stride, padding=self.padding,
+                              dilation=self.dilation, groups=self.groups*batch_size)
+        else:
+            output = F.conv3d(x, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,
+                              dilation=self.dilation, groups=self.groups * batch_size)
+
+        output = output.view(batch_size, self.out_planes, output.size(-3), output.size(-2), output.size(-1))
+        return output
+
+
+def conv3x3(in_planes, out_planes, kernel_size=3, bias=False, stride=1, groups=1, dilation=1):
+    return Dynamic_conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=bias, dilation=dilation)
+
+# def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+#     return Dynamic_conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+# def conv1x1(in_planes, out_planes, stride=1):
+#     """1x1 convolution"""
+#     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+def conv1x1(in_planes, out_planes, stride=1):
+    return Dynamic_conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
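+
+# Illustrative sketch (not part of the referenced repo): conv3x3 mirrors the
+# (in_channels, out_channels, kernel_size, bias) call signature of
+# models/edsr.py's default_conv, so it can be passed to EDSR as `conv`.
+def _demo_conv3x3_drop_in():
+    layer = conv3x3(64, 64, 3, bias=False)
+    x = torch.randn(2, 64, 24, 24)
+    print(layer(x).shape)  # torch.Size([2, 64, 24, 24]) -- same-size output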
+
+
+if __name__ == '__main__':
+    x = torch.randn(24, 3, 20)
+    model = Dynamic_conv1d(in_planes=3, out_planes=16, kernel_size=3, ratio=0.25, padding=1)
+    x = x.to('cuda:0')
+    model.to('cuda')
+    # model.attention.cuda()
+    # nn.Conv3d()
+    print(model(x).shape)
+    # anneal the temperature all the way down (34 -> 1 in steps of 3;
+    # the extra calls are no-ops)
+    for _ in range(13):
+        model.update_temperature()
+    # repeated forward passes after annealing
+    for _ in range(38):
+        print(model(x).shape)
+
diff --git a/models/dynet.py b/models/dynet.py
new file mode 100644
index 0000000..cf0f78f
--- /dev/null
+++ b/models/dynet.py
@@ -0,0 +1,102 @@
+####
+# Source: https://0809zheng.github.io/2022/12/20/dyconv.html
+####
+
+import functools
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.nn.modules.conv import _ConvNd
+from torch.nn.modules.utils import _pair
+from torch.nn.parameter import Parameter
+
+def default_conv(in_channels, out_channels, kernel_size, bias=True):
+    return nn.Conv2d(
+        in_channels, out_channels, kernel_size,
+        padding=(kernel_size//2), bias=bias)
+
+class _coefficient(nn.Module):
+    def __init__(self, in_channels, num_experts, out_channels, dropout_rate):
+        super(_coefficient, self).__init__()
+        self.num_experts = num_experts
+        self.dropout = nn.Dropout(dropout_rate)
+        self.fc = nn.Linear(in_channels, num_experts*out_channels)
+
+    def forward(self, x):
+        x = torch.flatten(x)
+        x = self.dropout(x)
+        x = self.fc(x)
+        x = x.view(self.num_experts, -1)
+        return torch.softmax(x, dim=0)
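+
+# Illustrative sketch (not part of the original file): the routing
+# coefficients are a softmax over the expert axis, so each output channel's
+# kernel is a convex combination of the num_experts candidate kernels.
+def _demo_coefficient():
+    coeff = _coefficient(in_channels=64, num_experts=3, out_channels=64, dropout_rate=0.0)
+    pooled = torch.randn(1, 64, 1, 1)  # stands in for a pooled feature map
+    r = coeff(pooled)                  # [num_experts, out_channels] = [3, 64]
+    print(r.shape, r.sum(dim=0))       # each column sums to 1 across experts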
+
+
+class DyNet2D(_ConvNd):
+    r"""
+    in_channels (int): Number of channels in the input image
+    out_channels (int): Number of channels produced by the convolution
+    kernel_size (int or tuple): Size of the convolving kernel
+    stride (int or tuple, optional): Stride of the convolution. Default: 1
+    padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
+    padding_mode (string, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
+    dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
+    groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
+    bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
+    num_experts (int): Number of experts per layer
+    """
+
+    # def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+    #              padding=0, dilation=1, groups=1,
+    #              bias=True, padding_mode='zeros', num_experts=3, dropout_rate=0.2):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1,
+                 bias=True, padding_mode='reflect', num_experts=3, dropout_rate=0.2):
+        kernel_size = _pair(kernel_size)
+        stride = _pair(stride)
+        padding = _pair(padding)
+        dilation = _pair(dilation)
+        super(DyNet2D, self).__init__(
+            in_channels, out_channels, kernel_size, stride, padding, dilation,
+            False, _pair(0), groups, bias, padding_mode)
+
+        # global average pooling
+        self._avg_pooling = functools.partial(F.adaptive_avg_pool2d, output_size=(1, 1))
+        # attention (fully connected routing layer)
+        self._coefficient_fn = _coefficient(in_channels, num_experts, out_channels, dropout_rate)
+        # one set of convolution weights per expert
+        self.weight = Parameter(torch.Tensor(
+            num_experts, out_channels, in_channels // groups, *kernel_size))
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.kaiming_uniform_(self.weight, a=nn.init.calculate_gain('relu'))
+        if self.bias is not None:
+            nn.init.zeros_(self.bias)
+
+    def _conv_forward(self, input, weight):
+        if self.padding_mode != 'zeros':
+            # NOTE: the manual pad below is hardcoded for 3x3 kernels, so the
+            # convolution itself uses zero padding
+            return F.conv2d(F.pad(input, (1, 1, 1, 1), mode=self.padding_mode),
+                            weight, self.bias, self.stride,
+                            _pair(0), self.dilation, self.groups)
+        return F.conv2d(input, weight, self.bias, self.stride,
+                        self.padding, self.dilation, self.groups)
+
+    def forward(self, inputs):  # [b, c, h, w]
+        res = []
+        for input in inputs:
+            input = input.unsqueeze(0)                # [1, c, h, w]
+            pooled_inputs = self._avg_pooling(input)  # [1, c, 1, 1], e.g. torch.Size([1, 64, 1, 1])
+            routing_weights = self._coefficient_fn(pooled_inputs)  # [num_experts, out_channels], e.g. torch.Size([3, 64])
+            # convex combination of the expert kernels, per output channel
+            kernels = torch.sum(routing_weights[:, :, None, None, None] * self.weight, 0)
+            out = self._conv_forward(input, kernels)
+            res.append(out)
+        return torch.cat(res, dim=0)
+
+    # NOTE: this per-sample for-loop is too slow
\ No newline at end of file
diff --git a/models/dytest.py b/models/dytest.py
new file mode 100644
index 0000000..4789567
--- /dev/null
+++ b/models/dytest.py
@@ -0,0 +1,139 @@
+# Earlier draft kept for reference: routing computed per batch, convolution
+# still applied in a per-sample loop.
+# import torch
+# import torch.nn as nn
+# import torch.nn.functional as F
+# from torch.nn.modules.utils import _pair
+
+
+# class DyNet2D(nn.Module):
+#     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+#                  padding=0, dilation=1, groups=1,
+#                  bias=True, padding_mode='reflect', num_experts=3, dropout_rate=0.2):
+#         super(DyNet2D, self).__init__()
+#         self.in_channels = in_channels
+#         self.out_channels = out_channels
+#         self.kernel_size = _pair(kernel_size)
+#         self.stride = _pair(stride)
+#         self.padding = _pair(padding)
+#         self.dilation = _pair(dilation)
+#         self.groups = groups
+#         self.bias = bias
+#         self.padding_mode = padding_mode
+#         self.num_experts = num_experts
+#         self.dropout_rate = dropout_rate
+
+#         # global average pooling
+#         self._avg_pooling = lambda x: F.adaptive_avg_pool2d(x, output_size=(1, 1))
+
+#         # attention (fully connected routing layer)
+#         self._coefficient_fn = self._create_coefficient_fn(in_channels, num_experts, dropout_rate)
+
+#         # one set of convolution weights per expert
+#         self.weight = nn.Parameter(torch.Tensor(num_experts, out_channels, in_channels // groups, *self.kernel_size))
+#         if bias:
+#             self.bias = nn.Parameter(torch.Tensor(out_channels))
+#         else:
+#             self.register_parameter('bias', None)
+
+#         self.reset_parameters()
+
+#     def _create_coefficient_fn(self, in_channels, num_experts, dropout_rate):
+#         return nn.Sequential(
+#             nn.Linear(in_channels, num_experts),
+#             nn.Softmax(dim=1)
+#         )
+
+#     def reset_parameters(self):
+#         nn.init.kaiming_uniform_(self.weight, a=nn.init.calculate_gain('relu'))
+#         if self.bias is not None:
+#             nn.init.zeros_(self.bias)
+
+#     def _conv_forward(self, input, weight):
+#         if self.padding_mode != 'zeros':
+#             return F.conv2d(F.pad(input, (1, 1, 1, 1), mode=self.padding_mode),
+#                             weight, self.bias, self.stride,
+#                             _pair(0), self.dilation, self.groups)
+#         return F.conv2d(input, weight, self.bias, self.stride,
+#                         self.padding, self.dilation, self.groups)
+
+#     def forward(self, inputs):  # [b, c, h, w]
+#         # global average pooling
+#         pooled_inputs = self._avg_pooling(inputs)  # [b, c, 1, 1]
+#         pooled_inputs = pooled_inputs.view(-1, self.in_channels)  # [b, c]
+
+#         # attention weight computation
+#         routing_weights = self._coefficient_fn(pooled_inputs)  # [b, num_experts]
+#         import pdb; pdb.set_trace()
+
+#         # weighted sum of the expert kernels
+#         kernels = torch.einsum('bn, nchw->bchw', routing_weights, self.weight)
+
+#         # per-sample convolution
+#         b, c, h, w = inputs.shape
+#         outputs = []
+#         for i in range(b):
+#             out = self._conv_forward(inputs[i].unsqueeze(0), kernels[i])
+#             outputs.append(out)
+#         return torch.cat(outputs, dim=0)
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.modules.utils import _pair
+
+
+class DyNet2D(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1,
+                 bias=True, padding_mode='reflect', num_experts=3, dropout_rate=0.2):
+        super(DyNet2D, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = _pair(kernel_size)
+        self.stride = _pair(stride)
+        self.padding = _pair(padding)
+        self.dilation = _pair(dilation)
+        self.groups = groups
+        self.num_experts = num_experts
+
+        # global average pooling
+        self._avg_pooling = lambda x: F.adaptive_avg_pool2d(x, output_size=(1, 1))
+
+        # attention (routing weights over the experts)
+        self._coefficient_fn = nn.Sequential(
+            nn.Linear(in_channels, num_experts),
+            nn.Softmax(dim=1)
+        )
+
+        # dynamic convolution weights, one set per expert
+        self.weight = nn.Parameter(torch.Tensor(num_experts, out_channels, in_channels // groups, *self.kernel_size))
+        if bias:
+            self.bias = nn.Parameter(torch.Tensor(out_channels))
+        else:
+            self.register_parameter('bias', None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.kaiming_uniform_(self.weight, a=nn.init.calculate_gain('relu'))
+        if self.bias is not None:
+            nn.init.zeros_(self.bias)
+    def forward(self, inputs):  # [b, c, h, w]
+        b, c, h, w = inputs.shape
+
+        # 1. compute the attention weights
+        pooled_inputs = self._avg_pooling(inputs)  # [b, c, 1, 1], e.g. torch.Size([16, 3, 1, 1])
+        pooled_inputs = pooled_inputs.view(b, c)   # [b, c], e.g. torch.Size([16, 3])
+        routing_weights = self._coefficient_fn(pooled_inputs)  # [b, num_experts], e.g. torch.Size([16, 3])
+
+        # 2. build the dynamic convolution kernels
+        kernels = torch.einsum('bn, nochw -> bochw', routing_weights, self.weight)  # [b, out_channels, in_channels, kH, kW]
+
+        # 3. batched convolution: fold the batch into the channel/group dimension
+        inputs = inputs.reshape(1, b * c, h, w)  # [1, b*c, h, w], e.g. torch.Size([1, 48, 48, 48])
+        kernels = kernels.reshape(b * self.out_channels, self.in_channels // self.groups, *self.kernel_size)  # [b*out_c, in_c, kH, kW], e.g. torch.Size([1024, 3, 3, 3])
+        bias = self.bias.repeat(b) if self.bias is not None else None
+        # padding is applied manually here (reflect pad hardcoded for 3x3
+        # kernels, as used by EDSR), so the convolution itself uses zero padding
+        outputs = F.conv2d(F.pad(inputs, (1, 1, 1, 1), mode="reflect"),
+                           kernels, bias, self.stride, 0, self.dilation, self.groups * b)
+
+        return outputs.reshape(b, self.out_channels, outputs.shape[-2], outputs.shape[-1])
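+
+
+# Illustrative sketch (not part of the original patch): checks the batched
+# grouped-conv formulation above against a plain per-sample loop, the same
+# computation models/dynet.py performs. Call manually to verify.
+def _demo_dynet2d():
+    torch.manual_seed(0)
+    m = DyNet2D(3, 64, 3, num_experts=3)
+    x = torch.randn(4, 3, 48, 48)
+    y = m(x)
+    print(y.shape)  # torch.Size([4, 64, 48, 48])
+    pooled = F.adaptive_avg_pool2d(x, 1).view(4, 3)
+    r = m._coefficient_fn(pooled)                        # [4, num_experts]
+    k = torch.einsum('bn, nochw -> bochw', r, m.weight)  # [4, 64, 3, 3, 3]
+    for i in range(4):
+        ref = F.conv2d(F.pad(x[i:i + 1], (1, 1, 1, 1), mode='reflect'),
+                       k[i], m.bias, m.stride, 0, m.dilation, m.groups)
+        assert torch.allclose(y[i:i + 1], ref, atol=1e-5)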