1
+ import torch
2
+ import torch .nn as nn
3
+ from torch .optim import Adam
4
+
1
5
import unittest
2
6
from unittest .mock import MagicMock
3
7
from torch .utils .data import DataLoader
4
8
from onepiece_classify .data import OnepieceImageDataLoader
5
9
from pathlib import Path
6
10
7
- from onepiece_classify .trainer import trainer
11
+ from onepiece_classify .trainer import Trainer
8
12
from onepiece_classify .models import image_recog
9
13
14
+
10
15
class TestOnepieceImageDataLoader (unittest .TestCase ):
11
16
def setUp (self ):
12
17
self .root_path = "data"
@@ -19,18 +24,71 @@ def setUp(self):
19
24
self .num_workers
20
25
)
21
26
22
- self .model = image_recog (num_classes = 18 )
27
+ self .device = "cuda" if torch .cuda .is_available () else "cpu"
28
+ self .model = image_recog (num_classes = 18 ).to (self .device )
29
+
30
+ lrate = 0.001
31
+ loss_fn = nn .CrossEntropyLoss ()
32
+ optimizer = Adam (self .model .parameters (), lr = lrate )
33
+ self .trainer = Trainer (
34
+ max_epochs = 10 ,
35
+ loss_fn = loss_fn ,
36
+ optim = optimizer ,
37
+ learning_rate = lrate
38
+ )
23
39
24
- self .training = trainer .fit (self .model , self .loader )
40
+ self .trainer .model = self .model
41
+ self .trainer .loader = self .loader
25
42
26
- def test_init (self ):
27
- self .assertTrue (isinstance (trainer .train_loss , list ))
28
- self .assertTrue (isinstance (trainer .train_acc , list ))
29
- self .assertTrue (isinstance (trainer .val_loss , list ))
30
- self .assertTrue (isinstance (trainer .val_acc , list ))
31
-
32
43
44
+
33
45
def test_train_step (self ):
34
- a , b = train_step (1 , 1 )
35
- self .assertTrue (isinstance (a , float ))
36
- self .assertTrue (isinstance (b , float ))
46
+ fake_batch_x = torch .rand (1 , 3 , 224 , 224 ).to (self .device )
47
+ fake_batch_y = torch .randint (0 , 18 , size = (1 ,)).to (self .device )
48
+ batch = (fake_batch_x , fake_batch_y )
49
+
50
+ _loss , _acc = self .trainer .train_step (batch )
51
+
52
+ self .assertEqual (type (_loss ), float )
53
+ self .assertEqual (type (_acc ), float )
54
+ self .assertTrue (_loss > 0. )
55
+ # self.assertEqual((_acc > 0) and (_acc < 1))
56
+
57
+ def test_val_step (self ):
58
+ fake_batch_x = torch .rand (1 , 3 , 224 , 224 ).to (self .device )
59
+ fake_batch_y = torch .randint (0 , 18 , size = (1 ,)).to (self .device )
60
+ batch = (fake_batch_x , fake_batch_y )
61
+
62
+ _loss , _acc = self .trainer .train_step (batch )
63
+
64
+ self .assertEqual (type (_loss ), float )
65
+ self .assertEqual (type (_acc ), float )
66
+ self .assertTrue (_loss > 0. )
67
+ # self.assertTrue((_acc > 0) and (_acc < 1))
68
+
69
+ def test_train_batch (self ):
70
+ _loss , _acc = self .trainer .train_batch ()
71
+
72
+ self .assertEqual (type (_loss ), float )
73
+ self .assertEqual (type (_acc ), float )
74
+ self .assertTrue (_loss > 0. )
75
+ self .assertTrue ((_acc > 0 ) and (_acc < 1 ))
76
+
77
+ def test_val_batch (self ):
78
+ _loss , _acc = self .trainer .val_batch ()
79
+
80
+ self .assertEqual (type (_loss ), float )
81
+ self .assertEqual (type (_acc ), float )
82
+ self .assertTrue (_loss > 0. )
83
+ self .assertTrue ((_acc > 0 ) and (_acc < 1 ))
84
+
85
+ # def test_init(self):
86
+ # self.assertTrue(isinstance(trainer.train_loss, list))
87
+ # self.assertTrue(isinstance(trainer.train_acc, list))
88
+ # self.assertTrue(isinstance(trainer.val_loss, list))
89
+ # self.assertTrue(isinstance(trainer.val_acc, list))
90
+
91
+ # def test_train_step(self):
92
+ # a, b = train_step(1, 1)
93
+ # self.assertTrue(isinstance(a, float))
94
+ # self.assertTrue(isinstance(b, float))
0 commit comments