Skip to content

Commit d84ba65

Browse files
committed
add Trainer class
1 parent a0f61e6 commit d84ba65

File tree

1 file changed

+113
-0
lines changed

1 file changed

+113
-0
lines changed
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
from operator import le
2+
import torch
3+
import torch.nn as nn
4+
from torch.utils.data import DataLoader
5+
6+
import numpy as np
7+
from typing import Dict
8+
from tqdm.auto import tqdm
9+
10+
11+
class Trainer:
12+
13+
def __init__(
14+
self,
15+
max_epochs: int,
16+
loss_fn: nn.Module,
17+
optim: torch.optim.Optimizer,
18+
learning_rate: float
19+
):
20+
21+
self.max_epochs = max_epochs
22+
self.loss_fn = loss_fn
23+
self.optim = optim
24+
self.learning_rate = learning_rate
25+
26+
self._train_loss = []
27+
self._train_acc = []
28+
self._val_loss = []
29+
self._val_acc = []
30+
31+
def train_step(self, batch, batch_idx):
32+
X, y = batch
33+
output = self.model(X)
34+
35+
loss = self.loss_fn(output, y)
36+
train_loss = loss.item()
37+
38+
self.optim.zero_grad()
39+
loss.backward()
40+
self.optim.step()
41+
42+
y_class = torch.argmax(torch.softmax(output, dim=1), dim=1)
43+
train_acc = (y_class == y).sum().item()/len(y_class)
44+
45+
return train_loss, train_acc
46+
47+
def val_step(self, batch, batch_idx):
48+
X, y = batch
49+
output = self.model(X)
50+
51+
loss = self.loss_fn(output, y)
52+
val_loss = loss.item()
53+
54+
y_class = torch.argmax(torch.softmax(output, dim=1), dim=1)
55+
val_acc = (y_class == y).sum().item()/len(y_class)
56+
57+
return val_loss, val_acc
58+
59+
def train_batch(self, epoch):
60+
self.model.train()
61+
62+
self.train_loss = 0
63+
self.train_acc = 0
64+
for batch_idx, batch in enumerate(self.loader.train_loader):
65+
self.train_loss, self.train_acc += self.train_step(batch=batch, batch_idx=batch_idx)
66+
67+
avg_loss = self.train_loss/len(self.loader.train_loader)
68+
avg_acc = self.train_acc/len(self.loader.train_loader)
69+
70+
return avg_loss, avg_acc
71+
72+
73+
def val_batch(self, epoch):
74+
self.model.eval()
75+
76+
self.val_loss = 0
77+
self.val_acc = 0
78+
with torch.no_grad():
79+
for batch_idx, batch in enumerate(self.loader.valid_loader):
80+
self.val_loss, self.val_acc += self.val_step(batch=batch, batch_idx=batch_idx)
81+
82+
avg_loss = self.val_loss/len(self.loader.valid_loader)
83+
avg_acc = self.val_acc/len(self.loader.valid_loader)
84+
85+
return avg_loss, avg_acc
86+
87+
def run(self):
88+
89+
for epoch in tqdm(range(self.max_epochs)):
90+
train_epoch_loss = self.train_batch(epoch)
91+
val_epoch_loss = self.val_batch(epoch)
92+
93+
self._train_loss.append(train_epoch_loss)
94+
self._val_loss.append(val_epoch_loss)
95+
96+
self._train_acc.append(train_epoch_loss)
97+
self._val_acc.append(val_epoch_loss)
98+
99+
# def save_model(self, path_to_save):
100+
# torch.save(self.model.state_dict(), path_to_save)
101+
102+
def fit(self, model: nn.Module, loader: DataLoader) -> Dict[str, float]:
103+
self.model = model
104+
self.loader = loader
105+
self.run()
106+
# self.save_model()
107+
108+
trainer = Trainer(
109+
max_epochs=5,
110+
loss_fn=nn.CrossEntropyLoss(),
111+
optim=torch.optim.Adam()
112+
learning_rate=0.001
113+
)

0 commit comments

Comments
 (0)