Skip to content

ProbProg: Making trace an operand #1444

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 90 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
90 commits
Select commit Hold shift + click to select a range
902ced9
generate
sbrantq May 2, 2025
e2c77e4
refactor
sbrantq May 2, 2025
e204d13
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq May 2, 2025
327b10a
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq May 6, 2025
d611ae4
add probprog pass to :all
sbrantq May 7, 2025
3672d83
improve test
sbrantq May 7, 2025
b70843e
only probprog opt mode
sbrantq May 8, 2025
597fa89
fix up test
sbrantq May 8, 2025
e6c2c0a
move
sbrantq May 12, 2025
9b9395e
simplify
sbrantq May 12, 2025
b3ba477
fix up
sbrantq May 14, 2025
47e9fe3
saving changes
sbrantq May 15, 2025
982b2bf
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq May 16, 2025
06b7464
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq May 17, 2025
bd73c62
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq May 19, 2025
a6fcca3
fix sample op
sbrantq May 20, 2025
e51e04b
save tests
sbrantq May 20, 2025
ce68f6a
temporarily removing probprog pass from :all as MLIR pass is not merg…
sbrantq May 20, 2025
94bbe62
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq May 20, 2025
d31bba6
undo enzyme binding change
sbrantq May 20, 2025
573fa02
format
sbrantq May 20, 2025
0264a3d
format
sbrantq May 20, 2025
2e18bdf
improve
sbrantq May 20, 2025
1f19979
improve
sbrantq May 20, 2025
096d790
get rid of result_and_mutated too
sbrantq May 22, 2025
bb319a3
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq May 31, 2025
9ac6535
working trace object pointer hacks + tests
sbrantq Jun 5, 2025
b24766f
Assuming scalar samples for now; simple Bayesian linear regression test
sbrantq Jun 5, 2025
3c52b39
exclamation mark
sbrantq Jun 5, 2025
af3d055
sample metadata
sbrantq Jun 6, 2025
6c7ffa3
fix up copy
sbrantq Jun 6, 2025
4e017d0
fix up copy
sbrantq Jun 6, 2025
e53fc7c
working vectorized blr test
sbrantq Jun 6, 2025
1dbf5c7
fix test warning
sbrantq Jun 11, 2025
dd9dcab
hacks to temporarily remove world age issue in tests
sbrantq Jun 11, 2025
a344726
partial refactoring
sbrantq Jun 12, 2025
ebeceb8
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq Jun 13, 2025
ef2e770
fixed tracing infra
sbrantq Jun 14, 2025
46e0f6b
transpose fix up
sbrantq Jun 16, 2025
1c5297c
minor changes
sbrantq Jun 17, 2025
d707053
reorder
sbrantq Jun 17, 2025
91a0850
API change
sbrantq Jun 20, 2025
561b051
better print
sbrantq Jun 20, 2025
99d7608
unconstrained real generate op
sbrantq Jun 25, 2025
b13f8bf
probprog postpasses
sbrantq Jun 25, 2025
6e4dc0c
bug fix for alising outputs
sbrantq Jun 26, 2025
5b5c1d1
generate op with constraints
sbrantq Jun 26, 2025
1ad167a
untraced call
sbrantq Jun 26, 2025
8f66b5f
working metropolis hastings (with hacks)
sbrantq Jun 26, 2025
850e3c4
set julia rng
sbrantq Jun 27, 2025
e1b3bcb
remove print
sbrantq Jun 27, 2025
659b963
less iterations. hiding prints
sbrantq Jun 27, 2025
537de49
add probprog test group
sbrantq Jun 27, 2025
04d2e44
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq Jun 27, 2025
8260fee
format
sbrantq Jun 27, 2025
0f94166
add probprog compile opt
sbrantq Jun 27, 2025
7f611fe
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq Jul 3, 2025
a05d2c2
pass all args even when w/o rng
sbrantq Jul 3, 2025
f40960f
updated probprog frontend for refactored simulate op
sbrantq Jul 3, 2025
f6ee849
probprog attr mlir api
sbrantq Jul 3, 2025
38e33de
adding cfunction mapping for AddWeightToTrace and AddRetvalToTrace ops
sbrantq Jul 4, 2025
127126d
adding traced_output_indices attr to simulate op
sbrantq Jul 4, 2025
3d66c7a
update tests
sbrantq Jul 4, 2025
1585483
refactored generate op
sbrantq Jul 8, 2025
34f35c4
@compile for generate op
sbrantq Jul 8, 2025
f4a6415
improve api
sbrantq Jul 8, 2025
b92a733
compiled generate test
sbrantq Jul 8, 2025
160561e
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq Jul 12, 2025
bbfa3f6
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq Jul 14, 2025
f4c4a88
save gc change
sbrantq Jul 15, 2025
d1be27c
enforcing calling convention (rng being the 0th operand) for sample &…
sbrantq Jul 16, 2025
b666813
enforcing calling convention (rng being 0th operand) for simulate/gen…
sbrantq Jul 17, 2025
c57a1e4
clean up
sbrantq Jul 18, 2025
2b81db9
refactored mh inference steps with new calling convention enforced
sbrantq Jul 18, 2025
e647b0d
improve
sbrantq Jul 20, 2025
4fe55a6
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq Jul 20, 2025
94b9e3a
reorganize
sbrantq Jul 20, 2025
b2d583a
format
sbrantq Jul 20, 2025
65d3595
fix up tests
sbrantq Jul 21, 2025
a29fbed
remove redundant cast
sbrantq Jul 21, 2025
87ced72
generate op fixup: replacing constrained_symbols with constrained_add…
sbrantq Jul 29, 2025
ebec467
minor
sbrantq Jul 29, 2025
f771bcb
update legacy inference API
sbrantq Jul 29, 2025
1908188
simplify
sbrantq Jul 29, 2025
0b71444
cleanup
sbrantq Jul 29, 2025
1a23c2e
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq Jul 29, 2025
9bd1dee
fix deadlock
sbrantq Jul 31, 2025
c9ff7c0
fix test
sbrantq Jul 31, 2025
3196989
don't print
sbrantq Jul 31, 2025
4afda71
clean up postpasses
sbrantq Jul 31, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,21 @@ enzymeActivityAttrGet(MlirContext ctx, int32_t val) {
(mlir::enzyme::Activity)val));
}

