Skip to content

Bouncy #38

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 139 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
139 commits
Select commit Hold shift + click to select a range
d77eaa2
bouncy updates
cscherrer Jun 7, 2022
8fff1cc
rand
cscherrer Jun 11, 2022
059ab30
minor fix
cscherrer Jun 11, 2022
4d65f63
Fix method ambiguity
cscherrer Jun 11, 2022
2dd7924
refactoring
cscherrer Jun 17, 2022
939f341
Merge branch 'dev' into temp
cscherrer Jun 17, 2022
78c3df1
Merge pull request #25 from cscherrer/temp
cscherrer Jun 17, 2022
97c29c5
Merge remote-tracking branch 'origin/dev' into dev
cscherrer Jun 17, 2022
52179a9
Merge pull request #20 from cscherrer/rand
cscherrer Jun 21, 2022
dc3c376
Disallow `rand` on ModelPosteriors
cscherrer Jun 21, 2022
3c03bd4
Merge branch 'dev' of https://github.com/cscherrer/Tilde.jl into dev
cscherrer Jun 21, 2022
a352754
Move astmodel.jl contents into model.jl
cscherrer Jun 21, 2022
3921a40
Drop old `DAGModel` stuff
cscherrer Jun 21, 2022
3aaaa74
drop dead code
cscherrer Jun 21, 2022
2f832a5
change `S` type param to `F`
cscherrer Jun 21, 2022
69ce0d8
drop on include
cscherrer Jun 21, 2022
36a2e33
edit for readability
cscherrer Jun 22, 2022
781c29d
update dependencies
cscherrer Jun 22, 2022
fbef2d1
Update src/core/models/abstractmodel.jl
cscherrer Jun 22, 2022
d41b73a
Update src/core/models/abstractmodel.jl
cscherrer Jun 22, 2022
cf209dc
Update src/core/models/model.jl
cscherrer Jun 22, 2022
bedd3fa
drop the F
cscherrer Jun 22, 2022
564ec84
Merge branch 'dev' of https://github.com/cscherrer/Tilde.jl into dev
cscherrer Jun 22, 2022
40933c3
predict (#21)
cscherrer Jun 22, 2022
9eefb93
test unbound args
cscherrer Jun 23, 2022
5e12f9a
cleanup
cscherrer Jun 23, 2022
2d29d0d
bugfix
cscherrer Jun 23, 2022
07dcf28
start on refactoring PDMP example
cscherrer Jun 24, 2022
c4ce48e
comments, mostly
cscherrer Jun 24, 2022
4011b4a
bugfix
cscherrer Jun 25, 2022
dd740bd
working on `predict`
cscherrer Jun 27, 2022
afb72fc
update Lens!!
cscherrer Jun 27, 2022
3d35531
Update benchmarks/bouncy.jl
cscherrer Jun 27, 2022
39fabf0
Update benchmarks/bouncy.jl
cscherrer Jun 27, 2022
761c60e
Merge branch 'main' into dev
cscherrer Jun 27, 2022
1d81bf7
lens stuff
cscherrer Jun 27, 2022
5c2ed88
drop optics change
cscherrer Jun 27, 2022
2ce7f40
moving things around a bit
cscherrer Jun 27, 2022
a0aabb8
Contexts (#32)
cscherrer Jul 4, 2022
e6cb833
Merge branch 'main' of https://github.com/cscherrer/Tilde.jl into dev
cscherrer Jul 4, 2022
6513e51
formatting
cscherrer Jul 4, 2022
75bdc37
bugfixes
cscherrer Jul 5, 2022
d8fb3cc
bugfix
cscherrer Jul 5, 2022
fd6ce5d
fix `predict`
cscherrer Jul 5, 2022
f315d9b
MaybeObserved shoudl be Unobserved
cscherrer Jul 5, 2022
7c93602
Move `anyfy`
cscherrer Jul 6, 2022
2181564
update `callify`
cscherrer Jul 6, 2022
ee2ab1a
deps
cscherrer Jul 6, 2022
b4da1e6
predict WIP
cscherrer Jul 7, 2022
5f1bf3c
latentof, etc
cscherrer Jul 11, 2022
98f5d05
drop observed rand stuff
cscherrer Jul 11, 2022
5e819a3
update deps
cscherrer Jul 11, 2022
4cbd74f
working on predict
cscherrer Jul 11, 2022
f57f397
bugfix in tests
cscherrer Jul 11, 2022
303ab5b
bugfix
cscherrer Jul 11, 2022
08d2b0e
add a note
cscherrer Jul 11, 2022
c6136ab
fix `logdensityof`
cscherrer Jul 11, 2022
9ab1c53
working on proj
cscherrer Jul 11, 2022
411f08e
updates
cscherrer Jul 11, 2022
7ac3f85
bugfix
cscherrer Jul 11, 2022
bc11132
updates
cscherrer Jul 12, 2022
e69d41b
rand stuff
cscherrer Jul 12, 2022
00a5a03
mucking about
cscherrer Jul 12, 2022
d25f4e4
refactoring
cscherrer Jul 13, 2022
441d627
fix `rand`
cscherrer Jul 13, 2022
a1bbc84
readability
cscherrer Jul 13, 2022
14c67dd
expand macros
cscherrer Jul 14, 2022
ab953b1
inline rand
cscherrer Jul 14, 2022
3d7b69c
rand updates
cscherrer Jul 14, 2022
bba9786
rand
cscherrer Jul 14, 2022
b887d24
rand
cscherrer Jul 14, 2022
5fc4eb3
drop ctx
cscherrer Jul 14, 2022
e9d55a2
move cfg from kwarg to "regular" arg
cscherrer Jul 14, 2022
8ece68d
drop unused type parameter
cscherrer Jul 14, 2022
4e8d102
help inference
cscherrer Jul 14, 2022
96d3ccf
update `rand`
cscherrer Jul 15, 2022
3321d5e
require MeasureBase 0.12.2
cscherrer Jul 17, 2022
cfcf575
mymerge
cscherrer Jul 18, 2022
459a908
mymerge
cscherrer Jul 18, 2022
570e082
updates
cscherrer Jul 18, 2022
39050eb
Merge branch 'dev-predict' of https://github.com/cscherrer/Tilde.jl i…
cscherrer Jul 18, 2022
3a2141d
more refactoring
cscherrer Jul 18, 2022
4008ba8
Drop unneeded argument
cscherrer Jul 18, 2022
662d9a4
tidying up
cscherrer Jul 18, 2022
0194a22
update
cscherrer Jul 18, 2022
e0f9e6c
update
cscherrer Jul 19, 2022
f09b6a3
Merge branch 'dev-predict' of github.com:cscherrer/Tilde.jl into dev-…
cscherrer Jul 19, 2022
d5889eb
gg_call => runmodel
cscherrer Jul 19, 2022
9f56597
lensvars
cscherrer Jul 19, 2022
3b7d579
logdensity
cscherrer Jul 19, 2022
8d7d32c
update interpret
cscherrer Jul 19, 2022
c74ddea
interpret => runmodel
cscherrer Jul 19, 2022
bb2eff8
testvalue
cscherrer Jul 19, 2022
c8fe8fa
logdensityof
cscherrer Jul 19, 2022
b6dd098
measures
cscherrer Jul 19, 2022
3468049
`model` on types
cscherrer Jul 20, 2022
e0ef186
`hasreturn`
cscherrer Jul 20, 2022
0c7c3ce
add a comment
cscherrer Jul 20, 2022
bb4ec23
more logdensity stuff
cscherrer Jul 20, 2022
3b2df09
update rand
cscherrer Jul 28, 2022
4d45693
simplify
cscherrer Jul 28, 2022
59dadf3
update TupleVectors dependency
cscherrer Jul 29, 2022
29ce089
working on `rand`
cscherrer Jul 29, 2022
2681c77
FixedRNG
cscherrer Jul 29, 2022
b005b60
AbstractTildeConfig => AbstractConfig
cscherrer Jul 31, 2022
32e81f0
more rand methods for FixedRNG
cscherrer Aug 1, 2022
27c242e
AbstractTildeConfig => AbstractConfig
cscherrer Aug 1, 2022
10b2638
Merge branch 'simple-return' of https://github.com/cscherrer/Tilde.jl…
cscherrer Aug 1, 2022
e99b453
unbreaking things
cscherrer Aug 4, 2022
e437348
updates
cscherrer Aug 4, 2022
8ad32e7
updates
cscherrer Aug 4, 2022
30aecc1
bumps deps
cscherrer Aug 5, 2022
3fb430b
tests passing!!
cscherrer Aug 10, 2022
fe426e8
add some rand methods
cscherrer Aug 10, 2022
802aae4
make logdensity_def work
cscherrer Aug 10, 2022
962eba7
Drop `latentof` stuff
cscherrer Aug 10, 2022
5e52854
bump version
cscherrer Aug 10, 2022
b3dc2c1
drop unneeded deps
cscherrer Aug 10, 2022
0571975
Drop commented-out code
cscherrer Aug 10, 2022
d0b56e1
Newline at EOF
cscherrer Aug 11, 2022
8b11368
drop old code
cscherrer Aug 11, 2022
bdcdd13
tidy up
cscherrer Aug 11, 2022
521c461
Newline at EOF
cscherrer Aug 11, 2022
fbe0846
Newline at EOF
cscherrer Aug 11, 2022
696685c
bugfix
cscherrer Aug 11, 2022
19441e3
bugfix
cscherrer Aug 12, 2022
a242f42
drop two `anyfy` calls
cscherrer Aug 12, 2022
fa18357
combine two lines
cscherrer Aug 12, 2022
2dcc0a1
simplify
cscherrer Aug 12, 2022
6ba41dd
more refactoring
cscherrer Aug 12, 2022
02081ef
updates
cscherrer Aug 14, 2022
4e9622c
drop old code
cscherrer Aug 14, 2022
41b0b06
drop old code
cscherrer Aug 14, 2022
d1e8319
insupport
cscherrer Aug 14, 2022
b006353
undo `bouncy` update (let's do that separately)
cscherrer Aug 14, 2022
473b4ed
update `bouncy`
cscherrer Aug 14, 2022
8bf6b9e
Merge branch 'main' into bouncy
cscherrer Aug 17, 2022
a3209ed
fix merge error
cscherrer Aug 17, 2022
2b5d3a7
fix whiespace
cscherrer Aug 17, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions benchmarks/Project.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
[deps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
Expand Down
76 changes: 47 additions & 29 deletions benchmarks/bouncy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,56 +14,76 @@ using ForwardDiff
using ForwardDiff: Dual
using Pathfinder
using Pathfinder.PDMats
using MCMCChains
using TupleVectors: chainvec
using Tilde.MeasureTheory: transform

Random.seed!(1)

function make_grads(post)
as_post = as(post)
d = TV.dimension(as_post)
tr(θ) = transform(as_post, θ)
obj(θ) = -Tilde.unsafe_logdensityof(post, tr(θ))
ℓ(θ) = -obj(θ)
@inline function dneglogp(t, x, v, args...) # two directional derivatives
f(t) = obj(x + t * v)
u = ForwardDiff.derivative(f, Dual{:hSrkahPmmC}(0.0, 1.0))
u.value, u.partials[]
end

gconfig = ForwardDiff.GradientConfig(obj, rand(d), ForwardDiff.Chunk{d}())
function ∇neglogp!(y, t, x, args...)
ForwardDiff.gradient!(y, obj, x, gconfig)
y
end
ℓ, dneglogp, ∇neglogp!
end

# ↑ general purpose
############################################################
# ↓ problem-specific

# read data
function readlrdata()
fname = joinpath("lr.data")
z = readdlm(fname)
A = z[:, 1:end-1]
A = z[:, 1:(end-1)]
A = [ones(size(A, 1)) A]
y = z[:, end] .- 1
return A, y
end
A, y = readlrdata();
At = collect(A');

model_lr = @model (At, y, σ) begin
d, n = size(At)
θ ~ Normal(σ = σ)^d
for j in 1:n
logitp = dot(view(At, :, j), θ)
logitp = view(At, :, j)' * θ
y[j] ~ Bernoulli(logitp = logitp)
end
end

# Define model arguments
A, y = readlrdata();
At = collect(A');
σ = 100.0

function make_grads(model_lr, At, y, σ)
post = model_lr(At, y, σ) | (; y)
as_post = as(post)
obj(θ) = -Tilde.unsafe_logdensityof(post, transform(as_post, θ))
ℓ(θ) = -obj(θ)
@inline function dneglogp(t, x, v) # two directional derivatives
f(t) = obj(x + t * v)
u = ForwardDiff.derivative(f, Dual{:hSrkahPmmC}(0.0, 1.0))
u.value, u.partials[]
end
# Represent the posterior
post = model_lr(At, y, σ) | (; y)

gconfig = ForwardDiff.GradientConfig(obj, rand(25), ForwardDiff.Chunk{25}())
function ∇neglogp!(y, t, x)
ForwardDiff.gradient!(y, obj, x, gconfig)
return
end
post, ℓ, dneglogp, ∇neglogp!
end
d = TV.dimension(as(post))


ℓ, dneglogp, ∇neglogp! = make_grads(post)

post, ℓ, dneglogp, ∇neglogp! = make_grads(model_lr, At, y, σ)
# Try things out
dneglogp(2.4, randn(25), randn(25));
∇neglogp!(randn(25), 2.1, randn(25));
# Make sure gradients are working
let
@show dneglogp(2.4, randn(d), randn(d))
y = Vector{Float64}(undef, d)
@show ∇neglogp!(y, 2.1, randn(d))
nothing
end

d = 25 # number of parameters
t0 = 0.0;
x0 = zeros(d); # starting point sampler
# estimated posterior mean (n=100000, 797s)
Expand Down Expand Up @@ -129,8 +149,7 @@ sampler = ZZB.NotFactSampler(
),
);

using TupleVectors: chainvec
using Tilde.MeasureTheory: transform
# @time first(Iterators.drop(tvs,1000))

function collect_sampler(t, sampler, n; progress = true, progress_stops = 20)
if progress
Expand Down Expand Up @@ -166,7 +185,6 @@ elapsed_time = @elapsed @time begin
bps_samples, info = collect_sampler(as(post), sampler, n; progress = false)
end

using MCMCChains
bps_chain = MCMCChains.Chains(bps_samples.θ);
bps_chain = setinfo(bps_chain, (; start_time = 0.0, stop_time = elapsed_time));

Expand Down