Skip to content

Commit 28c91b8

Browse files
authored
Merge pull request #10 from lombokai/dev
Merge to main from dev after developing minimun standard of machine learning project
2 parents 45312a1 + d909660 commit 28c91b8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+629
-374
lines changed
File renamed without changes.

.gitignore

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
# ignore dataset
22
raw_data/*
3-
src/data/train/*
4-
src/data/val/*
3+
data/train/*
4+
data/val/*
55

66
# Byte-compiled / optimized / DLL files
77
__pycache__/
88
*.py[cod]
99
*$py.class
10+
.vscode
1011

1112
# C extensions
1213
*.so
@@ -125,6 +126,7 @@ celerybeat.pid
125126
*.sage.py
126127

127128
# Environments
129+
.myenv
128130
.env
129131
.venv
130132
env/
@@ -163,3 +165,5 @@ cython_debug/
163165
# and can be added to the global gitignore or merged into this file. For a more nuclear
164166
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
165167
#.idea/
168+
169+
# data/

Makefile

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
run-test:
2+
PYTHONPATH=src python -m pytest -v
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

requirements.txt

+4-1
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,7 @@ pyarrow
44
seaborn
55
matplotlib
66
jupyter
7-
pytest
7+
pytest
8+
pytest-pythonpath
9+
torch
10+
torchvision

src/build_model.py

-22
This file was deleted.

src/data_setup.py

-34
This file was deleted.

src/engine.py

-150
This file was deleted.

src/inference.py

-10
This file was deleted.

src/onepiece_classify/__init__.py

Whitespace-only changes.
+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .setup_data import *
+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from torchvision import datasets
2+
from torch.utils.data import Dataset, DataLoader
3+
from onepiece_classify.transforms import (
4+
get_train_transforms,
5+
get_valid_transforms,
6+
get_test_transforms,
7+
)
8+
from pathlib import Path
9+
10+
11+
class OnepieceImageDataLoader:
12+
def __init__(
13+
self,
14+
root_path: str,
15+
batch_size: int = 32,
16+
num_workers: int = 0
17+
):
18+
self.root_path: Path = Path(root_path)
19+
self.train_path: Path = self.root_path.joinpath("train")
20+
self.valid_path: Path = self.root_path.joinpath("val")
21+
self.test_path: Path = self.root_path.joinpath("test")
22+
23+
self.batch_size = batch_size
24+
self.num_workers = num_workers
25+
26+
self.trainset = self._build_dataset(mode="train")
27+
self.validset = self._build_dataset(mode="valid")
28+
self.testset = self._build_dataset(mode="test")
29+
30+
self.train_loader = self._build_dataloader(mode="train", shuffle=True)
31+
self.valid_loader = self._build_dataloader(mode="valid", shuffle=False)
32+
self.test_loader = self._build_dataloader(mode="test", shuffle=False)
33+
34+
35+
def _build_dataset(self, mode='train') -> datasets.ImageFolder:
36+
dset = None
37+
if mode == "train":
38+
trans = get_train_transforms()
39+
path = self.train_path
40+
elif mode == "valid":
41+
trans = get_valid_transforms()
42+
path = self.valid_path
43+
elif mode == "test":
44+
trans = get_test_transforms()
45+
path = self.test_path
46+
else:
47+
trans = get_test_transforms()
48+
path = self.test_path
49+
50+
dset = datasets.ImageFolder(str(path), transform=trans)
51+
return dset
52+
53+
def _build_dataloader(self, mode: str = "train", shuffle: bool = True) -> DataLoader:
54+
if mode == "train":
55+
dset = self.trainset
56+
elif mode == "valid":
57+
dset = self.validset
58+
elif mode == "test":
59+
dset = self.testset
60+
else:
61+
dset = self.testset
62+
63+
loader = DataLoader(
64+
dset,
65+
batch_size=self.batch_size,
66+
shuffle=shuffle,
67+
pin_memory=True
68+
)
69+
70+
return loader
+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .base import *
2+
from .recognition import *

src/onepiece_classify/infer/base.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Optional
3+
import numpy as np
4+
from PIL import Image
5+
import torch
6+
7+
8+
class BaseInference(ABC):
9+
10+
@abstractmethod
11+
def pre_process(self, image: Optional[str|np.ndarray|Image.Image])->torch.Tensor:
12+
pass
13+
14+
@abstractmethod
15+
def forward(self, image_tensor: torch.Tensor) -> torch.Tensor:
16+
pass
17+
18+
@abstractmethod
19+
def post_process(self, output: torch.Tensor) -> str:
20+
pass
21+
22+
@abstractmethod
23+
def predict(self, image: Optional[str|np.ndarray|Image.Image]) -> dict:
24+
pass

0 commit comments

Comments
 (0)