diff --git a/Project.toml b/Project.toml index a079ce2..54188c5 100644 --- a/Project.toml +++ b/Project.toml @@ -6,9 +6,11 @@ version = "0.1.2" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" JuliaVariables = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -17,6 +19,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" MeasureBase = "fa1605e6-acd5-459c-a1e6-7e635759db14" MeasureTheory = "eadaa1a4-d27c-401d-8699-e962e1bbc33b" +MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377" NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50" NestedTuples = "a734d2a7-8d68-409b-9419-626914d4061d" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" @@ -38,17 +41,20 @@ TupleVectors = "615932cf-77b6-4358-adcd-5b7eba981d7e" [compat] Accessors = "0.1" -ArrayInterface = "4, 5, 6" +ArrayInterface = "5, 6" +ChainRulesCore = "1" DataStructures = "0.18" DensityInterface = "0.4" DiffResults = "1" +Graphs = "1" IfElse = "0.1" JuliaVariables = "0.2" MLStyle = "0.3,0.4" MacroTools = "0.5" MappedArrays = "0.3, 0.4" -MeasureBase = "0.9" +MeasureBase = "0.12.2" MeasureTheory = "0.16" +MetaGraphsNext = "0.3" NamedTupleTools = "0.12, 0.13, 0.14" NestedTuples = "0.3" RecipesBase = "1" @@ -63,7 +69,7 @@ StatsFuns = "0.9, 1" TransformVariables = "0.5, 0.6" Tricks = "0.1" TupleVectors = "0.1" -julia = "1.5" +julia = "1.6" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml index a32aeef..7a322e0 100644 --- a/benchmarks/Project.toml +++ 
b/benchmarks/Project.toml @@ -1,8 +1,6 @@ [deps] ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" -MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" -Optim = "429524aa-4258-5aef-a3af-852621145aeb" Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" diff --git a/benchmarks/bouncy.jl b/benchmarks/bouncy.jl index 53b41d2..c22cf86 100644 --- a/benchmarks/bouncy.jl +++ b/benchmarks/bouncy.jl @@ -14,56 +14,73 @@ using ForwardDiff using ForwardDiff: Dual using Pathfinder using Pathfinder.PDMats +using MCMCChains +using TupleVectors: chainvec +using Tilde.MeasureTheory: transform Random.seed!(1) +function make_grads(post) + as_post = as(post) + d = TV.dimension(as_post) + obj(θ) = -Tilde.unsafe_logdensityof(post, transform(as_post, θ)) + ℓ(θ) = -obj(θ) + @inline function dneglogp(t, x, v, args...) # two directional derivatives + f(t) = obj(x + t * v) + u = ForwardDiff.derivative(f, Dual{:hSrkahPmmC}(0.0, 1.0)) + u.value, u.partials[] + end + + gconfig = ForwardDiff.GradientConfig(obj, rand(d), ForwardDiff.Chunk{d}()) + function ∇neglogp!(y, t, x, args...) + ForwardDiff.gradient!(y, obj, x, gconfig) + y + end + ℓ, dneglogp, ∇neglogp! 
+end + +# ↑ general purpose +############################################################ +# ↓ problem-specific + # read data function readlrdata() fname = joinpath("lr.data") z = readdlm(fname) - A = z[:, 1:end-1] + A = z[:, 1:(end-1)] A = [ones(size(A, 1)) A] y = z[:, end] .- 1 return A, y end -A, y = readlrdata(); -At = collect(A'); model_lr = @model (At, y, σ) begin d, n = size(At) θ ~ Normal(σ = σ)^d for j in 1:n - logitp = dot(view(At, :, j), θ) + logitp = view(At, :, j)' * θ y[j] ~ Bernoulli(logitp = logitp) end end + +# Define model arguments +A, y = readlrdata(); +At = collect(A'); σ = 100.0 -function make_grads(model_lr, At, y, σ) - post = model_lr(At, y, σ) | (; y) - as_post = as(post) - obj(θ) = -Tilde.unsafe_logdensityof(post, transform(as_post, θ)) - ℓ(θ) = -obj(θ) - @inline function dneglogp(t, x, v) # two directional derivatives - f(t) = obj(x + t * v) - u = ForwardDiff.derivative(f, Dual{:hSrkahPmmC}(0.0, 1.0)) - u.value, u.partials[] - end +# Represent the posterior +post = model_lr(At, y, σ) | (; y) - gconfig = ForwardDiff.GradientConfig(obj, rand(25), ForwardDiff.Chunk{25}()) - function ∇neglogp!(y, t, x) - ForwardDiff.gradient!(y, obj, x, gconfig) - return - end - post, ℓ, dneglogp, ∇neglogp! -end +d = TV.dimension(as(post)) -post, ℓ, dneglogp, ∇neglogp! = make_grads(model_lr, At, y, σ) -# Try things out -dneglogp(2.4, randn(25), randn(25)); -∇neglogp!(randn(25), 2.1, randn(25)); +# Make sure gradients are working +let + ℓ, dneglogp, ∇neglogp! 
= make_grads(post) + @show dneglogp(2.4, randn(d), randn(d)) + y = Vector{Float64}(undef, d) + @show ∇neglogp!(y, 2.1, randn(d)) + nothing +end -d = 25 # number of parameters t0 = 0.0; x0 = zeros(d); # starting point sampler # estimated posterior mean (n=100000, 797s) @@ -129,8 +146,7 @@ sampler = ZZB.NotFactSampler( ), ); -using TupleVectors: chainvec -using Tilde.MeasureTheory: transform +# @time first(Iterators.drop(tvs,1000)) function collect_sampler(t, sampler, n; progress = true, progress_stops = 20) if progress @@ -166,7 +182,6 @@ elapsed_time = @elapsed @time begin bps_samples, info = collect_sampler(as(post), sampler, n; progress = false) end -using MCMCChains bps_chain = MCMCChains.Chains(bps_samples.θ); bps_chain = setinfo(bps_chain, (; start_time = 0.0, stop_time = elapsed_time)); diff --git a/src/GG/deprecated_codes/explicit_scope.jl b/src/GG/deprecated_codes/explicit_scope.jl index 3c0d6d2..283f5f6 100644 --- a/src/GG/deprecated_codes/explicit_scope.jl +++ b/src/GG/deprecated_codes/explicit_scope.jl @@ -3,7 +3,7 @@ function scoping(ast) @match ast begin :([$(frees...)]($(args...)) -> begin $(stmts...) - end) => let stmts = map(rec, stmts), arw = :(($(args...),) -> begin + end) => let stmts = map(rec, stmts), arw = :(($(args...),) -> begin $(stmts...) 
end) Expr(:scope, (), Tuple(frees), (), arw) diff --git a/src/GG/deprecated_codes/static_closure_conv.jl b/src/GG/deprecated_codes/static_closure_conv.jl index 542ffa7..95a4dd2 100644 --- a/src/GG/deprecated_codes/static_closure_conv.jl +++ b/src/GG/deprecated_codes/static_closure_conv.jl @@ -39,6 +39,11 @@ function mk_closure_static(expr, toplevel::Vector{Expr}) $Closure{$glob_name,typeof(frees)}(frees) end ) + ret = :( + let frees = $closure_arg + $Closure{$glob_name,typeof(frees)}(frees) + end + ) (fn_expr, ret) end diff --git a/src/Tilde.jl b/src/Tilde.jl index 7cf0c18..afe5b4b 100644 --- a/src/Tilde.jl +++ b/src/Tilde.jl @@ -8,6 +8,8 @@ using Reexport: @reexport @reexport using MeasureTheory using MeasureBase: productmeasure, Returns +import MeasureBase: latentof, jointof, manifestof + import DensityInterface: logdensityof import DensityInterface: densityof import DensityInterface: DensityKind @@ -79,29 +81,32 @@ include("callify.jl") end end +include("config.jl") +include("lensvars.jl") include("optics.jl") include("maybe.jl") include("core/models/abstractmodel.jl") -include("core/models/astmodel/astmodel.jl") include("core/models/model.jl") include("core/dependencies.jl") include("core/utils.jl") include("core/models/closure.jl") +include("maybeobserved.jl") include("core/models/posterior.jl") -include("primitives/interpret.jl") include("distributions/iid.jl") include("primitives/rand.jl") include("primitives/logdensity.jl") -include("primitives/logdensity_rel.jl") -include("primitives/insupport.jl") +# include("primitives/logdensity_rel.jl") +# include("primitives/insupport.jl") -# include("primitives/basemeasure.jl") include("primitives/testvalue.jl") -include("primitives/testparams.jl") -include("primitives/weightedsampling.jl") +# include("primitives/testparams.jl") +# include("primitives/weightedsampling.jl") include("primitives/measures.jl") -include("primitives/basemeasure.jl") +# include("primitives/basemeasure.jl") +# 
include("primitives/predict.jl") +# include("primitives/dag.jl") +include("primitives/runmodel.jl") include("transforms/utils.jl") diff --git a/src/callify.jl b/src/callify.jl index 80d39b6..4f8a574 100644 --- a/src/callify.jl +++ b/src/callify.jl @@ -5,10 +5,31 @@ using MLStyle Replace every `f(args...; kwargs..)` with `mycall(f, args...; kwargs...)` """ -function callify(mycall, ast) +function callify(g, ast) leaf(x) = x function branch(f, head, args) default() = Expr(head, map(f, args)...) + + # Convert `for` to `while` + if head == :for + arg1 = args[1] + @assert arg1.head == :(=) + a,A0 = arg1.args + A0 = callify(g, A0) + @gensym temp + @gensym state + @gensym A + return quote + $A = $A0 + $temp = $call($g, iterate, $A) + while $temp !== nothing + $a, $state = $temp + $(args[2]) + $temp = $call($g, iterate, $A, $state) + end + end + end + head == :call || return default() if first(args) == :~ && length(args) == 3 @@ -16,71 +37,19 @@ function callify(mycall, ast) end # At this point we know it's a function call - length(args) == 1 && return Expr(:call, mycall, first(args)) + length(args) == 1 && return Expr(:call, call, g, first(args)) fun = args[1] arg2 = args[2] if arg2 isa Expr && arg2.head == :parameters # keyword arguments (try dump(:(f(x,y;a=1, b=2))) to see this) - return Expr(:call, mycall, arg2, fun, map(f, Base.rest(args, 3))...) + return Expr(:call, call, g, arg2, fun, map(f, Base.rest(args, 3))...) else - return Expr(:call, mycall, map(f, args)...) + return Expr(:call, call, g, map(f, args)...) end end - foldast(leaf, branch)(ast) + foldast(leaf, branch)(ast) |> MacroTools.flatten end -# struct Provenance{T,S} -# value::T -# sources::S -# end - -# getvalue(p::Provenance) = p.value -# getvalue(x) = x - -# getsources(p::Provenance) = p.sources -# getsources(x) = Set() - -# function trace_provenance(f, args...; kwargs...) 
-# (newargs, arg_sources) = (getvalue.(args), union(getsources.(args)...)) - -# k = keys(kwargs) -# v = values(kwargs) -# newkwargs = NamedTuple{k}(map(getvalue, v)) - -# k = keys(kwargs) -# v = values(NamedTuple(kwargs)) -# newkwargs = NamedTuple{k}(getvalue.(v)) -# kwarg_sources = union(getsources.(args)...) - -# sources = union(arg_sources, kwarg_sources) -# Provenance(f(newargs...; newkwargs), sources) -# end - -# macro call(expr) -# callify(expr) -# end - -# julia> callify(:(f(g(x,y)))) -# :(call(f, call(g, x, y))) - -# julia> callify(:(f(x; a=3))) -# :(call(f, x; a = 3)) - -# julia> callify(:(a+b)) -# :(call(+, a, b)) - -# julia> callify(:(call(f,3))) -# :(call(f, 3)) - -# f(x) = x+1 - -# @call f(2) - -# using SymbolicUtils - -# @syms x::Vector{Float64} i::Int - -# @call getindex(x,i) diff --git a/src/config.jl b/src/config.jl new file mode 100644 index 0000000..ab2a80d --- /dev/null +++ b/src/config.jl @@ -0,0 +1,2 @@ +abstract type AbstractTildeConfig end + diff --git a/src/core/models/abstractmodel.jl b/src/core/models/abstractmodel.jl index 2903cd1..e9d573a 100644 --- a/src/core/models/abstractmodel.jl +++ b/src/core/models/abstractmodel.jl @@ -15,9 +15,22 @@ N gives the Names of arguments (each a Symbol) B gives the Body, as an Expr M gives the Module where the model is defined """ -abstract type AbstractModel{A,B,M} <: AbstractTransitionKernel end +abstract type AbstractModel{A,B,M,P} <: AbstractTransitionKernel end -abstract type AbstractConditionalModel{M,Args,Obs} <: AbstractMeasure end +abstract type AbstractConditionalModel{M,Args,Obs,P} <: AbstractMeasure end + +# getproj(::Type{T}) where {T} = Base.Fix1(project_joint, T) +getproj(::Type{<:AbstractConditionalModel{M,Args,Obs,P}}) where {M,Args,Obs,P} = MeasureBase.instance(P) + +getproj(::M) where {M<:AbstractConditionalModel} = getproj(M) + +# project_joint(::Type{<:AbstractConditionalModel{M,A,O,typeof(first)}}, p) where {M,A,O} = first(p) +# 
project_joint(::Type{<:AbstractConditionalModel{M,A,O,typeof(last)}}, p) where {M,A,O} = last(p) +# project_joint(::Type{<:AbstractConditionalModel{M,A,O,typeof(identity)}}, p) where {M,A,O} = p + +# project_joint(::AbstractConditionalModel{M,A,O,typeof(first)}, p) where {M,A,O} = first(p) +# project_joint(::AbstractConditionalModel{M,A,O,typeof(last)}, p) where {M,A,O} = last(p) +# project_joint(::AbstractConditionalModel{M,A,O,typeof(identity)}, p) where {M,A,O} = p argstype(::AbstractModel{A,B,M}) where {A,B,M} = A diff --git a/src/core/models/astmodel.jl b/src/core/models/astmodel.jl new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/core/models/astmodel.jl @@ -0,0 +1 @@ + diff --git a/src/core/models/astmodel/astmodel.jl b/src/core/models/astmodel/astmodel.jl deleted file mode 100644 index ff81810..0000000 --- a/src/core/models/astmodel/astmodel.jl +++ /dev/null @@ -1,82 +0,0 @@ -struct Model{A,B,M<:GG.TypeLevel} <: AbstractModel{A,B,M} - args::Vector{Symbol} - body::Expr -end - -function Model(theModule::Module, args::Vector{Symbol}, body::Expr) - A = NamedTuple{Tuple(args)} - - B = to_type(body) - M = to_type(theModule) - return Model{A,B,M}(args, body) -end - -model(m::Model) = m - -# ModelClosure{A,B,M,Args,Obs} <: AbstractModel{A,B,M,Argvals,Obs} -# model::Model{A,B,M} -# argvals :: Argvals -# obs :: Obs -# end - -function Base.convert(::Type{Expr}, m::Model) - numArgs = length(m.args) - args = if numArgs == 1 - m.args[1] - elseif numArgs > 1 - Expr(:tuple, [x for x in m.args]...) - end - - body = m.body - - q = if numArgs == 0 - @q begin - @model $body - end - else - @q begin - @model $(args) $body - end - end - - striplines(q).args[1] -end - -Base.show(io::IO, m::Model) = println(io, convert(Expr, m)) - -function type2model(::Type{Model{A,B,M}}) where {A,B,M} - args = [fieldnames(A)...] 
- body = from_type(B) - Model(from_type(M), convert(Vector{Symbol}, args), body) -end - -# julia> using Tilde, MeasureTheory - -# julia> m = @model begin -# p ~ Uniform() -# x ~ Bernoulli(p) |> iid(3) -# end; - -# julia> f = interpret(m); - -# julia> f(NamedTuple()) do x,d,ctx -# r = rand(d) -# (r, merge(ctx, NamedTuple{(x,)}((r,)))) -# end -# (p = 0.3863623559358842, x = Bool[0, 0, 0]) - -# julia> f(0) do x,d,n -# r = rand(d) -# (r, n+1) -# end -# 2 - -# julia> f -# function = (_tilde, _ctx0;) -> begin -# begin -# _ctx = _ctx0 -# (p, _ctx) = _tilde(:p, (Main).Uniform(), _ctx) -# (x, _ctx) = _tilde(:x, (Main).:|>((Main).Bernoulli(p), (Main).iid(3)), _ctx) -# return _ctx -# end -# end diff --git a/src/core/models/closure.jl b/src/core/models/closure.jl index ebfd6b3..16a54c0 100644 --- a/src/core/models/closure.jl +++ b/src/core/models/closure.jl @@ -1,8 +1,25 @@ -struct ModelClosure{M,A} <: AbstractConditionalModel{M,A,NamedTuple{(),Tuple{}}} +struct ModelClosure{M,V,P} <: AbstractConditionalModel{M,V,NamedTuple{(),Tuple{}},P} model::M - argvals::A + argvals::V end +setproj(m::AbstractMeasure, ::typeof(first)) = latentof(m) +setproj(m::AbstractMeasure, ::typeof(last)) = manifestof(m) +setproj(m::AbstractMeasure, ::typeof(identity)) = jointof(m) + +function setproj(c::ModelClosure{M,V}, f::F) where {M,V,F} + setproj(model(c), f)(argvals(c)) +end + +for f in [first, last, identity] + @eval begin + function setproj(c::ModelClosure{M,V}, ::typeof($f)) where {M,V} + setproj(model(c), $f)(argvals(c)) + end + end +end + + function Base.show(io::IO, mc::ModelClosure) println(io, "ModelClosure given") println(io, " arguments ", keys(argvals(mc))) @@ -23,9 +40,7 @@ end model(c::ModelClosure) = c.model -ModelClosure(m::AbstractModel) = ModelClosure(m, NamedTuple()) - -(m::AbstractModel)(nt::NamedTuple) = ModelClosure(m, nt) +(m::AbstractModel{A,B,M,P})(nt::NT) where {A,B,M,P,NT<:NamedTuple} = ModelClosure{Model{A,B,M,P}, NT, P}(m,nt) (mc::ModelClosure)(nt::NamedTuple) = 
ModelClosure(model(mc), merge(mc.argvals, nt)) @@ -36,5 +51,3 @@ obstype(::ModelClosure) = NamedTuple{(),Tuple{}} obstype(::Type{<:ModelClosure}) = NamedTuple{(),Tuple{}} type2model(::Type{MC}) where {M,MC<:ModelClosure{M}} = type2model(M) - -MeasureBase.condition(m::ModelClosure, nt::NamedTuple) = ModelPosterior(m, nt) diff --git a/src/core/models/conditional.jl b/src/core/models/conditional.jl deleted file mode 100644 index 102657b..0000000 --- a/src/core/models/conditional.jl +++ /dev/null @@ -1,49 +0,0 @@ -struct ModelClosure{M,A} <: AbstractModel{A,B} - model::M - argvals::A -end - -function Base.show(io::IO, cm::ModelClosure) - println(io, "ModelClosure given") - println(io, " arguments ", keys(argvals(cm))) - println(io, " observations ", keys(observations(cm))) - println(io, model(cm)) -end - -export argvals -argvals(c::ModelClosure) = c.argvals -argvals(c::ModelPosterior) = c.argvals -argvals(m::AbstractModel) = NamedTuple() - -export observations -observations(c::ModelClosure) = c.obs - -export observed -function observed(cm::ModelClosure{M,A,O}) where {M,A,O} - keys(schema(Obs)) -end - -model(c::ModelClosure) = c.model - -ModelClosure(m::AbstractModel) = ModelClosure(m, NamedTuple(), NamedTuple()) -model(::Type{<:ModelPosterior{M,A,O}}) where {M,A,O} = type2model(Model{M,A,O}) - -ModelPosterior(m::Model) = ModelPosterior(m, NamedTuple(), NamedTuple()) - -(m::AbstractModel)(nt::NamedTuple) = ModelClosure(m)(nt) - -(cm::ModelClosure)(nt::NamedTuple) = ModelClosure(cm.model, merge(cm.argvals, nt), cm.obs) - -(m::AbstractModel)(; argvals...) = m((; argvals...)) - -(m::AbstractModel)(args...) = m(NamedTuple{Tuple(m.args)}(args...)) -# (m::Model)(args...) = m(NamedTuple{Tuple(m.args)}(args)) -(m::Model)(args...) 
= m(argstype(m)(args)) - -import Base - -Base.:|(m::AbstractModel, nt::NamedTuple) = ModelClosure(m) | nt - -function Base.:|(cm::ModelClosure, nt::NamedTuple) - ModelClosure(cm.model, cm.argvals, merge(cm.obs, nt)) -end diff --git a/src/core/models/model.jl b/src/core/models/model.jl index ed4fe58..147eb1b 100644 --- a/src/core/models/model.jl +++ b/src/core/models/model.jl @@ -1,20 +1,90 @@ +struct Model{A,B,M<:GG.TypeLevel,P} <: AbstractModel{A,B,M,P} + args::Vector{Symbol} + body::Expr + jointproj::P +end + +function Model(theModule::Module, args::Vector{Symbol}, body::Expr) + A = NamedTuple{Tuple(args)} + B = to_type(body) + M = to_type(theModule) + return Model{A,B,M,typeof(last)}(args, body, last) +end + +export latentof, manifestof, jointof + + + +for f in [first, last, identity] + @eval begin + function setproj(m::Model{A,B,M}, ::typeof($f)) where {A,B,M} + Model{A,B,M,typeof($f)}(m.args, m.body, $f) + end + end +end + +setproj(m::Model{A,B,M}, f::F) where {A,B,M,F} = Model{A,B,M,F}(m.args, m.body, f) + +latentof(m::AbstractModel) = setproj(m, first) +manifestof(m::AbstractModel) = setproj(m, last) +jointof(m::AbstractModel) = setproj(m, identity) + +latentof(m::AbstractConditionalModel) = setproj(m, first) +manifestof(m::AbstractConditionalModel) = setproj(m, last) +jointof(m::AbstractConditionalModel) = setproj(m, identity) + + +model(m::Model) = m +model(::Type{M}) where {M} = type2model(M) + +function Base.convert(::Type{Expr}, m::Model) + numArgs = length(m.args) + args = if numArgs == 1 + m.args[1] + elseif numArgs > 1 + Expr(:tuple, [x for x in m.args]...) + end + + body = m.body + + q = if numArgs == 0 + @q begin + @model $body + end + else + @q begin + @model $(args) $body + end + end + + striplines(q).args[1] +end + +Base.show(io::IO, m::Model) = println(io, convert(Expr, m)) + +function type2model(::Type{Model{A,B,M,P}}) where {A,B,M,P} + args = Symbol[fieldnames(A)...] 
+ body = from_type(B) + jointproj = P.instance + Model{A,B,M,P}(args, body, jointproj) +end toargs(vs::Vector{Symbol}) = Tuple(vs) toargs(vs::NTuple{N,Symbol} where {N}) = vs macro model(vs::Expr, expr::Expr) - theModule = __module__ @assert vs.head == :tuple @assert expr.head == :block - Model(theModule, Vector{Symbol}(vs.args), expr) + ex = macroexpand(__module__, expr) + Model(__module__, Vector{Symbol}(vs.args), ex) end macro model(v::Symbol, expr::Expr) - theModule = __module__ - Model(theModule, [v], expr) + ex = macroexpand(__module__, expr) + Model(__module__, [v], ex) end macro model(expr::Expr) - theModule = __module__ - Model(theModule, Vector{Symbol}(), expr) + ex = macroexpand(__module__, expr) + Model(__module__, Vector{Symbol}(), ex) end diff --git a/src/core/models/posterior.jl b/src/core/models/posterior.jl index fb48ee7..18ecc8e 100644 --- a/src/core/models/posterior.jl +++ b/src/core/models/posterior.jl @@ -1,8 +1,12 @@ -struct ModelPosterior{M,A,O} <: AbstractConditionalModel{M,A,O} - closure::ModelClosure{M,A} +struct ModelPosterior{M,V,O,P} <: AbstractConditionalModel{M,V,O,P} + closure::ModelClosure{M,V,P} obs::O end +function setproj(p::ModelPosterior{M,V,O}, f::F) where {M,V,O,F} + ModelPosterior{M,V,O,F}(setproj(p.closure, f), observations(p)) +end + model(post::ModelPosterior) = model(post.closure) function Base.show(io::IO, cm::ModelPosterior) @@ -41,3 +45,5 @@ end function Base.:|(post::ModelPosterior, nt::NamedTuple) ModelPosterior(post.closure, merge(post.obs, nt)) end + +MeasureBase.condition(m::MC, nt::NT) where {M,V,P,MC<:ModelClosure{M,V,P},NT<:NamedTuple} = ModelPosterior{M, V, NT, P}(m, nt) diff --git a/src/core/utils.jl b/src/core/utils.jl index 60b1878..df64cd5 100644 --- a/src/core/utils.jl +++ b/src/core/utils.jl @@ -101,10 +101,6 @@ end import MacroTools: striplines, @q -# function arguments(model::DAGModel) -# model.args -# end - allequal(xs) = all(xs[1] .== xs) # # fold example usage: @@ -130,20 +126,19 @@ allequal(xs) = 
all(xs[1] .== xs) # # (s = [0.545324, 0.281332, 0.418541, 0.485946], a = 2.217762640580984) # From https://github.com/thautwarm/MLStyle.jl/issues/66 -@active LamExpr(x) begin - @match x begin - :($a -> begin - $(bs...) - end) => let exprs = filter(x -> !(x isa LineNumberNode), bs) - if length(exprs) == 1 - (a, exprs[1]) - else - (a, Expr(:block, bs...)) - end - end - _ => nothing - end -end +# @active LamExpr(x) begin +# @match x begin +# :($a -> begin $(bs...) end) => +# let exprs = filter(x -> !(x isa LineNumberNode), bs) +# if length(exprs) == 1 +# (a, exprs[1]) +# else +# (a, Expr(:block, bs...)) +# end +# end +# _ => nothing +# end +# end # using BenchmarkTools # f(;kwargs...) = kwargs[:a] + kwargs[:b] @@ -156,35 +151,12 @@ end # @__MODULE__ # names -# getprototype(::Type{NamedTuple{(),Tuple{}}}) = NamedTuple() -getprototype(::Type{NamedTuple{N,T} where {T<:Tuple}}) where {N} = NamedTuple{N} -getprototype(::NamedTuple{N,T} where {T<:Tuple}) where {N} = NamedTuple{N} - -function loadvals(argstype, obstype) - args = getntkeys(argstype) - obs = getntkeys(obstype) - loader = @q begin end - - for k in args - push!(loader.args, :($k = _args.$k)) - end - for k in obs - push!(loader.args, :($k = _obs.$k)) - end - - src -> (@q begin - $loader - $src - end) |> MacroTools.flatten -end - function loadvals(argstype, obstype, parstype) args = schema(argstype) data = schema(obstype) pars = schema(parstype) - loader = @q begin - end + loader = @q begin end for k in keys(args) ∪ keys(pars) ∪ keys(data) push!(loader.args, :(local $k)) @@ -223,11 +195,6 @@ function loadvals(argstype, obstype, parstype) end) |> MacroTools.flatten end -getntkeys(::NamedTuple{A,B}) where {A,B} = A -getntkeys(::Type{NamedTuple{A,B}}) where {A,B} = A -getntkeys(::Type{NamedTuple{A}}) where {A} = A -getntkeys(::Type{LazyMerge{X,Y}}) where {X,Y} = Tuple(getntkeys(X) ∪ getntkeys(Y)) - # This is just handy for REPLing, no direct connection to Tilde # julia> tower(Int) @@ -378,3 +345,57 @@ narrow_array(x) 
= collect(Base.Generator(identity, x)) function parse_optic(ex) unescape.(Accessors.parse_obj_optic(ex)) end + +Base.@pure function merge_names(an::Tuple{Vararg{Symbol}}, bn::Tuple{Vararg{Symbol}}) + @nospecialize an bn + names = Symbol[an...] + for n in bn + if !sym_in(n, an) + push!(names, n) + end + end + (names...,) +end + +Base.@pure function merge_types(names::Tuple{Vararg{Symbol}}, a::Type{<:NamedTuple}, b::Type{<:NamedTuple}) + @nospecialize names a b + bn = _nt_names(b) + return Tuple{Any[ fieldtype(sym_in(names[n], bn) ? b : a, names[n]) for n in 1:length(names) ]...} +end + + +@generated function mymerge(a::NamedTuple{an}, b::NamedTuple{bn}) where {an, bn} + names = Base.merge_names(an, bn) + types = Base.merge_types(names, a, b) + vals = Any[ :(getfield($(Base.sym_in(names[n], bn) ? :b : :a), $(QuoteNode(names[n])))) for n in 1:length(names) ] + quote + # $(Expr(:meta, :inline)) + NamedTuple{$names,$types}(($(vals...),))::NamedTuple{$names,$types} + end +end + + + +abstract type MayReturn end +struct HasReturn <: MayReturn end +struct NoReturn <: MayReturn end + +export hasreturn + +# These work just fine without the `@generated` but take *much* longer +# (92μs vs 1.3ns on a small model) +@generated function hasreturn(::M) where {M<:AbstractModel} + _hasreturn(body(M)) ? HasReturn() : NoReturn() +end + +@generated function hasreturn(::M) where {M<:AbstractConditionalModel} + _hasreturn(body(model(M))) ? HasReturn() : NoReturn() +end + + +_hasreturn(x) = false + +function _hasreturn(ast::Expr) + ast.head == :return && return true + return any(_hasreturn, ast.args) +end diff --git a/src/lensvars.jl b/src/lensvars.jl new file mode 100644 index 0000000..ec41c0f --- /dev/null +++ b/src/lensvars.jl @@ -0,0 +1,24 @@ +# Identify variables with a non-trivial lens +function lensvars(ast) + result = Symbol[] + leaf(x;kwargs...) = nothing + + function branch(f, head, args; kwargs...) + @match Expr(head, args...) 
begin + :(($x, $l) ~ $rhs) => begin + @match l begin + :((Accessors.opticcompose)()) => nothing + :(identity) => nothing + _ => push!(result, x) + end + end + _ => begin + foreach(f, args) + end + end + + return result + end + + foldast(leaf, branch)(opticize(ast)) +end \ No newline at end of file diff --git a/src/maybeobserved.jl b/src/maybeobserved.jl new file mode 100644 index 0000000..96abb12 --- /dev/null +++ b/src/maybeobserved.jl @@ -0,0 +1,16 @@ +abstract type MaybeObserved{N,T} end + +struct Observed{N,T} <: MaybeObserved{N,T} + value::T +end + +Observed{N}(x::T) where {N,T} = Observed{N,T}(x) + +struct Unobserved{N,T} <: MaybeObserved{N,T} + value::T +end + +Unobserved{N}(x::T) where {N,T} = Unobserved{N,T}(x) +NamedTuple(o::MaybeObserved{N,T}) where {N,T} = NamedTuple{(N,)}((o.value,)) + +value(obj::MaybeObserved) = obj.value diff --git a/src/optics.jl b/src/optics.jl index 7ddaea7..8de996a 100644 --- a/src/optics.jl +++ b/src/optics.jl @@ -43,6 +43,17 @@ end end end +# @inline function _setindex!(o::AbstractArray{T}, val::T, l::Lens!!{<:IndexLens}) where {T} +# setindex!(o, val, l.pure.indices...) +# end + +# # Attempting to set a value outside the current eltype widens the eltype +# @inline function _setindex!(o::AbstractArray{T}, val::V, l::Lens!!{<:IndexLens}) where {T,V} +# new_o = similar(o, Union{T,V}) +# new_o .= o +# setindex!(new_o, val, l.pure.indices...) +# end + @inline function Accessors.modify(f, o, l::Lens!!) 
set(o, l, f(l(o))) end diff --git a/src/primitives/basemeasure.jl b/src/primitives/basemeasure.jl index bf5eeae..b3202b0 100644 --- a/src/primitives/basemeasure.jl +++ b/src/primitives/basemeasure.jl @@ -1,43 +1,20 @@ import MeasureBase: basemeasure @inline function basemeasure(m::AbstractConditionalModel, pars; ctx = NamedTuple()) - gg_call(basemeasure, m, pars, NamedTuple(), ctx, (r, ctx) -> ctx) + runmodel(basemeasure, m, pars, NamedTuple(), ctx, (r, ctx) -> ctx) end @inline function tilde( ::typeof(basemeasure), + x::MaybeObserved{X}, lens, - xname, - x, d, cfg, ctx::NamedTuple, - _, - ::True, -) - xname = dynamic(xname) - xparent = getproperty(cfg.obs, xname) +) where {X} + xparent = getproperty(cfg.obs, X) x = lens(xparent) b = basemeasure(d, x) - ctx = merge(ctx, NamedTuple{(xname,)}((b,))) - (x, ctx, productmeasure(ctx)) -end - -@inline function tilde( - ::typeof(basemeasure), - lens, - xname, - x, - d, - cfg, - ctx::NamedTuple, - _, - ::False, -) - xname = dynamic(xname) - xparent = getproperty(cfg.pars, xname) - x = getproperty(cfg.pars, xname) - b = basemeasure(d, x) - ctx = merge(ctx, NamedTuple{(xname,)}((b,))) + ctx = merge(ctx, NamedTuple{(X,)}((b,))) (x, ctx, productmeasure(ctx)) end diff --git a/src/primitives/dag.jl b/src/primitives/dag.jl new file mode 100644 index 0000000..c45186a --- /dev/null +++ b/src/primitives/dag.jl @@ -0,0 +1,89 @@ +abstract type AbstractContext end + +struct GenericContext{T,M} <: AbstractContext + value::T + meta::M +end + +struct EmptyMeta end + +context(value, meta) = GenericContext(value, meta) +context(value) = GenericContext(value, EmptyMeta()) + +context_value(ctx::AbstractContext) = ctx.value +context_meta(ctx::AbstractContext) = ctx.meta + +export getdag + +using Graphs +using MetaGraphsNext + +struct MarkovContext{T} <: AbstractContext + value::T + meta::Set{Tuple{Symbol,Any}} +end + +function MarkovContext(ctx::MarkovContext, m::Set{Tuple{Symbol,Any}}) + newset = union(ctx.meta, m) + MarkovContext(ctx.value, 
newset) +end + +function MarkovContext(ctx::MarkovContext, m::Set{Tuple{Symbol,T}}) where {T} + newset = union(ctx.meta, Set{Tuple{Symbol,Any}}([m])) + MarkovContext(ctx.value, newset) +end + +function Base.show(io::IO, mc::MarkovContext) + print(io, "MarkovContext(", mc.value, ", ", mc.meta, ")") +end + +function markovinate(nt::NamedTuple{N,T}) where {N,T} + vals = tuple( + ( + MarkovContext(v, Set{Tuple{Symbol,Any}}([(k, identity)])) for + (k, v) in pairs(nt) + )..., + ) + NamedTuple{N}(vals) +end + +MarkovContext(x::MarkovContext) = x +MarkovContext(x) = MarkovContext(x, Set{Tuple{Symbol,Any}}()) + +markov_value(x) = x +markov_parents(x) = Set{Tuple{Symbol,Any}}() + +markov_value(x::MarkovContext) = x.value +markov_parents(x::MarkovContext) = x.meta + +function getdag(m::AbstractConditionalModel, pars) + cfg = NamedTuple() + pars = markovinate(pars) + ctx = (dag = MetaGraph(DiGraph(), Label = Tuple{Symbol,Any}),) + ctx = runmodel(getdag, m, pars, cfg, ctx, (r, ctx) -> ctx) + return ctx.dag +end + +# When a Tilde primitive `f` is called, every `g(args...)` is converted to +# `call(f, g, args...)` +function call(::typeof(getdag), g, args...) + val = g(map(markov_value, args)...) + parents = if isempty(args) + Set{Tuple{Symbol,Any}}() + else + union(map(markov_parents, args)...) 
+ end + MarkovContext(val, parents) +end + +@inline function tilde(::typeof(getdag), x::MaybeObserved{X}, lens, d, pars, ctx) where {X} + dag = ctx.dag + for p in markov_parents(d) + # Make sure vertices exist + dag[p] = nothing + dag[(X, lens)] = nothing + # Add a new edge in the DAG + dag[p, (X, lens)] = nothing + end + (MarkovContext(value(x), Set{Tuple{Symbol,Any}}([(X, lens)])), ctx, dag) +end diff --git a/src/primitives/insupport.jl b/src/primitives/insupport.jl index 71d869d..228c53e 100644 --- a/src/primitives/insupport.jl +++ b/src/primitives/insupport.jl @@ -2,5 +2,5 @@ import MeasureBase: insupport export insupport @inline function insupport(m::AbstractConditionalModel, x::NamedTuple) - mapreduce(insupport, (a, b) -> a && b, measures!(m, x), x) + mapreduce(insupport, (a, b) -> a && b, measures(m), x) end diff --git a/src/primitives/logdensity.jl b/src/primitives/logdensity.jl index ce44c37..797de81 100644 --- a/src/primitives/logdensity.jl +++ b/src/primitives/logdensity.jl @@ -1,4 +1,3 @@ - export logdensityof using NestedTuples: lazymerge @@ -6,47 +5,108 @@ import MeasureTheory using Accessors +struct LogdensityConfig{F} <: AbstractTildeConfig + f::F +end + +@inline retfun(cfg::LogdensityConfig{typeof(logdensityof)}, r, ctx) = ctx.ℓ +@inline retfun(cfg::LogdensityConfig{typeof(unsafe_logdensityof)}, r, ctx) = ctx.ℓ + + @inline function MeasureBase.logdensityof( - cm::AbstractConditionalModel, + cm::AbstractConditionalModel{M,A,O,typeof(first)}, pars::NamedTuple; - cfg = NamedTuple(), - ctx = NamedTuple(), - retfun = (r, ctx) -> ctx.ℓ, -) +) where {M,A,O} # cfg = merge(cfg, (pars=pars,)) - ctx = merge(ctx, (ℓ = 0.0,)) - gg_call(logdensityof, cm, pars, cfg, ctx, retfun) + cfg = LogdensityConfig(logdensityof) + runmodel(cfg, cm, pars, (ℓ=0.0,)) end -@inline function tilde(::typeof(logdensityof), lens, xname, x, d, cfg, ctx::NamedTuple) - x = x.value +@inline function tilde( + cfg::LogdensityConfig{typeof(logdensityof)}, + x::MaybeObserved{X}, + lens, + d, 
+ ctx::NamedTuple, +) where {X} + x = value(x) insupport(d, lens(x)) || return (x, ctx, ReturnNow(-Inf)) @reset ctx.ℓ += MeasureBase.unsafe_logdensityof(d, lens(x)) - (x, ctx, nothing) + (x, ctx) end @inline function MeasureBase.unsafe_logdensityof( - cm::AbstractConditionalModel, + cm::AbstractConditionalModel{M,A,O,typeof(first)}, pars::NamedTuple; - cfg = NamedTuple(), - ctx = NamedTuple(), - retfun = (r, ctx) -> ctx.ℓ, -) - # cfg = merge(cfg, (pars=pars,)) - ctx = merge(ctx, (ℓ = 0.0,)) - gg_call(unsafe_logdensityof, cm, pars, cfg, ctx, retfun) +) where {M,A,O} + cfg = LogdensityConfig(unsafe_logdensityof) + runmodel(cfg, cm, pars, (ℓ = 0.0,)) end @inline function tilde( - ::typeof(unsafe_logdensityof), + cfg::LogdensityConfig{typeof(unsafe_logdensityof)}, + x::MaybeObserved{X}, lens, - xname, - x, d, - cfg, ctx::NamedTuple, -) - x = x.value - @reset ctx.ℓ += MeasureBase.unsafe_logdensityof(d, lens(x)) - (x, ctx, ctx.ℓ) +) where {X} + x = value(x) + @reset ctx.ℓ += MeasureBase.unsafe_logdensityof(latentof(d), lens(x)) + (x, ctx) +end + +############################################################################### +# If a model has no return value, there's no need to take `latentof` + +@inline function MeasureBase.logdensityof( + cm::AbstractConditionalModel{M,A,O,typeof(last)}, + pars::NamedTuple; +) where {M,A,O} + _logdensityof(cm, pars, hasreturn(cm)) end + +@inline function _logdensityof( + cm::AbstractConditionalModel{M,A,O,typeof(last)}, + pars::NamedTuple, + ::HasReturn +) where {M,A,O} + @error """ + `logdensity` on Tilde models requires a latent space. Try + `logdensityof(latentof(...), pars)`. 
+ """ +end + +@inline function _logdensityof( + cm::AbstractConditionalModel{M,A,O,typeof(last)}, + pars::NamedTuple, + ::NoReturn +) where {M,A,O} + # cfg = merge(cfg, (pars=pars,)) + cfg = LogdensityConfig(logdensityof) + runmodel(cfg, latentof(cm), pars, (ℓ=0.0,)) +end + +############################################################################### +# Methods that throw errors + + +@inline function MeasureBase.logdensityof( + cm::AbstractConditionalModel, + pars::NamedTuple; +) + @error """ + `logdensity` on Tilde models requires a latent space. Try + `logdensityof(latentof(...), pars)`. + """ +end + +@inline function MeasureBase.unsafe_logdensityof( + cm::AbstractConditionalModel, + pars::NamedTuple; +) + @error """ + `unsafe_logdensity` on Tilde models requires a latent space. Try + `unsafe_logdensityof(latentof(...), pars)`. + """ +end + diff --git a/src/primitives/measures.jl b/src/primitives/measures.jl index 37f3e9d..75235bc 100644 --- a/src/primitives/measures.jl +++ b/src/primitives/measures.jl @@ -42,6 +42,13 @@ # Normal(μ = -3.75905,) # """ +struct MeasuresConfig{P} <: AbstractTildeConfig + pars::P +end + +@inline retfun(cfg::MeasuresConfig, r, ctx) = ctx + + export measures @inline function measures(m::AbstractConditionalModel, pars::NamedTuple{N,T}) where {N,T} @@ -51,9 +58,10 @@ export measures end sim(x) = x + cfg= MeasuresConfig(pars) ctx = rmap(sim, pars) - nt = gg_call(measures, m, pars, NamedTuple(), ctx, (r, ctx) -> ctx) + nt = runmodel(cfg, latentof(m), pars, ctx) f(x::AbstractArray) = productmeasure(narrow_array(x)) f(x) = x @@ -61,32 +69,21 @@ export measures rmap(f, nt) end -@inline function tilde( - ::typeof(measures), - ::typeof(identity), - xname, - ::Unobserved, - d, - cfg, - ctx, -) +@inline function tilde(cfg::MeasuresConfig, x::Unobserved{X}, d, ctx) where {X} x = testvalue(d) - xname = dynamic(xname) - ctx = merge(ctx, NamedTuple{(xname,)}((d,))) - (x, ctx, ctx) + ctx = merge(ctx, NamedTuple{X}((d,))) + (x, ctx) end -@inline 
function tilde(::typeof(measures), lens, xname, x::Unobserved, d, cfg, ctx) - xname = dynamic(xname) - ctx = set(ctx, PropertyLens{xname}() ⨟ Lens!!(lens), d) +@inline function tilde(cfg::MeasuresConfig, x::Unobserved{X}, lens, d, ctx) where {X} + ctx = set(ctx, PropertyLens{X}() ⨟ Lens!!(lens), d) - xnew = getproperty(cfg.pars, xname) - (xnew, ctx, ctx) + xnew = getproperty(cfg.pars, X) + (xnew, ctx) end -@inline function tilde(::typeof(measures), lens, xname, x::Observed, d, cfg, ctx) - x = x.value - (x, ctx, ctx) +@inline function tilde(cfg::MeasuresConfig, x::Observed{X}, lens, d, ctx) where {X} + (value(x), ctx) end function as(mdl::AbstractConditionalModel) @@ -94,4 +91,4 @@ function as(mdl::AbstractConditionalModel) as(map(as, ms)) end -measures(m) = measures(m, testvalue(m)) +measures(m) = measures(m, testvalue(latentof(m))) diff --git a/src/primitives/predict.jl b/src/primitives/predict.jl new file mode 100644 index 0000000..0558177 --- /dev/null +++ b/src/primitives/predict.jl @@ -0,0 +1,147 @@ +using Random: GLOBAL_RNG, AbstractRNG +using TupleVectors +export predict + + +anyfy(x) = x +anyfy(x::AbstractArray) = collect(Any, x) + +function anyfy(mc::ModelClosure) + m = model(mc) + a = rmap(anyfy, argvals(mc)) + m(a) +end + +function anyfy(mp::ModelPosterior) + m = model(mp) + a = rmap(anyfy, argvals(mp)) + o = rmap(anyfy, observations(mp)) + m(a) | o +end + +############################################################################### +# `predict` for forward random sampling + +@inline function predict(m::AbstractConditionalModel, pars) + predict(GLOBAL_RNG, m, pars) +end + +@inline function predict(rng::AbstractRNG, m::AbstractConditionalModel, pars::NamedTuple) + predict_rand(rng::AbstractRNG, m::AbstractConditionalModel, pars) +end + +@inline function predict_rand(rng::AbstractRNG, m::AbstractConditionalModel, pars) + cfg = (rng = rng, pars = pars) + ctx = NamedTuple() + runmodel(predict_rand, m, pars, cfg, ctx, (r, ctx) -> r) +end + + +@inline 
function tilde(::typeof(predict_rand), x, lens, d, cfg, ctx) + tilde_predict(cfg.rng, x, lens, d, cfg.pars, ctx) +end + + + +@generated function tilde_predict( + rng, + x::MaybeObserved{X}, + lens, + d, + pars::NamedTuple{N}, + ctx, +) where {X,N} + if X ∈ N + quote + # @info "$X ∈ N" + xnew = set(value(x), Lens!!(lens), lens(getproperty(pars, X))) + # ctx = merge(ctx, NamedTuple{(X,)}((xnew,))) + (xnew, ctx) + end + else + quote + # @info "$X ∉ N" + xnew = set(value(x), Lens!!(lens), rand(rng, d)) + ctx = merge(ctx, NamedTuple{(X,)}((xnew,))) + (xnew, ctx) + end + end +end + + +############################################################################### + + + +@inline function predict(f, m::AbstractConditionalModel, pars::NamedTuple) + m = anyfy(m) + pars = rmap(anyfy, pars) + cfg = (f = f, pars = pars) + ctx = NamedTuple() + runmodel(predict, m, pars, cfg, ctx, (r, ctx) -> r) +end + +@inline function predict(f, m::AbstractConditionalModel, tv::TupleVector) + n = length(tv) + @inbounds result = chainvec(predict(f, m, tv[1]), n) + @inbounds for j in 2:n + result[j] = predict(f, m, tv[j]) + end + return result +end + +@inline function tilde(::typeof(predict), x, lens, d, cfg, ctx) + tilde_predict(cfg.f, x, lens, d, cfg.pars, ctx) +end + +# @generated function tilde_predict( +# f, +# x::Observed{X}, +# lens, +# d, +# pars::NamedTuple{N}, +# ctx, +# ) where {X,N} +# if X ∈ N +# quote +# # @info "$X ∈ N" +# xnew = set(x.value, Lens!!(lens), lens(getproperty(pars, X))) +# # ctx = merge(ctx, NamedTuple{(X,)}((xnew,))) +# (xnew, ctx, ctx) +# end +# else +# quote +# # @info "$X ∉ N" +# x = x.value +# xnew = set(copy(x), Lens!!(lens), f(d, lens(x))) +# ctx = merge(ctx, NamedTuple{(X,)}((xnew,))) +# (xnew, ctx, ctx) +# end +# end +# end + +# @generated function tilde_predict( +# f, +# x::Unobserved{X}, +# lens, +# d, +# pars::NamedTuple{N}, +# ctx, +# ) where {X,N} +# if X ∈ N +# quote +# # @info "$X ∈ N" +# xnew = set(value(x), Lens!!(lens), lens(getproperty(pars, X))) 
+# # ctx = merge(ctx, NamedTuple{(X,)}((xnew,))) +# (xnew, ctx, ctx) +# end +# else +# quote +# # @info "$X ∉ N" +# # In this case x == Unobserved(missing) +# xnew = set(value(x), Lens!!(lens), f(d, missing)) +# ctx = merge(ctx, NamedTuple{(X,)}((xnew,))) +# (xnew, ctx, ctx) +# end +# end +# end diff --git a/src/primitives/rand.jl b/src/primitives/rand.jl index 5f16d01..f2a9f91 100644 --- a/src/primitives/rand.jl +++ b/src/primitives/rand.jl @@ -1,88 +1,145 @@ using Random: GLOBAL_RNG using TupleVectors: chainvec -export rand -EmptyNTtype = NamedTuple{(),Tuple{}} where {T<:Tuple} +struct RandConfig{T_rng, RNG, P} <: AbstractTildeConfig + rng::RNG + proj::P -@inline function Base.rand(rng::AbstractRNG, d::AbstractConditionalModel, N::Int) - r = chainvec(rand(rng, d), N) - for j in 2:N - @inbounds r[j] = rand(rng, d) - end - return r + RandConfig(::Type{T_rng}, rng::RNG, proj::P) where {T_rng, RNG<:AbstractRNG, P} = new{T_rng,RNG,P}(rng,proj) end -@inline Base.rand(d::AbstractConditionalModel, N::Int) = rand(GLOBAL_RNG, d, N) +RandConfig(rng,proj) = RandConfig(Float64, rng, proj) +RandConfig(proj) = RandConfig(Float64, Random.GLOBAL_RNG, proj) + -@inline function Base.rand(m::AbstractConditionalModel; kwargs...) - rand(GLOBAL_RNG, m; kwargs...) 
+@inline function retfun(cfg::RandConfig, joint::Pair, ctx) + cfg.proj(ctx => last(joint)) end + +export rand +EmptyNTtype = NamedTuple{(),Tuple{}} where {T<:Tuple} + + @inline function Base.rand( rng::AbstractRNG, - m::AbstractConditionalModel; - ctx = NamedTuple(), - retfun = (r, ctx) -> r, -) - cfg = (rng = rng,) - gg_call(rand, m, NamedTuple(), cfg, ctx, retfun) + ::Type{T_rng}, + mc::ModelClosure +) where {T_rng} + cfg = RandConfig(T_rng, rng, getproj(mc)) + pars = NamedTuple() + ctx = NamedTuple() + runmodel(cfg, mc, pars, ctx) end ############################################################################### -# ctx::NamedTuple +# tilde + @inline function tilde( - ::typeof(Base.rand), - lens, - xname, + cfg::RandConfig{T_rng, RNG, typeof(last)}, x::Unobserved, + lens, d, - cfg, - ctx::NamedTuple, -) - xnew = set(x.value, Lens!!(lens), rand(cfg.rng, d)) - ctx′ = merge(ctx, NamedTuple{(dynamic(xname),)}((xnew,))) - (xnew, ctx′, nothing) + ctx, +) where {T_rng, RNG} + r = rand(cfg.rng, T_rng, d) + xnew = set(value(x), Lens!!(lens), r) + (xnew, ctx) end @inline function tilde( - ::typeof(Base.rand), + cfg::RandConfig{T_rng, RNG, P}, + x::Unobserved{X}, lens, - xname, - x::Observed, d, - cfg, - ctx::NamedTuple, -) - (x.value, ctx, nothing) + ctx, +) where {X,T_rng, RNG,P} + joint = rand(cfg.rng, T_rng, jointof(d)) + latent, retn = joint + xnew = set(value(x), Lens!!(lens), retn) + ctx′ = mymerge(ctx, NamedTuple{(X,)}((latent,))) + (xnew, ctx′) end + + ############################################################################### -# ctx::Dict +# Dispatch helpers -@inline function tilde( - ::typeof(Base.rand), - lens::typeof(identity), - xname, - x, - d, - cfg, - ctx::Dict, +@inline function Base.rand(m::ModelClosure, args...; kwargs...) + rand(GLOBAL_RNG, Float64, m, args...; kwargs...) +end + +@inline function Base.rand(rng::AbstractRNG, m::ModelClosure, args...; kwargs...) + rand(rng, Float64, m, args...; kwargs...) 
+end
+
+@inline function Base.rand(::Type{T_rng}, m::ModelClosure, args...; kwargs...) where {T_rng}
+    rand(GLOBAL_RNG, T_rng, m, args...; kwargs...)
+end
+
+
+###############################################################################
+# Specifying an Integer argument creates a TupleVector
+
+
+
+@inline function Base.rand(m::ModelClosure, N::Integer; kwargs...)
+    rand(GLOBAL_RNG, Float64, m, N; kwargs...)
+end
+
+@inline function Base.rand(
+    rng::AbstractRNG,
+    mc::ModelClosure,
+    N::Integer;
+    kwargs...,
+)
-    x = rand(cfg.rng, d)
-    ctx[dynamic(xname)] = x
-    (x, ctx, nothing)
+    rand(rng, Float64, mc, N; kwargs...)
 end
 
