Skip to content

Commit e210b90

Browse files
committed
add trainer tests
1 parent d84ba65 commit e210b90

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

tests/trainer/test_trainer.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import unittest
2+
from unittest.mock import MagicMock
3+
from torch.utils.data import DataLoader
4+
from onepiece_classify.data import OnepieceImageDataLoader
5+
from pathlib import Path
6+
7+
from onepiece_classify.trainer import trainer
8+
from onepiece_classify.models import image_recog
9+
10+
class TestOnepieceImageDataLoader(unittest.TestCase):
11+
def setUp(self):
12+
self.root_path = "data"
13+
self.batch_size = 32
14+
self.num_workers = 4
15+
16+
self.loader = OnepieceImageDataLoader(
17+
self.root_path,
18+
self.batch_size,
19+
self.num_workers
20+
)
21+
22+
self.model = image_recog(num_classes=18)
23+
24+
self.training = trainer.fit(self.model, self.loader)
25+
26+
def test_init(self):
27+
self.assertTrue(isinstance(trainer.train_loss, list))
28+
self.assertTrue(isinstance(trainer.train_acc, list))
29+
self.assertTrue(isinstance(trainer.val_loss, list))
30+
self.assertTrue(isinstance(trainer.val_acc, list))
31+
32+
33+
def test_train_step(self):
34+
a, b = train_step(1, 1)
35+
self.assertTrue(isinstance(a, float))
36+
self.assertTrue(isinstance(b, float))

0 commit comments

Comments
 (0)