# Mean loss and accuracy of `model` over `data_loader`, computed on `device`;
# `limit` optionally caps the number of batches evaluated.
function loss_and_accuracy(data_loader, model, device; limit=nothing)
    acc = 0
    ls = 0.0f0
    num = 0
    i = 0
    for (x, y) in data_loader
        x, y = x |> device, y |> device
        ŷ = model(x)
        ls += logitcrossentropy(ŷ, y, agg=sum)
        acc += sum(onecold(ŷ) .== onecold(y))
        num += size(x)[end]
        if limit !== nothing
            i == limit && break
            i += 1
        end
    end
    return ls / num, acc / num
end

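# A minimal smoke test of `loss_and_accuracy`, assuming nothing beyond Flux
# itself; the tiny `toy_model`, the random data, and the CPU device are
# illustrative stand-ins, not part of the original script.
using Flux
using Flux: logitcrossentropy, onecold, onehotbatch

toy_model = Chain(Flux.flatten, Dense(32 * 32 * 3 => 10))
toy_x = rand(Float32, 32, 32, 3, 16)
toy_y = onehotbatch(rand(1:10, 16), 1:10)
toy_loader = Flux.DataLoader((toy_x, toy_y); batchsize=8)
loss, acc = loss_and_accuracy(toy_loader, toy_model, cpu)
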
function load_cifar10(; batchsize=1000)
    @info "loading CIFAR-10 dataset"
    train_dataset, test_dataset = CIFAR10(split=:train), CIFAR10(split=:test)
    train_x, train_y = train_dataset[:]
    test_x, test_y = test_dataset[:]
    @assert train_dataset.metadata["class_names"] == test_dataset.metadata["class_names"]
    labels = train_dataset.metadata["class_names"]

    # CIFAR10 label indices are zero-based; shift to 1-based to match `onehotbatch`
    train_y .+= 1
    test_y .+= 1

    train_y_ohb = Flux.onehotbatch(train_y, eachindex(labels))
    test_y_ohb = Flux.onehotbatch(test_y, eachindex(labels))

    train_loader = Flux.DataLoader((data=train_x, labels=train_y_ohb); batchsize, shuffle=true)
    test_loader = Flux.DataLoader((data=test_x, labels=test_y_ohb); batchsize)

    return train_loader, test_loader, labels
end

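# Sanity check of the loaders; the sizes below assume MLDatasets.jl's
# CIFAR-10 layout of 32x32x3 Float32 images and a batchsize of 100.
train_loader, test_loader, labels = load_cifar10(; batchsize=100)
batch = first(train_loader)
@assert size(batch.data) == (32, 32, 3, 100)  # WHCN image tensor
@assert size(batch.labels) == (10, 100)       # one-hot labels, 10 classes
@assert length(labels) == 10
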
function _train(model, train_loader, test_loader; epochs=45,
                device=gpu, limit=nothing, gpu_gc=true,
                gpu_stats=false, show_plots=false, to=TimerOutput())

    model = model |> device

    opt = Optimisers.Adam()
    state = Optimisers.setup(opt, model)

    train_loss_hist, train_acc_hist = Float64[], Float64[]
    test_loss_hist, test_acc_hist = Float64[], Float64[]

    @info "starting training"
    for epoch in 1:epochs
        i = 0
        @showprogress "training epoch $epoch/$epochs" for (x, y) in train_loader
            x, y = x |> device, y |> device
            @timeit to "batch step" begin
                gs, _ = gradient(model, x) do m, _x
                    logitcrossentropy(m(_x), y)
                end
                state, model = Optimisers.update(state, model, gs)
            end

            device === gpu && gpu_stats && CUDA.memory_status()
            if limit !== nothing
                i == limit && break
                i += 1
            end
        end

        @info "epoch $epoch complete. Testing..."
        train_loss, train_acc = loss_and_accuracy(train_loader, model, device; limit)
        @timeit to "testing" test_loss, test_acc = loss_and_accuracy(test_loader, model, device; limit)
        @info map(x -> round(x, digits=3), (; train_loss, train_acc, test_loss, test_acc))

        if show_plots
            push!(train_loss_hist, train_loss); push!(train_acc_hist, train_acc)
            push!(test_loss_hist, test_loss); push!(test_acc_hist, test_acc)
            plt = lineplot(1:epoch, train_loss_hist, name="train_loss", xlabel="epoch", ylabel="loss")
            lineplot!(plt, 1:epoch, test_loss_hist, name="test_loss")
            display(plt)
            plt = lineplot(1:epoch, train_acc_hist, name="train_acc", xlabel="epoch", ylabel="acc")
            lineplot!(plt, 1:epoch, test_acc_hist, name="test_acc")
            display(plt)
        end
        if device === gpu && gpu_gc
            GC.gc() # GPU will OOM without this
        end
    end
end

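# The loop above uses Flux's explicit-mode `gradient` with Optimisers.jl:
# the do-block receives the model and the batch, `gradient` returns one
# gradient per argument, and only the model gradient is kept. A standalone
# sketch of that pattern, with a hypothetical toy model:
using Flux, Optimisers
using Flux: logitcrossentropy, onehotbatch

m = Dense(4 => 2)
st = Optimisers.setup(Optimisers.Adam(), m)
x, y = rand(Float32, 4, 8), onehotbatch(rand(1:2, 8), 1:2)

gs, _ = gradient((m_, x_) -> logitcrossentropy(m_(x_), y), m, x)  # drop the x gradient
st, m = Optimisers.update(st, m, gs)
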
# Flux stacktraces are ludicrously big on Julia <1.10, so don't rethrow;
# print a compact error and return false instead.
function train(args...; kwargs...)
    try
        _train(args...; kwargs...)
    catch ex
        println()
        @error sprint(showerror, ex)
        GC.gc()
        return false
    end
end
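
# Possible entry point: a short, limited run on the CPU. The stand-in
# `Chain` model is an assumption; the real script would pass its own model.
train_loader, test_loader, _ = load_cifar10()
model = Chain(Flux.flatten, Dense(32 * 32 * 3 => 10))
train(model, train_loader, test_loader; epochs=1, device=cpu, limit=2)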