extern "C" MLIR_CAPI_EXPORTED MlirType enzymeTraceTypeGet(MlirContext ctx) {
return wrap(mlir::enzyme::TraceType::get(unwrap(ctx)));
}

extern "C" MLIR_CAPI_EXPORTED MlirType
enzymeConstraintTypeGet(MlirContext ctx) {
return wrap(mlir::enzyme::ConstraintType::get(unwrap(ctx)));
}

extern "C" MLIR_CAPI_EXPORTED MlirAttribute
enzymeSymbolAttrGet(MlirContext ctx, uint64_t symbol) {
mlir::Attribute attr = mlir::enzyme::SymbolAttr::get(unwrap(ctx), symbol);
return wrap(attr);
}

// Create profiler session and start profiling
extern "C" tsl::ProfilerSession *
CreateProfilerSession(uint32_t device_tracer_level,
Expand Down
2 changes: 2 additions & 0 deletions src/CompileOptions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,8 @@ function CompileOptions(;
:canonicalize,
:just_batch,
:none,
:probprog,
:probprog_no_lowering,
]
end

Expand Down
118 changes: 118 additions & 0 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1231,6 +1231,7 @@ end
# TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate
# However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass].
const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}"
const probprog_pass::String = "probprog{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize\"}"

function run_pass_pipeline!(mod, pass_pipeline, key=""; enable_verifier=true)
pm = MLIR.IR.PassManager()
Expand Down Expand Up @@ -1641,6 +1642,7 @@ function compile_mlir!(
blas_int_width = sizeof(BLAS.BlasInt) * 8
lower_enzymexla_linalg_pass = "lower-enzymexla-linalg{backend=$backend \
blas_int_width=$blas_int_width}"
lower_enzyme_probprog_pass = "lower-enzyme-probprog{backend=$backend}"

legalize_chlo_to_stablehlo =
if legalize_stablehlo_to_mhlo || compile_options.legalize_chlo_to_stablehlo
Expand Down Expand Up @@ -1807,6 +1809,122 @@ function compile_mlir!(
),
"no_enzyme",
)
elseif compile_options.optimization_passes === :probprog_no_lowering
run_pass_pipeline!(
mod,
join(
if compile_options.raise_first
[
"mark-func-memory-effects",
opt_passes,
kern,
raise_passes,
"enzyme-batch",
opt_passes2,
enzyme_pass,
probprog_pass,
opt_passes2,
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
(
if compile_options.legalize_chlo_to_stablehlo
["func.func(chlo-legalize-to-stablehlo)"]
else
[]
end
)...,
opt_passes2,
]
else
[
"mark-func-memory-effects",
opt_passes,
"enzyme-batch",
opt_passes2,
enzyme_pass,
probprog_pass,
opt_passes2,
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
(
if compile_options.legalize_chlo_to_stablehlo
["func.func(chlo-legalize-to-stablehlo)"]
else
[]
end
)...,
opt_passes2,
kern,
raise_passes,
]
end,
",",
),
"probprog_no_lowering",
)
elseif compile_options.optimization_passes === :probprog
run_pass_pipeline!(
mod,
join(
if compile_options.raise_first
[
"mark-func-memory-effects",
opt_passes,
kern,
raise_passes,
"enzyme-batch",
opt_passes2,
enzyme_pass,
probprog_pass,
opt_passes2,
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
(
if compile_options.legalize_chlo_to_stablehlo
["func.func(chlo-legalize-to-stablehlo)"]
else
[]
end
)...,
opt_passes2,
lower_enzymexla_linalg_pass,
lower_enzyme_probprog_pass,
jit,
]
else
[
"mark-func-memory-effects",
opt_passes,
"enzyme-batch",
opt_passes2,
enzyme_pass,
probprog_pass,
opt_passes2,
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
(
if compile_options.legalize_chlo_to_stablehlo
["func.func(chlo-legalize-to-stablehlo)"]
else
[]
end
)...,
opt_passes2,
kern,
raise_passes,
lower_enzymexla_linalg_pass,
lower_enzyme_probprog_pass,
jit,
]
end,
",",
),
"probprog",
)
elseif compile_options.optimization_passes === :only_enzyme
run_pass_pipeline!(
mod,
Expand Down
1 change: 1 addition & 0 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ include("Tracing.jl")
include("Compiler.jl")

include("Overlay.jl")
include("probprog/ProbProg.jl")

# Serialization
include("serialization/Serialization.jl")
Expand Down
2 changes: 2 additions & 0 deletions src/Types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ function ConcretePJRTArray(
end

Base.wait(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = foreach(wait, x.data)
Base.isready(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = all(isready, x.data)
XLA.client(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = XLA.client(x.data)
function XLA.device(x::Union{ConcretePJRTArray,ConcretePJRTNumber})
x.sharding isa Sharding.NoShardInfo && return XLA.device(only(x.data))
Expand Down Expand Up @@ -412,6 +413,7 @@ function ConcreteIFRTArray(
end

Base.wait(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = wait(x.data)
Base.isready(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = isready(x.data)
XLA.client(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = XLA.client(x.data)
function XLA.device(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber})
return XLA.device(x.data)
Expand Down
87 changes: 87 additions & 0 deletions src/probprog/Display.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Reference: https://github.com/probcomp/Gen.jl/blob/91d798f2d2f0c175b1be3dc6daf3a10a8acf5da3/src/choice_map.jl#L104
function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple)
VERT = '\u2502'
PLUS = '\u251C'
HORZ = '\u2500'
LAST = '\u2514'

indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n'])
indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' '])
indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' '])

for i in vert_bars
indent_vert[i] = VERT
indent[i] = VERT
indent_last[i] = VERT
end

indent_vert_str = join(indent_vert)
indent_str = join(indent)
indent_last_str = join(indent_last)

sorted_choices = sort(collect(trace.choices); by=x -> x[1])
n = length(sorted_choices)

if trace.retval !== nothing
n += 1
end

if trace.weight !== nothing
n += 1
end

cur = 1

if trace.retval !== nothing
print(io, indent_vert_str)
print(io, (cur == n ? indent_last_str : indent_str) * "retval : $(trace.retval)\n")
cur += 1
end

if trace.weight !== nothing
print(io, indent_vert_str)
print(io, (cur == n ? indent_last_str : indent_str) * "weight : $(trace.weight)\n")
cur += 1
end

for (key, value) in sorted_choices
print(io, indent_vert_str)
print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n")
cur += 1
end

sorted_subtraces = sort(collect(trace.subtraces); by=x -> x[1])
n += length(sorted_subtraces)

for (key, subtrace) in sorted_subtraces
print(io, indent_vert_str)
print(io, (cur == n ? indent_last_str : indent_str) * "subtrace on $(repr(key))\n")
_show_pretty(
io, subtrace, pre + 4, cur == n ? (vert_bars...,) : (vert_bars..., pre + 1)
)
cur += 1
end
end

function Base.show(io::IO, ::MIME"text/plain", trace::ProbProgTrace)
println(io, "ProbProgTrace:")
if isempty(trace.choices) && trace.retval === nothing && trace.weight === nothing
println(io, " (empty)")
else
_show_pretty(io, trace, 0, ())
end
end

function Base.show(io::IO, trace::ProbProgTrace)
if get(io, :compact, false)
choices_count = length(trace.choices)
has_retval = trace.retval !== nothing
print(io, "ProbProgTrace($(choices_count) choices")
if has_retval
print(io, ", retval=$(trace.retval), weight=$(trace.weight)")
end
print(io, ")")
else
show(io, MIME"text/plain"(), trace)
end
end
Loading
Loading