Skip to content

Commit 151e845

Browse files
committed
add test for train_step, val_step, train_batch, val_batch
1 parent c53e408 commit 151e845

File tree

1 file changed

+70
-12
lines changed

1 file changed

+70
-12
lines changed

tests/trainer/test_trainer.py

Lines changed: 70 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1+
import torch
2+
import torch.nn as nn
3+
from torch.optim import Adam
4+
15
import unittest
26
from unittest.mock import MagicMock
37
from torch.utils.data import DataLoader
48
from onepiece_classify.data import OnepieceImageDataLoader
59
from pathlib import Path
610

7-
from onepiece_classify.trainer import trainer
11+
from onepiece_classify.trainer import Trainer
812
from onepiece_classify.models import image_recog
913

14+
1015
class TestOnepieceImageDataLoader(unittest.TestCase):
1116
def setUp(self):
1217
self.root_path = "data"
@@ -19,18 +24,71 @@ def setUp(self):
1924
self.num_workers
2025
)
2126

22-
self.model = image_recog(num_classes=18)
27+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
28+
self.model = image_recog(num_classes=18).to(self.device)
29+
30+
lrate = 0.001
31+
loss_fn = nn.CrossEntropyLoss()
32+
optimizer = Adam(self.model.parameters(), lr=lrate)
33+
self.trainer = Trainer(
34+
max_epochs = 10,
35+
loss_fn = loss_fn,
36+
optim = optimizer,
37+
learning_rate = lrate
38+
)
2339

24-
self.training = trainer.fit(self.model, self.loader)
40+
self.trainer.model = self.model
41+
self.trainer.loader = self.loader
2542

26-
def test_init(self):
27-
self.assertTrue(isinstance(trainer.train_loss, list))
28-
self.assertTrue(isinstance(trainer.train_acc, list))
29-
self.assertTrue(isinstance(trainer.val_loss, list))
30-
self.assertTrue(isinstance(trainer.val_acc, list))
31-
3243

44+
3345
def test_train_step(self):
34-
a, b = train_step(1, 1)
35-
self.assertTrue(isinstance(a, float))
36-
self.assertTrue(isinstance(b, float))
46+
fake_batch_x = torch.rand(1, 3, 224, 224).to(self.device)
47+
fake_batch_y = torch.randint(0, 18, size=(1,)).to(self.device)
48+
batch = (fake_batch_x, fake_batch_y)
49+
50+
_loss, _acc = self.trainer.train_step(batch)
51+
52+
self.assertEqual(type(_loss), float)
53+
self.assertEqual(type(_acc), float)
54+
self.assertTrue(_loss > 0.)
55+
# self.assertEqual((_acc > 0) and (_acc < 1))
56+
57+
def test_val_step(self):
58+
fake_batch_x = torch.rand(1, 3, 224, 224).to(self.device)
59+
fake_batch_y = torch.randint(0, 18, size=(1,)).to(self.device)
60+
batch = (fake_batch_x, fake_batch_y)
61+
62+
_loss, _acc = self.trainer.train_step(batch)
63+
64+
self.assertEqual(type(_loss), float)
65+
self.assertEqual(type(_acc), float)
66+
self.assertTrue(_loss > 0.)
67+
# self.assertTrue((_acc > 0) and (_acc < 1))
68+
69+
def test_train_batch(self):
70+
_loss, _acc = self.trainer.train_batch()
71+
72+
self.assertEqual(type(_loss), float)
73+
self.assertEqual(type(_acc), float)
74+
self.assertTrue(_loss > 0.)
75+
self.assertTrue((_acc > 0) and (_acc < 1))
76+
77+
def test_val_batch(self):
78+
_loss, _acc = self.trainer.val_batch()
79+
80+
self.assertEqual(type(_loss), float)
81+
self.assertEqual(type(_acc), float)
82+
self.assertTrue(_loss > 0.)
83+
self.assertTrue((_acc > 0) and (_acc < 1))
84+
85+
# def test_init(self):
86+
# self.assertTrue(isinstance(trainer.train_loss, list))
87+
# self.assertTrue(isinstance(trainer.train_acc, list))
88+
# self.assertTrue(isinstance(trainer.val_loss, list))
89+
# self.assertTrue(isinstance(trainer.val_acc, list))
90+
91+
# def test_train_step(self):
92+
# a, b = train_step(1, 1)
93+
# self.assertTrue(isinstance(a, float))
94+
# self.assertTrue(isinstance(b, float))

0 commit comments

Comments
 (0)