import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from typing import Dict, Type
from tqdm.auto import tqdm

class Trainer:

    def __init__(
        self,
        max_epochs: int,
        loss_fn: nn.Module,
        optim: Type[torch.optim.Optimizer],
        learning_rate: float,
    ):
        self.max_epochs = max_epochs
        self.loss_fn = loss_fn
        # Store the optimizer class (e.g. torch.optim.Adam); it is
        # instantiated in fit(), once the model's parameters are available.
        self.optim_cls = optim
        self.learning_rate = learning_rate

        # Per-epoch metric histories, filled in by run().
        self._train_loss = []
        self._train_acc = []
        self._val_loss = []
        self._val_acc = []

    def train_step(self, batch, batch_idx):
        X, y = batch
        output = self.model(X)

        loss = self.loss_fn(output, y)
        train_loss = loss.item()

        # Standard optimization step: clear gradients, backpropagate, update.
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()

        # Predicted class per sample: argmax over softmax probabilities
        # (equivalent to taking argmax over the raw logits).
        y_class = torch.argmax(torch.softmax(output, dim=1), dim=1)
        train_acc = (y_class == y).sum().item() / len(y_class)

        return train_loss, train_acc

    def val_step(self, batch, batch_idx):
        X, y = batch
        output = self.model(X)

        loss = self.loss_fn(output, y)
        val_loss = loss.item()

        y_class = torch.argmax(torch.softmax(output, dim=1), dim=1)
        val_acc = (y_class == y).sum().item() / len(y_class)

        return val_loss, val_acc

    def train_batch(self, epoch):
        self.model.train()

        total_loss = 0.0
        total_acc = 0.0
        for batch_idx, batch in enumerate(self.train_loader):
            loss, acc = self.train_step(batch=batch, batch_idx=batch_idx)
            total_loss += loss
            total_acc += acc

        # Average the per-batch metrics over the epoch.
        avg_loss = total_loss / len(self.train_loader)
        avg_acc = total_acc / len(self.train_loader)

        return avg_loss, avg_acc

    def val_batch(self, epoch):
        self.model.eval()

        total_loss = 0.0
        total_acc = 0.0
        # Disable gradient tracking for validation.
        with torch.no_grad():
            for batch_idx, batch in enumerate(self.valid_loader):
                loss, acc = self.val_step(batch=batch, batch_idx=batch_idx)
                total_loss += loss
                total_acc += acc

        avg_loss = total_loss / len(self.valid_loader)
        avg_acc = total_acc / len(self.valid_loader)

        return avg_loss, avg_acc

    def run(self):
        for epoch in tqdm(range(self.max_epochs)):
            train_loss, train_acc = self.train_batch(epoch)
            val_loss, val_acc = self.val_batch(epoch)

            self._train_loss.append(train_loss)
            self._val_loss.append(val_loss)

            self._train_acc.append(train_acc)
            self._val_acc.append(val_acc)

    # def save_model(self, path_to_save):
    #     torch.save(self.model.state_dict(), path_to_save)

    def fit(self, model: nn.Module, train_loader: DataLoader,
            valid_loader: DataLoader) -> Dict[str, float]:
        self.model = model
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        # Build the optimizer now that the model's parameters exist.
        self.optim = self.optim_cls(model.parameters(), lr=self.learning_rate)
        self.run()
        # self.save_model()

        # Final-epoch metrics; the full histories remain on the instance.
        return {
            "train_loss": self._train_loss[-1],
            "train_acc": self._train_acc[-1],
            "val_loss": self._val_loss[-1],
            "val_acc": self._val_acc[-1],
        }

trainer = Trainer(
    max_epochs=5,
    loss_fn=nn.CrossEntropyLoss(),
    optim=torch.optim.Adam,  # pass the optimizer class; fit() instantiates it
    learning_rate=0.001,
)
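# A minimal usage sketch (not part of the original file): exercises the Trainer
# on random tensors with a hypothetical two-layer model, assuming a 10-class
# classification task over 32 input features. The model, data, and sizes are
# illustrative placeholders, not the author's setup.
if __name__ == "__main__":
    from torch.utils.data import TensorDataset

    X = torch.randn(256, 32)           # 256 samples, 32 features
    y = torch.randint(0, 10, (256,))   # integer class labels in [0, 10)
    dataset = TensorDataset(X, y)

    train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
    valid_loader = DataLoader(dataset, batch_size=32)

    model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
    metrics = trainer.fit(model, train_loader, valid_loader)
    print(metrics)  # final-epoch train/val loss and accuracy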