Dev-predict #34

Draft: wants to merge 100 commits into base: main

Commits
d77eaa2
bouncy updates
cscherrer Jun 7, 2022
8fff1cc
rand
cscherrer Jun 11, 2022
059ab30
minor fix
cscherrer Jun 11, 2022
4d65f63
Fix method ambiguity
cscherrer Jun 11, 2022
2dd7924
refactoring
cscherrer Jun 17, 2022
939f341
Merge branch 'dev' into temp
cscherrer Jun 17, 2022
78c3df1
Merge pull request #25 from cscherrer/temp
cscherrer Jun 17, 2022
97c29c5
Merge remote-tracking branch 'origin/dev' into dev
cscherrer Jun 17, 2022
52179a9
Merge pull request #20 from cscherrer/rand
cscherrer Jun 21, 2022
dc3c376
Disallow `rand` on ModelPosteriors
cscherrer Jun 21, 2022
3c03bd4
Merge branch 'dev' of https://github.com/cscherrer/Tilde.jl into dev
cscherrer Jun 21, 2022
a352754
Move astmodel.jl contents into model.jl
cscherrer Jun 21, 2022
3921a40
Drop old `DAGModel` stuff
cscherrer Jun 21, 2022
3aaaa74
drop dead code
cscherrer Jun 21, 2022
2f832a5
change `S` type param to `F`
cscherrer Jun 21, 2022
69ce0d8
drop on include
cscherrer Jun 21, 2022
36a2e33
edit for readability
cscherrer Jun 22, 2022
781c29d
update dependencies
cscherrer Jun 22, 2022
fbef2d1
Update src/core/models/abstractmodel.jl
cscherrer Jun 22, 2022
d41b73a
Update src/core/models/abstractmodel.jl
cscherrer Jun 22, 2022
cf209dc
Update src/core/models/model.jl
cscherrer Jun 22, 2022
bedd3fa
drop the F
cscherrer Jun 22, 2022
564ec84
Merge branch 'dev' of https://github.com/cscherrer/Tilde.jl into dev
cscherrer Jun 22, 2022
40933c3
predict (#21)
cscherrer Jun 22, 2022
9eefb93
test unbound args
cscherrer Jun 23, 2022
5e12f9a
cleanup
cscherrer Jun 23, 2022
2d29d0d
bugfix
cscherrer Jun 23, 2022
07dcf28
start on refactoring PDMP example
cscherrer Jun 24, 2022
c4ce48e
comments, mostly
cscherrer Jun 24, 2022
4011b4a
bugfix
cscherrer Jun 25, 2022
dd740bd
working on `predict`
cscherrer Jun 27, 2022
afb72fc
update Lens!!
cscherrer Jun 27, 2022
3d35531
Update benchmarks/bouncy.jl
cscherrer Jun 27, 2022
39fabf0
Update benchmarks/bouncy.jl
cscherrer Jun 27, 2022
761c60e
Merge branch 'main' into dev
cscherrer Jun 27, 2022
1d81bf7
lens stuff
cscherrer Jun 27, 2022
5c2ed88
drop optics change
cscherrer Jun 27, 2022
2ce7f40
moving things around a bit
cscherrer Jun 27, 2022
a0aabb8
Contexts (#32)
cscherrer Jul 4, 2022
e6cb833
Merge branch 'main' of https://github.com/cscherrer/Tilde.jl into dev
cscherrer Jul 4, 2022
6513e51
formatting
cscherrer Jul 4, 2022
75bdc37
bugfixes
cscherrer Jul 5, 2022
d8fb3cc
bugfix
cscherrer Jul 5, 2022
fd6ce5d
fix `predict`
cscherrer Jul 5, 2022
f315d9b
MaybeObserved should be Unobserved
cscherrer Jul 5, 2022
7c93602
Move `anyfy`
cscherrer Jul 6, 2022
2181564
update `callify`
cscherrer Jul 6, 2022
ee2ab1a
deps
cscherrer Jul 6, 2022
b4da1e6
predict WIP
cscherrer Jul 7, 2022
5f1bf3c
latentof, etc
cscherrer Jul 11, 2022
98f5d05
drop observed rand stuff
cscherrer Jul 11, 2022
5e819a3
update deps
cscherrer Jul 11, 2022
4cbd74f
working on predict
cscherrer Jul 11, 2022
f57f397
bugfix in tests
cscherrer Jul 11, 2022
303ab5b
bugfix
cscherrer Jul 11, 2022
08d2b0e
add a note
cscherrer Jul 11, 2022
c6136ab
fix `logdensityof`
cscherrer Jul 11, 2022
9ab1c53
working on proj
cscherrer Jul 11, 2022
411f08e
updates
cscherrer Jul 11, 2022
7ac3f85
bugfix
cscherrer Jul 11, 2022
bc11132
updates
cscherrer Jul 12, 2022
e69d41b
rand stuff
cscherrer Jul 12, 2022
00a5a03
mucking about
cscherrer Jul 12, 2022
d25f4e4
refactoring
cscherrer Jul 13, 2022
441d627
fix `rand`
cscherrer Jul 13, 2022
a1bbc84
readability
cscherrer Jul 13, 2022
14c67dd
expand macros
cscherrer Jul 14, 2022
ab953b1
inline rand
cscherrer Jul 14, 2022
3d7b69c
rand updates
cscherrer Jul 14, 2022
bba9786
rand
cscherrer Jul 14, 2022
b887d24
rand
cscherrer Jul 14, 2022
5fc4eb3
drop ctx
cscherrer Jul 14, 2022
e9d55a2
move cfg from kwarg to "regular" arg
cscherrer Jul 14, 2022
8ece68d
drop unused type parameter
cscherrer Jul 14, 2022
4e8d102
help inference
cscherrer Jul 14, 2022
96d3ccf
update `rand`
cscherrer Jul 15, 2022
3321d5e
require MeasureBase 0.12.2
cscherrer Jul 17, 2022
cfcf575
mymerge
cscherrer Jul 18, 2022
459a908
mymerge
cscherrer Jul 18, 2022
570e082
updates
cscherrer Jul 18, 2022
39050eb
Merge branch 'dev-predict' of https://github.com/cscherrer/Tilde.jl i…
cscherrer Jul 18, 2022
3a2141d
more refactoring
cscherrer Jul 18, 2022
4008ba8
Drop unneeded argument
cscherrer Jul 18, 2022
662d9a4
tidying up
cscherrer Jul 18, 2022
0194a22
update
cscherrer Jul 18, 2022
e0f9e6c
update
cscherrer Jul 19, 2022
f09b6a3
Merge branch 'dev-predict' of github.com:cscherrer/Tilde.jl into dev-…
cscherrer Jul 19, 2022
d5889eb
gg_call => runmodel
cscherrer Jul 19, 2022
9f56597
lensvars
cscherrer Jul 19, 2022
3b7d579
logdensity
cscherrer Jul 19, 2022
8d7d32c
update interpret
cscherrer Jul 19, 2022
c74ddea
interpret => runmodel
cscherrer Jul 19, 2022
bb2eff8
testvalue
cscherrer Jul 19, 2022
c8fe8fa
logdensityof
cscherrer Jul 19, 2022
b6dd098
measures
cscherrer Jul 19, 2022
3468049
`model` on types
cscherrer Jul 20, 2022
e0ef186
`hasreturn`
cscherrer Jul 20, 2022
0c7c3ce
add a comment
cscherrer Jul 20, 2022
bb4ec23
more logdensity stuff
cscherrer Jul 20, 2022
3b2df09
update rand
cscherrer Jul 28, 2022
12 changes: 9 additions & 3 deletions Project.toml
@@ -6,9 +6,11 @@ version = "0.1.2"
[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
JuliaVariables = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -17,6 +19,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
MeasureBase = "fa1605e6-acd5-459c-a1e6-7e635759db14"
MeasureTheory = "eadaa1a4-d27c-401d-8699-e962e1bbc33b"
MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377"
NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50"
NestedTuples = "a734d2a7-8d68-409b-9419-626914d4061d"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
@@ -38,17 +41,20 @@ TupleVectors = "615932cf-77b6-4358-adcd-5b7eba981d7e"

[compat]
Accessors = "0.1"
ArrayInterface = "4, 5, 6"
ArrayInterface = "5, 6"
ChainRulesCore = "1"
DataStructures = "0.18"
DensityInterface = "0.4"
DiffResults = "1"
Graphs = "1"
IfElse = "0.1"
JuliaVariables = "0.2"
MLStyle = "0.3,0.4"
MacroTools = "0.5"
MappedArrays = "0.3, 0.4"
MeasureBase = "0.9"
MeasureBase = "0.12.2"
MeasureTheory = "0.16"
MetaGraphsNext = "0.3"
NamedTupleTools = "0.12, 0.13, 0.14"
NestedTuples = "0.3"
RecipesBase = "1"
@@ -63,7 +69,7 @@ StatsFuns = "0.9, 1"
TransformVariables = "0.5, 0.6"
Tricks = "0.1"
TupleVectors = "0.1"
julia = "1.5"
julia = "1.6"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
2 changes: 0 additions & 2 deletions benchmarks/Project.toml
@@ -1,8 +1,6 @@
[deps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
73 changes: 44 additions & 29 deletions benchmarks/bouncy.jl
@@ -14,56 +14,73 @@ using ForwardDiff
using ForwardDiff: Dual
using Pathfinder
using Pathfinder.PDMats
using MCMCChains
using TupleVectors: chainvec
using Tilde.MeasureTheory: transform

Random.seed!(1)

function make_grads(post)
as_post = as(post)
d = TV.dimension(as_post)
obj(θ) = -Tilde.unsafe_logdensityof(post, transform(as_post, θ))
ℓ(θ) = -obj(θ)
@inline function dneglogp(t, x, v, args...) # two directional derivatives
f(t) = obj(x + t * v)
u = ForwardDiff.derivative(f, Dual{:hSrkahPmmC}(0.0, 1.0))
u.value, u.partials[]
end

gconfig = ForwardDiff.GradientConfig(obj, rand(d), ForwardDiff.Chunk{d}())
function ∇neglogp!(y, t, x, args...)
ForwardDiff.gradient!(y, obj, x, gconfig)
y
end
ℓ, dneglogp, ∇neglogp!
end

# ↑ general purpose
############################################################
# ↓ problem-specific

# read data
function readlrdata()
fname = joinpath("lr.data")
z = readdlm(fname)
A = z[:, 1:end-1]
A = z[:, 1:(end-1)]
A = [ones(size(A, 1)) A]
y = z[:, end] .- 1
return A, y
end
A, y = readlrdata();
At = collect(A');

model_lr = @model (At, y, σ) begin
d, n = size(At)
θ ~ Normal(σ = σ)^d
for j in 1:n
logitp = dot(view(At, :, j), θ)
logitp = view(At, :, j)' * θ
y[j] ~ Bernoulli(logitp = logitp)
end
end

# Define model arguments
A, y = readlrdata();
At = collect(A');
σ = 100.0

function make_grads(model_lr, At, y, σ)
post = model_lr(At, y, σ) | (; y)
as_post = as(post)
obj(θ) = -Tilde.unsafe_logdensityof(post, transform(as_post, θ))
ℓ(θ) = -obj(θ)
@inline function dneglogp(t, x, v) # two directional derivatives
f(t) = obj(x + t * v)
u = ForwardDiff.derivative(f, Dual{:hSrkahPmmC}(0.0, 1.0))
u.value, u.partials[]
end
# Represent the posterior
post = model_lr(At, y, σ) | (; y)

gconfig = ForwardDiff.GradientConfig(obj, rand(25), ForwardDiff.Chunk{25}())
function ∇neglogp!(y, t, x)
ForwardDiff.gradient!(y, obj, x, gconfig)
return
end
post, ℓ, dneglogp, ∇neglogp!
end
d = TV.dimension(as(post))

post, ℓ, dneglogp, ∇neglogp! = make_grads(model_lr, At, y, σ)
# Try things out
dneglogp(2.4, randn(25), randn(25));
∇neglogp!(randn(25), 2.1, randn(25));
# Make sure gradients are working
let
ℓ, dneglogp, ∇neglogp! = make_grads(post)
@show dneglogp(2.4, randn(d), randn(d))
y = Vector{Float64}(undef, d)
@show ∇neglogp!(y, 2.1, randn(d))
nothing
end

d = 25 # number of parameters
t0 = 0.0;
x0 = zeros(d); # starting point sampler
# estimated posterior mean (n=100000, 797s)
@@ -129,8 +146,7 @@ sampler = ZZB.NotFactSampler(
),
);

using TupleVectors: chainvec
using Tilde.MeasureTheory: transform
# @time first(Iterators.drop(tvs,1000))

function collect_sampler(t, sampler, n; progress = true, progress_stops = 20)
if progress
Expand Down Expand Up @@ -166,7 +182,6 @@ elapsed_time = @elapsed @time begin
bps_samples, info = collect_sampler(as(post), sampler, n; progress = false)
end

using MCMCChains
bps_chain = MCMCChains.Chains(bps_samples.θ);
bps_chain = setinfo(bps_chain, (; start_time = 0.0, stop_time = elapsed_time));

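The `dneglogp` helper in this diff evaluates the objective once with a scalar `ForwardDiff.Dual` seed to get both the directional value and the directional derivative in a single pass. A minimal standalone sketch of the same trick (the name `directional` and the test function `f` are illustrative, not part of this PR):

```julia
using ForwardDiff
using ForwardDiff: Dual

# Directional derivative of f at x along v in one evaluation:
# seed t = Dual(0.0, 1.0), evaluate f(x + t*v), then the Dual's
# value is f(x) and its single partial is ∇f(x)⋅v.
function directional(f, x, v)
    t = Dual{:dir}(0.0, 1.0)   # tagged scalar dual seed
    u = f(x .+ t .* v)
    u.value, u.partials[]
end

f(x) = sum(abs2, x)            # f(x) = ‖x‖², so ∇f(x) = 2x
x = [1.0, 2.0]
v = [0.0, 1.0]
val, dir = directional(f, x, v)  # val = 5.0, dir = 2⋅dot(x, v) = 4.0
```

This avoids building the full gradient when only a single directional derivative is needed, which is why the bouncy-particle sampler setup above keeps `dneglogp` separate from the full-gradient `∇neglogp!`.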
2 changes: 1 addition & 1 deletion src/GG/deprecated_codes/explicit_scope.jl
@@ -3,7 +3,7 @@ function scoping(ast)
@match ast begin
:([$(frees...)]($(args...)) -> begin
$(stmts...)
end) => let stmts = map(rec, stmts), arw = :(($(args...),) -> begin
end) => let stmts = map(rec, stmts), arw = :(($(args...),) -> begin
$(stmts...)
end)
Expr(:scope, (), Tuple(frees), (), arw)
5 changes: 5 additions & 0 deletions src/GG/deprecated_codes/static_closure_conv.jl
@@ -39,6 +39,11 @@ function mk_closure_static(expr, toplevel::Vector{Expr})
$Closure{$glob_name,typeof(frees)}(frees)
end
)
ret = :(
let frees = $closure_arg
$Closure{$glob_name,typeof(frees)}(frees)
end
)
(fn_expr, ret)
end

21 changes: 13 additions & 8 deletions src/Tilde.jl
@@ -8,6 +8,8 @@ using Reexport: @reexport
@reexport using MeasureTheory
using MeasureBase: productmeasure, Returns

import MeasureBase: latentof, jointof, manifestof

import DensityInterface: logdensityof
import DensityInterface: densityof
import DensityInterface: DensityKind
@@ -79,29 +81,32 @@ include("callify.jl")
end
end

include("config.jl")
include("lensvars.jl")
include("optics.jl")
include("maybe.jl")
include("core/models/abstractmodel.jl")
include("core/models/astmodel/astmodel.jl")
include("core/models/model.jl")
include("core/dependencies.jl")
include("core/utils.jl")
include("core/models/closure.jl")
include("maybeobserved.jl")
include("core/models/posterior.jl")
include("primitives/interpret.jl")
include("distributions/iid.jl")

include("primitives/rand.jl")
include("primitives/logdensity.jl")
include("primitives/logdensity_rel.jl")
include("primitives/insupport.jl")
# include("primitives/logdensity_rel.jl")
# include("primitives/insupport.jl")

# include("primitives/basemeasure.jl")
include("primitives/testvalue.jl")
include("primitives/testparams.jl")
include("primitives/weightedsampling.jl")
# include("primitives/testparams.jl")
# include("primitives/weightedsampling.jl")
include("primitives/measures.jl")
include("primitives/basemeasure.jl")
# include("primitives/basemeasure.jl")
# include("primitives/predict.jl")
# include("primitives/dag.jl")
include("primitives/runmodel.jl")

include("transforms/utils.jl")

83 changes: 26 additions & 57 deletions src/callify.jl
@@ -5,82 +5,51 @@ using MLStyle

Replace every `f(args...; kwargs...)` with `mycall(f, args...; kwargs...)`
"""
function callify(mycall, ast)
function callify(g, ast)
leaf(x) = x
function branch(f, head, args)
default() = Expr(head, map(f, args)...)

# Convert `for` to `while`
if head == :for
arg1 = args[1]
@assert arg1.head == :(=)
a,A0 = arg1.args
A0 = callify(g, A0)
@gensym temp
@gensym state
@gensym A
return quote
$A = $A0
$temp = $call($g, iterate, $A)
while $temp !== nothing
$a, $state = $temp
$(args[2])
$temp = $call($g, iterate, $A, $state)
end
end
end

head == :call || return default()

if first(args) == :~ && length(args) == 3
return default()
end

# At this point we know it's a function call
length(args) == 1 && return Expr(:call, mycall, first(args))
length(args) == 1 && return Expr(:call, call, g, first(args))

fun = args[1]
arg2 = args[2]

if arg2 isa Expr && arg2.head == :parameters
# keyword arguments (try dump(:(f(x,y;a=1, b=2))) to see this)
return Expr(:call, mycall, arg2, fun, map(f, Base.rest(args, 3))...)
return Expr(:call, call, g, arg2, fun, map(f, Base.rest(args, 3))...)
else
return Expr(:call, mycall, map(f, args)...)
return Expr(:call, call, g, map(f, args)...)
end
end

foldast(leaf, branch)(ast)
foldast(leaf, branch)(ast) |> MacroTools.flatten
end

# struct Provenance{T,S}
# value::T
# sources::S
# end

# getvalue(p::Provenance) = p.value
# getvalue(x) = x

# getsources(p::Provenance) = p.sources
# getsources(x) = Set()

# function trace_provenance(f, args...; kwargs...)
# (newargs, arg_sources) = (getvalue.(args), union(getsources.(args)...))

# k = keys(kwargs)
# v = values(kwargs)
# newkwargs = NamedTuple{k}(map(getvalue, v))

# k = keys(kwargs)
# v = values(NamedTuple(kwargs))
# newkwargs = NamedTuple{k}(getvalue.(v))
# kwarg_sources = union(getsources.(args)...)

# sources = union(arg_sources, kwarg_sources)
# Provenance(f(newargs...; newkwargs), sources)
# end

# macro call(expr)
# callify(expr)
# end

# julia> callify(:(f(g(x,y))))
# :(call(f, call(g, x, y)))

# julia> callify(:(f(x; a=3)))
# :(call(f, x; a = 3))

# julia> callify(:(a+b))
# :(call(+, a, b))

# julia> callify(:(call(f,3)))
# :(call(f, 3))

# f(x) = x+1

# @call f(2)

# using SymbolicUtils

# @syms x::Vector{Float64} i::Int

# @call getindex(x,i)
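The `callify` pass above threads every call site through an interceptor, which is what lets the model interpreter observe each function call. As a rough illustration of the core transformation only (names `callify_sketch` and `call` are hypothetical here; this ignores keyword arguments and the `for`-to-`while` lowering the real `callify` also performs):

```julia
# Minimal sketch: rewrite f(args...) into call(g, f, args...),
# recursing into nested expressions.
function callify_sketch(g, ex)
    ex isa Expr || return ex                     # leaves pass through
    args = map(a -> callify_sketch(g, a), ex.args)
    if ex.head == :call
        return Expr(:call, :call, g, args...)    # intercept the call
    end
    return Expr(ex.head, args...)                # other nodes: recurse only
end

callify_sketch(:g, :(f(h(x), y)))
# produces :(call(g, f, call(g, h, x), y))
```

The interceptor `g` then decides at run time how each call is executed, which is the hook the contexts work in this PR builds on.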
2 changes: 2 additions & 0 deletions src/config.jl
@@ -0,0 +1,2 @@
abstract type AbstractTildeConfig end

17 changes: 15 additions & 2 deletions src/core/models/abstractmodel.jl
@@ -15,9 +15,22 @@ N gives the Names of arguments (each a Symbol)
B gives the Body, as an Expr
M gives the Module where the model is defined
"""
abstract type AbstractModel{A,B,M} <: AbstractTransitionKernel end
abstract type AbstractModel{A,B,M,P} <: AbstractTransitionKernel end

abstract type AbstractConditionalModel{M,Args,Obs} <: AbstractMeasure end
abstract type AbstractConditionalModel{M,Args,Obs,P} <: AbstractMeasure end

# getproj(::Type{T}) where {T} = Base.Fix1(project_joint, T)
getproj(::Type{<:AbstractConditionalModel{M,Args,Obs,P}}) where {M,Args,Obs,P} = MeasureBase.instance(P)

getproj(::M) where {M<:AbstractConditionalModel} = getproj(M)

# project_joint(::Type{<:AbstractConditionalModel{M,A,O,typeof(first)}}, p) where {M,A,O} = first(p)
# project_joint(::Type{<:AbstractConditionalModel{M,A,O,typeof(last)}}, p) where {M,A,O} = last(p)
# project_joint(::Type{<:AbstractConditionalModel{M,A,O,typeof(identity)}}, p) where {M,A,O} = p

# project_joint(::AbstractConditionalModel{M,A,O,typeof(first)}, p) where {M,A,O} = first(p)
# project_joint(::AbstractConditionalModel{M,A,O,typeof(last)}, p) where {M,A,O} = last(p)
# project_joint(::AbstractConditionalModel{M,A,O,typeof(identity)}, p) where {M,A,O} = p

argstype(::AbstractModel{A,B,M}) where {A,B,M} = A

1 change: 1 addition & 0 deletions src/core/models/astmodel.jl
@@ -0,0 +1 @@
