diff --git a/idlmam.py b/idlmam.py index ed8685e..64cc772 100644 --- a/idlmam.py +++ b/idlmam.py @@ -120,7 +120,7 @@ def train_simple_network(model, loss_func, train_loader, test_loader=None, score results[item] = [] #SGD is Stochastic Gradient Decent. - optimizer = torch.optim.SGD(model.parameters(), lr=0.001) + optimizer = torch.optim.SGD(model.parameters(), lr=lr) #Place the model on the correct compute resource (CPU or GPU) model.to(device) for epoch in tqdm(range(epochs), desc="Epoch"):