Skip to content

Commit c53e408

Browse files
committed
edit a few line for test purposes
1 parent 9a4322e commit c53e408

File tree

1 file changed

+12
-15
lines changed

1 file changed

+12
-15
lines changed

src/onepiece_classify/trainer/trainer.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(
2828
self._val_loss = []
2929
self._val_acc = []
3030

31-
def train_step(self, batch, batch_idx):
31+
def train_step(self, batch, batch_idx: int = None):
3232
X, y = batch
3333
output = self.model(X)
3434

@@ -44,7 +44,7 @@ def train_step(self, batch, batch_idx):
4444

4545
return train_loss, train_acc
4646

47-
def val_step(self, batch, batch_idx):
47+
def val_step(self, batch, batch_idx: int=None):
4848
X, y = batch
4949
output = self.model(X)
5050

@@ -56,29 +56,33 @@ def val_step(self, batch, batch_idx):
5656

5757
return val_loss, val_acc
5858

59-
def train_batch(self, epoch):
59+
def train_batch(self, epoch: int=None):
6060
self.model.train()
6161

6262
self.train_loss = 0
6363
self.train_acc = 0
6464
for batch_idx, batch in enumerate(self.loader.train_loader):
65-
self.train_loss, self.train_acc += self.train_step(batch=batch, batch_idx=batch_idx)
66-
65+
_loss, _acc = self.train_step(batch=batch, batch_idx=batch_idx)
66+
self.train_loss += _loss
67+
self.train_acc += _acc
68+
6769
avg_loss = self.train_loss/len(self.loader.train_loader)
6870
avg_acc = self.train_acc/len(self.loader.train_loader)
6971

7072
return avg_loss, avg_acc
7173

7274

73-
def val_batch(self, epoch):
75+
def val_batch(self, epoch: int=None):
7476
self.model.eval()
7577

7678
self.val_loss = 0
7779
self.val_acc = 0
7880
with torch.no_grad():
7981
for batch_idx, batch in enumerate(self.loader.valid_loader):
80-
self.val_loss, self.val_acc += self.val_step(batch=batch, batch_idx=batch_idx)
81-
82+
_loss, _acc = self.val_step(batch=batch, batch_idx=batch_idx)
83+
self.val_loss += _loss
84+
self.val_acc += _acc
85+
8286
avg_loss = self.val_loss/len(self.loader.valid_loader)
8387
avg_acc = self.val_acc/len(self.loader.valid_loader)
8488

@@ -104,10 +108,3 @@ def fit(self, model: nn.Module, loader: DataLoader) -> Dict[str, float]:
104108
self.loader = loader
105109
self.run()
106110
# self.save_model()
107-
108-
trainer = Trainer(
109-
max_epochs=5,
110-
loss_fn=nn.CrossEntropyLoss(),
111-
optim=torch.optim.Adam()
112-
learning_rate=0.001
113-
)

0 commit comments

Comments
 (0)