|
| 1 | +import torch, torch.nn as nn |
| 2 | +import torch.nn.functional as F |
| 3 | +from torch.autograd import Variable |
| 4 | +from torchvision.models.inception import Inception3 |
| 5 | +from warnings import warn |
| 6 | +from torch.utils.model_zoo import load_url |
| 7 | + |
| 8 | + |
| 9 | +class BeheadedInception3(Inception3): |
| 10 | + """ Like torchvision.models.inception.Inception3 but the head goes separately """ |
| 11 | + |
| 12 | + def forward(self, x): |
| 13 | + if self.transform_input: |
| 14 | + x = x.clone() |
| 15 | + x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 |
| 16 | + x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 |
| 17 | + x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 |
| 18 | + else: warn("Input isn't transformed") |
| 19 | + x = self.Conv2d_1a_3x3(x) |
| 20 | + x = self.Conv2d_2a_3x3(x) |
| 21 | + x = self.Conv2d_2b_3x3(x) |
| 22 | + x = F.max_pool2d(x, kernel_size=3, stride=2) |
| 23 | + x = self.Conv2d_3b_1x1(x) |
| 24 | + x = self.Conv2d_4a_3x3(x) |
| 25 | + x = F.max_pool2d(x, kernel_size=3, stride=2) |
| 26 | + x = self.Mixed_5b(x) |
| 27 | + x = self.Mixed_5c(x) |
| 28 | + x = self.Mixed_5d(x) |
| 29 | + x = self.Mixed_6a(x) |
| 30 | + x = self.Mixed_6b(x) |
| 31 | + x = self.Mixed_6c(x) |
| 32 | + x = self.Mixed_6d(x) |
| 33 | + x = self.Mixed_6e(x) |
| 34 | + x = self.Mixed_7a(x) |
| 35 | + x = self.Mixed_7b(x) |
| 36 | + x_for_attn = x = self.Mixed_7c(x) |
| 37 | + # 8 x 8 x 2048 |
| 38 | + x = F.avg_pool2d(x, kernel_size=8) |
| 39 | + # 1 x 1 x 2048 |
| 40 | + x_for_capt = x = x.view(x.size(0), -1) |
| 41 | + # 2048 |
| 42 | + x = self.fc(x) |
| 43 | + # 1000 (num_classes) |
| 44 | + return x_for_attn, x_for_capt, x |
| 45 | + |
| 46 | + |
| 47 | +def beheaded_inception_v3(transform_input=True): |
| 48 | + model= BeheadedInception3(transform_input=transform_input) |
| 49 | + inception_url = 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth' |
| 50 | + model.load_state_dict(load_url(inception_url)) |
| 51 | + return model |
| 52 | + |
0 commit comments