-@inline function tilde(
-    ::typeof(Base.rand),
-    lens,
-    xname,
-    x,
-    m::AbstractConditionalModel,
-    cfg,
-    ctx::Dict,
+
+@inline function Base.rand(
+    ::Type{T_rng},
+    mc::ModelClosure,
+    N::Integer;
+    kwargs...,
+) where {T_rng}
+    rand(GLOBAL_RNG, T_rng, mc, N; kwargs...)
+end
+
+
+@inline function Base.rand(
+    rng::AbstractRNG,
+    ::Type{T_rng},
+    mc::ModelClosure,
+    N::Integer,
+) where {T_rng}
+    r = chainvec(rand(rng, T_rng, mc), N)
+    for j in 2:N
+        @inbounds r[j] = rand(rng, T_rng, mc)
+    end
+    return r
+end
+
+###############################################################################
+# Cases that throw errors
+
+function Base.rand(
+    ::AbstractRNG,
+    ::Type,
+    m::AbstractModel,
+    args...;
+    kwargs...
 )
-    args = get(cfg.args, dynamic(xname), Dict())
-    cfg = merge(cfg, (args = args,))
-    tilde(rand, lens, xname, x, m(cfg.args), cfg, ctx)
+    @error "`rand` called on Model without arguments. Try `m(args)` or `m()` if the model has no arguments"
 end
+
+function Base.rand(
+    ::AbstractRNG,
+    ::Type,
+    m::ModelPosterior,
+    args...;
+    kwargs...
+)
+    @error "`rand` called on ModelPosterior. 
`rand` does not allow conditioning; try `predict`" +end \ No newline at end of file diff --git a/src/primitives/interpret.jl b/src/primitives/runmodel.jl similarity index 58% rename from src/primitives/interpret.jl rename to src/primitives/runmodel.jl index 4d6fe0d..0344573 100644 --- a/src/primitives/interpret.jl +++ b/src/primitives/runmodel.jl @@ -1,27 +1,17 @@ -export interpret +export runmodel @inline function inkeys(::StaticSymbol{s}, ::Type{NamedTuple{N,T}}) where {s,N,T} return s ∈ N end -function interpret(m::Model{A,B,M}, tilde, ctx0) where {A,B,M} - theModule = getmodule(m) - mk_function(theModule, make_body(theModule, m.body, tilde, ctx0)) -end - function make_body(M, f, m::AbstractModel) make_body(M, body(m)) end -struct Observed{T} - value::T -end - -struct Unobserved{T} - value::T -end +call(f, g, args...; kwargs...) = g(args...; kwargs...) -function make_body(M, f, ast::Expr, retfun, argsT, obsT, parsT) +function make_body(M, ast::Expr, proj, argsT, obsT, parsT, paramnames) + paramvals = Expr(:tuple, paramnames...) knownvars = union(keys.(schema.((argsT, obsT, parsT)))...) function go(ex, scope = (bounds = Var[], freevars = Var[], bound_inits = Symbol[])) @match ex begin @@ -43,28 +33,35 @@ function make_body(M, f, ast::Expr, retfun, argsT, obsT, parsT) # X = to_type(unsolved_lhs) # M = to_type(unsolve(rhs)) - inargs = inkeys(sx, argsT) + # inargs = inkeys(sx, argsT) inobs = inkeys(sx, obsT) - inpars = inkeys(sx, parsT) + # inpars = inkeys(sx, parsT) rhs = unsolve(rhs) - - xval = if inobs - :($Observed($x)) + + obj = if inobs + # TODO: Even if `x` is observed, we may have `lens(x) == missing` + :($Observed{$qx}($x)) else - (x ∈ knownvars ? 
:($Unobserved($x)) : :($Unobserved(missing))) + (if x ∈ knownvars + :($Unobserved{$qx}($x)) + else + :($Unobserved{$qx}(missing)) + end) end - st = :(($x, _ctx, _retn) = $tilde($f, $l, $sx, $xval, $rhs, _cfg, _ctx)) - qst = QuoteNode(st) + st = :(($x, _ctx) = $tilde(_cfg, $obj, $l, $rhs, _ctx)) + # qst = QuoteNode(st) q = quote # println($qst) $st - _retn isa Tilde.ReturnNow && return _retn.value + _ctx isa Tilde.ReturnNow && return _ctx.value end q end - :(return $r) => :(return $retfun($r, _ctx)) + :(return $r) => quote + return Tilde.retfun(_cfg, NamedTuple{$paramnames}($paramvals) => $r, _ctx) + end Expr(:scoped, new_scope, ex) => begin go(ex, new_scope) @@ -76,9 +73,10 @@ function make_body(M, f, ast::Expr, retfun, argsT, obsT, parsT) end end - body = go(@q begin - $(solve_scope(opticize(ast))) - end) |> unsolve |> MacroTools.flatten + body = + go(@q begin + $(solve_scope(opticize(ast))) + end) |> unsolve |> MacroTools.flatten body end @@ -91,14 +89,18 @@ end # error(ex) # end -@generated function gg_call( - ::F, +struct KnownVars{A,O,P} + args::A + obs::O + pars::P +end + +@generated function runmodel( + _cfg, _mc::MC, _pars::NamedTuple{N,T}, - _cfg, _ctx, - ::R, -) where {F,MC,N,T,R} +) where {MC,N,T} _m = type2model(MC) M = getmodule(_m) @@ -108,19 +110,20 @@ end body = _m.body |> loadvals(argsT, obsT, parsT) - f = MeasureBase.instance(F) - _retfun = MeasureBase.instance(R) - body = make_body(M, f, body, _retfun, argsT, obsT, parsT) + paramnames = tuple(parameters(_m)...) + paramvals = Expr(:tuple, paramnames...) 
+ _proj = getproj(MC) + body = make_body(M, body, _proj, argsT, obsT, parsT, paramnames) q = MacroTools.flatten( - @q @inline function (_mc, _cfg, _ctx, _pars, _retfun) - local _retn + @q function (_mc, _cfg, _ctx, _pars) _args = $argvals(_mc) _obs = $observations(_mc) - _cfg = merge(_cfg, (args = _args, obs = _obs, pars = _pars)) + # _vars = KnownVars(_args, _obs, _pars) $body # If body doesn't have a return, default to `return ctx` - return $_retfun(_ctx, _ctx) + _params = NamedTuple{$paramnames}($paramvals) + return Tilde.retfun(_cfg, _params => _params, _ctx) end ) diff --git a/src/primitives/testparams.jl b/src/primitives/testparams.jl index ca9893c..7506492 100644 --- a/src/primitives/testparams.jl +++ b/src/primitives/testparams.jl @@ -6,7 +6,7 @@ EmptyNTtype = NamedTuple{(),Tuple{}} where {T<:Tuple} testparams(d::AbstractMeasure) = testvalue(d) @inline function testparams(mc::ModelClosure; cfg = NamedTuple(), ctx = NamedTuple()) - gg_call(testparams, mc, NamedTuple(), cfg, ctx, (r, ctx) -> ctx) + runmodel(testparams, mc, NamedTuple(), cfg, ctx, (r, ctx) -> ctx) end ############################################################################### @@ -14,22 +14,26 @@ end @inline function tilde( ::typeof(testparams), + x::MaybeObserved{X}, lens::typeof(identity), - xname, - x, d, cfg, ctx::NamedTuple, - _, - _, -) +) where {X} xnew = testparams(d) - ctx′ = merge(ctx, NamedTuple{(dynamic(xname),)}((xnew,))) - (xnew, ctx′, ctx′) + ctx′ = merge(ctx, NamedTuple{(X,)}((xnew,))) + (xnew, ctx′) end -@inline function tilde(::typeof(testparams), lens, xname, x, d, cfg, ctx::NamedTuple, _, _) +@inline function tilde( + ::typeof(testparams), + x::MaybeObserved{X}, + lens, + d, + cfg, + ctx::NamedTuple, +) where {X} xnew = set(x, Lens!!(lens), testparams(d)) - ctx′ = merge(ctx, NamedTuple{(dynamic(xname),)}((xnew,))) - (xnew, ctx′, ctx′) + ctx′ = merge(ctx, NamedTuple{(X,)}((xnew,))) + (xnew, ctx′) end diff --git a/src/primitives/testvalue.jl b/src/primitives/testvalue.jl 
index acc46b1..e0607dd 100644 --- a/src/primitives/testvalue.jl +++ b/src/primitives/testvalue.jl @@ -1,39 +1,40 @@ using TupleVectors: chainvec import MeasureTheory: testvalue +struct TestValueConfig{P} <: AbstractTildeConfig + proj::P +end + +@inline retfun(cfg::TestValueConfig, r, ctx) = cfg.proj(r) + + export testvalue EmptyNTtype = NamedTuple{(),Tuple{}} where {T<:Tuple} -@inline function testvalue( - mc::AbstractConditionalModel; - cfg = NamedTuple(), - ctx = NamedTuple(), -) - gg_call(testvalue, mc, NamedTuple(), cfg, ctx, (r, ctx) -> r) +@inline function testvalue(mc::AbstractConditionalModel) + cfg = TestValueConfig(getproj(mc)) + ctx = NamedTuple() + runmodel(cfg, mc, NamedTuple(), ctx) end @inline function tilde( - ::typeof(testvalue), + ::TestValueConfig, + x::Unobserved{X}, lens, - xname, - x::Unobserved, d, - cfg, ctx::NamedTuple, -) - xnew = set(x.value, Lens!!(lens), testvalue(d)) - ctx′ = merge(ctx, NamedTuple{(dynamic(xname),)}((xnew,))) - (xnew, ctx′, nothing) +) where {X} + xnew = set(value(x), Lens!!(lens), testvalue(d)) + ctx′ = merge(ctx, NamedTuple{(X,)}((xnew,))) + (xnew, ctx′) end @inline function tilde( - ::typeof(testvalue), + ::TestValueConfig, + x::Observed{X}, lens, - xname, - x::Observed, d, - cfg, ctx::NamedTuple, -) - (x.value, ctx, nothing) +) where {X} + (lens(value(x)), ctx) end diff --git a/src/primitives/weightedsampling.jl b/src/primitives/weightedsampling.jl index 5a90c7d..2eff099 100644 --- a/src/primitives/weightedsampling.jl +++ b/src/primitives/weightedsampling.jl @@ -13,36 +13,33 @@ end ) cfg = (rng = rng,) ctx = (ℓ = 0.0, pars = NamedTuple()) - gg_call(weightedrand, m, NamedTuple(), cfg, ctx, (r, ctx) -> ctx) + runmodel(weightedrand, m, NamedTuple(), cfg, ctx, (r, ctx) -> ctx) end @inline function tilde( ::typeof(weightedrand), + x::Observed{X}, lens, - xname, - x::Observed, d, cfg, ctx::NamedTuple, -) - x = x.value - xname = dynamic(xname) +) where {X} + x = value(x) Δℓ = logdensityof(d, lens(x)) @reset ctx.ℓ += Δℓ - 
(x, ctx, ctx) + (x, ctx) end @inline function tilde( ::typeof(weightedrand), + x::Unobserved{X}, lens, - xname, - x::Unobserved, d, cfg, ctx::NamedTuple, -) - xnew = set(x.value, Lens!!(lens), rand(cfg.rng, d)) - pars = merge(ctx.pars, NamedTuple{(dynamic(xname),)}((xnew,))) +) where {X} + xnew = set(value(x), Lens!!(lens), rand(cfg.rng, d)) + pars = merge(ctx.pars, NamedTuple{(X,)}((xnew,))) ctx = merge(ctx, (pars = pars,)) - (xnew, ctx, nothing) + (xnew, ctx) end diff --git a/test/runtests.jl b/test/runtests.jl index 8f57c3b..6b8ec8c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,7 @@ import TransformVariables as TV using Aqua using Tilde -Aqua.test_all(Tilde; ambiguities=false, unbound_args=false) +Aqua.test_all(Tilde; ambiguities=false) include("examples-list.jl") @@ -59,8 +59,10 @@ include("examples-list.jl") y ~ Bernoulli(p) return y end - - mean(predict(m(), [(p=p,) for p in rand(10000)])) isa Float64 + + @test predict(m(), (p=rand(),)) isa Bool + + # @test mean(predict(m(), [(p=p,) for p in rand(10000)])) isa AbstractFloat end @testset "https://github.com/cscherrer/Soss.jl/issues/258" begin @@ -172,7 +174,7 @@ end -end + @testset "Nested models" begin @@ -199,3 +201,5 @@ end @test rand(m(); ctx=()) isa Bool @test logdensity(m(), rand(m())) isa Float64 end + +end \ No newline at end of file diff --git a/test/transforms.jl b/test/transforms.jl index f50f45d..ea80567 100644 --- a/test/transforms.jl +++ b/test/transforms.jl @@ -1,12 +1,3 @@ -# Check for Model equality up to reorderings of a few fields -function ≊(m1::DAGModel,m2::DAGModel) - function eq_tuples(nt1::NamedTuple,nt2::NamedTuple) - return length(nt1)==length(nt2) && all(nt1[k]==nt2[k] for k in keys(nt1)) - end - return Set(arguments(m1))==Set(arguments(m2)) && m1.retn==m2.retn && eq_tuples(m1.dists,m2.dists) && eq_tuples(m1.vals,m2.vals) -end - - m = @model (n,α,β) begin p ~ Beta(α, β) x ~ Binomial(n, p)