Skip to content

Commit 1b8aecf

Browse files
cscherrermschauer
andauthored
0.2 Major Refactoring (#35)
* bouncy updates * rand * minor fix * Fix method ambiguity * refactoring * Disallow `rand` on ModelPosteriors * Move astmodel.jl contents into model.jl * Drop old `DAGModel` stuff * drop dead code * change `S` type param to `F` * drop on include * edit for readability * update dependencies * Update src/core/models/abstractmodel.jl * Update src/core/models/abstractmodel.jl * Update src/core/models/model.jl * drop the F * predict (#21) * predict * inbounds * fancy predict * bugfixes * bump MeasureBase version * test unbound args * cleanup * bugfix * start on refactoring PDMP example * comments, mostly * bugfix * working on `predict` * update Lens!! * Update benchmarks/bouncy.jl Co-authored-by: Moritz Schauer <[email protected]> * Update benchmarks/bouncy.jl Co-authored-by: Moritz Schauer <[email protected]> * lens stuff * drop optics change * moving things around a bit * Contexts (#32) * comment out unused LamExpr * maybeobserved.jl * refactoring * dag.jl (rough draft) * refactor * dag * comment out unused LamExpr * maybeobserved.jl * refactoring * dag.jl (rough draft) * refactor * dag * drop old call stuff * dags * formatting * formatting * bugfixes * bugfix * fix `predict` * MaybeObserved shoudl be Unobserved * Move `anyfy` * update `callify` * deps * predict WIP * latentof, etc * drop observed rand stuff * update deps * working on predict * bugfix in tests * bugfix * add a note * fix `logdensityof` * working on proj * updates * bugfix * updates * rand stuff * mucking about * refactoring * fix `rand` * readability * expand macros * inline rand * rand updates * rand * rand * drop ctx * move cfg from kwarg to "regular" arg * drop unused type parameter * help inference * update `rand` * require MeasureBase 0.12.2 * mymerge * mymerge * updates * more refactoring * Drop unneeded argument * tidying up * update * update * gg_call => runmodel * lensvars * logdensity * update interpret * interpret => runmodel * testvalue * logdensityof * measures * `model` on types * `hasreturn` * add a comment * more logdensity stuff * update rand * simplify * update TupleVectors dependency * working on `rand` * FixedRNG * AbstractTildeConfig => AbstractConfig * more rand methods for FixedRNG * AbstractTildeConfig => AbstractConfig * unbreaking things * updates * updates * bumps deps * tests passing!! * add some rand methods * make logdensity_def work * Drop `latentof` stuff * bump version * drop unneeded deps * Drop commented-out code * Newline at EOF * drop old code * tidy up * Newline at EOF * Newline at EOF * bugfix * bugfix * drop two `anyfy` calls * combine two lines * simplify * more refactoring * updates * drop old code * drop old code * insupport * undo `bouncy` update (let's do that separately) * drop benchmarks `Project.toml` update for now * formatting * add tests * Compat.jl (for `allequal`) * Need MeasureTheory 0.17.2 * drop old `allequal` * May not actually need Compat here * dropRuntimeGeneratedFunctions Co-authored-by: Moritz Schauer <[email protected]>
1 parent 49c3ddc commit 1b8aecf

31 files changed

+1185
-778
lines changed

Project.toml

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,29 @@
11
name = "Tilde"
22
uuid = "73a6ac3c-4b34-4cca-a813-308f7589d80d"
33
authors = ["Chad Scherrer <[email protected]> and contributors"]
4-
version = "0.1.2"
4+
version = "0.2.0"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
8-
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
8+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
99
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1010
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
11-
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
11+
DistributionMeasures = "35643b39-bfd4-4670-843f-16596ca89bf3"
1212
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
1313
JuliaVariables = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec"
1414
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1515
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
1616
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
17-
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
1817
MeasureBase = "fa1605e6-acd5-459c-a1e6-7e635759db14"
1918
MeasureTheory = "eadaa1a4-d27c-401d-8699-e962e1bbc33b"
2019
NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50"
2120
NestedTuples = "a734d2a7-8d68-409b-9419-626914d4061d"
22-
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
2321
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
24-
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
2522
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2623
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
27-
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
2824
SampleChains = "754583d1-7fc4-4dab-93b5-5eaca5c9622e"
2925
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
30-
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
3126
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
32-
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3327
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
3428
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
3529
TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff"
@@ -38,32 +32,28 @@ TupleVectors = "615932cf-77b6-4358-adcd-5b7eba981d7e"
3832

3933
[compat]
4034
Accessors = "0.1"
41-
ArrayInterface = "4, 5, 6"
35+
ChainRulesCore = "1"
4236
DataStructures = "0.18"
4337
DensityInterface = "0.4"
44-
DiffResults = "1"
38+
DistributionMeasures = "0.2"
4539
IfElse = "0.1"
4640
JuliaVariables = "0.2"
4741
MLStyle = "0.3,0.4"
4842
MacroTools = "0.5"
49-
MappedArrays = "0.3, 0.4"
50-
MeasureBase = "0.9"
51-
MeasureTheory = "0.16"
43+
MeasureBase = "0.13"
44+
MeasureTheory = "0.17.2"
5245
NamedTupleTools = "0.12, 0.13, 0.14"
5346
NestedTuples = "0.3"
54-
RecipesBase = "1"
5547
Reexport = "1"
5648
Requires = "1"
57-
RuntimeGeneratedFunctions = "0.5"
5849
SampleChains = "0.5"
59-
SpecialFunctions = "1, 2"
6050
Static = "0.5, 0.6"
6151
StatsBase = "0.33"
6252
StatsFuns = "0.9, 1"
6353
TransformVariables = "0.5, 0.6"
6454
Tricks = "0.1"
65-
TupleVectors = "0.1"
66-
julia = "1.5"
55+
TupleVectors = "0.1, 0.2"
56+
julia = "1.6"
6757

6858
[extras]
6959
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"

src/GG/deprecated_codes/explicit_scope.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ function scoping(ast)
33
@match ast begin
44
:([$(frees...)]($(args...)) -> begin
55
$(stmts...)
6-
end) => let stmts = map(rec, stmts), arw = :(($(args...),) -> begin
6+
end) => let stmts = map(rec, stmts), arw = :(($(args...),) -> begin
77
$(stmts...)
88
end)
99
Expr(:scope, (), Tuple(frees), (), arw)

src/GG/deprecated_codes/static_closure_conv.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ function mk_closure_static(expr, toplevel::Vector{Expr})
3939
$Closure{$glob_name,typeof(frees)}(frees)
4040
end
4141
)
42+
ret = :(
43+
let frees = $closure_arg
44+
$Closure{$glob_name,typeof(frees)}(frees)
45+
end
46+
)
4247
(fn_expr, ret)
4348
end
4449

src/Tilde.jl

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import DensityInterface: densityof
1313
import DensityInterface: DensityKind
1414
using DensityInterface
1515

16+
using DistributionMeasures
1617
using NamedTupleTools
1718
using SampleChains
1819
# using SymbolicCodegen
@@ -30,7 +31,6 @@ import MLStyle
3031
# using MonteCarloMeasurements: Particles, StaticParticles, AbstractParticles
3132

3233
using Requires
33-
using ArrayInterface: StaticInt
3434
using Static
3535

3636
using IfElse: ifelse
@@ -42,8 +42,6 @@ using TupleVectors: unwrap
4242
# using SimplePosets: SimplePoset
4343
# import SimplePosets
4444

45-
using RuntimeGeneratedFunctions
46-
RuntimeGeneratedFunctions.init(@__MODULE__)
4745
using MeasureBase: AbstractTransitionKernel
4846

4947
using NestedTuples: TypelevelExpr
@@ -79,29 +77,33 @@ include("callify.jl")
7977
end
8078
end
8179

80+
include("fixedrng.jl")
81+
include("config.jl")
82+
include("lensvars.jl")
8283
include("optics.jl")
8384
include("maybe.jl")
8485
include("core/models/abstractmodel.jl")
85-
include("core/models/astmodel/astmodel.jl")
8686
include("core/models/model.jl")
8787
include("core/dependencies.jl")
8888
include("core/utils.jl")
8989
include("core/models/closure.jl")
90+
include("maybeobserved.jl")
9091
include("core/models/posterior.jl")
91-
include("primitives/interpret.jl")
9292
include("distributions/iid.jl")
9393

9494
include("primitives/rand.jl")
9595
include("primitives/logdensity.jl")
96-
include("primitives/logdensity_rel.jl")
96+
# include("primitives/logdensity_rel.jl")
9797
include("primitives/insupport.jl")
9898

99-
# include("primitives/basemeasure.jl")
10099
include("primitives/testvalue.jl")
101-
include("primitives/testparams.jl")
102-
include("primitives/weightedsampling.jl")
100+
# include("primitives/testparams.jl")
101+
# include("primitives/weightedsampling.jl")
103102
include("primitives/measures.jl")
104103
include("primitives/basemeasure.jl")
104+
include("primitives/predict.jl")
105+
# include("primitives/dag.jl")
106+
include("primitives/runmodel.jl")
105107

106108
include("transforms/utils.jl")
107109

@@ -111,4 +113,9 @@ function __init__()
111113
end
112114
end
113115

116+
Base.copy(m::AbstractModel) = m
117+
Base.copy(cl::ModelClosure) = model(cl)(rmap(copy, argvals(cl)))
118+
119+
Base.copy(post::ModelPosterior) = copy(post.closure) | rmap(copy, observations(post))
120+
114121
end # module

src/callify.jl

Lines changed: 26 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -5,82 +5,50 @@ using MLStyle
55
66
Replace every `f(args...; kwargs..)` with `mycall(f, args...; kwargs...)`
77
"""
8-
function callify(mycall, ast)
8+
function callify(g, ast)
99
leaf(x) = x
1010
function branch(f, head, args)
1111
default() = Expr(head, map(f, args)...)
12+
13+
# Convert `for` to `while`
14+
if head == :for
15+
arg1 = args[1]
16+
@assert arg1.head == :(=)
17+
a, A0 = arg1.args
18+
A0 = callify(g, A0)
19+
@gensym temp
20+
@gensym state
21+
@gensym A
22+
return quote
23+
$A = $A0
24+
$temp = $call($g, iterate, $A)
25+
while $temp !== nothing
26+
$a, $state = $temp
27+
$(args[2])
28+
$temp = $call($g, iterate, $A, $state)
29+
end
30+
end
31+
end
32+
1233
head == :call || return default()
1334

1435
if first(args) == :~ && length(args) == 3
1536
return default()
1637
end
1738

1839
# At this point we know it's a function call
19-
length(args) == 1 && return Expr(:call, mycall, first(args))
40+
length(args) == 1 && return Expr(:call, call, g, first(args))
2041

2142
fun = args[1]
2243
arg2 = args[2]
2344

2445
if arg2 isa Expr && arg2.head == :parameters
2546
# keyword arguments (try dump(:(f(x,y;a=1, b=2))) to see this)
26-
return Expr(:call, mycall, arg2, fun, map(f, Base.rest(args, 3))...)
47+
return Expr(:call, call, g, arg2, fun, map(f, Base.rest(args, 3))...)
2748
else
28-
return Expr(:call, mycall, map(f, args)...)
49+
return Expr(:call, call, g, map(f, args)...)
2950
end
3051
end
3152

32-
foldast(leaf, branch)(ast)
53+
foldast(leaf, branch)(ast) |> MacroTools.flatten
3354
end
34-
35-
# struct Provenance{T,S}
36-
# value::T
37-
# sources::S
38-
# end
39-
40-
# getvalue(p::Provenance) = p.value
41-
# getvalue(x) = x
42-
43-
# getsources(p::Provenance) = p.sources
44-
# getsources(x) = Set()
45-
46-
# function trace_provenance(f, args...; kwargs...)
47-
# (newargs, arg_sources) = (getvalue.(args), union(getsources.(args)...))
48-
49-
# k = keys(kwargs)
50-
# v = values(kwargs)
51-
# newkwargs = NamedTuple{k}(map(getvalue, v))
52-
53-
# k = keys(kwargs)
54-
# v = values(NamedTuple(kwargs))
55-
# newkwargs = NamedTuple{k}(getvalue.(v))
56-
# kwarg_sources = union(getsources.(args)...)
57-
58-
# sources = union(arg_sources, kwarg_sources)
59-
# Provenance(f(newargs...; newkwargs), sources)
60-
# end
61-
62-
# macro call(expr)
63-
# callify(expr)
64-
# end
65-
66-
# julia> callify(:(f(g(x,y))))
67-
# :(call(f, call(g, x, y)))
68-
69-
# julia> callify(:(f(x; a=3)))
70-
# :(call(f, x; a = 3))
71-
72-
# julia> callify(:(a+b))
73-
# :(call(+, a, b))
74-
75-
# julia> callify(:(call(f,3)))
76-
# :(call(f, 3))
77-
78-
# f(x) = x+1
79-
80-
# @call f(2)
81-
82-
# using SymbolicUtils
83-
84-
# @syms x::Vector{Float64} i::Int
85-
86-
# @call getindex(x,i)

src/config.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
abstract type AbstractConfig end

src/core/models/astmodel.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

src/core/models/astmodel/astmodel.jl

Lines changed: 0 additions & 82 deletions
This file was deleted.

0 commit comments

Comments
 (0)