
Commit a8b66a2

Merge pull request #8 from lombokai/feature/inference

Feature/inference

2 parents 737f07c + 151e845, commit a8b66a2

File tree: 5 files changed, +207 −1 lines changed
Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
from .build_model import *
Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
from .trainer import *
Lines changed: 110 additions & 0 deletions

@@ -0,0 +1,110 @@
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from tqdm.auto import tqdm


class Trainer:
    """Minimal training loop: per-epoch train/validation passes with loss/accuracy bookkeeping."""

    def __init__(
        self,
        max_epochs: int,
        loss_fn: nn.Module,
        optim: torch.optim.Optimizer,
        learning_rate: float
    ):
        self.max_epochs = max_epochs
        self.loss_fn = loss_fn
        self.optim = optim
        self.learning_rate = learning_rate

        # per-epoch history, filled in by run()
        self._train_loss = []
        self._train_acc = []
        self._val_loss = []
        self._val_acc = []

    def train_step(self, batch, batch_idx: int = None):
        X, y = batch
        output = self.model(X)

        loss = self.loss_fn(output, y)
        train_loss = loss.item()

        self.optim.zero_grad()
        loss.backward()
        self.optim.step()

        # predicted class = argmax over the softmax probabilities
        y_class = torch.argmax(torch.softmax(output, dim=1), dim=1)
        train_acc = (y_class == y).sum().item() / len(y_class)

        return train_loss, train_acc

    def val_step(self, batch, batch_idx: int = None):
        X, y = batch
        output = self.model(X)

        loss = self.loss_fn(output, y)
        val_loss = loss.item()

        y_class = torch.argmax(torch.softmax(output, dim=1), dim=1)
        val_acc = (y_class == y).sum().item() / len(y_class)

        return val_loss, val_acc

    def train_batch(self, epoch: int = None):
        self.model.train()

        self.train_loss = 0
        self.train_acc = 0
        for batch_idx, batch in enumerate(self.loader.train_loader):
            _loss, _acc = self.train_step(batch=batch, batch_idx=batch_idx)
            self.train_loss += _loss
            self.train_acc += _acc

        avg_loss = self.train_loss / len(self.loader.train_loader)
        avg_acc = self.train_acc / len(self.loader.train_loader)

        return avg_loss, avg_acc

    def val_batch(self, epoch: int = None):
        self.model.eval()

        self.val_loss = 0
        self.val_acc = 0
        with torch.no_grad():
            for batch_idx, batch in enumerate(self.loader.valid_loader):
                _loss, _acc = self.val_step(batch=batch, batch_idx=batch_idx)
                self.val_loss += _loss
                self.val_acc += _acc

        avg_loss = self.val_loss / len(self.loader.valid_loader)
        avg_acc = self.val_acc / len(self.loader.valid_loader)

        return avg_loss, avg_acc

    def run(self):
        for epoch in tqdm(range(self.max_epochs)):
            # each *_batch call returns (average loss, average accuracy) for the epoch
            train_epoch_loss, train_epoch_acc = self.train_batch(epoch)
            val_epoch_loss, val_epoch_acc = self.val_batch(epoch)

            self._train_loss.append(train_epoch_loss)
            self._val_loss.append(val_epoch_loss)

            self._train_acc.append(train_epoch_acc)
            self._val_acc.append(val_epoch_acc)

    # def save_model(self, path_to_save):
    #     torch.save(self.model.state_dict(), path_to_save)

    def fit(self, model: nn.Module, loader: DataLoader) -> None:
        # `loader` is expected to expose .train_loader and .valid_loader
        # (as OnepieceImageDataLoader does in the tests below)
        self.model = model
        self.loader = loader
        self.run()
        # self.save_model()
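
Taken together, fit() stores the model and the loader bundle, then run() iterates the epochs and records per-epoch averages. A minimal usage sketch, mirroring the test setup further below; it assumes the local "data" directory that OnepieceImageDataLoader reads and passes its arguments positionally, as the tests do:

import torch.nn as nn
from torch.optim import Adam

from onepiece_classify.data import OnepieceImageDataLoader
from onepiece_classify.models import image_recog
from onepiece_classify.trainer import Trainer

# root path, batch size, num workers -- positional, as in the tests below
loader = OnepieceImageDataLoader("data", 32, 4)

# Trainer does not move batches between devices, so model and data
# must already live on the same device (CPU here).
model = image_recog(num_classes=18)

trainer = Trainer(
    max_epochs=10,
    loss_fn=nn.CrossEntropyLoss(),
    optim=Adam(model.parameters(), lr=0.001),
    learning_rate=0.001,
)

# fit() stores the model and loader, then run() records per-epoch averages
# in trainer._train_loss / _train_acc / _val_loss / _val_acc.
trainer.fit(model, loader)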

tests/models/test_image_recog_model.py

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 import torch
-from onepiece_classify.models.build_model import image_recog
+from onepiece_classify.models import image_recog
 
 def test_model():
     rand_tensor = torch.rand([1, 3, 224, 224])

tests/trainer/test_trainer.py

Lines changed: 94 additions & 0 deletions

@@ -0,0 +1,94 @@
import torch
import torch.nn as nn
from torch.optim import Adam

import unittest
from onepiece_classify.data import OnepieceImageDataLoader

from onepiece_classify.trainer import Trainer
from onepiece_classify.models import image_recog


class TestOnepieceImageDataLoader(unittest.TestCase):
    def setUp(self):
        self.root_path = "data"
        self.batch_size = 32
        self.num_workers = 4

        self.loader = OnepieceImageDataLoader(
            self.root_path,
            self.batch_size,
            self.num_workers
        )

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = image_recog(num_classes=18).to(self.device)

        lrate = 0.001
        loss_fn = nn.CrossEntropyLoss()
        optimizer = Adam(self.model.parameters(), lr=lrate)
        self.trainer = Trainer(
            max_epochs=10,
            loss_fn=loss_fn,
            optim=optimizer,
            learning_rate=lrate
        )

        self.trainer.model = self.model
        self.trainer.loader = self.loader

    def test_train_step(self):
        fake_batch_x = torch.rand(1, 3, 224, 224).to(self.device)
        fake_batch_y = torch.randint(0, 18, size=(1,)).to(self.device)
        batch = (fake_batch_x, fake_batch_y)

        _loss, _acc = self.trainer.train_step(batch)

        self.assertEqual(type(_loss), float)
        self.assertEqual(type(_acc), float)
        self.assertTrue(_loss > 0.)
        # self.assertTrue((_acc > 0) and (_acc < 1))

    def test_val_step(self):
        fake_batch_x = torch.rand(1, 3, 224, 224).to(self.device)
        fake_batch_y = torch.randint(0, 18, size=(1,)).to(self.device)
        batch = (fake_batch_x, fake_batch_y)

        _loss, _acc = self.trainer.val_step(batch)

        self.assertEqual(type(_loss), float)
        self.assertEqual(type(_acc), float)
        self.assertTrue(_loss > 0.)
        # self.assertTrue((_acc > 0) and (_acc < 1))

    def test_train_batch(self):
        _loss, _acc = self.trainer.train_batch()

        self.assertEqual(type(_loss), float)
        self.assertEqual(type(_acc), float)
        self.assertTrue(_loss > 0.)
        self.assertTrue((_acc > 0) and (_acc < 1))

    def test_val_batch(self):
        _loss, _acc = self.trainer.val_batch()

        self.assertEqual(type(_loss), float)
        self.assertEqual(type(_acc), float)
        self.assertTrue(_loss > 0.)
        self.assertTrue((_acc > 0) and (_acc < 1))

    # def test_init(self):
    #     self.assertTrue(isinstance(trainer.train_loss, list))
    #     self.assertTrue(isinstance(trainer.train_acc, list))
    #     self.assertTrue(isinstance(trainer.val_loss, list))
    #     self.assertTrue(isinstance(trainer.val_acc, list))

    # def test_train_step(self):
    #     a, b = train_step(1, 1)
    #     self.assertTrue(isinstance(a, float))
    #     self.assertTrue(isinstance(b, float))
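
The setUp above builds OnepieceImageDataLoader from a local "data" directory, so these tests assume that dataset is present on disk. As a rough sketch of the loader contract Trainer actually relies on (any object exposing train_loader and valid_loader), the snippet below exercises fit() end to end on random tensors; the FakeLoaders class is illustrative only and not part of this commit:

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

from onepiece_classify.models import image_recog
from onepiece_classify.trainer import Trainer


class FakeLoaders:
    # Hypothetical stand-in exposing the same attributes Trainer reads
    # from OnepieceImageDataLoader: .train_loader and .valid_loader.
    def __init__(self, batch_size: int = 4):
        X = torch.rand(8, 3, 224, 224)
        y = torch.randint(0, 18, size=(8,))
        dataset = TensorDataset(X, y)
        self.train_loader = DataLoader(dataset, batch_size=batch_size)
        self.valid_loader = DataLoader(dataset, batch_size=batch_size)


model = image_recog(num_classes=18)
trainer = Trainer(
    max_epochs=1,
    loss_fn=nn.CrossEntropyLoss(),
    optim=Adam(model.parameters(), lr=1e-3),
    learning_rate=1e-3,
)
trainer.fit(model, FakeLoaders())  # one quick pass over the random data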
