Skip to content

Commit d909660

Browse files
authored
Merge pull request #9 from lombokai/feature/inference
Feature/inference
2 parents a8b66a2 + 3eb6dbb commit d909660

File tree

14 files changed

+262
-437
lines changed

14 files changed

+262
-437
lines changed

.gitignore

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
# ignore dataset
22
raw_data/*
3-
src/data/train/*
4-
src/data/val/*
3+
data/train/*
4+
data/val/*
55

66
# Byte-compiled / optimized / DLL files
77
__pycache__/
88
*.py[cod]
99
*$py.class
10+
.vscode
1011

1112
# C extensions
1213
*.so
@@ -125,6 +126,7 @@ celerybeat.pid
125126
*.sage.py
126127

127128
# Environments
129+
.myenv
128130
.env
129131
.venv
130132
env/
@@ -164,4 +166,4 @@ cython_debug/
164166
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
165167
#.idea/
166168

167-
data/
169+
# data/

src/onepiece_classify/data/data_setup.py

-34
This file was deleted.

src/onepiece_classify/data/setup_data.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
)
88
from pathlib import Path
99

10+
1011
class OnepieceImageDataLoader:
1112
def __init__(
1213
self,
@@ -67,6 +68,3 @@ def _build_dataloader(self, mode: str = "train", shuffle: bool = True) -> DataLo
6768
)
6869

6970
return loader
70-
71-
72-
+2-1
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
from .base import *
1+
from .base import *
2+
from .recognition import *

src/onepiece_classify/infer/base.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
class BaseInference(ABC):
99

1010
@abstractmethod
11-
def pre_process(self, image: Optional[str, np.ndarray, Image])->torch.Tensor:
11+
def pre_process(self, image: Optional[str|np.ndarray|Image.Image])->torch.Tensor:
1212
pass
1313

1414
@abstractmethod
@@ -20,5 +20,5 @@ def post_process(self, output: torch.Tensor) -> str:
2020
pass
2121

2222
@abstractmethod
23-
def predict(self, image: Optional[str, np.ndarray, Image]) -> dict:
24-
pass
23+
def predict(self, image: Optional[str|np.ndarray|Image.Image]) -> dict:
24+
pass

src/onepiece_classify/infer/inference.py

-10
This file was deleted.

src/onepiece_classify/infer/predict.py

-70
This file was deleted.
+80-16
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,92 @@
1+
from PIL import Image
2+
import numpy as np
3+
import torch
4+
from pathlib import Path
5+
from .base import BaseInference
6+
from typing import Optional, Tuple, Dict
17

8+
from onepiece_classify.models import image_recog
9+
from onepiece_classify.transforms import get_test_transforms
210

311

4-
5-
from .base import BaseInference
6-
import torch
7-
812
class ImageRecognition(BaseInference):
913

1014
def __init__(self, model_path: str):
11-
self.model_path = model_path
15+
self.model_path = Path(model_path)
16+
self.class_dict = {
17+
0: 'Ace',
18+
1: 'Akainu',
19+
2: 'Brook',
20+
3: 'Chopper',
21+
4: 'Crocodile',
22+
5: 'Franky',
23+
6: 'Jinbei',
24+
7: 'Kurohige',
25+
8: 'Law',
26+
9: 'Luffy',
27+
10: 'Mihawk',
28+
11: 'Nami',
29+
12: 'Rayleigh',
30+
13: 'Robin',
31+
14: 'Sanji',
32+
15: 'Shanks',
33+
16: 'Usopp',
34+
17: 'Zoro',
35+
}
36+
self.nclass = len(self.class_dict)
37+
self.model = self._build_model()
1238

13-
self.model = self._build_model(model_path)
14-
15-
def _build_model(self, model_path):
39+
def _build_model(self):
1640
# load model
17-
state_dict = torch.load(model_path)
18-
19-
model = ImageModel.load(state_dict)
20-
return model
41+
state_dict = torch.load(self.model_path)
42+
model_backbone = image_recog(self.nclass)
43+
model_backbone.load_state_dict(state_dict)
44+
return model_backbone
2145

46+
def pre_process(self, image: Optional[str | np.ndarray | Image.Image]) -> torch.Tensor:
47+
48+
trans = get_test_transforms()
49+
50+
if isinstance(image, str):
51+
img = Image.open(image).convert("RGB")
52+
img = trans(img).unsqueeze(0)
53+
54+
elif isinstance(image, Image.Image):
55+
img = image.convert("RGB")
56+
img = trans(img).unsqueeze(0)
57+
58+
elif isinstance(image, np.ndarray):
59+
img = image.astype(np.uint8)
60+
img = Image.fromarray(img).convert("RGB")
61+
img = trans(img).unsqueeze(0)
62+
63+
else:
64+
print("Image type not recognized")
65+
66+
return img
67+
2268
def forward(self, image_tensor: torch.Tensor) -> torch.Tensor:
23-
return self.model.forward(image_tensor)
24-
69+
self.model.eval()
70+
71+
result = self.model(image_tensor)
72+
return result
2573

74+
def post_process(self, output: torch.Tensor) -> Tuple[str, float]:
75+
76+
logits_prob = torch.softmax(output, dim=1).squeeze()
77+
class_idx = int(torch.argmax(logits_prob))
78+
79+
class_names = self.class_dict[class_idx]
80+
confidence = logits_prob[class_idx]
81+
return (class_names, float(confidence))
2682

27-
def recognition():
28-
return ImageRecognition(model_path="src/checkpoint/checkpoint_notebook.pth")
83+
def predict(self, image: Optional[str|np.ndarray|Image.Image]) -> Dict[str, str]:
84+
85+
tensor_img = self.pre_process(image=image)
86+
logits = self.forward(tensor_img)
87+
class_names, confidence = self.post_process(logits)
88+
89+
return {
90+
"class_names": class_names,
91+
"confidence": f"{confidence:.4f}"
92+
}

src/onepiece_classify/models/build_model.py

+10-12
Original file line numberDiff line numberDiff line change
@@ -7,29 +7,27 @@ class ImageRecogModel(nn.Module):
77

88
def __init__(self, num_classes):
99
super().__init__()
10-
1110
self.num_classes = num_classes
12-
self.backbone = self._build_backbone()
13-
self.in_features = self._build_backbone().classifier[0].in_features
14-
self.backbone.classifier = nn.Sequential(
15-
nn.Dropout(p=0.2),
16-
nn.Linear(self.in_features, out_features=self.num_classes)
17-
)
18-
19-
# self.dropout = nn.Dropout(0.2)
2011

21-
def _build_backbone(self):
12+
def build_backbone(self):
2213
model = models.mobilenet_v3_large(weights="DEFAULT")
2314

2415
for param in model.parameters():
2516
param.requires_grad = False
2617

18+
in_features = model.classifier[0].in_features
19+
model.classifier = nn.Sequential(
20+
nn.Dropout(p=0.2),
21+
nn.Linear(in_features, out_features=self.num_classes)
22+
)
23+
2724
return model
2825

2926
def forward(self, x):
30-
x = self.backbone(x)
27+
model = self.build_backbone()
28+
x = model(x)
3129
return x
3230

3331
def image_recog(num_classes):
34-
net = ImageRecogModel(num_classes)
32+
net = ImageRecogModel(num_classes).build_backbone()
3533
return net

0 commit comments

Comments
 (0)