From 3399b2334c9eadf5a310edb5fd41eaa6e61a23a4 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Sun, 26 Oct 2025 16:50:30 -0500 Subject: [PATCH 01/10] static HMC --- src/CompileOptions.jl | 1 + src/Compiler.jl | 66 ++++ src/Reactant.jl | 1 + src/Types.jl | 2 + src/probprog/Display.jl | 87 +++++ src/probprog/FFI.jl | 768 ++++++++++++++++++++++++++++++++++++++ src/probprog/HMC.jl | 126 +++++++ src/probprog/MH.jl | 95 +++++ src/probprog/Modeling.jl | 257 +++++++++++++ src/probprog/ProbProg.jl | 28 ++ src/probprog/Types.jl | 77 ++++ src/probprog/Utils.jl | 154 ++++++++ test/probprog/generate.jl | 142 +++++++ test/probprog/hmc.jl | 138 +++++++ test/probprog/mh.jl | 114 ++++++ test/probprog/sample.jl | 88 +++++ test/probprog/simulate.jl | 115 ++++++ test/runtests.jl | 7 + 18 files changed, 2266 insertions(+) create mode 100644 src/probprog/Display.jl create mode 100644 src/probprog/FFI.jl create mode 100644 src/probprog/HMC.jl create mode 100644 src/probprog/MH.jl create mode 100644 src/probprog/Modeling.jl create mode 100644 src/probprog/ProbProg.jl create mode 100644 src/probprog/Types.jl create mode 100644 src/probprog/Utils.jl create mode 100644 test/probprog/generate.jl create mode 100644 test/probprog/hmc.jl create mode 100644 test/probprog/mh.jl create mode 100644 test/probprog/sample.jl create mode 100644 test/probprog/simulate.jl diff --git a/src/CompileOptions.jl b/src/CompileOptions.jl index e8cac78be6..925c357e1a 100644 --- a/src/CompileOptions.jl +++ b/src/CompileOptions.jl @@ -229,6 +229,7 @@ function CompileOptions(; :canonicalize, :just_batch, :none, + :probprog, ] end diff --git a/src/Compiler.jl b/src/Compiler.jl index bea2f19472..9a26888880 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1318,6 +1318,7 @@ end # TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate # However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass]. const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize,arith-raise{stablehlo=true}\"}" +const probprog_pass::String = "probprog{postpasses=\"arith-raise{stablehlo=true}\"}" function run_pass_pipeline!(mod, pass_pipeline, key=""; enable_verifier=true) pm = MLIR.IR.PassManager() @@ -1905,6 +1906,71 @@ function compile_mlir!( ), "no_enzyme", ) + elseif compile_options.optimization_passes === :probprog + run_pass_pipeline!( + mod, + join( + if compile_options.raise_first + [ + "mark-func-memory-effects", + opt_passes, + kern, + raise_passes, + "enzyme-batch", + opt_passes2, + probprog_pass, + "lower-probprog-to-stablehlo{backend=$backend}", + "outline-enzyme-regions", + enzyme_pass, + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + ( + if compile_options.legalize_chlo_to_stablehlo + ["func.func(chlo-legalize-to-stablehlo)"] + else + [] + end + )..., + opt_passes2, + lower_enzymexla_linalg_pass, + "lower-probprog-trace-ops{backend=$backend}", + jit, + ] + else + [ + "mark-func-memory-effects", + opt_passes, + "enzyme-batch", + opt_passes2, + probprog_pass, + "lower-probprog-to-stablehlo{backend=$backend}", + "outline-enzyme-regions", + enzyme_pass, + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + ( + if compile_options.legalize_chlo_to_stablehlo + ["func.func(chlo-legalize-to-stablehlo)"] + else + [] + end + )..., + opt_passes2, + kern, + raise_passes, + lower_enzymexla_linalg_pass, + "lower-probprog-trace-ops{backend=$backend}", + jit, + ] + end, + ",", + ), + "probprog", + ) elseif compile_options.optimization_passes === :only_enzyme run_pass_pipeline!( mod, diff --git a/src/Reactant.jl b/src/Reactant.jl index 69cee2b3b0..0ef9af5145 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -246,6 +246,7 @@ include("Tracing.jl") include("Compiler.jl") include("Overlay.jl") +include("probprog/ProbProg.jl") # Serialization include("serialization/Serialization.jl") diff --git a/src/Types.jl b/src/Types.jl index 6d932ecc0c..c00ef99438 100644 --- a/src/Types.jl +++ b/src/Types.jl @@ -241,6 +241,7 @@ function ConcretePJRTArray( end Base.wait(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = foreach(wait, x.data) +Base.isready(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = all(isready, x.data) XLA.client(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = XLA.client(x.data) function XLA.device(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) x.sharding isa Sharding.NoShardInfo && return XLA.device(only(x.data)) @@ -420,6 +421,7 @@ function ConcreteIFRTArray( end Base.wait(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = wait(x.data) +Base.isready(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = isready(x.data) XLA.client(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = XLA.client(x.data) function XLA.device(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) return XLA.device(x.data) diff --git a/src/probprog/Display.jl b/src/probprog/Display.jl new file mode 100644 index 0000000000..a81992eb71 --- /dev/null +++ b/src/probprog/Display.jl @@ -0,0 +1,87 @@ +# Reference: https://github.com/probcomp/Gen.jl/blob/91d798f2d2f0c175b1be3dc6daf3a10a8acf5da3/src/choice_map.jl#L104 +function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple) + VERT = '\u2502' + PLUS = '\u251C' + HORZ = '\u2500' + LAST = '\u2514' + + indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) + indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' ']) + indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' ']) + + for i in vert_bars + indent_vert[i] = VERT + indent[i] = VERT + indent_last[i] = VERT + end + + indent_vert_str = join(indent_vert) + indent_str = join(indent) + indent_last_str = join(indent_last) + + sorted_choices = sort(collect(trace.choices); by=x -> x[1]) + n = length(sorted_choices) + + if trace.retval !== nothing + n += 1 + end + + if trace.weight !== nothing + n += 1 + end + + cur = 1 + + if trace.retval !== nothing + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "retval : $(trace.retval)\n") + cur += 1 + end + + if trace.weight !== nothing + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "weight : $(trace.weight)\n") + cur += 1 + end + + for (key, value) in sorted_choices + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n") + cur += 1 + end + + sorted_subtraces = sort(collect(trace.subtraces); by=x -> x[1]) + n += length(sorted_subtraces) + + for (key, subtrace) in sorted_subtraces + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "subtrace on $(repr(key))\n") + _show_pretty( + io, subtrace, pre + 4, cur == n ? (vert_bars...,) : (vert_bars..., pre + 1) + ) + cur += 1 + end +end + +function Base.show(io::IO, ::MIME"text/plain", trace::ProbProgTrace) + println(io, "ProbProgTrace:") + if isempty(trace.choices) && trace.retval === nothing && trace.weight === nothing + println(io, " (empty)") + else + _show_pretty(io, trace, 0, ()) + end +end + +function Base.show(io::IO, trace::ProbProgTrace) + if get(io, :compact, false) + choices_count = length(trace.choices) + has_retval = trace.retval !== nothing + print(io, "ProbProgTrace($(choices_count) choices") + if has_retval + print(io, ", retval=$(trace.retval), weight=$(trace.weight)") + end + print(io, ")") + else + show(io, MIME"text/plain"(), trace) + end +end diff --git a/src/probprog/FFI.jl b/src/probprog/FFI.jl new file mode 100644 index 0000000000..2f40390727 --- /dev/null +++ b/src/probprog/FFI.jl @@ -0,0 +1,768 @@ +using ..Reactant: MLIR, Profiler + +function initTrace(trace_ptr_ptr::Ptr{Ptr{Any}}) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.initTrace"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + tr = ProbProgTrace() + _keepalive!(tr) + + unsafe_store!(trace_ptr_ptr, pointer_from_objref(tr)) + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function addSampleToTrace( + trace_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + sample_ptr_array::Ptr{Ptr{Any}}, + num_outputs_ptr::Ptr{UInt64}, + ndims_array::Ptr{UInt64}, + shape_ptr_array::Ptr{Ptr{UInt64}}, + width_array::Ptr{UInt64}, +) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.addSampleToTrace"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + trace = nothing + try + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + catch + @ccall printf("Trace dereference failure\n"::Cstring)::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol + num_outputs = unsafe_load(num_outputs_ptr) + ndims_array = unsafe_wrap(Array, ndims_array, num_outputs) + width_array = unsafe_wrap(Array, width_array, num_outputs) + shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_outputs) + sample_ptr_array = unsafe_wrap(Array, sample_ptr_array, num_outputs) + + vals = Any[] + for i in 1:num_outputs + ndims = ndims_array[i] + width = width_array[i] + shape_ptr = shape_ptr_array[i] + sample_ptr = sample_ptr_array[i] + + julia_type = if width == 32 + Float32 + elseif width == 64 + Float64 + elseif width == 1 + Bool + else + nothing + end + + if julia_type === nothing + @ccall printf( + "Unsupported datatype width: %lld\n"::Cstring, width::Int64 + )::Cvoid + return nothing + end + + if ndims == 0 + push!(vals, unsafe_load(Ptr{julia_type}(sample_ptr))) + else + shape = unsafe_wrap(Array, shape_ptr, ndims) + push!(vals, copy(unsafe_wrap(Array, Ptr{julia_type}(sample_ptr), Tuple(shape)))) + end + end + + trace.choices[symbol] = tuple(vals...) + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function addSubtrace( + trace_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + subtrace_ptr_ptr::Ptr{Ptr{Any}}, +) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.addSubtrace"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + trace = nothing + try + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + catch + @ccall printf("Trace dereference failure\n"::Cstring)::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol + subtrace = unsafe_pointer_to_objref(unsafe_load(subtrace_ptr_ptr))::ProbProgTrace + + trace.subtraces[symbol] = subtrace + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function addWeightToTrace(trace_ptr_ptr::Ptr{Ptr{Any}}, weight_ptr::Ptr{Any}) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.addWeightToTrace"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + trace = nothing + try + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + catch + @ccall printf("Trace dereference failure\n"::Cstring)::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + trace.weight = unsafe_load(Ptr{Float64}(weight_ptr)) + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function addRetvalToTrace( + trace_ptr_ptr::Ptr{Ptr{Any}}, + retval_ptr_array::Ptr{Ptr{Any}}, + num_results_ptr::Ptr{UInt64}, + ndims_array::Ptr{UInt64}, + shape_ptr_array::Ptr{Ptr{UInt64}}, + width_array::Ptr{UInt64}, +) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.addRetvalToTrace"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + trace = nothing + try + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + catch + @ccall printf("Trace dereference failure\n"::Cstring)::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + num_results = unsafe_load(num_results_ptr) + + if num_results == 0 + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + ndims_array = unsafe_wrap(Array, ndims_array, num_results) + width_array = unsafe_wrap(Array, width_array, num_results) + shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_results) + retval_ptr_array = unsafe_wrap(Array, retval_ptr_array, num_results) + + vals = Any[] + for i in 1:num_results + ndims = ndims_array[i] + width = width_array[i] + shape_ptr = shape_ptr_array[i] + retval_ptr = retval_ptr_array[i] + + julia_type = if width == 32 + Float32 + elseif width == 64 + Float64 + elseif width == 1 + Bool + else + nothing + end + + if julia_type === nothing + @ccall printf( + "Unsupported datatype width: %lld\n"::Cstring, width::Int64 + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + if ndims == 0 + push!(vals, unsafe_load(Ptr{julia_type}(retval_ptr))) + else + shape = unsafe_wrap(Array, shape_ptr, ndims) + push!(vals, copy(unsafe_wrap(Array, Ptr{julia_type}(retval_ptr), Tuple(shape)))) + end + end + + trace.retval = tuple(vals...) + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function getSampleFromConstraint( + constraint_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + sample_ptr_array::Ptr{Ptr{Any}}, + num_samples_ptr::Ptr{UInt64}, + ndims_array::Ptr{UInt64}, + shape_ptr_array::Ptr{Ptr{UInt64}}, + width_array::Ptr{UInt64}, +) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.getSampleFromConstraint"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + constraint = unsafe_pointer_to_objref(unsafe_load(constraint_ptr_ptr))::Constraint + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol + num_samples = unsafe_load(num_samples_ptr) + ndims_array = unsafe_wrap(Array, ndims_array, num_samples) + width_array = unsafe_wrap(Array, width_array, num_samples) + shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_samples) + sample_ptr_array = unsafe_wrap(Array, sample_ptr_array, num_samples) + + tostore = get(constraint, Address(symbol), nothing) + + if tostore === nothing + @ccall printf( + "No constraint found for symbol: %s\n"::Cstring, string(symbol)::Cstring + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + for i in 1:num_samples + ndims = ndims_array[i] + width = width_array[i] + shape_ptr = shape_ptr_array[i] + sample_ptr = sample_ptr_array[i] + + julia_type = if width == 32 + Float32 + elseif width == 64 + Float64 + elseif width == 1 + Bool + else + nothing + end + + if julia_type === nothing + @ccall printf( + "Unsupported datatype width: %zd\n"::Cstring, width::Csize_t + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + if julia_type != eltype(tostore[i]) + @ccall printf( + "Type mismatch in constrained sample: %s != %s\n"::Cstring, + string(julia_type)::Cstring, + string(eltype(tostore[i]))::Cstring, + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + if ndims == 0 + unsafe_store!(Ptr{julia_type}(sample_ptr), tostore[i]) + else + shape = unsafe_wrap(Array, shape_ptr, ndims) + dest = unsafe_wrap(Array, Ptr{julia_type}(sample_ptr), Tuple(shape)) + + dest_size = size(dest) + src_size = size(tostore[i]) + + if dest_size != src_size + @ccall printf( + "Shape mismatch in constrained sample: expected %zd dims, got %zd\n"::Cstring, + length(dest_size)::Csize_t, + length(src_size)::Csize_t, + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + copyto!(dest, tostore[i]) + end + end + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function getSubconstraint( + constraint_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + subconstraint_ptr_ptr::Ptr{Ptr{Any}}, +) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.getSubconstraint"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + constraint = unsafe_pointer_to_objref(unsafe_load(constraint_ptr_ptr))::Constraint + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol + + subconstraint = Constraint() + + for (key, value) in constraint + if key.path[1] == symbol + @assert isa(key, Address) "Expected Address type for constraint key" + @assert length(key.path) > 1 "Expected composite address with length > 1" + tail_address = Address(key.path[2:end]) + subconstraint[tail_address] = value + end + end + + if isempty(subconstraint) + @ccall printf( + "No subconstraint found for symbol: %s\n"::Cstring, string(symbol)::Cstring + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + _keepalive!(subconstraint) + unsafe_store!(subconstraint_ptr_ptr, pointer_from_objref(subconstraint)) + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function getSampleFromTrace( + trace_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + sample_ptr_array::Ptr{Ptr{Any}}, + num_samples_ptr::Ptr{UInt64}, + ndims_array::Ptr{UInt64}, + shape_ptr_array::Ptr{Ptr{UInt64}}, + width_array::Ptr{UInt64}, +) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.getSampleFromTrace"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + trace = nothing + try + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + catch + @ccall printf("Trace dereference failure\n"::Cstring)::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol + num_samples = unsafe_load(num_samples_ptr) + ndims_array = unsafe_wrap(Array, ndims_array, num_samples) + width_array = unsafe_wrap(Array, width_array, num_samples) + shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_samples) + sample_ptr_array = unsafe_wrap(Array, sample_ptr_array, num_samples) + + tostore = get(trace.choices, symbol, nothing) + + if tostore === nothing + @ccall printf( + "No sample found in trace for symbol: %s\n"::Cstring, string(symbol)::Cstring + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + for i in 1:num_samples + ndims = ndims_array[i] + width = width_array[i] + shape_ptr = shape_ptr_array[i] + sample_ptr = sample_ptr_array[i] + + julia_type = if width == 32 + Float32 + elseif width == 64 + Float64 + elseif width == 1 + Bool + else + nothing + end + + if julia_type === nothing + @ccall printf( + "Unsupported datatype width: %zd\n"::Cstring, width::Csize_t + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + if julia_type != eltype(tostore[i]) + @ccall printf( + "Type mismatch in trace sample: %s != %s\n"::Cstring, + string(julia_type)::Cstring, + string(eltype(tostore[i]))::Cstring, + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + if ndims == 0 + unsafe_store!(Ptr{julia_type}(sample_ptr), tostore[i]) + else + shape = unsafe_wrap(Array, shape_ptr, ndims) + dest = unsafe_wrap(Array, Ptr{julia_type}(sample_ptr), Tuple(shape)) + + dest_size = size(dest) + src_size = size(tostore[i]) + + if dest_size != src_size + @ccall printf( + "Shape mismatch in trace sample: expected %zd dims, got %zd\n"::Cstring, + length(dest_size)::Csize_t, + length(src_size)::Csize_t, + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + copyto!(dest, tostore[i]) + end + end + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function getSubtrace( + trace_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + subtrace_ptr_ptr::Ptr{Ptr{Any}}, +) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.getSubtrace"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + trace = nothing + try + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + catch + @ccall printf("Trace dereference failure\n"::Cstring)::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol + + subtrace = get(trace.subtraces, symbol, nothing) + + if subtrace === nothing + @ccall printf( + "No subtrace found for symbol: %s\n"::Cstring, string(symbol)::Cstring + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + _keepalive!(subtrace) + unsafe_store!(subtrace_ptr_ptr, pointer_from_objref(subtrace)) + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function getWeightFromTrace(trace_ptr_ptr::Ptr{Ptr{Any}}, weight_ptr::Ptr{Any}) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.getWeightFromTrace"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + trace = nothing + try + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + catch + @ccall printf("Trace dereference failure\n"::Cstring)::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + unsafe_store!(Ptr{Float64}(weight_ptr), trace.weight) + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function getFlattenedSamplesFromTrace( + trace_ptr_ptr::Ptr{Ptr{Any}}, + num_addresses_ptr::Ptr{UInt64}, + total_symbols_ptr::Ptr{UInt64}, + address_lengths_ptr::Ptr{UInt64}, + flattened_symbols_ptr::Ptr{UInt64}, + position_ptr::Ptr{Any}, +) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.getFlattenedSamplesFromTrace"::Cstring, + Profiler.TRACE_ME_LEVEL_CRITICAL::Cint, + )::Int64 + + trace = nothing + try + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + catch + @ccall printf("No trace found\n"::Cstring)::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + num_addresses = unsafe_load(num_addresses_ptr) + total_symbols = unsafe_load(total_symbols_ptr) + + address_lengths = unsafe_wrap(Array, address_lengths_ptr, num_addresses) + flattened_symbols = unsafe_wrap(Array, flattened_symbols_ptr, total_symbols) + + addresses = Vector{Vector{Symbol}}() + symbol_idx = 1 + for i in 1:num_addresses + addr_len = address_lengths[i] + address = Symbol[] + for j in 1:addr_len + symbol_ptr_value = flattened_symbols[symbol_idx] + symbol = unsafe_pointer_to_objref(Ptr{Any}(symbol_ptr_value))::Symbol + push!(address, symbol) + symbol_idx += 1 + end + push!(addresses, address) + end + + flattened_values = Float64[] + + for address in addresses + current_trace = trace + + for (idx, symbol) in enumerate(address) + if idx < length(address) + if !haskey(current_trace.subtraces, symbol) + @ccall printf( + "No subtrace found for symbol in address path: %s\n"::Cstring, + string(symbol)::Cstring, + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + current_trace = current_trace.subtraces[symbol] + else + if !haskey(current_trace.choices, symbol) + @ccall printf( + "No sample found for symbol: %s\n"::Cstring, string(symbol)::Cstring + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + sample_tuple = current_trace.choices[symbol] + + for sample_val in sample_tuple + if isa(sample_val, AbstractArray) + for val in sample_val + push!(flattened_values, Float64(val)) + end + else + push!(flattened_values, Float64(sample_val)) + end + end + end + end + end + + position_array = unsafe_wrap( + Array, Ptr{Float64}(position_ptr), length(flattened_values) + ) + copyto!(position_array, flattened_values) + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function dump( + value_ptr::Ptr{Any}, + label_ptr::Ptr{UInt8}, + ndims_ptr::Ptr{UInt64}, + shape_ptr::Ptr{UInt64}, + width_ptr::Ptr{UInt64}, +) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.dump"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + label = unsafe_string(label_ptr) + ndims = unsafe_load(ndims_ptr) + width = unsafe_load(width_ptr) + + julia_type = if width == 32 + Float32 + elseif width == 64 + Float64 + elseif width == 1 + Bool + else + @ccall printf( + "DUMP ERROR: Unsupported datatype width: %lld\n"::Cstring, width::Int64 + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + println("═══ DUMP: $label ═══") + + if ndims == 0 + value = unsafe_load(Ptr{julia_type}(value_ptr)) + println(" Scalar ($julia_type): $value") + else + shape = unsafe_wrap(Array, shape_ptr, ndims) + value_array = unsafe_wrap(Array, Ptr{julia_type}(value_ptr), Tuple(shape)) + + println(" Shape: $(Tuple(shape))") + println(" Type: Array{$julia_type}") + println(" Values:") + + total_elements = prod(shape) + if total_elements <= 20 + println(" ", value_array) + else + println(" [$(total_elements) elements]") + println(" min: $(minimum(value_array))") + println(" max: $(maximum(value_array))") + println(" mean: $(sum(value_array) / total_elements)") + println(" First 10: $(value_array[1:min(10, total_elements)])") + end + end + + println("═══════════════════════════════════") + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function __init__() + init_trace_ptr = @cfunction(initTrace, Cvoid, (Ptr{Ptr{Any}},)) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_init_trace::Cstring, init_trace_ptr::Ptr{Cvoid} + )::Cvoid + + add_sample_to_trace_ptr = @cfunction( + addSampleToTrace, + Cvoid, + ( + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{UInt64}, + Ptr{UInt64}, + Ptr{Ptr{UInt64}}, + Ptr{UInt64}, + ) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_add_sample_to_trace::Cstring, add_sample_to_trace_ptr::Ptr{Cvoid} + )::Cvoid + + add_subtrace_ptr = @cfunction( + addSubtrace, Cvoid, (Ptr{Ptr{Any}}, Ptr{Ptr{Any}}, Ptr{Ptr{Any}}) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_add_subtrace::Cstring, add_subtrace_ptr::Ptr{Cvoid} + )::Cvoid + + add_weight_to_trace_ptr = @cfunction(addWeightToTrace, Cvoid, (Ptr{Ptr{Any}}, Ptr{Any})) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_add_weight_to_trace::Cstring, add_weight_to_trace_ptr::Ptr{Cvoid} + )::Cvoid + + add_retval_to_trace_ptr = @cfunction( + addRetvalToTrace, + Cvoid, + ( + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{UInt64}, + Ptr{UInt64}, + Ptr{Ptr{UInt64}}, + Ptr{UInt64}, + ), + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_add_retval_to_trace::Cstring, add_retval_to_trace_ptr::Ptr{Cvoid} + )::Cvoid + + get_sample_from_constraint_ptr = @cfunction( + getSampleFromConstraint, + Cvoid, + ( + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{UInt64}, + Ptr{UInt64}, + Ptr{Ptr{UInt64}}, + Ptr{UInt64}, + ) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_get_sample_from_constraint::Cstring, + get_sample_from_constraint_ptr::Ptr{Cvoid}, + )::Cvoid + + get_subconstraint_ptr = @cfunction( + getSubconstraint, Cvoid, (Ptr{Ptr{Any}}, Ptr{Ptr{Any}}, Ptr{Ptr{Any}}) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_get_subconstraint::Cstring, get_subconstraint_ptr::Ptr{Cvoid} + )::Cvoid + + get_sample_from_trace_ptr = @cfunction( + getSampleFromTrace, + Cvoid, + ( + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{UInt64}, + Ptr{UInt64}, + Ptr{Ptr{UInt64}}, + Ptr{UInt64}, + ) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_get_sample_from_trace::Cstring, + get_sample_from_trace_ptr::Ptr{Cvoid}, + )::Cvoid + + get_subtrace_ptr = @cfunction( + getSubtrace, Cvoid, (Ptr{Ptr{Any}}, Ptr{Ptr{Any}}, Ptr{Ptr{Any}}) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_get_subtrace::Cstring, get_subtrace_ptr::Ptr{Cvoid} + )::Cvoid + + get_weight_from_trace_ptr = @cfunction( + getWeightFromTrace, Cvoid, (Ptr{Ptr{Any}}, Ptr{Any}) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_get_weight_from_trace::Cstring, + get_weight_from_trace_ptr::Ptr{Cvoid}, + )::Cvoid + + get_flattened_samples_from_trace_ptr = @cfunction( + getFlattenedSamplesFromTrace, + Cvoid, + (Ptr{Ptr{Any}}, Ptr{UInt64}, Ptr{UInt64}, Ptr{UInt64}, Ptr{UInt64}, Ptr{Any}) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_get_flattened_samples_from_trace::Cstring, + get_flattened_samples_from_trace_ptr::Ptr{Cvoid}, + )::Cvoid + + dump_ptr = @cfunction( + dump, Cvoid, (Ptr{Any}, Ptr{UInt8}, Ptr{UInt64}, Ptr{UInt64}, Ptr{UInt64}) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_dump::Cstring, dump_ptr::Ptr{Cvoid} + )::Cvoid + + return nothing +end diff --git a/src/probprog/HMC.jl b/src/probprog/HMC.jl new file mode 100644 index 0000000000..134044c3b8 --- /dev/null +++ b/src/probprog/HMC.jl @@ -0,0 +1,126 @@ +using ..Reactant: ConcreteRNumber, TracedRArray + +function hmc( + rng::AbstractRNG, + original_trace::Union{ProbProgTrace,TracedRArray{UInt64,0}}, + f::Function, + args::Vararg{Any,Nargs}; + selection::Selection, + mass=nothing, + step_size=nothing, + num_steps=nothing, + initial_momentum=nothing, +) where {Nargs} + args = (rng, args...) + mlir_fn_res, argprefix, resprefix, _ = process_probprog_function(f, args, "hmc") + + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + inputs = process_probprog_inputs(linear_args, f, args, fnwrap, argprefix) + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet( + MLIR.IR.context()::MLIR.API.MlirContext + )::MLIR.IR.Type + + trace_val = if original_trace isa TracedRArray{UInt64,0} + MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [original_trace.mlir_data]; outputs=[trace_ty] + ), + 1, + ) + else + # First iteration: promote a ProbProgTrace to tensor + promoted = to_trace_tensor(original_trace) + MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [TracedUtils.get_mlir_data(promoted)]; outputs=[trace_ty] + ), + 1, + ) + end + + selection_attr = MLIR.IR.Attribute[] + for address in selection + address_attr = MLIR.IR.Attribute[] + for sym in address.path + sym_addr = reinterpret(UInt64, pointer_from_objref(sym)) + push!( + address_attr, + @ccall MLIR.API.mlir_c.enzymeSymbolAttrGet( + MLIR.IR.context()::MLIR.API.MlirContext, sym_addr::UInt64 + )::MLIR.IR.Attribute + ) + end + push!(selection_attr, MLIR.IR.Attribute(address_attr)) + end + + trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet( + MLIR.IR.context()::MLIR.API.MlirContext + )::MLIR.IR.Type + accepted_ty = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Bool)) + + alg_attr = @ccall MLIR.API.mlir_c.enzymeMCMCAlgorithmAttrGet( + MLIR.IR.context()::MLIR.API.MlirContext, + 0::Int32, # 0 = HMC + )::MLIR.IR.Attribute + + mass_val = nothing + if !isnothing(mass) + mass_val = TracedUtils.get_mlir_data(mass) + end + + step_size_val = nothing + if !isnothing(step_size) + step_size_val = TracedUtils.get_mlir_data(step_size) + end + + num_steps_val = nothing + if !isnothing(num_steps) + num_steps_val = TracedUtils.get_mlir_data(num_steps) + end + + initial_momentum_val = nothing + if !isnothing(initial_momentum) + initial_momentum_val = TracedUtils.get_mlir_data(initial_momentum) + end + + hmc_op = MLIR.Dialects.enzyme.mcmc( + inputs, + trace_val, + mass_val; + step_size=step_size_val, + num_steps=num_steps_val, + initial_momentum=initial_momentum_val, + new_trace=trace_ty, + accepted=accepted_ty, + output_rng_state=out_tys[1], # by convention + alg=alg_attr, + fn=fn_attr, + selection=MLIR.IR.Attribute(selection_attr), + ) + + # (new_trace, accepted, output_rng_state) + process_probprog_outputs( + hmc_op, linear_results, result, f, args, fnwrap, resprefix, argprefix, 2, true + ) + + new_trace_val = MLIR.IR.result(hmc_op, 1) + new_trace_ptr = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [new_trace_val]; outputs=[MLIR.IR.TensorType(Int64[], MLIR.IR.Type(UInt64))] + ), + 1, + ) + + new_trace = TracedRArray{UInt64,0}((), new_trace_ptr, ()) + accepted = TracedRArray{Bool,0}((), MLIR.IR.result(hmc_op, 2), ()) + + return new_trace, accepted, result +end diff --git a/src/probprog/MH.jl b/src/probprog/MH.jl new file mode 100644 index 0000000000..93d5bbca79 --- /dev/null +++ b/src/probprog/MH.jl @@ -0,0 +1,95 @@ +using ..Reactant: ConcreteRNumber, TracedRArray + +function mh( + rng::AbstractRNG, + original_trace::Union{ProbProgTrace,TracedRArray{UInt64,0}}, + f::Function, + args::Vararg{Any,Nargs}; + selection::Selection, +) where {Nargs} + args = (rng, args...) + mlir_fn_res, argprefix, resprefix, _ = process_probprog_function(f, args, "mh") + + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + inputs = process_probprog_inputs(linear_args, f, args, fnwrap, argprefix) + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet( + MLIR.IR.context()::MLIR.API.MlirContext + )::MLIR.IR.Type + + if original_trace isa TracedRArray{UInt64,0} + # Use MLIR data from previous iteration + trace_val = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [original_trace.mlir_data]; outputs=[trace_ty] + ), + 1, + ) + else + # First iteration: create constant from pointer + promoted = to_trace_tensor(original_trace) + trace_val = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [TracedUtils.get_mlir_data(promoted)]; outputs=[trace_ty] + ), + 1, + ) + end + + selection_attr = MLIR.IR.Attribute[] + for address in selection + address_attr = MLIR.IR.Attribute[] + for sym in address.path + sym_addr = reinterpret(UInt64, pointer_from_objref(sym)) + push!( + address_attr, + @ccall MLIR.API.mlir_c.enzymeSymbolAttrGet( + MLIR.IR.context()::MLIR.API.MlirContext, sym_addr::UInt64 + )::MLIR.IR.Attribute + ) + end + push!(selection_attr, MLIR.IR.Attribute(address_attr)) + end + + trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet( + MLIR.IR.context()::MLIR.API.MlirContext + )::MLIR.IR.Type + accepted_ty = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Bool)) + + mh_op = MLIR.Dialects.enzyme.mh( + inputs, + trace_val; + new_trace=trace_ty, + accepted=accepted_ty, + output_rng_state=out_tys[1], # by convention + fn=fn_attr, + selection=MLIR.IR.Attribute(selection_attr), + ) + + # Return (new_trace, accepted, output_rng_state) + process_probprog_outputs( + mh_op, linear_results, result, f, args, fnwrap, resprefix, argprefix, 2, true + ) + + new_trace_val = MLIR.IR.result(mh_op, 1) + new_trace_ptr = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [new_trace_val]; outputs=[MLIR.IR.TensorType(Int64[], MLIR.IR.Type(UInt64))] + ), + 1, + ) + + new_trace = TracedRArray{UInt64,0}((), new_trace_ptr, ()) + accepted = TracedRArray{Bool,0}((), MLIR.IR.result(mh_op, 2), ()) + + return new_trace, accepted, result +end + +const metropolis_hastings = mh diff --git a/src/probprog/Modeling.jl b/src/probprog/Modeling.jl new file mode 100644 index 0000000000..b3e1f82e9f --- /dev/null +++ b/src/probprog/Modeling.jl @@ -0,0 +1,257 @@ +using ..Reactant: MLIR, TracedUtils, AbstractRNG, TracedRArray, ConcreteRNumber +using ..Compiler: @jit, @compile + +include("Utils.jl") + +function sample( + rng::AbstractRNG, + f::Function, + args::Vararg{Any,Nargs}; + symbol::Symbol=gensym("sample"), + logpdf::Union{Nothing,Function}=nothing, +) where {Nargs} + args_with_rng = (rng, args...) + mlir_fn_res, argprefix, resprefix, _ = process_probprog_function( + f, args_with_rng, "sample" + ) + + (; result, linear_args, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + inputs = process_probprog_inputs(linear_args, f, args_with_rng, fnwrap, argprefix) + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + + sym = TracedUtils.get_attribute_by_name(func2, "sym_name") + fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym)) + + symbol_addr = reinterpret(UInt64, pointer_from_objref(symbol)) + symbol_attr = @ccall MLIR.API.mlir_c.enzymeSymbolAttrGet( + MLIR.IR.context()::MLIR.API.MlirContext, symbol_addr::UInt64 + )::MLIR.IR.Attribute + + # Construct logpdf attribute if `logpdf` function is provided. + logpdf_attr = nothing + if logpdf isa Function + samples = f(args_with_rng...) + + # Assume that logpdf parameters follow `(sample, args...)` convention. + logpdf_args = (samples, args...) + + logpdf_mlir = TracedUtils.make_mlir_fn( + logpdf, + logpdf_args, + (), + string(logpdf), + false; + do_transpose=false, + args_in_result=:result, + ) + + logpdf_sym = TracedUtils.get_attribute_by_name(logpdf_mlir.f, "sym_name") + logpdf_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(logpdf_sym)) + end + + sample_op = MLIR.Dialects.enzyme.sample( + inputs; + outputs=out_tys, + fn=fn_attr, + logpdf=logpdf_attr, + symbol=symbol_attr, + name=Base.String(symbol), + ) + + process_probprog_outputs( + sample_op, linear_results, result, f, args_with_rng, fnwrap, resprefix, argprefix + ) + + return result +end + +function untraced_call(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs} + args_with_rng = (rng, args...) + mlir_fn_res, argprefix, resprefix, _ = process_probprog_function( + f, args_with_rng, "call" + ) + + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + inputs = process_probprog_inputs(linear_args, f, args_with_rng, fnwrap, argprefix) + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + call_op = MLIR.Dialects.enzyme.untracedCall(inputs; outputs=out_tys, fn=fn_attr) + + process_probprog_outputs( + call_op, linear_results, result, f, args_with_rng, fnwrap, resprefix, argprefix + ) + + return result +end + +# Gen-like helper function. +function simulate_(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs} + trace = nothing + + compiled_fn = @compile optimize = :probprog simulate(rng, f, args...) + + seed_buffer = only(rng.seed.data).buffer + GC.@preserve seed_buffer begin + t, _, _ = compiled_fn(rng, f, args...) + trace = from_trace_tensor(t) + end + + return trace, trace.weight +end + +function simulate(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs} + args = (rng, args...) + mlir_fn_res, argprefix, resprefix, _ = process_probprog_function(f, args, "simulate") + + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + inputs = process_probprog_inputs(linear_args, f, args, fnwrap, argprefix) + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet( + MLIR.IR.context()::MLIR.API.MlirContext + )::MLIR.IR.Type + weight_ty = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Float64)) + + simulate_op = MLIR.Dialects.enzyme.simulate( + inputs; trace=trace_ty, weight=weight_ty, outputs=out_tys, fn=fn_attr + ) + + process_probprog_outputs( + simulate_op, linear_results, result, f, args, fnwrap, resprefix, argprefix, 2 + ) + + trace = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [MLIR.IR.result(simulate_op, 1)]; + outputs=[MLIR.IR.TensorType(Int64[], MLIR.IR.Type(UInt64))], + ), + 1, + ) + + trace = TracedRArray{UInt64,0}((), trace, ()) + weight = TracedRArray{Float64,0}((), MLIR.IR.result(simulate_op, 2), ()) + + return trace, weight, result +end + +# Gen-like helper function. +function generate_( + rng::AbstractRNG, + f::Function, + args::Vararg{Any,Nargs}; + constraint::Constraint=Constraint(), +) where {Nargs} + trace = nothing + + constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(constraint))) + + constrained_addresses = extract_addresses(constraint) + + function wrapper_fn(rng, constraint_ptr, args...) + return generate(rng, f, args...; constraint_ptr, constrained_addresses) + end + + compiled_fn = @compile optimize = :probprog wrapper_fn(rng, constraint_ptr, args...) + + seed_buffer = only(rng.seed.data).buffer + GC.@preserve seed_buffer constraint begin + t, _, _ = compiled_fn(rng, constraint_ptr, args...) + trace = from_trace_tensor(t) + end + + return trace, trace.weight +end + +function generate( + rng::AbstractRNG, + f::Function, + args::Vararg{Any,Nargs}; + constraint_ptr::TracedRNumber, + constrained_addresses::Set{Address}, +) where {Nargs} + args = (rng, args...) + mlir_fn_res, argprefix, resprefix, _ = process_probprog_function(f, args, "generate") + + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + inputs = process_probprog_inputs(linear_args, f, args, fnwrap, argprefix) + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + constraint_ty = @ccall MLIR.API.mlir_c.enzymeConstraintTypeGet( + MLIR.IR.context()::MLIR.API.MlirContext + )::MLIR.IR.Type + + constraint_val = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [TracedUtils.get_mlir_data(constraint_ptr)]; outputs=[constraint_ty] + ), + 1, + ) + + constrained_addresses_attr = MLIR.IR.Attribute[] + for address in constrained_addresses + address_attr = MLIR.IR.Attribute[] + for sym in address.path + sym_addr = reinterpret(UInt64, pointer_from_objref(sym)) + push!( + address_attr, + @ccall MLIR.API.mlir_c.enzymeSymbolAttrGet( + MLIR.IR.context()::MLIR.API.MlirContext, sym_addr::UInt64 + )::MLIR.IR.Attribute + ) + end + push!(constrained_addresses_attr, MLIR.IR.Attribute(address_attr)) + end + + trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet( + MLIR.IR.context()::MLIR.API.MlirContext + )::MLIR.IR.Type + weight_ty = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Float64)) + + generate_op = MLIR.Dialects.enzyme.generate( + inputs, + constraint_val; + trace=trace_ty, + weight=weight_ty, + outputs=out_tys, + fn=fn_attr, + constrained_addresses=MLIR.IR.Attribute(constrained_addresses_attr), + ) + + process_probprog_outputs( + generate_op, linear_results, result, f, args, fnwrap, resprefix, argprefix, 2 + ) + + trace = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [MLIR.IR.result(generate_op, 1)]; + outputs=[MLIR.IR.TensorType(Int64[], MLIR.IR.Type(UInt64))], + ), + 1, + ) + + trace = TracedRArray{UInt64,0}((), trace, ()) + weight = TracedRArray{Float64,0}((), MLIR.IR.result(generate_op, 2), ()) + + return trace, weight, result +end diff --git a/src/probprog/ProbProg.jl b/src/probprog/ProbProg.jl new file mode 100644 index 0000000000..23de59676f --- /dev/null +++ b/src/probprog/ProbProg.jl @@ -0,0 +1,28 @@ +module ProbProg + +using ..Reactant: + MLIR, TracedUtils, AbstractRNG, TracedRArray, TracedRNumber, ConcreteRNumber +using ..Compiler: @jit, @compile + +include("Types.jl") +include("FFI.jl") +include("Modeling.jl") +include("Display.jl") +include("MH.jl") +include("HMC.jl") + +# Types. +export ProbProgTrace, Constraint, Selection, Address + +# Utility functions. +export get_choices, select +export to_trace_tensor, from_trace_tensor +export to_constraint_tensor, from_constraint_tensor + +# Core MLIR ops. +export sample, untraced_call, simulate, generate, mh, hmc + +# Gen-like helper functions. +export simulate_, generate_ + +end diff --git a/src/probprog/Types.jl b/src/probprog/Types.jl new file mode 100644 index 0000000000..98f189d9a0 --- /dev/null +++ b/src/probprog/Types.jl @@ -0,0 +1,77 @@ +using Base: ReentrantLock + +mutable struct ProbProgTrace + choices::Dict{Symbol,Any} + retval::Any + weight::Any + subtraces::Dict{Symbol,Any} + + function ProbProgTrace() + return new(Dict{Symbol,Any}(), nothing, nothing, Dict{Symbol,Any}()) + end +end + +struct Address + path::Vector{Symbol} + + Address(path::Vector{Symbol}) = new(path) +end + +Address(sym::Symbol) = Address([sym]) +Address(syms::Symbol...) = Address([syms...]) + +Base.:(==)(a::Address, b::Address) = a.path == b.path +Base.hash(a::Address, h::UInt) = hash(a.path, h) + +mutable struct Constraint <: AbstractDict{Address,Any} + dict::Dict{Address,Any} + + function Constraint(pairs::Pair...) + dict = Dict{Address,Any}() + for pair in pairs + symbols = Symbol[] + current = pair + while isa(current, Pair) && isa(current.first, Symbol) + push!(symbols, current.first) + current = current.second + end + dict[Address(symbols...)] = current + end + return new(dict) + end + + Constraint() = new(Dict{Address,Any}()) + Constraint(d::Dict{Address,Any}) = new(d) +end + +Base.getindex(c::Constraint, k::Address) = c.dict[k] +Base.setindex!(c::Constraint, v, k::Address) = (c.dict[k] = v) +Base.delete!(c::Constraint, k::Address) = delete!(c.dict, k) +Base.keys(c::Constraint) = keys(c.dict) +Base.values(c::Constraint) = values(c.dict) +Base.iterate(c::Constraint) = iterate(c.dict) +Base.iterate(c::Constraint, state) = iterate(c.dict, state) +Base.length(c::Constraint) = length(c.dict) +Base.isempty(c::Constraint) = isempty(c.dict) +Base.haskey(c::Constraint, k::Address) = haskey(c.dict, k) +Base.get(c::Constraint, k::Address, default) = get(c.dict, k, default) + +extract_addresses(constraint::Constraint) = Set(keys(constraint)) + +const Selection = Set{Address} + +const _probprog_ref_lock = ReentrantLock() +const _probprog_refs = IdDict() + +function _keepalive!(tr::Any) + lock(_probprog_ref_lock) + try + _probprog_refs[tr] = tr + finally + unlock(_probprog_ref_lock) + end + return tr +end + +get_choices(trace::ProbProgTrace) = trace.choices +select(addrs::Address...) = Set{Address}([addrs...]) diff --git a/src/probprog/Utils.jl b/src/probprog/Utils.jl new file mode 100644 index 0000000000..7e8c197d5a --- /dev/null +++ b/src/probprog/Utils.jl @@ -0,0 +1,154 @@ +using ..Reactant: MLIR, TracedUtils, Ops, TracedRArray +import ..Reactant: promote_to + +""" + process_probprog_function(f, args_with_rng, op_name) + +This function handles the probprog argument convention where: +- **Index 1**: RNG state +- **Index 2**: Function `f` (when wrapped) +- **Index 3+**: Remaining arguments + +This wrapper ensures the RNG state is threaded through as the first result, +followed by the actual function results. +""" +function process_probprog_function(f, args_with_rng, op_name) + argprefix = gensym(op_name * "arg") + resprefix = gensym(op_name * "result") + resargprefix = gensym(op_name * "resarg") + + wrapper_fn = (all_args...) -> begin + res = f(all_args...) + (all_args[1], (res isa Tuple ? res : (res,))...) + end + + mlir_fn_res = TracedUtils.make_mlir_fn( + wrapper_fn, + args_with_rng, + (), + string(f), + false; + do_transpose=false, + args_in_result=:result, + argprefix, + resprefix, + resargprefix, + ) + + return mlir_fn_res, argprefix, resprefix, resargprefix +end + +""" + process_probprog_inputs(linear_args, f, args_with_rng, fnwrap, argprefix) + +This function handles the probprog argument convention where: +- **Index 1**: RNG state +- **Index 2**: Function `f` (when `fnwrap` is true) +- **Index 3+**: Other arguments +""" +function process_probprog_inputs(linear_args, f, args_with_rng, fnwrap, argprefix) + inputs = MLIR.IR.Value[] + for a in linear_args + idx, path = TracedUtils.get_argidx(a, argprefix) + if idx == 2 && fnwrap + TracedUtils.push_val!(inputs, f, path[3:end]) + else + if fnwrap && idx > 1 + idx -= 1 + end + TracedUtils.push_val!(inputs, args_with_rng[idx], path[3:end]) + end + end + return inputs +end + +""" + process_probprog_outputs(op, linear_results, result, f, args_with_rng, fnwrap, resprefix, argprefix, start_idx=0, rng_only=false) + +This function handles the probprog argument convention where: +- **Index 1**: RNG state +- **Index 2**: Function `f` (when `fnwrap` is true) +- **Index 3+**: Other arguments + +When setting results, the function checks: +1. If result path matches `resprefix`, store in `result` +2. If result path matches `argprefix`, store in `args_with_rng` (adjust indices for wrapped function) + +`start_idx` varies depending on the ProbProg operation: +- `sample` and `untraced_call` return only function outputs: + Use `start_idx=0`: `linear_results[i]` corresponds to `op.result[i]` +- `simulate` and `generate` return trace, weight, then outputs: + Use `start_idx=2`: `linear_results[i]` corresponds to `op.result[i+2]` +- `mh` and `regenerate` return trace, accepted/weight, rng_state (no model outputs): + Use `start_idx=2, rng_only=true`: only process first result (rng_state) + +`rng_only`: When true, only process the first result (RNG state), skipping model outputs +""" +function process_probprog_outputs( + op, + linear_results, + result, + f, + args_with_rng, + fnwrap, + resprefix, + argprefix, + start_idx=0, + rng_only=false, +) + num_to_process = rng_only ? 1 : length(linear_results) + + for i in 1:num_to_process + res = linear_results[i] + resv = MLIR.IR.result(op, i + start_idx) + + if TracedUtils.has_idx(res, resprefix) + path = TracedUtils.get_idx(res, resprefix) + TracedUtils.set!(result, path[2:end], resv) + end + + if TracedUtils.has_idx(res, argprefix) + idx, path = TracedUtils.get_argidx(res, argprefix) + if fnwrap && idx == 2 + TracedUtils.set!(f, path[3:end], resv) + else + if fnwrap && idx > 2 + idx -= 1 + end + TracedUtils.set!(args_with_rng[idx], path[3:end], resv) + end + end + + if !TracedUtils.has_idx(res, resprefix) && !TracedUtils.has_idx(res, argprefix) + TracedUtils.set!(res, (), resv) + end + end +end + +to_trace_tensor(t::ProbProgTrace) = promote_to(TracedRArray{UInt64,0}, t) + +function from_trace_tensor(trace_tensor) + while !isready(trace_tensor) + yield() + end + return unsafe_pointer_to_objref(Ptr{Any}(Array(trace_tensor)[1]))::ProbProgTrace +end + +function promote_to(::Type{TracedRArray{UInt64,0}}, t::ProbProgTrace) + ptr = reinterpret(UInt64, pointer_from_objref(t)) + return Ops.fill(ptr, Int64[]) +end + +to_constraint_tensor(c::Constraint) = promote_to(TracedRArray{UInt64,0}, c) + +function from_constraint_tensor(constraint_tensor) + while !isready(constraint_tensor) + yield() + end + return unsafe_pointer_to_objref(Ptr{Any}(Array(constraint_tensor)[1]))::Constraint +end + +function promote_to(::Type{TracedRArray{UInt64,0}}, c::Constraint) + ptr = reinterpret(UInt64, pointer_from_objref(c)) + return Ops.fill(ptr, Int64[]) +end diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl new file mode 100644 index 0000000000..f5fa4fea38 --- /dev/null +++ b/test/probprog/generate.jl @@ -0,0 +1,142 @@ +using Reactant, Test, Random, Statistics +using Reactant: ProbProg, ReactantRNG + +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) + +function normal_logpdf(x, μ, σ, _) + return -length(x) * log(σ) - length(x) / 2 * log(2π) - + sum((x .- μ) .^ 2 ./ (2 .* (σ .^ 2))) +end + +function model(rng, μ, σ, shape) + _, s = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:s, logpdf=normal_logpdf) + _, t = ProbProg.sample(rng, normal, s, σ, shape; symbol=:t, logpdf=normal_logpdf) + return t +end + +function two_normals(rng, μ, σ, shape) + _, x = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:x, logpdf=normal_logpdf) + _, y = ProbProg.sample(rng, normal, x, σ, shape; symbol=:y, logpdf=normal_logpdf) + return y +end + +function nested_model(rng, μ, σ, shape) + _, s = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:s, logpdf=normal_logpdf) + _, t = ProbProg.sample(rng, two_normals, s, σ, shape; symbol=:t) + _, u = ProbProg.sample(rng, two_normals, t, σ, shape; symbol=:u) + return u +end + +@testset "Generate" begin + @testset "unconstrained" begin + shape = (1000,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + trace, weight = ProbProg.generate_(rng, model, μ, σ, shape) + @test mean(trace.retval[1]) ≈ 0.0 atol = 0.05 rtol = 0.05 + end + + @testset "constrained" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + constraint = ProbProg.Constraint(:s => (fill(0.1, shape),)) + + trace, weight = ProbProg.generate_(rng, model, μ, σ, shape; constraint) + + @test trace.choices[:s][1] == constraint[ProbProg.Address(:s)][1] + + expected_weight = + normal_logpdf(constraint[ProbProg.Address(:s)][1], 0.0, 1.0, shape) + + normal_logpdf( + trace.choices[:t][1], constraint[ProbProg.Address(:s)][1], 1.0, shape + ) + @test weight ≈ expected_weight atol = 1e-6 + end + + @testset "composite addresses" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + constraint = ProbProg.Constraint( + :s => (fill(0.1, shape),), + :t => :x => (fill(0.2, shape),), + :u => :y => (fill(0.3, shape),), + ) + + trace, weight = ProbProg.generate_(rng, nested_model, μ, σ, shape; constraint) + + @test trace.choices[:s][1] == fill(0.1, shape) + @test trace.subtraces[:t].choices[:x][1] == fill(0.2, shape) + @test trace.subtraces[:u].choices[:y][1] == fill(0.3, shape) + + s_weight = normal_logpdf(fill(0.1, shape), 0.0, 1.0, shape) + tx_weight = normal_logpdf(fill(0.2, shape), fill(0.1, shape), 1.0, shape) + ty_weight = normal_logpdf( + trace.subtraces[:t].choices[:y][1], fill(0.2, shape), 1.0, shape + ) + ux_weight = normal_logpdf( + trace.subtraces[:u].choices[:x][1], + trace.subtraces[:t].choices[:y][1], + 1.0, + shape, + ) + uy_weight = normal_logpdf( + fill(0.3, shape), trace.subtraces[:u].choices[:x][1], 1.0, shape + ) + + expected_weight = s_weight + tx_weight + ty_weight + ux_weight + uy_weight + @test weight ≈ expected_weight atol = 1e-6 + end + + @testset "compiled" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + constraint1 = ProbProg.Constraint(:s => (fill(0.1, shape),)) + + constrained_addresses = ProbProg.extract_addresses(constraint1) + + constraint_ptr1 = Reactant.ConcreteRNumber( + reinterpret(UInt64, pointer_from_objref(constraint1)) + ) + + wrapper_fn(rng, constraint_ptr, μ, σ) = ProbProg.generate( + rng, model, μ, σ, shape; constraint_ptr, constrained_addresses + ) + + compiled_fn = @compile optimize = :probprog wrapper_fn(rng, constraint_ptr1, μ, σ) + + trace1 = nothing + seed_buffer = only(rng.seed.data).buffer + GC.@preserve seed_buffer constraint1 begin + trace1, _ = compiled_fn(rng, constraint_ptr1, μ, σ) + trace1 = ProbProg.from_trace_tensor(trace1) + end + + constraint2 = ProbProg.Constraint(:s => (fill(0.2, shape),)) + constraint_ptr2 = Reactant.ConcreteRNumber( + reinterpret(UInt64, pointer_from_objref(constraint2)) + ) + + trace2 = nothing + seed_buffer = only(rng.seed.data).buffer + GC.@preserve seed_buffer constraint2 begin + trace2, _ = compiled_fn(rng, constraint_ptr2, μ, σ) + trace2 = ProbProg.from_trace_tensor(trace2) + end + + @test trace1.choices[:s][1] != trace2.choices[:s][1] + end +end diff --git a/test/probprog/hmc.jl b/test/probprog/hmc.jl new file mode 100644 index 0000000000..03c3394a40 --- /dev/null +++ b/test/probprog/hmc.jl @@ -0,0 +1,138 @@ +using Reactant, Test, Random +using Statistics +using Reactant: ProbProg, ReactantRNG + +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) + +function normal_logpdf(x, μ, σ, _) + return -length(x) * log(σ) - length(x) / 2 * log(2π) - + sum((x .- μ) .^ 2 ./ (2 .* (σ .^ 2))) +end + +function model(rng, xs) + _, param_a = ProbProg.sample( + rng, normal, 0.0, 5.0, (1,); symbol=:param_a, logpdf=normal_logpdf + ) + _, param_b = ProbProg.sample( + rng, normal, 0.0, 5.0, (1,); symbol=:param_b, logpdf=normal_logpdf + ) + + _, ys_a = ProbProg.sample( + rng, normal, param_a .+ xs[1:5], 0.5, (5,); symbol=:ys_a, logpdf=normal_logpdf + ) + + _, ys_b = ProbProg.sample( + rng, normal, param_b .+ xs[6:10], 0.5, (5,); symbol=:ys_b, logpdf=normal_logpdf + ) + + return vcat(ys_a, ys_b) +end + +function hmc_program( + rng, + model, + xs, + step_size, + num_steps, + mass, + initial_momentum, + constraint_ptr, + constrained_addresses, +) + t, _, _ = ProbProg.generate( + rng, + model, + xs; + constraint_ptr=constraint_ptr, + constrained_addresses=constrained_addresses, + ) + + t, accepted, _ = ProbProg.hmc( + rng, + t, + model, + xs; + selection=ProbProg.select(ProbProg.Address(:param_a), ProbProg.Address(:param_b)), + mass=mass, + step_size=step_size, + num_steps=num_steps, + initial_momentum=initial_momentum, + ) + + return t, accepted +end + +@testset "hmc" begin + seed = Reactant.to_rarray(UInt64[1, 5]) + rng = ReactantRNG(seed) + + xs = [-4.5, -3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5, 4.5] + ys_a = [-2.3, -1.6, -0.4, 0.6, 1.4] + ys_b = [-2.6, -1.4, -0.6, 0.4, 1.6] + obs = ProbProg.Constraint( + :param_a => ([0.0],), :param_b => ([0.0],), :ys_a => (ys_a,), :ys_b => (ys_b,) + ) + constrained_addresses = ProbProg.extract_addresses(obs) + constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(obs))) + + step_size = ConcreteRNumber(0.001) + num_steps_compile = ConcreteRNumber(1000) + num_steps_run = ConcreteRNumber(40000000) + mass = nothing + initial_momentum = ConcreteRArray([0.0, 0.0]) + + code = @code_hlo optimize = :probprog hmc_program( + rng, + model, + xs, + step_size, + num_steps_compile, + mass, + initial_momentum, + constraint_ptr, + constrained_addresses, + ) + @test contains(repr(code), "enzyme_probprog_get_flattened_samples_from_trace") + @test contains(repr(code), "enzyme_probprog_get_weight_from_trace") + @test !contains(repr(code), "enzyme.mh") + @test !contains(repr(code), "enzyme.mcmc") + + compile_time_s = @elapsed begin + compiled_fn = @compile optimize = :probprog hmc_program( + rng, + model, + xs, + step_size, + num_steps_compile, + mass, + initial_momentum, + constraint_ptr, + constrained_addresses, + ) + end + println("HMC compile time: $(round(compile_time_s * 1000, digits=2)) ms") + + seed_buffer = only(rng.seed.data).buffer + trace = nothing + GC.@preserve seed_buffer obs begin + run_time_s = @elapsed begin + trace_ptr, _ = compiled_fn( + rng, + model, + xs, + step_size, + num_steps_run, + mass, + initial_momentum, + constraint_ptr, + constrained_addresses, + ) + trace = ProbProg.from_trace_tensor(trace_ptr) + end + println("HMC run time: $(round(run_time_s * 1000, digits=2)) ms") + end + + # NumPyro results + @test only(trace.choices[:param_a])[1] ≈ 0.01327671 rtol = 1e-6 + @test only(trace.choices[:param_b])[1] ≈ -0.01965474 rtol = 1e-6 +end diff --git a/test/probprog/mh.jl b/test/probprog/mh.jl new file mode 100644 index 0000000000..b05021827d --- /dev/null +++ b/test/probprog/mh.jl @@ -0,0 +1,114 @@ +using Reactant, Test, Random +using Reactant: ProbProg, ReactantRNG + +# Reference: https://www.gen.dev/docs/stable/getting_started/linear_regression/ + +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) + +function normal_logpdf(x, μ, σ, _) + return -length(x) * log(σ) - length(x) / 2 * log(2π) - + sum((x .- μ) .^ 2 ./ (2 .* (σ .^ 2))) +end + +function model(rng, xs) + _, slope = ProbProg.sample( + rng, normal, 0.0, 2.0, (1,); symbol=:slope, logpdf=normal_logpdf + ) + _, intercept = ProbProg.sample( + rng, normal, 0.0, 10.0, (1,); symbol=:intercept, logpdf=normal_logpdf + ) + + _, ys = ProbProg.sample( + rng, + normal, + slope .* xs .+ intercept, + 1.0, + (length(xs),); + symbol=:ys, + logpdf=normal_logpdf, + ) + + return ys +end + +function mh_program(rng, model, xs, num_iters, constraint_ptr, constrained_addresses) + init_trace, _, _ = ProbProg.generate( + rng, + model, + xs; + constraint_ptr=constraint_ptr, + constrained_addresses=constrained_addresses, + ) + + trace_ptr = init_trace + @trace for _ in 1:num_iters + trace_ptr, _ = ProbProg.mh( + rng, trace_ptr, model, xs; selection=ProbProg.select(ProbProg.Address(:slope)) + ) + trace_ptr, _ = ProbProg.mh( + rng, + trace_ptr, + model, + xs; + selection=ProbProg.select(ProbProg.Address(:intercept)), + ) + end + + return trace_ptr +end + +@testset "linear_regression" begin + @testset "simulate" begin + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + + xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + xs_r = Reactant.to_rarray(xs) + + trace, _ = ProbProg.simulate_(rng, model, xs_r) + + @test haskey(trace.choices, :slope) + @test haskey(trace.choices, :intercept) + @test haskey(trace.choices, :ys) + end + + @testset "inference" begin + seed = Reactant.to_rarray(UInt64[1, 5]) + rng = ReactantRNG(seed) + + xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + ys = [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, -9.90] + obs = ProbProg.Constraint(:ys => (ys,)) + num_iters = ConcreteRNumber(10000) + constrained_addresses = ProbProg.extract_addresses(obs) + constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(obs))) + + code = @code_hlo optimize = :probprog mh_program( + rng, model, xs, 10000, constraint_ptr, constrained_addresses + ) + @test contains(repr(code), "enzyme_probprog_get_sample_from_trace") + @test contains(repr(code), "enzyme_probprog_get_weight_from_trace") + @test !contains(repr(code), "enzyme.mh") + + compiled_fn = @compile optimize = :probprog mh_program( + rng, model, xs, num_iters, constraint_ptr, constrained_addresses + ) + + trace = nothing + seed_buffer = only(rng.seed.data).buffer + num_iters = ConcreteRNumber(1000) + GC.@preserve seed_buffer obs begin + trace_ptr = compiled_fn( + rng, model, xs, num_iters, constraint_ptr, constrained_addresses + ) + trace = ProbProg.from_trace_tensor(trace_ptr) + end + + slope = only(trace.choices[:slope])[1] + intercept = only(trace.choices[:intercept])[1] + @show slope, intercept + + @test slope ≈ -2.0 rtol = 0.1 + @test intercept ≈ 10.0 rtol = 0.1 + end +end diff --git a/test/probprog/sample.jl b/test/probprog/sample.jl new file mode 100644 index 0000000000..b7889c46dd --- /dev/null +++ b/test/probprog/sample.jl @@ -0,0 +1,88 @@ +using Reactant, Test, Random +using Reactant: ProbProg, ReactantRNG + +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) + +function one_sample(rng, μ, σ, shape) + _, s = ProbProg.sample(rng, normal, μ, σ, shape) + return s +end + +function two_samples(rng, μ, σ, shape) + _ = ProbProg.sample(rng, normal, μ, σ, shape) + _, t = ProbProg.sample(rng, normal, μ, σ, shape) + return t +end + +function compose(rng, μ, σ, shape) + _, s = ProbProg.sample(rng, normal, μ, σ, shape) + _, t = ProbProg.sample(rng, normal, s, σ, shape) + return t +end + +@testset "test" begin + @testset "normal_hlo" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + code = @code_hlo optimize = false ProbProg.sample(rng, normal, μ, σ, shape) + @test contains(repr(code), "enzyme.sample") + end + + @testset "two_samples_hlo" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + code = @code_hlo optimize = false ProbProg.sample(rng, two_samples, μ, σ, shape) + @test contains(repr(code), "enzyme.sample") + end + + @testset "compose" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + before = @code_hlo optimize = false ProbProg.untraced_call( + rng, compose, μ, σ, shape + ) + @test contains(repr(before), "enzyme.sample") + + after = @code_hlo optimize = :probprog ProbProg.untraced_call( + rng, compose, μ, σ, shape + ) + @test !contains(repr(after), "enzyme.sample") + end + + @testset "rng_state" begin + shape = (10,) + + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + rng1 = ReactantRNG(copy(seed)) + + _, X = @jit optimize = :probprog ProbProg.untraced_call( + rng1, one_sample, μ, σ, shape + ) + @test !all(rng1.seed .== seed) + + rng2 = ReactantRNG(copy(seed)) + _, Y = @jit optimize = :probprog ProbProg.untraced_call( + rng2, two_samples, μ, σ, shape + ) + + @test !all(rng2.seed .== seed) + @test !all(rng2.seed .== rng1.seed) + + @test !all(X .≈ Y) + end +end diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl new file mode 100644 index 0000000000..3be45bc256 --- /dev/null +++ b/test/probprog/simulate.jl @@ -0,0 +1,115 @@ +using Reactant, Test, Random +using Reactant: ProbProg, ReactantRNG + +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) + +function normal_logpdf(x, μ, σ, _) + return -length(x) * log(σ) - length(x) / 2 * log(2π) - + sum((x .- μ) .^ 2 ./ (2 .* (σ .^ 2))) +end + +function product_two_normals(rng, μ, σ, shape) + _, a = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:a, logpdf=normal_logpdf) + _, b = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:b, logpdf=normal_logpdf) + return a .* b +end + +function model(rng, μ, σ, shape) + _, s = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:s, logpdf=normal_logpdf) + _, t = ProbProg.sample(rng, normal, s, σ, shape; symbol=:t, logpdf=normal_logpdf) + return t +end + +function model2(rng, μ, σ, shape) + _, s = ProbProg.sample(rng, product_two_normals, μ, σ, shape; symbol=:s) + _, t = ProbProg.sample(rng, product_two_normals, s, σ, shape; symbol=:t) + return t +end + +@testset "Simulate" begin + @testset "hlo" begin + shape = (3, 3, 3) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + before = @code_hlo optimize = false ProbProg.simulate(rng, model, μ, σ, shape) + @test contains(repr(before), "enzyme.simulate") + + after = @code_hlo optimize = :probprog ProbProg.simulate(rng, model, μ, σ, shape) + @test !contains(repr(after), "enzyme.simulate") + @test !contains(repr(after), "enzyme.addSampleToTrace") + @test !contains(repr(after), "enzyme.addWeightToTrace") + @test !contains(repr(after), "enzyme.addRetvalToTrace") + end + + @testset "normal_simulate" begin + shape = (3, 3, 3) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + trace, weight = ProbProg.simulate_(rng, model, μ, σ, shape) + + @test size(trace.retval[1]) == shape + @test haskey(trace.choices, :s) + @test haskey(trace.choices, :t) + @test size(trace.choices[:s][1]) == shape + @test size(trace.choices[:t][1]) == shape + @test trace.weight isa Float64 + end + + @testset "simple_fake" begin + op(_, x, y) = x * y' + logpdf(res, _, _) = sum(res) + function fake_model(rng, x, y) + _, res = ProbProg.sample(rng, op, x, y; symbol=:matmul, logpdf=logpdf) + return res + end + + x = reshape(collect(Float64, 1:12), (4, 3)) + y = reshape(collect(Float64, 1:12), (4, 3)) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + + trace, weight = ProbProg.simulate_(rng, fake_model, x_ra, y_ra) + + @test Array(trace.retval[1]) == op(rng, x, y) + @test haskey(trace.choices, :matmul) + @test trace.choices[:matmul][1] == op(rng, x, y) + @test trace.weight == logpdf(op(rng, x, y), x, y) + end + + @testset "submodel_fake" begin + shape = (3, 3, 3) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + trace, weight = ProbProg.simulate_(rng, model2, μ, σ, shape) + + @test size(trace.retval[1]) == shape + + @test length(trace.choices) == 2 + @test haskey(trace.choices, :s) + @test haskey(trace.choices, :t) + + @test length(trace.subtraces) == 2 + @test haskey(trace.subtraces[:s].choices, :a) + @test haskey(trace.subtraces[:s].choices, :b) + @test haskey(trace.subtraces[:t].choices, :a) + @test haskey(trace.subtraces[:t].choices, :b) + + @test size(trace.choices[:s][1]) == shape + @test size(trace.choices[:t][1]) == shape + + @test trace.weight isa Float64 + + @test trace.weight ≈ trace.subtraces[:s].weight + trace.subtraces[:t].weight + end +end diff --git a/test/runtests.jl b/test/runtests.jl index bc10705e43..86f9c77a27 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -73,4 +73,11 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Lux Integration" include("nn/lux.jl") end end + + if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "probprog" + @safetestset "ProbProg Sample" include("probprog/sample.jl") + @safetestset "ProbProg Simulate" include("probprog/simulate.jl") + @safetestset "ProbProg Generate" include("probprog/generate.jl") + @safetestset "ProbProg HMC" include("probprog/hmc.jl") + end end From 61f49a141c845bc9fa4e99287942bb4dfd84661d Mon Sep 17 00:00:00 2001 From: sbrantq Date: Sun, 26 Oct 2025 16:55:17 -0500 Subject: [PATCH 02/10] test --- test/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/runtests.jl b/test/runtests.jl index 86f9c77a27..ffa6d6cba0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -78,6 +78,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "ProbProg Sample" include("probprog/sample.jl") @safetestset "ProbProg Simulate" include("probprog/simulate.jl") @safetestset "ProbProg Generate" include("probprog/generate.jl") + @safetestset "ProbProg MH" include("probprog/mh.jl") @safetestset "ProbProg HMC" include("probprog/hmc.jl") end end From bac63f1e1d37b772a9d19ba8a34f02ca176ce32a Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 30 Oct 2025 00:30:16 -0500 Subject: [PATCH 03/10] refactored with callcache --- src/probprog/HMC.jl | 34 +++---- src/probprog/MH.jl | 34 +++---- src/probprog/Modeling.jl | 144 ++++++++++++++--------------- src/probprog/Utils.jl | 190 +++++++++++++++++++++++--------------- test/probprog/sample.jl | 15 ++- test/probprog/simulate.jl | 39 ++++++++ 6 files changed, 275 insertions(+), 181 deletions(-) diff --git a/src/probprog/HMC.jl b/src/probprog/HMC.jl index 134044c3b8..501448d92b 100644 --- a/src/probprog/HMC.jl +++ b/src/probprog/HMC.jl @@ -12,17 +12,10 @@ function hmc( initial_momentum=nothing, ) where {Nargs} args = (rng, args...) - mlir_fn_res, argprefix, resprefix, _ = process_probprog_function(f, args, "hmc") - - (; result, linear_args, in_tys, linear_results) = mlir_fn_res - fnwrap = mlir_fn_res.fnwrapped - func2 = mlir_fn_res.f - - inputs = process_probprog_inputs(linear_args, f, args, fnwrap, argprefix) - out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] - - fname = TracedUtils.get_attribute_by_name(func2, "sym_name") - fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + (; f_name, mlir_caller_args, mlir_result_types, traced_result, linear_results, fnwrapped, argprefix, resprefix) = process_probprog_function( + f, args, "hmc" + ) + fn_attr = MLIR.IR.FlatSymbolRefAttribute(f_name) trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet( MLIR.IR.context()::MLIR.API.MlirContext @@ -92,7 +85,7 @@ function hmc( end hmc_op = MLIR.Dialects.enzyme.mcmc( - inputs, + mlir_caller_args, trace_val, mass_val; step_size=step_size_val, @@ -100,15 +93,24 @@ function hmc( initial_momentum=initial_momentum_val, new_trace=trace_ty, accepted=accepted_ty, - output_rng_state=out_tys[1], # by convention + output_rng_state=mlir_result_types[1], # by convention alg=alg_attr, fn=fn_attr, selection=MLIR.IR.Attribute(selection_attr), ) # (new_trace, accepted, output_rng_state) - process_probprog_outputs( - hmc_op, linear_results, result, f, args, fnwrap, resprefix, argprefix, 2, true + traced_result = process_probprog_outputs( + hmc_op, + linear_results, + traced_result, + f, + args, + fnwrapped, + resprefix, + argprefix, + 2, + true, ) new_trace_val = MLIR.IR.result(hmc_op, 1) @@ -122,5 +124,5 @@ function hmc( new_trace = TracedRArray{UInt64,0}((), new_trace_ptr, ()) accepted = TracedRArray{Bool,0}((), MLIR.IR.result(hmc_op, 2), ()) - return new_trace, accepted, result + return new_trace, accepted, traced_result end diff --git a/src/probprog/MH.jl b/src/probprog/MH.jl index 93d5bbca79..cbe51daa17 100644 --- a/src/probprog/MH.jl +++ b/src/probprog/MH.jl @@ -8,17 +8,10 @@ function mh( selection::Selection, ) where {Nargs} args = (rng, args...) - mlir_fn_res, argprefix, resprefix, _ = process_probprog_function(f, args, "mh") - - (; result, linear_args, in_tys, linear_results) = mlir_fn_res - fnwrap = mlir_fn_res.fnwrapped - func2 = mlir_fn_res.f - - inputs = process_probprog_inputs(linear_args, f, args, fnwrap, argprefix) - out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] - - fname = TracedUtils.get_attribute_by_name(func2, "sym_name") - fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + (; f_name, mlir_caller_args, mlir_result_types, traced_result, linear_results, fnwrapped, argprefix, resprefix) = process_probprog_function( + f, args, "mh" + ) + fn_attr = MLIR.IR.FlatSymbolRefAttribute(f_name) trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet( MLIR.IR.context()::MLIR.API.MlirContext @@ -64,18 +57,27 @@ function mh( accepted_ty = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Bool)) mh_op = MLIR.Dialects.enzyme.mh( - inputs, + mlir_caller_args, trace_val; new_trace=trace_ty, accepted=accepted_ty, - output_rng_state=out_tys[1], # by convention + output_rng_state=mlir_result_types[1], # by convention fn=fn_attr, selection=MLIR.IR.Attribute(selection_attr), ) # Return (new_trace, accepted, output_rng_state) - process_probprog_outputs( - mh_op, linear_results, result, f, args, fnwrap, resprefix, argprefix, 2, true + traced_result = process_probprog_outputs( + mh_op, + linear_results, + traced_result, + f, + args, + fnwrapped, + resprefix, + argprefix, + 2, + true, ) new_trace_val = MLIR.IR.result(mh_op, 1) @@ -89,7 +91,7 @@ function mh( new_trace = TracedRArray{UInt64,0}((), new_trace_ptr, ()) accepted = TracedRArray{Bool,0}((), MLIR.IR.result(mh_op, 2), ()) - return new_trace, accepted, result + return new_trace, accepted, traced_result end const metropolis_hastings = mh diff --git a/src/probprog/Modeling.jl b/src/probprog/Modeling.jl index b3e1f82e9f..85e2f88675 100644 --- a/src/probprog/Modeling.jl +++ b/src/probprog/Modeling.jl @@ -11,86 +11,76 @@ function sample( logpdf::Union{Nothing,Function}=nothing, ) where {Nargs} args_with_rng = (rng, args...) - mlir_fn_res, argprefix, resprefix, _ = process_probprog_function( + (; f_name, mlir_caller_args, mlir_result_types, traced_result, linear_results, fnwrapped, argprefix, resprefix) = process_probprog_function( f, args_with_rng, "sample" ) - (; result, linear_args, linear_results) = mlir_fn_res - fnwrap = mlir_fn_res.fnwrapped - func2 = mlir_fn_res.f - - inputs = process_probprog_inputs(linear_args, f, args_with_rng, fnwrap, argprefix) - out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] - - sym = TracedUtils.get_attribute_by_name(func2, "sym_name") - fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym)) - + fn_attr = MLIR.IR.FlatSymbolRefAttribute(f_name) symbol_addr = reinterpret(UInt64, pointer_from_objref(symbol)) symbol_attr = @ccall MLIR.API.mlir_c.enzymeSymbolAttrGet( MLIR.IR.context()::MLIR.API.MlirContext, symbol_addr::UInt64 )::MLIR.IR.Attribute - # Construct logpdf attribute if `logpdf` function is provided. logpdf_attr = nothing if logpdf isa Function samples = f(args_with_rng...) - # Assume that logpdf parameters follow `(sample, args...)` convention. + # Logpdf calling convention: `(sample, args...)` (no rng state) logpdf_args = (samples, args...) - logpdf_mlir = TracedUtils.make_mlir_fn( - logpdf, - logpdf_args, - (), - string(logpdf), - false; - do_transpose=false, - args_in_result=:result, + logpdf_attr = MLIR.IR.FlatSymbolRefAttribute( + process_probprog_function(logpdf, logpdf_args, "logpdf", false).f_name ) - - logpdf_sym = TracedUtils.get_attribute_by_name(logpdf_mlir.f, "sym_name") - logpdf_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(logpdf_sym)) end sample_op = MLIR.Dialects.enzyme.sample( - inputs; - outputs=out_tys, + mlir_caller_args; + outputs=mlir_result_types, fn=fn_attr, logpdf=logpdf_attr, symbol=symbol_attr, name=Base.String(symbol), ) - process_probprog_outputs( - sample_op, linear_results, result, f, args_with_rng, fnwrap, resprefix, argprefix + traced_result = process_probprog_outputs( + sample_op, + linear_results, + traced_result, + f, + args_with_rng, + fnwrapped, + resprefix, + argprefix, ) - return result + return traced_result end function untraced_call(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs} args_with_rng = (rng, args...) - mlir_fn_res, argprefix, resprefix, _ = process_probprog_function( - f, args_with_rng, "call" - ) - (; result, linear_args, in_tys, linear_results) = mlir_fn_res - fnwrap = mlir_fn_res.fnwrapped - func2 = mlir_fn_res.f - - inputs = process_probprog_inputs(linear_args, f, args_with_rng, fnwrap, argprefix) - out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + (; f_name, mlir_caller_args, mlir_result_types, traced_result, linear_results, fnwrapped, argprefix, resprefix) = process_probprog_function( + f, args_with_rng, "untraced_call" + ) - fname = TracedUtils.get_attribute_by_name(func2, "sym_name") - fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + fn_attr = MLIR.IR.FlatSymbolRefAttribute(f_name) - call_op = MLIR.Dialects.enzyme.untracedCall(inputs; outputs=out_tys, fn=fn_attr) + call_op = MLIR.Dialects.enzyme.untracedCall( + mlir_caller_args; outputs=mlir_result_types, fn=fn_attr + ) - process_probprog_outputs( - call_op, linear_results, result, f, args_with_rng, fnwrap, resprefix, argprefix + traced_result = process_probprog_outputs( + call_op, + linear_results, + traced_result, + f, + args_with_rng, + fnwrapped, + resprefix, + argprefix, ) - return result + return traced_result end # Gen-like helper function. @@ -110,17 +100,10 @@ end function simulate(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs} args = (rng, args...) - mlir_fn_res, argprefix, resprefix, _ = process_probprog_function(f, args, "simulate") - - (; result, linear_args, in_tys, linear_results) = mlir_fn_res - fnwrap = mlir_fn_res.fnwrapped - func2 = mlir_fn_res.f - - inputs = process_probprog_inputs(linear_args, f, args, fnwrap, argprefix) - out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] - - fname = TracedUtils.get_attribute_by_name(func2, "sym_name") - fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + (; f_name, mlir_caller_args, mlir_result_types, traced_result, linear_results, fnwrapped, argprefix, resprefix) = process_probprog_function( + f, args, "simulate" + ) + fn_attr = MLIR.IR.FlatSymbolRefAttribute(f_name) trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet( MLIR.IR.context()::MLIR.API.MlirContext @@ -128,11 +111,23 @@ function simulate(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where weight_ty = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Float64)) simulate_op = MLIR.Dialects.enzyme.simulate( - inputs; trace=trace_ty, weight=weight_ty, outputs=out_tys, fn=fn_attr + mlir_caller_args; + trace=trace_ty, + weight=weight_ty, + outputs=mlir_result_types, + fn=fn_attr, ) - process_probprog_outputs( - simulate_op, linear_results, result, f, args, fnwrap, resprefix, argprefix, 2 + traced_result = process_probprog_outputs( + simulate_op, + linear_results, + traced_result, + f, + args, + fnwrapped, + resprefix, + argprefix, + 2, ) trace = MLIR.IR.result( @@ -146,7 +141,7 @@ function simulate(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where trace = TracedRArray{UInt64,0}((), trace, ()) weight = TracedRArray{Float64,0}((), MLIR.IR.result(simulate_op, 2), ()) - return trace, weight, result + return trace, weight, traced_result end # Gen-like helper function. @@ -185,17 +180,12 @@ function generate( constrained_addresses::Set{Address}, ) where {Nargs} args = (rng, args...) - mlir_fn_res, argprefix, resprefix, _ = process_probprog_function(f, args, "generate") - (; result, linear_args, in_tys, linear_results) = mlir_fn_res - fnwrap = mlir_fn_res.fnwrapped - func2 = mlir_fn_res.f - - inputs = process_probprog_inputs(linear_args, f, args, fnwrap, argprefix) - out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + (; f_name, mlir_caller_args, mlir_result_types, traced_result, linear_results, fnwrapped, argprefix, resprefix) = process_probprog_function( + f, args, "generate" + ) - fname = TracedUtils.get_attribute_by_name(func2, "sym_name") - fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + fn_attr = MLIR.IR.FlatSymbolRefAttribute(f_name) constraint_ty = @ccall MLIR.API.mlir_c.enzymeConstraintTypeGet( MLIR.IR.context()::MLIR.API.MlirContext @@ -229,17 +219,25 @@ function generate( weight_ty = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Float64)) generate_op = MLIR.Dialects.enzyme.generate( - inputs, + mlir_caller_args, constraint_val; trace=trace_ty, weight=weight_ty, - outputs=out_tys, + outputs=mlir_result_types, fn=fn_attr, constrained_addresses=MLIR.IR.Attribute(constrained_addresses_attr), ) - process_probprog_outputs( - generate_op, linear_results, result, f, args, fnwrap, resprefix, argprefix, 2 + traced_result = process_probprog_outputs( + generate_op, + linear_results, + traced_result, + f, + args, + fnwrapped, + resprefix, + argprefix, + 2, ) trace = MLIR.IR.result( @@ -253,5 +251,5 @@ function generate( trace = TracedRArray{UInt64,0}((), trace, ()) weight = TracedRArray{Float64,0}((), MLIR.IR.result(generate_op, 2), ()) - return trace, weight, result + return trace, weight, traced_result end diff --git a/src/probprog/Utils.jl b/src/probprog/Utils.jl index 7e8c197d5a..fa05c2d812 100644 --- a/src/probprog/Utils.jl +++ b/src/probprog/Utils.jl @@ -1,8 +1,21 @@ -using ..Reactant: MLIR, TracedUtils, Ops, TracedRArray +using ..Reactant: + MLIR, + TracedUtils, + Ops, + TracedRArray, + Compiler, + OrderedIdDict, + make_tracer, + TracedToTypes, + TracedTrack, + TracedType, + TracedSetPath import ..Reactant: promote_to """ - process_probprog_function(f, args_with_rng, op_name) + process_probprog_function(f, args, op_name) + +Note: by convention `args` must have the RNG state as the first argument. This function handles the probprog argument convention where: - **Index 1**: RNG state @@ -12,58 +25,86 @@ This function handles the probprog argument convention where: This wrapper ensures the RNG state is threaded through as the first result, followed by the actual function results. """ -function process_probprog_function(f, args_with_rng, op_name) - argprefix = gensym(op_name * "arg") - resprefix = gensym(op_name * "result") - resargprefix = gensym(op_name * "resarg") - - wrapper_fn = (all_args...) -> begin - res = f(all_args...) - (all_args[1], (res isa Tuple ? res : (res,))...) +function process_probprog_function(f, args, op_name, with_rng=true) + seen = OrderedIdDict() + cache_key = [] + make_tracer(seen, (f, args...), cache_key, TracedToTypes) + cache = Compiler.callcache() + + if haskey(cache, cache_key) + (; f_name, mlir_result_types, traced_result, mutated_args, linear_results, fnwrapped, argprefix, resprefix, resargprefix) = cache[cache_key] + else + f_name = String(gensym(Symbol(f))) + argprefix::Symbol = gensym(op_name * "arg") + resprefix::Symbol = gensym(op_name * "result") + resargprefix::Symbol = gensym(op_name * "resarg") + + wrapper_fn = if !with_rng + f + else + (all_args...) -> begin + res = f(all_args...) + (all_args[1], (res isa Tuple ? res : (res,))...) + end + end + + temp = TracedUtils.make_mlir_fn( + wrapper_fn, + args, + (), + f_name, + false; + do_transpose=false, + args_in_result=:result, + argprefix, + resprefix, + resargprefix, + ) + + (; traced_result, ret, mutated_args, linear_results, fnwrapped) = temp + mlir_result_types = [ + MLIR.IR.type(MLIR.IR.operand(ret, i)) for i in 1:MLIR.IR.noperands(ret) + ] + cache[cache_key] = (; + f_name, + mlir_result_types, + traced_result, + mutated_args, + linear_results, + fnwrapped, + argprefix, + resprefix, + resargprefix, + ) + end + + seen_cache = OrderedIdDict() + make_tracer(seen_cache, fnwrapped ? (f, args) : args, (), TracedTrack; toscalar=false) + linear_args = [] + mlir_caller_args = MLIR.IR.Value[] + for (_, v) in seen_cache + v isa TracedType || continue + push!(linear_args, v) + push!(mlir_caller_args, v.mlir_data) + v.paths = v.paths[1:(end - 1)] end - mlir_fn_res = TracedUtils.make_mlir_fn( - wrapper_fn, - args_with_rng, - (), - string(f), - false; - do_transpose=false, - args_in_result=:result, + return (; + f_name, + linear_args, + mlir_caller_args, + mlir_result_types, + traced_result, + linear_results, + fnwrapped, argprefix, resprefix, resargprefix, ) - - return mlir_fn_res, argprefix, resprefix, resargprefix end """ - process_probprog_inputs(linear_args, f, args_with_rng, fnwrap, argprefix) - -This function handles the probprog argument convention where: -- **Index 1**: RNG state -- **Index 2**: Function `f` (when `fnwrap` is true) -- **Index 3+**: Other arguments -""" -function process_probprog_inputs(linear_args, f, args_with_rng, fnwrap, argprefix) - inputs = MLIR.IR.Value[] - for a in linear_args - idx, path = TracedUtils.get_argidx(a, argprefix) - if idx == 2 && fnwrap - TracedUtils.push_val!(inputs, f, path[3:end]) - else - if fnwrap && idx > 1 - idx -= 1 - end - TracedUtils.push_val!(inputs, args_with_rng[idx], path[3:end]) - end - end - return inputs -end - -""" - process_probprog_outputs(op, linear_results, result, f, args_with_rng, fnwrap, resprefix, argprefix, start_idx=0, rng_only=false) + process_probprog_outputs(op, linear_results, traced_result, f, args, fnwrapped, resprefix, argprefix, offset=0, rng_only=false) This function handles the probprog argument convention where: - **Index 1**: RNG state @@ -72,57 +113,62 @@ This function handles the probprog argument convention where: When setting results, the function checks: 1. If result path matches `resprefix`, store in `result` -2. If result path matches `argprefix`, store in `args_with_rng` (adjust indices for wrapped function) +2. If result path matches `argprefix`, store in `args` (adjust indices for wrapped function) -`start_idx` varies depending on the ProbProg operation: +`offset` varies depending on the ProbProg operation: - `sample` and `untraced_call` return only function outputs: - Use `start_idx=0`: `linear_results[i]` corresponds to `op.result[i]` + Use `offset=0`: `linear_results[i]` corresponds to `op.result[i]` - `simulate` and `generate` return trace, weight, then outputs: - Use `start_idx=2`: `linear_results[i]` corresponds to `op.result[i+2]` + Use `offset=2`: `linear_results[i]` corresponds to `op.result[i+2]` - `mh` and `regenerate` return trace, accepted/weight, rng_state (no model outputs): - Use `start_idx=2, rng_only=true`: only process first result (rng_state) + Use `offset=2, rng_only=true`: only process first result (rng_state) `rng_only`: When true, only process the first result (RNG state), skipping model outputs """ function process_probprog_outputs( op, linear_results, - result, + traced_result, f, - args_with_rng, - fnwrap, + args, + fnwrapped, resprefix, argprefix, - start_idx=0, + offset=0, rng_only=false, ) + seen_results = OrderedIdDict() + traced_result = make_tracer( + seen_results, traced_result, (), TracedSetPath; toscalar=false + ) + num_to_process = rng_only ? 1 : length(linear_results) for i in 1:num_to_process res = linear_results[i] - resv = MLIR.IR.result(op, i + start_idx) - - if TracedUtils.has_idx(res, resprefix) - path = TracedUtils.get_idx(res, resprefix) - TracedUtils.set!(result, path[2:end], resv) - end + resv = MLIR.IR.result(op, i + offset) - if TracedUtils.has_idx(res, argprefix) - idx, path = TracedUtils.get_argidx(res, argprefix) - if fnwrap && idx == 2 - TracedUtils.set!(f, path[3:end], resv) - else - if fnwrap && idx > 2 - idx -= 1 + for path in res.paths + if length(path) == 0 + continue + end + if path[1] == resprefix + TracedUtils.set!(traced_result, path[2:end], resv) + elseif path[1] == argprefix + idx = path[2]::Int + if fnwrapped && idx == 2 + TracedUtils.set!(f, path[3:end], resv) + else + if fnwrapped && idx > 2 + idx -= 1 + end + TracedUtils.set!(args[idx], path[3:end], resv) end - TracedUtils.set!(args_with_rng[idx], path[3:end], resv) end end - - if !TracedUtils.has_idx(res, resprefix) && !TracedUtils.has_idx(res, argprefix) - TracedUtils.set!(res, (), resv) - end end + + return traced_result end to_trace_tensor(t::ProbProgTrace) = promote_to(TracedRArray{UInt64,0}, t) diff --git a/test/probprog/sample.jl b/test/probprog/sample.jl index b7889c46dd..410fa716a1 100644 --- a/test/probprog/sample.jl +++ b/test/probprog/sample.jl @@ -3,14 +3,19 @@ using Reactant: ProbProg, ReactantRNG normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) +function normal_logpdf(x, μ, σ, _) + return -length(x) * log(σ) - length(x) / 2 * log(2π) - + sum((x .- μ) .^ 2 ./ (2 .* (σ .^ 2))) +end + function one_sample(rng, μ, σ, shape) - _, s = ProbProg.sample(rng, normal, μ, σ, shape) + _, s = ProbProg.sample(rng, normal, μ, σ, shape; logpdf=normal_logpdf) return s end function two_samples(rng, μ, σ, shape) - _ = ProbProg.sample(rng, normal, μ, σ, shape) - _, t = ProbProg.sample(rng, normal, μ, σ, shape) + _ = ProbProg.sample(rng, normal, μ, σ, shape; logpdf=normal_logpdf) + _, t = ProbProg.sample(rng, normal, μ, σ, shape; logpdf=normal_logpdf) return t end @@ -28,7 +33,9 @@ end μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - code = @code_hlo optimize = false ProbProg.sample(rng, normal, μ, σ, shape) + code = @code_hlo optimize = false ProbProg.sample( + rng, normal, μ, σ, shape; logpdf=normal_logpdf + ) @test contains(repr(code), "enzyme.sample") end diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl index 3be45bc256..a1b0cec3d8 100644 --- a/test/probprog/simulate.jl +++ b/test/probprog/simulate.jl @@ -26,6 +26,12 @@ function model2(rng, μ, σ, shape) return t end +function model3(rng, μ, σ, shape) + _, s = ProbProg.sample(rng, product_two_normals, μ, σ, shape; symbol=:s) + _, t = ProbProg.sample(rng, product_two_normals, μ, σ, shape; symbol=:t) + return s, t +end + @testset "Simulate" begin @testset "hlo" begin shape = (3, 3, 3) @@ -104,6 +110,39 @@ end @test haskey(trace.subtraces[:s].choices, :b) @test haskey(trace.subtraces[:t].choices, :a) @test haskey(trace.subtraces[:t].choices, :b) + @test trace.subtraces[:s].choices[:a][1] !== trace.subtraces[:t].choices[:a][1] + @test trace.subtraces[:s].choices[:b][1] !== trace.subtraces[:t].choices[:b][1] + + @test size(trace.choices[:s][1]) == shape + @test size(trace.choices[:t][1]) == shape + + @test trace.weight isa Float64 + + @test trace.weight ≈ trace.subtraces[:s].weight + trace.subtraces[:t].weight + end + + @testset "submodel_subtraces" begin + shape = (3, 3, 3) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + trace, weight = ProbProg.simulate_(rng, model3, μ, σ, shape) + + @test size(trace.retval[1]) == shape + + @test length(trace.choices) == 2 + @test haskey(trace.choices, :s) + @test haskey(trace.choices, :t) + + @test length(trace.subtraces) == 2 + @test haskey(trace.subtraces[:s].choices, :a) + @test haskey(trace.subtraces[:s].choices, :b) + @test haskey(trace.subtraces[:t].choices, :a) + @test haskey(trace.subtraces[:t].choices, :b) + @test trace.subtraces[:s].choices[:a][1] !== trace.subtraces[:t].choices[:a][1] + @test trace.subtraces[:s].choices[:b][1] !== trace.subtraces[:t].choices[:b][1] @test size(trace.choices[:s][1]) == shape @test size(trace.choices[:t][1]) == shape From 703f572a8b36441b6a82ccc4a0e732cd60f66daf Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 30 Oct 2025 22:08:34 -0500 Subject: [PATCH 04/10] automated tracing of ProbProgTrace and Constraint structs --- src/Reactant.jl | 4 ++- src/probprog/Modeling.jl | 25 +++++-------- src/probprog/Types.jl | 7 ++++ src/probprog/Utils.jl | 75 ++++++++++++++++++++++++++++----------- test/probprog/generate.jl | 27 +++++--------- test/probprog/hmc.jl | 21 ++++------- test/probprog/mh.jl | 35 +++++++----------- 7 files changed, 102 insertions(+), 92 deletions(-) diff --git a/src/Reactant.jl b/src/Reactant.jl index 0ef9af5145..25228a441b 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -246,11 +246,13 @@ include("Tracing.jl") include("Compiler.jl") include("Overlay.jl") -include("probprog/ProbProg.jl") # Serialization include("serialization/Serialization.jl") +# ProbProg +include("probprog/ProbProg.jl") + using .Compiler: @compile, @code_hlo, @code_mhlo, @jit, @code_xla, traced_getfield, compile export ConcreteRArray, ConcreteRNumber, diff --git a/src/probprog/Modeling.jl b/src/probprog/Modeling.jl index 85e2f88675..6755cfb259 100644 --- a/src/probprog/Modeling.jl +++ b/src/probprog/Modeling.jl @@ -92,7 +92,7 @@ function simulate_(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where seed_buffer = only(rng.seed.data).buffer GC.@preserve seed_buffer begin t, _, _ = compiled_fn(rng, f, args...) - trace = from_trace_tensor(t) + trace = ProbProgTrace(t) end return trace, trace.weight @@ -146,27 +146,20 @@ end # Gen-like helper function. function generate_( - rng::AbstractRNG, - f::Function, - args::Vararg{Any,Nargs}; - constraint::Constraint=Constraint(), + rng::AbstractRNG, constraint::Constraint, f::Function, args::Vararg{Any,Nargs} ) where {Nargs} trace = nothing - constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(constraint))) - constrained_addresses = extract_addresses(constraint) - function wrapper_fn(rng, constraint_ptr, args...) - return generate(rng, f, args...; constraint_ptr, constrained_addresses) - end - - compiled_fn = @compile optimize = :probprog wrapper_fn(rng, constraint_ptr, args...) + compiled_fn = @compile optimize = :probprog generate( + rng, constraint, f, args...; constrained_addresses + ) seed_buffer = only(rng.seed.data).buffer GC.@preserve seed_buffer constraint begin - t, _, _ = compiled_fn(rng, constraint_ptr, args...) - trace = from_trace_tensor(t) + t, _, _ = compiled_fn(rng, constraint, f, args...) + trace = ProbProgTrace(t) end return trace, trace.weight @@ -174,9 +167,9 @@ end function generate( rng::AbstractRNG, + constraint, f::Function, args::Vararg{Any,Nargs}; - constraint_ptr::TracedRNumber, constrained_addresses::Set{Address}, ) where {Nargs} args = (rng, args...) @@ -193,7 +186,7 @@ function generate( constraint_val = MLIR.IR.result( MLIR.Dialects.builtin.unrealized_conversion_cast( - [TracedUtils.get_mlir_data(constraint_ptr)]; outputs=[constraint_ty] + [TracedUtils.get_mlir_data(constraint)]; outputs=[constraint_ty] ), 1, ) diff --git a/src/probprog/Types.jl b/src/probprog/Types.jl index 98f189d9a0..c1c94d7031 100644 --- a/src/probprog/Types.jl +++ b/src/probprog/Types.jl @@ -1,4 +1,5 @@ using Base: ReentrantLock +using ..Reactant: AbstractConcreteNumber, AbstractConcreteArray mutable struct ProbProgTrace choices::Dict{Symbol,Any} @@ -9,6 +10,9 @@ mutable struct ProbProgTrace function ProbProgTrace() return new(Dict{Symbol,Any}(), nothing, nothing, Dict{Symbol,Any}()) end + function ProbProgTrace(x::Union{AbstractConcreteNumber,AbstractConcreteArray}) + return convert(ProbProgTrace, x) + end end struct Address @@ -42,6 +46,9 @@ mutable struct Constraint <: AbstractDict{Address,Any} Constraint() = new(Dict{Address,Any}()) Constraint(d::Dict{Address,Any}) = new(d) + function Constraint(x::Union{AbstractConcreteNumber,AbstractConcreteArray}) + return convert(Constraint, x) + end end Base.getindex(c::Constraint, k::Address) = c.dict[k] diff --git a/src/probprog/Utils.jl b/src/probprog/Utils.jl index fa05c2d812..867eb97f46 100644 --- a/src/probprog/Utils.jl +++ b/src/probprog/Utils.jl @@ -3,14 +3,20 @@ using ..Reactant: TracedUtils, Ops, TracedRArray, + TracedRNumber, Compiler, OrderedIdDict, - make_tracer, TracedToTypes, - TracedTrack, TracedType, - TracedSetPath -import ..Reactant: promote_to + TracedTrack, + TracedSetPath, + ConcreteToTraced, + AbstractConcreteArray, + XLA, + Sharding, + to_number +import ..Reactant: promote_to, make_tracer +import ..Compiler: donate_argument! """ process_probprog_function(f, args, op_name) @@ -171,30 +177,59 @@ function process_probprog_outputs( return traced_result end -to_trace_tensor(t::ProbProgTrace) = promote_to(TracedRArray{UInt64,0}, t) +function promote_to(::Type{TracedRArray{UInt64,0}}, t::Union{ProbProgTrace,Constraint}) + return Ops.fill(reinterpret(UInt64, pointer_from_objref(t)), Int64[]) +end -function from_trace_tensor(trace_tensor) - while !isready(trace_tensor) +function Base.convert( + ::Type{T}, x::AbstractConcreteArray +) where {T<:Union{ProbProgTrace,Constraint}} + while !isready(x) yield() end - return unsafe_pointer_to_objref(Ptr{Any}(Array(trace_tensor)[1]))::ProbProgTrace + return unsafe_pointer_to_objref(Ptr{Any}(collect(x)[1]))::T end -function promote_to(::Type{TracedRArray{UInt64,0}}, t::ProbProgTrace) - ptr = reinterpret(UInt64, pointer_from_objref(t)) - return Ops.fill(ptr, Int64[]) +function Base.convert( + ::Type{T}, x::AbstractConcreteNumber +) where {T<:Union{ProbProgTrace,Constraint}} + while !isready(x) + yield() + end + return unsafe_pointer_to_objref(Ptr{Any}(to_number(x)))::T end -to_constraint_tensor(c::Constraint) = promote_to(TracedRArray{UInt64,0}, c) - -function from_constraint_tensor(constraint_tensor) - while !isready(constraint_tensor) - yield() +function Base.getproperty(t::Union{ProbProgTrace,Constraint}, s::Symbol) + if s === :data + return ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(t))).data + else + return getfield(t, s) end - return unsafe_pointer_to_objref(Ptr{Any}(Array(constraint_tensor)[1]))::Constraint end -function promote_to(::Type{TracedRArray{UInt64,0}}, c::Constraint) - ptr = reinterpret(UInt64, pointer_from_objref(c)) - return Ops.fill(ptr, Int64[]) +function donate_argument!( + ::Any, ::Union{ProbProgTrace,Constraint}, ::Int, ::Any, ::Any +) + return nothing +end + +Base.@nospecializeinfer function make_tracer( + seen, + @nospecialize(prev::Union{ProbProgTrace,Constraint}), + @nospecialize(path), + mode; + @nospecialize(sharding = Sharding.NoSharding()), + kwargs..., +) + if mode == ConcreteToTraced + haskey(seen, prev) && return seen[prev]::TracedRNumber{UInt64} + result = TracedRNumber{UInt64}((path,), nothing) + seen[prev] = result + return result + elseif mode == TracedToTypes + push!(path, typeof(prev)) + return nothing + else + error("Unsupported mode for $(typeof(prev)): $mode") + end end diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl index f5fa4fea38..403f2e67e3 100644 --- a/test/probprog/generate.jl +++ b/test/probprog/generate.jl @@ -34,7 +34,7 @@ end rng = ReactantRNG(seed) μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - trace, weight = ProbProg.generate_(rng, model, μ, σ, shape) + trace, weight = ProbProg.generate_(rng, ProbProg.Constraint(), model, μ, σ, shape) @test mean(trace.retval[1]) ≈ 0.0 atol = 0.05 rtol = 0.05 end @@ -47,7 +47,7 @@ end constraint = ProbProg.Constraint(:s => (fill(0.1, shape),)) - trace, weight = ProbProg.generate_(rng, model, μ, σ, shape; constraint) + trace, weight = ProbProg.generate_(rng, constraint, model, μ, σ, shape) @test trace.choices[:s][1] == constraint[ProbProg.Address(:s)][1] @@ -72,7 +72,7 @@ end :u => :y => (fill(0.3, shape),), ) - trace, weight = ProbProg.generate_(rng, nested_model, μ, σ, shape; constraint) + trace, weight = ProbProg.generate_(rng, constraint, nested_model, μ, σ, shape) @test trace.choices[:s][1] == fill(0.1, shape) @test trace.subtraces[:t].choices[:x][1] == fill(0.2, shape) @@ -108,33 +108,24 @@ end constrained_addresses = ProbProg.extract_addresses(constraint1) - constraint_ptr1 = Reactant.ConcreteRNumber( - reinterpret(UInt64, pointer_from_objref(constraint1)) + compiled_fn = @compile optimize = :probprog ProbProg.generate( + rng, constraint1, model, μ, σ, shape; constrained_addresses ) - wrapper_fn(rng, constraint_ptr, μ, σ) = ProbProg.generate( - rng, model, μ, σ, shape; constraint_ptr, constrained_addresses - ) - - compiled_fn = @compile optimize = :probprog wrapper_fn(rng, constraint_ptr1, μ, σ) - trace1 = nothing seed_buffer = only(rng.seed.data).buffer GC.@preserve seed_buffer constraint1 begin - trace1, _ = compiled_fn(rng, constraint_ptr1, μ, σ) - trace1 = ProbProg.from_trace_tensor(trace1) + trace1, _ = compiled_fn(rng, constraint1, model, μ, σ, shape) + trace1 = ProbProg.ProbProgTrace(trace1) end constraint2 = ProbProg.Constraint(:s => (fill(0.2, shape),)) - constraint_ptr2 = Reactant.ConcreteRNumber( - reinterpret(UInt64, pointer_from_objref(constraint2)) - ) trace2 = nothing seed_buffer = only(rng.seed.data).buffer GC.@preserve seed_buffer constraint2 begin - trace2, _ = compiled_fn(rng, constraint_ptr2, μ, σ) - trace2 = ProbProg.from_trace_tensor(trace2) + trace2, _ = compiled_fn(rng, constraint2, model, μ, σ, shape) + trace2 = ProbProg.ProbProgTrace(trace2) end @test trace1.choices[:s][1] != trace2.choices[:s][1] diff --git a/test/probprog/hmc.jl b/test/probprog/hmc.jl index 03c3394a40..c88a92c7d6 100644 --- a/test/probprog/hmc.jl +++ b/test/probprog/hmc.jl @@ -36,16 +36,10 @@ function hmc_program( num_steps, mass, initial_momentum, - constraint_ptr, + constraint, constrained_addresses, ) - t, _, _ = ProbProg.generate( - rng, - model, - xs; - constraint_ptr=constraint_ptr, - constrained_addresses=constrained_addresses, - ) + t, _, _ = ProbProg.generate(rng, constraint, model, xs; constrained_addresses) t, accepted, _ = ProbProg.hmc( rng, @@ -73,7 +67,6 @@ end :param_a => ([0.0],), :param_b => ([0.0],), :ys_a => (ys_a,), :ys_b => (ys_b,) ) constrained_addresses = ProbProg.extract_addresses(obs) - constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(obs))) step_size = ConcreteRNumber(0.001) num_steps_compile = ConcreteRNumber(1000) @@ -89,7 +82,7 @@ end num_steps_compile, mass, initial_momentum, - constraint_ptr, + obs, constrained_addresses, ) @test contains(repr(code), "enzyme_probprog_get_flattened_samples_from_trace") @@ -106,7 +99,7 @@ end num_steps_compile, mass, initial_momentum, - constraint_ptr, + obs, constrained_addresses, ) end @@ -116,7 +109,7 @@ end trace = nothing GC.@preserve seed_buffer obs begin run_time_s = @elapsed begin - trace_ptr, _ = compiled_fn( + trace, _ = compiled_fn( rng, model, xs, @@ -124,10 +117,10 @@ end num_steps_run, mass, initial_momentum, - constraint_ptr, + obs, constrained_addresses, ) - trace = ProbProg.from_trace_tensor(trace_ptr) + trace = ProbProg.ProbProgTrace(trace) end println("HMC run time: $(round(run_time_s * 1000, digits=2)) ms") end diff --git a/test/probprog/mh.jl b/test/probprog/mh.jl index b05021827d..f331d8abd1 100644 --- a/test/probprog/mh.jl +++ b/test/probprog/mh.jl @@ -31,30 +31,22 @@ function model(rng, xs) return ys end -function mh_program(rng, model, xs, num_iters, constraint_ptr, constrained_addresses) +function mh_program(rng, model, xs, num_iters, constraint, constrained_addresses) init_trace, _, _ = ProbProg.generate( - rng, - model, - xs; - constraint_ptr=constraint_ptr, - constrained_addresses=constrained_addresses, + rng, constraint, model, xs; constrained_addresses=constrained_addresses ) - trace_ptr = init_trace + trace = init_trace @trace for _ in 1:num_iters - trace_ptr, _ = ProbProg.mh( - rng, trace_ptr, model, xs; selection=ProbProg.select(ProbProg.Address(:slope)) + trace, _ = ProbProg.mh( + rng, trace, model, xs; selection=ProbProg.select(ProbProg.Address(:slope)) ) - trace_ptr, _ = ProbProg.mh( - rng, - trace_ptr, - model, - xs; - selection=ProbProg.select(ProbProg.Address(:intercept)), + trace, _ = ProbProg.mh( + rng, trace, model, xs; selection=ProbProg.select(ProbProg.Address(:intercept)) ) end - return trace_ptr + return trace end @testset "linear_regression" begin @@ -81,27 +73,24 @@ end obs = ProbProg.Constraint(:ys => (ys,)) num_iters = ConcreteRNumber(10000) constrained_addresses = ProbProg.extract_addresses(obs) - constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(obs))) code = @code_hlo optimize = :probprog mh_program( - rng, model, xs, 10000, constraint_ptr, constrained_addresses + rng, model, xs, num_iters, obs, constrained_addresses ) @test contains(repr(code), "enzyme_probprog_get_sample_from_trace") @test contains(repr(code), "enzyme_probprog_get_weight_from_trace") @test !contains(repr(code), "enzyme.mh") compiled_fn = @compile optimize = :probprog mh_program( - rng, model, xs, num_iters, constraint_ptr, constrained_addresses + rng, model, xs, num_iters, obs, constrained_addresses ) trace = nothing seed_buffer = only(rng.seed.data).buffer num_iters = ConcreteRNumber(1000) GC.@preserve seed_buffer obs begin - trace_ptr = compiled_fn( - rng, model, xs, num_iters, constraint_ptr, constrained_addresses - ) - trace = ProbProg.from_trace_tensor(trace_ptr) + trace = compiled_fn(rng, model, xs, num_iters, obs, constrained_addresses) + trace = ProbProg.ProbProgTrace(trace) end slope = only(trace.choices[:slope])[1] From 6ef12c2f5ff1fc9827d136527ccb7588535e6b0a Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 30 Oct 2025 22:31:16 -0500 Subject: [PATCH 05/10] clean up --- src/probprog/HMC.jl | 25 +++++++------------------ src/probprog/MH.jl | 26 +++++++------------------- 2 files changed, 14 insertions(+), 37 deletions(-) diff --git a/src/probprog/HMC.jl b/src/probprog/HMC.jl index 501448d92b..a009aa33d7 100644 --- a/src/probprog/HMC.jl +++ b/src/probprog/HMC.jl @@ -2,7 +2,7 @@ using ..Reactant: ConcreteRNumber, TracedRArray function hmc( rng::AbstractRNG, - original_trace::Union{ProbProgTrace,TracedRArray{UInt64,0}}, + original_trace, f::Function, args::Vararg{Any,Nargs}; selection::Selection, @@ -21,23 +21,12 @@ function hmc( MLIR.IR.context()::MLIR.API.MlirContext )::MLIR.IR.Type - trace_val = if original_trace isa TracedRArray{UInt64,0} - MLIR.IR.result( - MLIR.Dialects.builtin.unrealized_conversion_cast( - [original_trace.mlir_data]; outputs=[trace_ty] - ), - 1, - ) - else - # First iteration: promote a ProbProgTrace to tensor - promoted = to_trace_tensor(original_trace) - MLIR.IR.result( - MLIR.Dialects.builtin.unrealized_conversion_cast( - [TracedUtils.get_mlir_data(promoted)]; outputs=[trace_ty] - ), - 1, - ) - end + trace_val = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [TracedUtils.get_mlir_data(original_trace)]; outputs=[trace_ty] + ), + 1, + ) selection_attr = MLIR.IR.Attribute[] for address in selection diff --git a/src/probprog/MH.jl b/src/probprog/MH.jl index cbe51daa17..f1d6b9f693 100644 --- a/src/probprog/MH.jl +++ b/src/probprog/MH.jl @@ -2,7 +2,7 @@ using ..Reactant: ConcreteRNumber, TracedRArray function mh( rng::AbstractRNG, - original_trace::Union{ProbProgTrace,TracedRArray{UInt64,0}}, + original_trace, f::Function, args::Vararg{Any,Nargs}; selection::Selection, @@ -17,24 +17,12 @@ function mh( MLIR.IR.context()::MLIR.API.MlirContext )::MLIR.IR.Type - if original_trace isa TracedRArray{UInt64,0} - # Use MLIR data from previous iteration - trace_val = MLIR.IR.result( - MLIR.Dialects.builtin.unrealized_conversion_cast( - [original_trace.mlir_data]; outputs=[trace_ty] - ), - 1, - ) - else - # First iteration: create constant from pointer - promoted = to_trace_tensor(original_trace) - trace_val = MLIR.IR.result( - MLIR.Dialects.builtin.unrealized_conversion_cast( - [TracedUtils.get_mlir_data(promoted)]; outputs=[trace_ty] - ), - 1, - ) - end + trace_val = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [TracedUtils.get_mlir_data(original_trace)]; outputs=[trace_ty] + ), + 1, + ) selection_attr = MLIR.IR.Attribute[] for address in selection From 429e05d8216d488bcfd2e422442e15c78b99647b Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 30 Oct 2025 22:42:59 -0500 Subject: [PATCH 06/10] format --- src/probprog/Utils.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/probprog/Utils.jl b/src/probprog/Utils.jl index 867eb97f46..f5223180dc 100644 --- a/src/probprog/Utils.jl +++ b/src/probprog/Utils.jl @@ -207,9 +207,7 @@ function Base.getproperty(t::Union{ProbProgTrace,Constraint}, s::Symbol) end end -function donate_argument!( - ::Any, ::Union{ProbProgTrace,Constraint}, ::Int, ::Any, ::Any -) +function donate_argument!(::Any, ::Union{ProbProgTrace,Constraint}, ::Int, ::Any, ::Any) return nothing end From f4d2a181e6e027c804f23ba9d94b4a7a6a5b0bab Mon Sep 17 00:00:00 2001 From: sbrantq Date: Sun, 2 Nov 2025 01:38:06 -0500 Subject: [PATCH 07/10] clean up --- test/probprog/hmc.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/probprog/hmc.jl b/test/probprog/hmc.jl index c88a92c7d6..bd4ea56216 100644 --- a/test/probprog/hmc.jl +++ b/test/probprog/hmc.jl @@ -87,7 +87,6 @@ end ) @test contains(repr(code), "enzyme_probprog_get_flattened_samples_from_trace") @test contains(repr(code), "enzyme_probprog_get_weight_from_trace") - @test !contains(repr(code), "enzyme.mh") @test !contains(repr(code), "enzyme.mcmc") compile_time_s = @elapsed begin From da2ad9bd050751dbe971218c3fc64a5c58e641c3 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Sat, 8 Nov 2025 22:42:58 -0600 Subject: [PATCH 08/10] CI --- src/probprog/HMC.jl | 2 +- src/probprog/MH.jl | 2 +- src/probprog/Modeling.jl | 4 ++-- src/probprog/ProbProg.jl | 2 -- src/probprog/Utils.jl | 26 +++++--------------------- 5 files changed, 9 insertions(+), 27 deletions(-) diff --git a/src/probprog/HMC.jl b/src/probprog/HMC.jl index a009aa33d7..d523f5827e 100644 --- a/src/probprog/HMC.jl +++ b/src/probprog/HMC.jl @@ -1,4 +1,4 @@ -using ..Reactant: ConcreteRNumber, TracedRArray +using ..Reactant: TracedRArray function hmc( rng::AbstractRNG, diff --git a/src/probprog/MH.jl b/src/probprog/MH.jl index f1d6b9f693..2c2c93c8a8 100644 --- a/src/probprog/MH.jl +++ b/src/probprog/MH.jl @@ -1,4 +1,4 @@ -using ..Reactant: ConcreteRNumber, TracedRArray +using ..Reactant: TracedRArray function mh( rng::AbstractRNG, diff --git a/src/probprog/Modeling.jl b/src/probprog/Modeling.jl index 6755cfb259..3dc1b11b16 100644 --- a/src/probprog/Modeling.jl +++ b/src/probprog/Modeling.jl @@ -1,5 +1,5 @@ -using ..Reactant: MLIR, TracedUtils, AbstractRNG, TracedRArray, ConcreteRNumber -using ..Compiler: @jit, @compile +using ..Reactant: MLIR, TracedUtils, AbstractRNG, TracedRArray +using ..Compiler: @compile include("Utils.jl") diff --git a/src/probprog/ProbProg.jl b/src/probprog/ProbProg.jl index 23de59676f..2c9c043538 100644 --- a/src/probprog/ProbProg.jl +++ b/src/probprog/ProbProg.jl @@ -16,8 +16,6 @@ export ProbProgTrace, Constraint, Selection, Address # Utility functions. export get_choices, select -export to_trace_tensor, from_trace_tensor -export to_constraint_tensor, from_constraint_tensor # Core MLIR ops. export sample, untraced_call, simulate, generate, mh, hmc diff --git a/src/probprog/Utils.jl b/src/probprog/Utils.jl index f5223180dc..91dfdc2f67 100644 --- a/src/probprog/Utils.jl +++ b/src/probprog/Utils.jl @@ -12,7 +12,6 @@ using ..Reactant: TracedSetPath, ConcreteToTraced, AbstractConcreteArray, - XLA, Sharding, to_number import ..Reactant: promote_to, make_tracer @@ -21,15 +20,8 @@ import ..Compiler: donate_argument! """ process_probprog_function(f, args, op_name) -Note: by convention `args` must have the RNG state as the first argument. - -This function handles the probprog argument convention where: -- **Index 1**: RNG state -- **Index 2**: Function `f` (when wrapped) -- **Index 3+**: Remaining arguments - -This wrapper ensures the RNG state is threaded through as the first result, -followed by the actual function results. +By convention `args` must have the RNG state as the first argument. +Ensures the RNG state is threaded through as the first result, followed by the actual function results. """ function process_probprog_function(f, args, op_name, with_rng=true) seen = OrderedIdDict() @@ -114,22 +106,14 @@ end This function handles the probprog argument convention where: - **Index 1**: RNG state -- **Index 2**: Function `f` (when `fnwrap` is true) +- **Index 2**: Function `f` (when `fnwrapped` is true) - **Index 3+**: Other arguments -When setting results, the function checks: -1. If result path matches `resprefix`, store in `result` -2. If result path matches `argprefix`, store in `args` (adjust indices for wrapped function) - -`offset` varies depending on the ProbProg operation: -- `sample` and `untraced_call` return only function outputs: - Use `offset=0`: `linear_results[i]` corresponds to `op.result[i]` +`offset` and `rng_only` vary depending on the ProbProg operation, e.g.: - `simulate` and `generate` return trace, weight, then outputs: Use `offset=2`: `linear_results[i]` corresponds to `op.result[i+2]` -- `mh` and `regenerate` return trace, accepted/weight, rng_state (no model outputs): +- `mh` and `regenerate` return trace, accepted/weight, new rng_state: Use `offset=2, rng_only=true`: only process first result (rng_state) - -`rng_only`: When true, only process the first result (RNG state), skipping model outputs """ function process_probprog_outputs( op, From 1d26fbb4106fad535f9ac5689e4b604af6c94c40 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Sat, 8 Nov 2025 23:13:55 -0600 Subject: [PATCH 09/10] minor --- test/probprog/hmc.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/probprog/hmc.jl b/test/probprog/hmc.jl index bd4ea56216..eb5b967633 100644 --- a/test/probprog/hmc.jl +++ b/test/probprog/hmc.jl @@ -47,10 +47,10 @@ function hmc_program( model, xs; selection=ProbProg.select(ProbProg.Address(:param_a), ProbProg.Address(:param_b)), - mass=mass, - step_size=step_size, - num_steps=num_steps, - initial_momentum=initial_momentum, + mass, + step_size, + num_steps, + initial_momentum, ) return t, accepted From b4aa321da16a8c822230757f29045353b3ad31bd Mon Sep 17 00:00:00 2001 From: sbrantq Date: Sat, 8 Nov 2025 23:51:11 -0600 Subject: [PATCH 10/10] ci --- src/probprog/Utils.jl | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/src/probprog/Utils.jl b/src/probprog/Utils.jl index 91dfdc2f67..2907ba312d 100644 --- a/src/probprog/Utils.jl +++ b/src/probprog/Utils.jl @@ -17,12 +17,6 @@ using ..Reactant: import ..Reactant: promote_to, make_tracer import ..Compiler: donate_argument! -""" - process_probprog_function(f, args, op_name) - -By convention `args` must have the RNG state as the first argument. -Ensures the RNG state is threaded through as the first result, followed by the actual function results. -""" function process_probprog_function(f, args, op_name, with_rng=true) seen = OrderedIdDict() cache_key = [] @@ -101,20 +95,6 @@ function process_probprog_function(f, args, op_name, with_rng=true) ) end -""" - process_probprog_outputs(op, linear_results, traced_result, f, args, fnwrapped, resprefix, argprefix, offset=0, rng_only=false) - -This function handles the probprog argument convention where: -- **Index 1**: RNG state -- **Index 2**: Function `f` (when `fnwrapped` is true) -- **Index 3+**: Other arguments - -`offset` and `rng_only` vary depending on the ProbProg operation, e.g.: -- `simulate` and `generate` return trace, weight, then outputs: - Use `offset=2`: `linear_results[i]` corresponds to `op.result[i+2]` -- `mh` and `regenerate` return trace, accepted/weight, new rng_state: - Use `offset=2, rng_only=true`: only process first result (rng_state) -""" function process_probprog_outputs( op, linear_results,