Make run_ad return both primal and gradient time #1002

Closed · wants to merge 8 commits
5 changes: 4 additions & 1 deletion HISTORY.md
@@ -24,9 +24,12 @@ Please see the API documentation for more details.

There is now also an `rng` keyword argument to help seed parameter generation.

- Finally, instead of specifying `value_atol` and `grad_atol`, you can now specify `atol` and `rtol` which are used for both value and gradient.
+ Instead of specifying `value_atol` and `grad_atol`, you can now specify `atol` and `rtol` which are used for both value and gradient.
Their semantics are the same as in Julia's `isapprox`; two values are equal if they satisfy either `atol` or `rtol`.
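For reference, a minimal sketch of that criterion in the scalar case (this mirrors Julia's `isapprox` and is illustrative only, not part of the diff):

```julia
# A pair of values passes if EITHER the absolute or the relative tolerance is met.
approx_equal(a, b; atol=1e-8, rtol=sqrt(eps())) =
    abs(a - b) <= max(atol, rtol * max(abs(a), abs(b)))
```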

+ Finally, the `ADResult` object returned by `run_ad` now has both `grad_time` and `primal_time` fields, which contain the time taken to calculate the gradient of logp and the time taken to evaluate logp itself, respectively.
+ Previously there was only a single `time_vs_primal` field, which held the ratio of these two times.
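A minimal usage sketch of the new fields (hypothetical model; `run_ad` as documented in this PR):

```julia
using DynamicPPL, Distributions
using ADTypes: AutoForwardDiff
using DynamicPPL.TestUtils.AD: run_ad

@model demo() = x ~ Normal()

result = run_ad(demo(), AutoForwardDiff(); benchmark=true)
# The two new fields, and the ratio formerly stored in `time_vs_primal`:
result.grad_time, result.primal_time, result.grad_time / result.primal_time
```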

### `DynamicPPL.TestUtils.check_model`

You now need to explicitly pass a `VarInfo` argument to `check_model` and `check_model_and_trace`.
16 changes: 8 additions & 8 deletions benchmarks/Project.toml
@@ -4,28 +4,28 @@ version = "0.1.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
- BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
+ Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de"
> **Member (Author):** note: Chairmarks is already a strong dep of DPPL

Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[sources]
DynamicPPL = {path = "../"}

[compat]
- ADTypes = "1.14.0"
- BenchmarkTools = "1.6.0"
- Distributions = "0.25.117"
+ ADTypes = "1"
+ Chairmarks = "1"
+ Distributions = "0.25"
DynamicPPL = "0.37"
ForwardDiff = "0.10.38, 1"
LogDensityProblems = "2.1.2"
Mooncake = "0.4"
- PrettyTables = "2.4.0"
- ReverseDiff = "1.15.3"
+ PrettyTables = "2"
+ ReverseDiff = "1"
StableRNGs = "1"
Statistics = "1"
58 changes: 27 additions & 31 deletions benchmarks/benchmarks.jl
@@ -1,9 +1,9 @@
using Pkg

- using DynamicPPLBenchmarks: Models, make_suite, model_dimension
- using BenchmarkTools: @benchmark, median, run
+ using DynamicPPLBenchmarks: Models, to_backend, make_varinfo
+ using DynamicPPL.TestUtils.AD: run_ad, NoTest
+ using Chairmarks: @be
using PrettyTables: PrettyTables, ft_printf
using StableRNGs: StableRNG
+ using Statistics: median

rng = StableRNG(23)

@@ -35,48 +35,45 @@ chosen_combinations = [
Models.simple_assume_observe(randn(rng)),
:typed,
:forwarddiff,
- false,
),
("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff, false),
("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true),
("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true),
("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true),
("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true),
("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true),
("Loop univariate 1k", loop_univariate1k, :typed, :mooncake, true),
("Multivariate 1k", multivariate1k, :typed, :mooncake, true),
("Loop univariate 10k", loop_univariate10k, :typed, :mooncake, true),
("Multivariate 10k", multivariate10k, :typed, :mooncake, true),
("Dynamic", Models.dynamic(), :typed, :mooncake, true),
("Submodel", Models.parent(randn(rng)), :typed, :mooncake, true),
("LDA", lda_instance, :typed, :reversediff, true),
("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff),
("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff),
("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff),
("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff),
("Smorgasbord", smorgasbord_instance, :typed, :reversediff),
("Smorgasbord", smorgasbord_instance, :typed, :mooncake),
("Loop univariate 1k", loop_univariate1k, :typed, :mooncake),
("Multivariate 1k", multivariate1k, :typed, :mooncake),
("Loop univariate 10k", loop_univariate10k, :typed, :mooncake),
("Multivariate 10k", multivariate10k, :typed, :mooncake),
("Dynamic", Models.dynamic(), :typed, :mooncake),
("Submodel", Models.parent(randn(rng)), :typed, :mooncake),
("LDA", lda_instance, :typed, :reversediff),
> **penelopeysm (Member, Author) commented on lines -40 to +50, Jul 29, 2025:** This PR also removes the option to set the varinfo to be linked or unlinked. Going forward, everything is linked. In practice there really aren't any cases where AD is run with an unlinked varinfo (indeed, running with an unlinked varinfo is a recipe for bugs when you have constraints, as with Dirichlet or LKJCholesky; see e.g. TuringLang/ADTests#7), so I don't think that we should do it here.
]

# Time running a model-like function that does not use DynamicPPL, as a reference point.
# Eval timings will be relative to this.
reference_time = begin
obs = randn(rng)
- median(@benchmark Models.simple_assume_observe_non_model(obs)).time
+ median(@be Models.simple_assume_observe_non_model(obs)).time
end

- results_table = Tuple{String,Int,String,String,Bool,Float64,Float64}[]
+ results_table = Tuple{String,Int,String,String,Float64,Float64}[]

- for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations
+ for (model_name, model, varinfo_choice, adbackend) in chosen_combinations
@info "Running benchmark for $model_name"
- suite = make_suite(model, varinfo_choice, adbackend, islinked)
- results = run(suite)
- eval_time = median(results["evaluation"]).time
- relative_eval_time = eval_time / reference_time
- ad_eval_time = median(results["gradient"]).time
- relative_ad_eval_time = ad_eval_time / eval_time
+ adtype = to_backend(adbackend)
+ varinfo = make_varinfo(model, varinfo_choice)
+ ad_result = run_ad(model, adtype; test=NoTest(), benchmark=true, varinfo=varinfo)
+ relative_eval_time = ad_result.primal_time / reference_time
+ relative_ad_eval_time = ad_result.grad_time / ad_result.primal_time
push!(
results_table,
(
model_name,
- model_dimension(model, islinked),
+ length(varinfo[:]),
string(adbackend),
string(varinfo_choice),
- islinked,
relative_eval_time,
relative_ad_eval_time,
),
@@ -89,14 +86,13 @@ header = [
"Dimension",
"AD Backend",
"VarInfo Type",
"Linked",
"Eval Time / Ref Time",
"AD Time / Eval Time",
]
PrettyTables.pretty_table(
table_matrix;
header=header,
tf=PrettyTables.tf_markdown,
formatters=ft_printf("%.1f", [6, 7]),
formatters=ft_printf("%.1f", [5, 6]),
crop=:none, # Always print the whole table, even if it doesn't fit in the terminal.
)
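To illustrate the review comment above about always benchmarking with linked varinfos, here is a hedged sketch (hypothetical model; linking a `Dirichlet(ones(3))` leaves only 2 unconstrained dimensions, so AD never steps outside the simplex):

```julia
using DynamicPPL, Distributions

@model function constrained()
    x ~ Dirichlet(ones(3))  # x lives on the probability simplex
end

model = constrained()
vi = VarInfo(model)
length(vi[:])                        # 3 values in constrained space
linked = DynamicPPL.link(vi, model)
length(linked[:])                    # 2 values in unconstrained space, safe for AD
```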
70 changes: 16 additions & 54 deletions benchmarks/src/DynamicPPLBenchmarks.jl
@@ -1,10 +1,8 @@
module DynamicPPLBenchmarks

- using DynamicPPL: VarInfo, SimpleVarInfo, VarName
- using BenchmarkTools: BenchmarkGroup, @benchmarkable
+ using DynamicPPL: Model, VarInfo, SimpleVarInfo
using DynamicPPL: DynamicPPL
using ADTypes: ADTypes
- using LogDensityProblems: LogDensityProblems

using ForwardDiff: ForwardDiff
using Mooncake: Mooncake
@@ -14,29 +12,14 @@ using StableRNGs: StableRNG
include("./Models.jl")
using .Models: Models

- export Models, make_suite, model_dimension
-
- """
-     model_dimension(model, islinked)
-
- Return the dimension of `model`, accounting for linking, if any.
- """
- function model_dimension(model, islinked)
- vi = VarInfo()
- model(StableRNG(23), vi)
- if islinked
- vi = DynamicPPL.link(vi, model)
- end
- return length(vi[:])
- end
+ export Models, to_backend, make_varinfo

# Utility functions for representing AD backends using symbols.
# Copied from TuringBenchmarking.jl.
const SYMBOL_TO_BACKEND = Dict(
:forwarddiff => ADTypes.AutoForwardDiff(),
:reversediff => ADTypes.AutoReverseDiff(; compile=false),
:reversediff_compiled => ADTypes.AutoReverseDiff(; compile=true),
- :mooncake => ADTypes.AutoMooncake(; config=nothing),
+ :mooncake => ADTypes.AutoMooncake(),
)

to_backend(x) = error("Unknown backend: $x")
@@ -48,58 +31,37 @@ function to_backend(x::Union{AbstractString,Symbol})
end

"""
- make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::Bool)
+ make_varinfo(model, varinfo_choice::Symbol)

- Create a benchmark suite for `model` using the selected varinfo type and AD backend.
+ Create a VarInfo for the given `model` using the selected varinfo type.
Available varinfo choices:
• `:untyped` → uses `DynamicPPL.untyped_varinfo(model)`
• `:typed` → uses `DynamicPPL.typed_varinfo(model)`
- • `:simple_namedtuple` → uses `SimpleVarInfo{Float64}(model())`
- • `:simple_dict` → builds a `SimpleVarInfo{Float64}` from a Dict (pre-populated with the model’s outputs)
- The AD backend should be specified as a Symbol (e.g. `:forwarddiff`, `:reversediff`, `:zygote`).
+ • `:simple_namedtuple` → builds a `SimpleVarInfo{Float64}(::NamedTuple)`
+ • `:simple_dict` → builds a `SimpleVarInfo{Float64}(::Dict)`

- `islinked` determines whether to link the VarInfo for evaluation.
+ The VarInfo is always linked.
"""
- function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::Bool)
+ function make_varinfo(model::Model, varinfo_choice::Symbol)
rng = StableRNG(23)

- suite = BenchmarkGroup()

vi = if varinfo_choice == :untyped
DynamicPPL.untyped_varinfo(rng, model)
elseif varinfo_choice == :typed
DynamicPPL.typed_varinfo(rng, model)
elseif varinfo_choice == :simple_namedtuple
- SimpleVarInfo{Float64}(model(rng))
+ vi = DynamicPPL.typed_varinfo(rng, model)
+ vals = DynamicPPL.values_as(vi, NamedTuple)
+ SimpleVarInfo{Float64}(vals)
elseif varinfo_choice == :simple_dict
- retvals = model(rng)
- vns = [VarName{k}() for k in keys(retvals)]
- SimpleVarInfo{Float64}(Dict(zip(vns, values(retvals))))
+ vi = DynamicPPL.typed_varinfo(rng, model)
+ vals = DynamicPPL.values_as(vi, Dict)
+ SimpleVarInfo{Float64}(vals)
> **penelopeysm (Member, Author) commented on lines -76 to +59, Jul 29, 2025:** The old code only works if the model explicitly returns a NamedTuple with only plain Symbols (which these models happen to do). I am aware that this pattern is peppered all over the code base (e.g. with the demo models too), but I think we should try to avoid having magic return values and instead rely on functionality that is designed to work on all models.

else
error("Unknown varinfo choice: $varinfo_choice")
end

- adbackend = to_backend(adbackend)
-
- if islinked
- vi = DynamicPPL.link(vi, model)
- end
-
- f = DynamicPPL.LogDensityFunction(
- model, DynamicPPL.getlogjoint_internal, vi; adtype=adbackend
- )
- # The parameters at which we evaluate f.
- θ = vi[:]
-
- # Run once to trigger compilation.
- LogDensityProblems.logdensity_and_gradient(f, θ)
- suite["gradient"] = @benchmarkable $(LogDensityProblems.logdensity_and_gradient)($f, $θ)
-
- # Also benchmark just standard model evaluation because why not.
- suite["evaluation"] = @benchmarkable $(LogDensityProblems.logdensity)($f, $θ)
-
- return suite
+ return DynamicPPL.link!!(vi, model)
end

end # module
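A hedged usage sketch of the refactored helpers above, mirroring how `benchmarks.jl` now drives them (the model is hypothetical; `values_as` is the same API the new `make_varinfo` branches rely on, and it works regardless of what the model returns):

```julia
using DynamicPPL, Distributions
using DynamicPPLBenchmarks: make_varinfo, to_backend
using DynamicPPL.TestUtils.AD: run_ad, NoTest

@model function demo()
    a ~ Normal()
    b ~ Exponential()
    return nothing  # no "magic" NamedTuple return value required
end

model = demo()
# Variable values come from the VarInfo, not from the model's return value:
DynamicPPL.values_as(VarInfo(model), NamedTuple)  # (a = ..., b = ...)

vi = make_varinfo(model, :typed)  # always returned linked
result = run_ad(model, to_backend(:forwarddiff); test=NoTest(), benchmark=true, varinfo=vi)
result.grad_time / result.primal_time  # AD overhead relative to plain evaluation
```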
41 changes: 33 additions & 8 deletions src/test_utils/ad.jl
@@ -109,8 +109,11 @@ struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloat}
value_actual::Tresult
"The gradient of logp (calculated using `adtype`)"
grad_actual::Vector{Tresult}
"If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself"
time_vs_primal::Union{Nothing,Tresult}
"If benchmarking was requested, the time taken by the AD backend to evaluate the gradient
of logp"
grad_time::Union{Nothing,Tresult}
"If benchmarking was requested, the time taken by the AD backend to evaluate logp"
primal_time::Union{Nothing,Tresult}
end

"""
Expand All @@ -121,6 +124,8 @@ end
benchmark=false,
atol::AbstractFloat=1e-8,
rtol::AbstractFloat=sqrt(eps()),
+ getlogdensity::Function=getlogjoint_internal,
+ rng::AbstractRNG=default_rng(),
varinfo::AbstractVarInfo=link(VarInfo(model), model),
params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
verbose=true,
@@ -174,6 +179,21 @@ Everything else is optional, and can be categorised into several groups:
prep_params)`. You could then evaluate the gradient at a different set of
parameters using the `params` keyword argument.

+ 3. _Which type of logp is being calculated._
+
+ By default, `run_ad` evaluates the 'internal log joint density' of the model,
+ i.e., the log joint density in the unconstrained space. Thus, for example, in
+
+     @model f() = x ~ LogNormal()
+
+ the internal log joint density is `logpdf(Normal(), log(x))`. This is the
+ relevant log density for e.g. Hamiltonian Monte Carlo samplers and is therefore
+ the most useful to test.
+
+ If you want the log joint density in the original model parameterisation, you
+ can use `getlogjoint`. Likewise, if you want only the prior or likelihood,
+ you can use `getlogprior` or `getloglikelihood`, respectively.

4. _How to specify the results to compare against._

Once logp and its gradient have been calculated with the specified `adtype`,
@@ -277,14 +297,18 @@ function run_ad(
end

# Benchmark
- time_vs_primal = if benchmark
+ grad_time, primal_time = if benchmark
primal_benchmark = @be (ldf, params) logdensity(_[1], _[2])
grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2])
- t = median(grad_benchmark).time / median(primal_benchmark).time
- verbose && println("grad / primal : $(t)")
- t
+ median_primal = median(primal_benchmark).time
+ median_grad = median(grad_benchmark).time
+ r(f) = round(f; sigdigits=4)
+ verbose && println(
+ "grad / primal : $(r(median_grad))/$(r(median_primal)) = $(r(median_grad / median_primal))",
+ )
+ (median_grad, median_primal)
else
- nothing
+ nothing, nothing
end

return ADResult(
@@ -299,7 +323,8 @@
grad_true,
value,
grad,
- time_vs_primal,
+ grad_time,
+ primal_time,
)
end
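Finally, a sketch of the `getlogdensity` keyword described in the docstring above (hypothetical model; accessor names as given in the docstring):

```julia
using DynamicPPL, Distributions
using ADTypes: AutoForwardDiff
using DynamicPPL.TestUtils.AD: run_ad, NoTest

@model f() = x ~ LogNormal()

# Default: the internal log joint density in unconstrained space, as used by HMC.
run_ad(f(), AutoForwardDiff(); test=NoTest())

# The log joint in the original (constrained) parameterisation instead:
run_ad(f(), AutoForwardDiff(); test=NoTest(), getlogdensity=DynamicPPL.getlogjoint)
```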
