Skip to content

Commit 09e9551

Browse files
set up training benchmarking script
1 parent 43e0e9d commit 09e9551

File tree

3 files changed

+194
-0
lines changed

3 files changed

+194
-0
lines changed

benchmarking/Project.toml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
[deps]
2+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
4+
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
5+
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
6+
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
7+
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
8+
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
9+
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
10+
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
11+
12+
[compat]
13+
CUDA = "5"
14+
Flux = "0.14"
15+
MLDatasets = "0.7"
16+
Metalhead = "0.9"
17+
Optimisers = "0.3"
18+
ProgressMeter = "1.9"
19+
TimerOutputs = "0.5"
20+
UnicodePlots = "3.6"
21+
cuDNN = "1.2"

benchmarking/benchmark.jl

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
2+
using CUDA, cuDNN
3+
using Flux
4+
using Flux: logitcrossentropy, onecold, onehotbatch
5+
using Metalhead
6+
using MLDatasets
7+
using Optimisers
8+
using ProgressMeter
9+
using TimerOutputs
10+
using UnicodePlots
11+
12+
include("tooling.jl")
13+
14+
epochs = 45
15+
batchsize = 1000
16+
device = gpu
17+
allow_skips = true
18+
19+
train_loader, test_loader, labels = load_cifar10(; batchsize)
20+
nlabels = length(labels)
21+
firstbatch = first(first(train_loader))
22+
imsize = size(firstbatch)[1:2]
23+
24+
to = TimerOutput()
25+
26+
# these should all be the smallest variant of each that is tested in `/test`
27+
modelstrings = (
28+
"AlexNet()",
29+
"VGG(11, batchnorm=true)",
30+
"SqueezeNet()",
31+
"ResNet(18)",
32+
"WideResNet(50)",
33+
"ResNeXt(50, cardinality=32, base_width=4)",
34+
"SEResNet(18)",
35+
"SEResNeXt(50, cardinality=32, base_width=4)",
36+
"Res2Net(50, base_width=26, scale=4)",
37+
"Res2NeXt(50)",
38+
"GoogLeNet(batchnorm=true)",
39+
"DenseNet(121)",
40+
"Inceptionv3()",
41+
"Inceptionv4()",
42+
"InceptionResNetv2()",
43+
"Xception()",
44+
"MobileNetv1(0.5)",
45+
"MobileNetv2(0.5)",
46+
"MobileNetv3(:small, 0.5)",
47+
"MNASNet(MNASNet, 0.5)",
48+
"EfficientNet(:b0)",
49+
"EfficientNetv2(:small)",
50+
"ConvMixer(:small)",
51+
"ConvNeXt(:small)",
52+
# "MLPMixer()", # found no tests
53+
# "ResMLP()", # found no tests
54+
# "gMLP()", # found no tests
55+
"ViT(:tiny)",
56+
"UNet()"
57+
)
58+
59+
for (i, modstring) in enumerate(modelstrings)
60+
@timeit to "$modstring" begin
61+
@info "Evaluating $i/$(length(modelstrings)) $modstring"
62+
@timeit to "First Load" eval(Meta.parse(modstring))
63+
@timeit to "Second Load" model=eval(Meta.parse(modstring))
64+
@timeit to "Training" train(model,
65+
train_loader,
66+
test_loader;
67+
to,
68+
device)||(allow_skips || break)
69+
end
70+
end
71+
print_timer(to; sortby = :firstexec)

