@@ -28,7 +28,7 @@ def __init__(
28
28
self ._val_loss = []
29
29
self ._val_acc = []
30
30
31
- def train_step (self , batch , batch_idx ):
31
+ def train_step (self , batch , batch_idx : int = None ):
32
32
X , y = batch
33
33
output = self .model (X )
34
34
@@ -44,7 +44,7 @@ def train_step(self, batch, batch_idx):
44
44
45
45
return train_loss , train_acc
46
46
47
- def val_step (self , batch , batch_idx ):
47
+ def val_step (self , batch , batch_idx : int = None ):
48
48
X , y = batch
49
49
output = self .model (X )
50
50
@@ -56,29 +56,33 @@ def val_step(self, batch, batch_idx):
56
56
57
57
return val_loss , val_acc
58
58
59
- def train_batch (self , epoch ):
59
+ def train_batch (self , epoch : int = None ):
60
60
self .model .train ()
61
61
62
62
self .train_loss = 0
63
63
self .train_acc = 0
64
64
for batch_idx , batch in enumerate (self .loader .train_loader ):
65
- self .train_loss , self .train_acc += self .train_step (batch = batch , batch_idx = batch_idx )
66
-
65
+ _loss , _acc = self .train_step (batch = batch , batch_idx = batch_idx )
66
+ self .train_loss += _loss
67
+ self .train_acc += _acc
68
+
67
69
avg_loss = self .train_loss / len (self .loader .train_loader )
68
70
avg_acc = self .train_acc / len (self .loader .train_loader )
69
71
70
72
return avg_loss , avg_acc
71
73
72
74
73
- def val_batch (self , epoch ):
75
+ def val_batch (self , epoch : int = None ):
74
76
self .model .eval ()
75
77
76
78
self .val_loss = 0
77
79
self .val_acc = 0
78
80
with torch .no_grad ():
79
81
for batch_idx , batch in enumerate (self .loader .valid_loader ):
80
- self .val_loss , self .val_acc += self .val_step (batch = batch , batch_idx = batch_idx )
81
-
82
+ _loss , _acc = self .val_step (batch = batch , batch_idx = batch_idx )
83
+ self .val_loss += _loss
84
+ self .val_acc += _acc
85
+
82
86
avg_loss = self .val_loss / len (self .loader .valid_loader )
83
87
avg_acc = self .val_acc / len (self .loader .valid_loader )
84
88
@@ -104,10 +108,3 @@ def fit(self, model: nn.Module, loader: DataLoader) -> Dict[str, float]:
104
108
self .loader = loader
105
109
self .run ()
106
110
# self.save_model()
107
-
108
- trainer = Trainer (
109
- max_epochs = 5 ,
110
- loss_fn = nn .CrossEntropyLoss (),
111
- optim = torch .optim .Adam ()
112
- learning_rate = 0.001
113
- )
0 commit comments