-
Notifications
You must be signed in to change notification settings - Fork 36
Make run_ad
return both primal and gradient time
#1002
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
8d460c0
90127ab
2f82a03
637e4bc
2686fcc
b0216ff
63bb81f
4bbebb7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,9 @@ | ||
using Pkg | ||
|
||
using DynamicPPLBenchmarks: Models, make_suite, model_dimension | ||
using BenchmarkTools: @benchmark, median, run | ||
using DynamicPPLBenchmarks: Models, to_backend, make_varinfo | ||
using DynamicPPL.TestUtils.AD: run_ad, NoTest | ||
using Chairmarks: @be | ||
using PrettyTables: PrettyTables, ft_printf | ||
using StableRNGs: StableRNG | ||
using Statistics: median | ||
|
||
rng = StableRNG(23) | ||
|
||
|
@@ -35,48 +35,45 @@ chosen_combinations = [ | |
Models.simple_assume_observe(randn(rng)), | ||
:typed, | ||
:forwarddiff, | ||
false, | ||
), | ||
("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff, false), | ||
("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true), | ||
("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true), | ||
("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true), | ||
("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true), | ||
("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true), | ||
("Loop univariate 1k", loop_univariate1k, :typed, :mooncake, true), | ||
("Multivariate 1k", multivariate1k, :typed, :mooncake, true), | ||
("Loop univariate 10k", loop_univariate10k, :typed, :mooncake, true), | ||
("Multivariate 10k", multivariate10k, :typed, :mooncake, true), | ||
("Dynamic", Models.dynamic(), :typed, :mooncake, true), | ||
("Submodel", Models.parent(randn(rng)), :typed, :mooncake, true), | ||
("LDA", lda_instance, :typed, :reversediff, true), | ||
("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff), | ||
("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff), | ||
("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff), | ||
("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff), | ||
("Smorgasbord", smorgasbord_instance, :typed, :reversediff), | ||
("Smorgasbord", smorgasbord_instance, :typed, :mooncake), | ||
("Loop univariate 1k", loop_univariate1k, :typed, :mooncake), | ||
("Multivariate 1k", multivariate1k, :typed, :mooncake), | ||
("Loop univariate 10k", loop_univariate10k, :typed, :mooncake), | ||
("Multivariate 10k", multivariate10k, :typed, :mooncake), | ||
("Dynamic", Models.dynamic(), :typed, :mooncake), | ||
("Submodel", Models.parent(randn(rng)), :typed, :mooncake), | ||
("LDA", lda_instance, :typed, :reversediff), | ||
Comment on lines
-40
to
+50
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This PR also removes the option to set the varinfo to be linked or unlinked. Going forward, everything is linked. In practice there really aren't any cases where AD is run with an unlinked varinfo (indeed, running with an unlinked varinfo is a recipe for bugs when you have constraints like in Dirichlet or LKJCholesky; see e.g. TuringLang/ADTests#7), so I don't think that we should do it here. |
||
] | ||
|
||
# Time running a model-like function that does not use DynamicPPL, as a reference point. | ||
# Eval timings will be relative to this. | ||
reference_time = begin | ||
obs = randn(rng) | ||
median(@benchmark Models.simple_assume_observe_non_model(obs)).time | ||
median(@be Models.simple_assume_observe_non_model(obs)).time | ||
end | ||
|
||
results_table = Tuple{String,Int,String,String,Bool,Float64,Float64}[] | ||
results_table = Tuple{String,Int,String,String,Float64,Float64}[] | ||
|
||
for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations | ||
for (model_name, model, varinfo_choice, adbackend) in chosen_combinations | ||
@info "Running benchmark for $model_name" | ||
suite = make_suite(model, varinfo_choice, adbackend, islinked) | ||
results = run(suite) | ||
eval_time = median(results["evaluation"]).time | ||
relative_eval_time = eval_time / reference_time | ||
ad_eval_time = median(results["gradient"]).time | ||
relative_ad_eval_time = ad_eval_time / eval_time | ||
adtype = to_backend(adbackend) | ||
varinfo = make_varinfo(model, varinfo_choice) | ||
ad_result = run_ad(model, adtype; test=NoTest(), benchmark=true, varinfo=varinfo) | ||
relative_eval_time = ad_result.primal_time / reference_time | ||
relative_ad_eval_time = ad_result.grad_time / ad_result.primal_time | ||
push!( | ||
results_table, | ||
( | ||
model_name, | ||
model_dimension(model, islinked), | ||
length(varinfo[:]), | ||
string(adbackend), | ||
string(varinfo_choice), | ||
islinked, | ||
relative_eval_time, | ||
relative_ad_eval_time, | ||
), | ||
|
@@ -89,14 +86,13 @@ header = [ | |
"Dimension", | ||
"AD Backend", | ||
"VarInfo Type", | ||
"Linked", | ||
"Eval Time / Ref Time", | ||
"AD Time / Eval Time", | ||
] | ||
PrettyTables.pretty_table( | ||
table_matrix; | ||
header=header, | ||
tf=PrettyTables.tf_markdown, | ||
formatters=ft_printf("%.1f", [6, 7]), | ||
formatters=ft_printf("%.1f", [5, 6]), | ||
crop=:none, # Always print the whole table, even if it doesn't fit in the terminal. | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,8 @@ | ||
module DynamicPPLBenchmarks | ||
|
||
using DynamicPPL: VarInfo, SimpleVarInfo, VarName | ||
using BenchmarkTools: BenchmarkGroup, @benchmarkable | ||
using DynamicPPL: Model, VarInfo, SimpleVarInfo | ||
using DynamicPPL: DynamicPPL | ||
using ADTypes: ADTypes | ||
using LogDensityProblems: LogDensityProblems | ||
|
||
using ForwardDiff: ForwardDiff | ||
using Mooncake: Mooncake | ||
|
@@ -14,29 +12,14 @@ using StableRNGs: StableRNG | |
include("./Models.jl") | ||
using .Models: Models | ||
|
||
export Models, make_suite, model_dimension | ||
|
||
""" | ||
model_dimension(model, islinked) | ||
|
||
Return the dimension of `model`, accounting for linking, if any. | ||
""" | ||
function model_dimension(model, islinked) | ||
vi = VarInfo() | ||
model(StableRNG(23), vi) | ||
if islinked | ||
vi = DynamicPPL.link(vi, model) | ||
end | ||
return length(vi[:]) | ||
end | ||
export Models, to_backend, make_varinfo | ||
|
||
# Utility functions for representing AD backends using symbols. | ||
# Copied from TuringBenchmarking.jl. | ||
const SYMBOL_TO_BACKEND = Dict( | ||
:forwarddiff => ADTypes.AutoForwardDiff(), | ||
:reversediff => ADTypes.AutoReverseDiff(; compile=false), | ||
:reversediff_compiled => ADTypes.AutoReverseDiff(; compile=true), | ||
:mooncake => ADTypes.AutoMooncake(; config=nothing), | ||
:mooncake => ADTypes.AutoMooncake(), | ||
) | ||
|
||
to_backend(x) = error("Unknown backend: $x") | ||
|
@@ -48,58 +31,37 @@ function to_backend(x::Union{AbstractString,Symbol}) | |
end | ||
|
||
""" | ||
make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::Bool) | ||
make_varinfo(model, varinfo_choice::Symbol) | ||
|
||
Create a benchmark suite for `model` using the selected varinfo type and AD backend. | ||
Create a VarInfo for the given `model` using the selected varinfo type. | ||
Available varinfo choices: | ||
• `:untyped` → uses `DynamicPPL.untyped_varinfo(model)` | ||
• `:typed` → uses `DynamicPPL.typed_varinfo(model)` | ||
• `:simple_namedtuple` → uses `SimpleVarInfo{Float64}(model())` | ||
• `:simple_dict` → builds a `SimpleVarInfo{Float64}` from a Dict (pre-populated with the model’s outputs) | ||
|
||
The AD backend should be specified as a Symbol (e.g. `:forwarddiff`, `:reversediff`, `:zygote`). | ||
• `:simple_namedtuple` → builds a `SimpleVarInfo{Float64}(::NamedTuple)` | ||
• `:simple_dict` → builds a `SimpleVarInfo{Float64}(::Dict)` | ||
|
||
`islinked` determines whether to link the VarInfo for evaluation. | ||
The VarInfo is always linked. | ||
""" | ||
function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::Bool) | ||
function make_varinfo(model::Model, varinfo_choice::Symbol) | ||
rng = StableRNG(23) | ||
|
||
suite = BenchmarkGroup() | ||
|
||
vi = if varinfo_choice == :untyped | ||
DynamicPPL.untyped_varinfo(rng, model) | ||
elseif varinfo_choice == :typed | ||
DynamicPPL.typed_varinfo(rng, model) | ||
elseif varinfo_choice == :simple_namedtuple | ||
SimpleVarInfo{Float64}(model(rng)) | ||
vi = DynamicPPL.typed_varinfo(rng, model) | ||
vals = DynamicPPL.values_as(vi, NamedTuple) | ||
SimpleVarInfo{Float64}(vals) | ||
elseif varinfo_choice == :simple_dict | ||
retvals = model(rng) | ||
vns = [VarName{k}() for k in keys(retvals)] | ||
SimpleVarInfo{Float64}(Dict(zip(vns, values(retvals)))) | ||
vi = DynamicPPL.typed_varinfo(rng, model) | ||
vals = DynamicPPL.values_as(vi, Dict) | ||
SimpleVarInfo{Float64}(vals) | ||
Comment on lines
-76
to
+59
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The old code only works if the model explicitly returns a NamedTuple with only plain Symbols as keys (which the benchmark models happen to do). I am aware that this pattern is peppered all over the code base (e.g. with demo models too), but I think we should try to avoid having magic return values and instead rely on functionality that is designed to work on all models. |
||
else | ||
error("Unknown varinfo choice: $varinfo_choice") | ||
end | ||
|
||
adbackend = to_backend(adbackend) | ||
|
||
if islinked | ||
vi = DynamicPPL.link(vi, model) | ||
end | ||
|
||
f = DynamicPPL.LogDensityFunction( | ||
model, DynamicPPL.getlogjoint_internal, vi; adtype=adbackend | ||
) | ||
# The parameters at which we evaluate f. | ||
θ = vi[:] | ||
|
||
# Run once to trigger compilation. | ||
LogDensityProblems.logdensity_and_gradient(f, θ) | ||
suite["gradient"] = @benchmarkable $(LogDensityProblems.logdensity_and_gradient)($f, $θ) | ||
|
||
# Also benchmark just standard model evaluation because why not. | ||
suite["evaluation"] = @benchmarkable $(LogDensityProblems.logdensity)($f, $θ) | ||
|
||
return suite | ||
return DynamicPPL.link!!(vi, model) | ||
end | ||
|
||
end # module |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: Chairmarks is already a strong dependency of DynamicPPL (DPPL).