benchmarking/tooling.jl

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
function loss_and_accuracy(data_loader, model, device; limit = nothing)
2+
acc = 0
3+
ls = 0.0f0
4+
num = 0
5+
i = 0
6+
for (x, y) in data_loader
7+
x, y = x |> device, y |> device
8+
= model(x)
9+
ls += logitcrossentropy(ŷ, y, agg=sum)
10+
acc += sum(onecold(ŷ) .== onecold(y))
11+
num += size(x)[end]
12+
if limit !== nothing
13+
i == limit && break
14+
i += 1
15+
end
16+
end
17+
return ls / num, acc / num
18+
end
19+
20+
function load_cifar10(; batchsize=1000)
21+
@info "loading CIFAR-10 dataset"
22+
train_dataset, test_dataset = CIFAR10(split=:train), CIFAR10(split=:test)
23+
train_x, train_y = train_dataset[:]
24+
test_x, test_y = test_dataset[:]
25+
@assert train_dataset.metadata["class_names"] == test_dataset.metadata["class_names"]
26+
labels = train_dataset.metadata["class_names"]
27+
28+
# CIFAR10 label indices seem to be zero-indexed
29+
train_y .+= 1
30+
test_y .+= 1
31+
32+
train_y_ohb = Flux.onehotbatch(train_y, eachindex(labels))
33+
test_y_ohb = Flux.onehotbatch(test_y, eachindex(labels))
34+
35+
train_loader = Flux.DataLoader((data=train_x, labels=train_y_ohb); batchsize, shuffle=true)
36+
test_loader = Flux.DataLoader((data=test_x, labels=test_y_ohb); batchsize)
37+
38+
return train_loader, test_loader, labels
39+
end
40+
41+
function _train(model, train_loader, test_loader; epochs = 45, device = gpu, limit=nothing, gpu_gc=true, gpu_stats=false, show_plots=false, to=TimerOutput())
42+
43+
model = model |> device
44+
45+
opt = Optimisers.Adam()
46+
state = Optimisers.setup(opt, model)
47+
48+
train_loss_hist, train_acc_hist = Float64[], Float64[]
49+
test_loss_hist, test_acc_hist = Float64[], Float64[]
50+
51+
@info "starting training"
52+
for epoch in 1:epochs
53+
i = 0
54+
@showprogress "training epoch $epoch/$epochs" for (x, y) in train_loader
55+
x, y = x |> device, y |> device
56+
@timeit to "batch step" begin
57+
gs, _ = gradient(model, x) do m, _x
58+
logitcrossentropy(m(_x), y)
59+
end
60+
state, model = Optimisers.update(state, model, gs)
61+
end
62+
63+
device === gpu && gpu_stats && CUDA.memory_status()
64+
if limit !== nothing
65+
i == limit && break
66+
i += 1
67+
end
68+
end
69+
70+
@info "epoch $epoch complete. Testing..."
71+
train_loss, train_acc = loss_and_accuracy(train_loader, model, device; limit)
72+
@timeit to "testing" test_loss, test_acc = loss_and_accuracy(test_loader, model, device; limit)
73+
@info map(x->round(x, digits=3), (; train_loss, train_acc, test_loss, test_acc))
74+
75+
if show_plots
76+
push!(train_loss_hist, train_loss); push!(train_acc_hist, train_acc);
77+
push!(test_loss_hist, test_loss); push!(test_acc_hist, test_acc);
78+
plt = lineplot(1:epoch, train_loss_hist, name = "train_loss", xlabel="epoch", ylabel="loss")
79+
lineplot!(plt, 1:epoch, test_loss_hist, name = "test_loss")
80+
display(plt)
81+
plt = lineplot(1:epoch, train_acc_hist, name = "train_acc", xlabel="epoch", ylabel="acc")
82+
lineplot!(plt, 1:epoch, test_acc_hist, name = "test_acc")
83+
display(plt)
84+
end
85+
if device === gpu && gpu_gc
86+
GC.gc() # GPU will OOM without this
87+
end
88+
end
89+
end
90+
91+
# because Flux stacktraces are ludicrously big on <1.10 so don't show them
92+
function train(args...;kwargs...)
93+
try
94+
_train(args...; kwargs...)
95+
catch ex
96+
rethrow()
97+
println()
98+
@error sprint(showerror, ex)
99+
GC.gc()
100+
return false
101+
end
102+
end

0 commit comments

Comments
 (0)