diff --git a/src/CompileOptions.jl b/src/CompileOptions.jl index e8cac78be6..925c357e1a 100644 --- a/src/CompileOptions.jl +++ b/src/CompileOptions.jl @@ -229,6 +229,7 @@ function CompileOptions(; :canonicalize, :just_batch, :none, + :probprog, ] end diff --git a/src/Compiler.jl b/src/Compiler.jl index bea2f19472..9a26888880 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1318,6 +1318,7 @@ end # TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate # However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass]. const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize,arith-raise{stablehlo=true}\"}" +const probprog_pass::String = "probprog{postpasses=\"arith-raise{stablehlo=true}\"}" function run_pass_pipeline!(mod, pass_pipeline, key=""; enable_verifier=true) pm = MLIR.IR.PassManager() @@ -1905,6 +1906,71 @@ function compile_mlir!( ), "no_enzyme", ) + elseif compile_options.optimization_passes === :probprog + run_pass_pipeline!( + mod, + join( + if compile_options.raise_first + [ + "mark-func-memory-effects", + opt_passes, + kern, + raise_passes, + "enzyme-batch", + opt_passes2, + probprog_pass, + "lower-probprog-to-stablehlo{backend=$backend}", + "outline-enzyme-regions", + enzyme_pass, + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + ( + if compile_options.legalize_chlo_to_stablehlo + ["func.func(chlo-legalize-to-stablehlo)"] + else + [] + end + )..., + opt_passes2, + lower_enzymexla_linalg_pass, + "lower-probprog-trace-ops{backend=$backend}", + jit, + ] + else + [ + "mark-func-memory-effects", + opt_passes, + "enzyme-batch", + opt_passes2, + probprog_pass, + "lower-probprog-to-stablehlo{backend=$backend}", + "outline-enzyme-regions", + enzyme_pass, + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + ( + if compile_options.legalize_chlo_to_stablehlo + ["func.func(chlo-legalize-to-stablehlo)"] + else + [] + end + )..., + opt_passes2, + kern, + raise_passes, + lower_enzymexla_linalg_pass, + "lower-probprog-trace-ops{backend=$backend}", + jit, + ] + end, + ",", + ), + "probprog", + ) elseif compile_options.optimization_passes === :only_enzyme run_pass_pipeline!( mod, diff --git a/src/Reactant.jl b/src/Reactant.jl index 69cee2b3b0..25228a441b 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -250,6 +250,9 @@ include("Overlay.jl") # Serialization include("serialization/Serialization.jl") +# ProbProg +include("probprog/ProbProg.jl") + using .Compiler: @compile, @code_hlo, @code_mhlo, @jit, @code_xla, traced_getfield, compile export ConcreteRArray, ConcreteRNumber, diff --git a/src/Types.jl b/src/Types.jl index 6d932ecc0c..c00ef99438 100644 --- a/src/Types.jl +++ b/src/Types.jl @@ -241,6 +241,7 @@ function ConcretePJRTArray( end Base.wait(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = foreach(wait, x.data) +Base.isready(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = all(isready, x.data) XLA.client(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = XLA.client(x.data) function XLA.device(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) x.sharding isa Sharding.NoShardInfo && return XLA.device(only(x.data)) @@ -420,6 +421,7 @@ function ConcreteIFRTArray( end Base.wait(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = wait(x.data) +Base.isready(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = isready(x.data) XLA.client(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = XLA.client(x.data) function XLA.device(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) return XLA.device(x.data) diff --git a/src/probprog/Display.jl b/src/probprog/Display.jl new file mode 100644 index 0000000000..a81992eb71 --- /dev/null +++ b/src/probprog/Display.jl @@ -0,0 +1,87 @@ +# Reference: https://github.com/probcomp/Gen.jl/blob/91d798f2d2f0c175b1be3dc6daf3a10a8acf5da3/src/choice_map.jl#L104 +function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple) + VERT = '\u2502' + PLUS = '\u251C' + HORZ = '\u2500' + LAST = '\u2514' + + indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) + indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' ']) + indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' ']) + + for i in vert_bars + indent_vert[i] = VERT + indent[i] = VERT + indent_last[i] = VERT + end + + indent_vert_str = join(indent_vert) + indent_str = join(indent) + indent_last_str = join(indent_last) + + sorted_choices = sort(collect(trace.choices); by=x -> x[1]) + n = length(sorted_choices) + + if trace.retval !== nothing + n += 1 + end + + if trace.weight !== nothing + n += 1 + end + + cur = 1 + + if trace.retval !== nothing + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "retval : $(trace.retval)\n") + cur += 1 + end + + if trace.weight !== nothing + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "weight : $(trace.weight)\n") + cur += 1 + end + + for (key, value) in sorted_choices + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n") + cur += 1 + end + + sorted_subtraces = sort(collect(trace.subtraces); by=x -> x[1]) + n += length(sorted_subtraces) + + for (key, subtrace) in sorted_subtraces + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "subtrace on $(repr(key))\n") + _show_pretty( + io, subtrace, pre + 4, cur == n ? (vert_bars...,) : (vert_bars..., pre + 1) + ) + cur += 1 + end +end + +function Base.show(io::IO, ::MIME"text/plain", trace::ProbProgTrace) + println(io, "ProbProgTrace:") + if isempty(trace.choices) && trace.retval === nothing && trace.weight === nothing + println(io, " (empty)") + else + _show_pretty(io, trace, 0, ()) + end +end + +function Base.show(io::IO, trace::ProbProgTrace) + if get(io, :compact, false) + choices_count = length(trace.choices) + has_retval = trace.retval !== nothing + print(io, "ProbProgTrace($(choices_count) choices") + if has_retval + print(io, ", retval=$(trace.retval), weight=$(trace.weight)") + end + print(io, ")") + else + show(io, MIME"text/plain"(), trace) + end +end diff --git a/src/probprog/FFI.jl b/src/probprog/FFI.jl new file mode 100644 index 0000000000..2f40390727 --- /dev/null +++ b/src/probprog/FFI.jl @@ -0,0 +1,768 @@ +using ..Reactant: MLIR, Profiler + +function initTrace(trace_ptr_ptr::Ptr{Ptr{Any}}) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.initTrace"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + tr = ProbProgTrace() + _keepalive!(tr) + + unsafe_store!(trace_ptr_ptr, pointer_from_objref(tr)) + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function addSampleToTrace( + trace_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + sample_ptr_array::Ptr{Ptr{Any}}, + num_outputs_ptr::Ptr{UInt64}, + ndims_array::Ptr{UInt64}, + shape_ptr_array::Ptr{Ptr{UInt64}}, + width_array::Ptr{UInt64}, +) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.addSampleToTrace"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + trace = nothing + try + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + catch + @ccall printf("Trace dereference failure\n"::Cstring)::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol + num_outputs = unsafe_load(num_outputs_ptr) + ndims_array = unsafe_wrap(Array, ndims_array, num_outputs) + width_array = unsafe_wrap(Array, width_array, num_outputs) + shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_outputs) + sample_ptr_array = unsafe_wrap(Array, sample_ptr_array, num_outputs) + + vals = Any[] + for i in 1:num_outputs + ndims = ndims_array[i] + width = width_array[i] + shape_ptr = shape_ptr_array[i] + sample_ptr = sample_ptr_array[i] + + julia_type = if width == 32 + Float32 + elseif width == 64 + Float64 + elseif width == 1 + Bool + else + nothing + end + + if julia_type === nothing + @ccall printf( + "Unsupported datatype width: %lld\n"::Cstring, width::Int64 + )::Cvoid + return nothing + end + + if ndims == 0 + push!(vals, unsafe_load(Ptr{julia_type}(sample_ptr))) + else + shape = unsafe_wrap(Array, shape_ptr, ndims) + push!(vals, copy(unsafe_wrap(Array, Ptr{julia_type}(sample_ptr), Tuple(shape)))) + end + end + + trace.choices[symbol] = tuple(vals...) + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function addSubtrace( + trace_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + subtrace_ptr_ptr::Ptr{Ptr{Any}}, +) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.addSubtrace"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + trace = nothing + try + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + catch + @ccall printf("Trace dereference failure\n"::Cstring)::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol + subtrace = unsafe_pointer_to_objref(unsafe_load(subtrace_ptr_ptr))::ProbProgTrace + + trace.subtraces[symbol] = subtrace + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function addWeightToTrace(trace_ptr_ptr::Ptr{Ptr{Any}}, weight_ptr::Ptr{Any}) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.addWeightToTrace"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + trace = nothing + try + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + catch + @ccall printf("Trace dereference failure\n"::Cstring)::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + trace.weight = unsafe_load(Ptr{Float64}(weight_ptr)) + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function addRetvalToTrace( + trace_ptr_ptr::Ptr{Ptr{Any}}, + retval_ptr_array::Ptr{Ptr{Any}}, + num_results_ptr::Ptr{UInt64}, + ndims_array::Ptr{UInt64}, + shape_ptr_array::Ptr{Ptr{UInt64}}, + width_array::Ptr{UInt64}, +) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.addRetvalToTrace"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + trace = nothing + try + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + catch + @ccall printf("Trace dereference failure\n"::Cstring)::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + num_results = unsafe_load(num_results_ptr) + + if num_results == 0 + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + ndims_array = unsafe_wrap(Array, ndims_array, num_results) + width_array = unsafe_wrap(Array, width_array, num_results) + shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_results) + retval_ptr_array = unsafe_wrap(Array, retval_ptr_array, num_results) + + vals = Any[] + for i in 1:num_results + ndims = ndims_array[i] + width = width_array[i] + shape_ptr = shape_ptr_array[i] + retval_ptr = retval_ptr_array[i] + + julia_type = if width == 32 + Float32 + elseif width == 64 + Float64 + elseif width == 1 + Bool + else + nothing + end + + if julia_type === nothing + @ccall printf( + "Unsupported datatype width: %lld\n"::Cstring, width::Int64 + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + if ndims == 0 + push!(vals, unsafe_load(Ptr{julia_type}(retval_ptr))) + else + shape = unsafe_wrap(Array, shape_ptr, ndims) + push!(vals, copy(unsafe_wrap(Array, Ptr{julia_type}(retval_ptr), Tuple(shape)))) + end + end + + trace.retval = tuple(vals...) + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function getSampleFromConstraint( + constraint_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + sample_ptr_array::Ptr{Ptr{Any}}, + num_samples_ptr::Ptr{UInt64}, + ndims_array::Ptr{UInt64}, + shape_ptr_array::Ptr{Ptr{UInt64}}, + width_array::Ptr{UInt64}, +) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.getSampleFromConstraint"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + constraint = unsafe_pointer_to_objref(unsafe_load(constraint_ptr_ptr))::Constraint + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol + num_samples = unsafe_load(num_samples_ptr) + ndims_array = unsafe_wrap(Array, ndims_array, num_samples) + width_array = unsafe_wrap(Array, width_array, num_samples) + shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_samples) + sample_ptr_array = unsafe_wrap(Array, sample_ptr_array, num_samples) + + tostore = get(constraint, Address(symbol), nothing) + + if tostore === nothing + @ccall printf( + "No constraint found for symbol: %s\n"::Cstring, string(symbol)::Cstring + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + for i in 1:num_samples + ndims = ndims_array[i] + width = width_array[i] + shape_ptr = shape_ptr_array[i] + sample_ptr = sample_ptr_array[i] + + julia_type = if width == 32 + Float32 + elseif width == 64 + Float64 + elseif width == 1 + Bool + else + nothing + end + + if julia_type === nothing + @ccall printf( + "Unsupported datatype width: %zd\n"::Cstring, width::Csize_t + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + if julia_type != eltype(tostore[i]) + @ccall printf( + "Type mismatch in constrained sample: %s != %s\n"::Cstring, + string(julia_type)::Cstring, + string(eltype(tostore[i]))::Cstring, + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + if ndims == 0 + unsafe_store!(Ptr{julia_type}(sample_ptr), tostore[i]) + else + shape = unsafe_wrap(Array, shape_ptr, ndims) + dest = unsafe_wrap(Array, Ptr{julia_type}(sample_ptr), Tuple(shape)) + + dest_size = size(dest) + src_size = size(tostore[i]) + + if dest_size != src_size + @ccall printf( + "Shape mismatch in constrained sample: expected %zd dims, got %zd\n"::Cstring, + length(dest_size)::Csize_t, + length(src_size)::Csize_t, + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + copyto!(dest, tostore[i]) + end + end + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function getSubconstraint( + constraint_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + subconstraint_ptr_ptr::Ptr{Ptr{Any}}, +) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.getSubconstraint"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + constraint = unsafe_pointer_to_objref(unsafe_load(constraint_ptr_ptr))::Constraint + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol + + subconstraint = Constraint() + + for (key, value) in constraint + if key.path[1] == symbol + @assert isa(key, Address) "Expected Address type for constraint key" + @assert length(key.path) > 1 "Expected composite address with length > 1" + tail_address = Address(key.path[2:end]) + subconstraint[tail_address] = value + end + end + + if isempty(subconstraint) + @ccall printf( + "No subconstraint found for symbol: %s\n"::Cstring, string(symbol)::Cstring + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + _keepalive!(subconstraint) + unsafe_store!(subconstraint_ptr_ptr, pointer_from_objref(subconstraint)) + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function getSampleFromTrace( + trace_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + sample_ptr_array::Ptr{Ptr{Any}}, + num_samples_ptr::Ptr{UInt64}, + ndims_array::Ptr{UInt64}, + shape_ptr_array::Ptr{Ptr{UInt64}}, + width_array::Ptr{UInt64}, +) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.getSampleFromTrace"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + trace = nothing + try + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + catch + @ccall printf("Trace dereference failure\n"::Cstring)::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol + num_samples = unsafe_load(num_samples_ptr) + ndims_array = unsafe_wrap(Array, ndims_array, num_samples) + width_array = unsafe_wrap(Array, width_array, num_samples) + shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_samples) + sample_ptr_array = unsafe_wrap(Array, sample_ptr_array, num_samples) + + tostore = get(trace.choices, symbol, nothing) + + if tostore === nothing + @ccall printf( + "No sample found in trace for symbol: %s\n"::Cstring, string(symbol)::Cstring + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + for i in 1:num_samples + ndims = ndims_array[i] + width = width_array[i] + shape_ptr = shape_ptr_array[i] + sample_ptr = sample_ptr_array[i] + + julia_type = if width == 32 + Float32 + elseif width == 64 + Float64 + elseif width == 1 + Bool + else + nothing + end + + if julia_type === nothing + @ccall printf( + "Unsupported datatype width: %zd\n"::Cstring, width::Csize_t + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + if julia_type != eltype(tostore[i]) + @ccall printf( + "Type mismatch in trace sample: %s != %s\n"::Cstring, + string(julia_type)::Cstring, + string(eltype(tostore[i]))::Cstring, + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + if ndims == 0 + unsafe_store!(Ptr{julia_type}(sample_ptr), tostore[i]) + else + shape = unsafe_wrap(Array, shape_ptr, ndims) + dest = unsafe_wrap(Array, Ptr{julia_type}(sample_ptr), Tuple(shape)) + + dest_size = size(dest) + src_size = size(tostore[i]) + + if dest_size != src_size + @ccall printf( + "Shape mismatch in trace sample: expected %zd dims, got %zd\n"::Cstring, + length(dest_size)::Csize_t, + length(src_size)::Csize_t, + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + copyto!(dest, tostore[i]) + end + end + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function getSubtrace( + trace_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + subtrace_ptr_ptr::Ptr{Ptr{Any}}, +) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.getSubtrace"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + trace = nothing + try + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + catch + @ccall printf("Trace dereference failure\n"::Cstring)::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol + + subtrace = get(trace.subtraces, symbol, nothing) + + if subtrace === nothing + @ccall printf( + "No subtrace found for symbol: %s\n"::Cstring, string(symbol)::Cstring + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + _keepalive!(subtrace) + unsafe_store!(subtrace_ptr_ptr, pointer_from_objref(subtrace)) + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function getWeightFromTrace(trace_ptr_ptr::Ptr{Ptr{Any}}, weight_ptr::Ptr{Any}) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.getWeightFromTrace"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + trace = nothing + try + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + catch + @ccall printf("Trace dereference failure\n"::Cstring)::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + unsafe_store!(Ptr{Float64}(weight_ptr), trace.weight) + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function getFlattenedSamplesFromTrace( + trace_ptr_ptr::Ptr{Ptr{Any}}, + num_addresses_ptr::Ptr{UInt64}, + total_symbols_ptr::Ptr{UInt64}, + address_lengths_ptr::Ptr{UInt64}, + flattened_symbols_ptr::Ptr{UInt64}, + position_ptr::Ptr{Any}, +) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.getFlattenedSamplesFromTrace"::Cstring, + Profiler.TRACE_ME_LEVEL_CRITICAL::Cint, + )::Int64 + + trace = nothing + try + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + catch + @ccall printf("No trace found\n"::Cstring)::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + num_addresses = unsafe_load(num_addresses_ptr) + total_symbols = unsafe_load(total_symbols_ptr) + + address_lengths = unsafe_wrap(Array, address_lengths_ptr, num_addresses) + flattened_symbols = unsafe_wrap(Array, flattened_symbols_ptr, total_symbols) + + addresses = Vector{Vector{Symbol}}() + symbol_idx = 1 + for i in 1:num_addresses + addr_len = address_lengths[i] + address = Symbol[] + for j in 1:addr_len + symbol_ptr_value = flattened_symbols[symbol_idx] + symbol = unsafe_pointer_to_objref(Ptr{Any}(symbol_ptr_value))::Symbol + push!(address, symbol) + symbol_idx += 1 + end + push!(addresses, address) + end + + flattened_values = Float64[] + + for address in addresses + current_trace = trace + + for (idx, symbol) in enumerate(address) + if idx < length(address) + if !haskey(current_trace.subtraces, symbol) + @ccall printf( + "No subtrace found for symbol in address path: %s\n"::Cstring, + string(symbol)::Cstring, + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + current_trace = current_trace.subtraces[symbol] + else + if !haskey(current_trace.choices, symbol) + @ccall printf( + "No sample found for symbol: %s\n"::Cstring, string(symbol)::Cstring + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + sample_tuple = current_trace.choices[symbol] + + for sample_val in sample_tuple + if isa(sample_val, AbstractArray) + for val in sample_val + push!(flattened_values, Float64(val)) + end + else + push!(flattened_values, Float64(sample_val)) + end + end + end + end + end + + position_array = unsafe_wrap( + Array, Ptr{Float64}(position_ptr), length(flattened_values) + ) + copyto!(position_array, flattened_values) + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function dump( + value_ptr::Ptr{Any}, + label_ptr::Ptr{UInt8}, + ndims_ptr::Ptr{UInt64}, + shape_ptr::Ptr{UInt64}, + width_ptr::Ptr{UInt64}, +) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.dump"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + label = unsafe_string(label_ptr) + ndims = unsafe_load(ndims_ptr) + width = unsafe_load(width_ptr) + + julia_type = if width == 32 + Float32 + elseif width == 64 + Float64 + elseif width == 1 + Bool + else + @ccall printf( + "DUMP ERROR: Unsupported datatype width: %lld\n"::Cstring, width::Int64 + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + println("═══ DUMP: $label ═══") + + if ndims == 0 + value = unsafe_load(Ptr{julia_type}(value_ptr)) + println(" Scalar ($julia_type): $value") + else + shape = unsafe_wrap(Array, shape_ptr, ndims) + value_array = unsafe_wrap(Array, Ptr{julia_type}(value_ptr), Tuple(shape)) + + println(" Shape: $(Tuple(shape))") + println(" Type: Array{$julia_type}") + println(" Values:") + + total_elements = prod(shape) + if total_elements <= 20 + println(" ", value_array) + else + println(" [$(total_elements) elements]") + println(" min: $(minimum(value_array))") + println(" max: $(maximum(value_array))") + println(" mean: $(sum(value_array) / total_elements)") + println(" First 10: $(value_array[1:min(10, total_elements)])") + end + end + + println("═══════════════════════════════════") + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function __init__() + init_trace_ptr = @cfunction(initTrace, Cvoid, (Ptr{Ptr{Any}},)) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_init_trace::Cstring, init_trace_ptr::Ptr{Cvoid} + )::Cvoid + + add_sample_to_trace_ptr = @cfunction( + addSampleToTrace, + Cvoid, + ( + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{UInt64}, + Ptr{UInt64}, + Ptr{Ptr{UInt64}}, + Ptr{UInt64}, + ) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_add_sample_to_trace::Cstring, add_sample_to_trace_ptr::Ptr{Cvoid} + )::Cvoid + + add_subtrace_ptr = @cfunction( + addSubtrace, Cvoid, (Ptr{Ptr{Any}}, Ptr{Ptr{Any}}, Ptr{Ptr{Any}}) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_add_subtrace::Cstring, add_subtrace_ptr::Ptr{Cvoid} + )::Cvoid + + add_weight_to_trace_ptr = @cfunction(addWeightToTrace, Cvoid, (Ptr{Ptr{Any}}, Ptr{Any})) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_add_weight_to_trace::Cstring, add_weight_to_trace_ptr::Ptr{Cvoid} + )::Cvoid + + add_retval_to_trace_ptr = @cfunction( + addRetvalToTrace, + Cvoid, + ( + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{UInt64}, + Ptr{UInt64}, + Ptr{Ptr{UInt64}}, + Ptr{UInt64}, + ), + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_add_retval_to_trace::Cstring, add_retval_to_trace_ptr::Ptr{Cvoid} + )::Cvoid + + get_sample_from_constraint_ptr = @cfunction( + getSampleFromConstraint, + Cvoid, + ( + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{UInt64}, + Ptr{UInt64}, + Ptr{Ptr{UInt64}}, + Ptr{UInt64}, + ) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_get_sample_from_constraint::Cstring, + get_sample_from_constraint_ptr::Ptr{Cvoid}, + )::Cvoid + + get_subconstraint_ptr = @cfunction( + getSubconstraint, Cvoid, (Ptr{Ptr{Any}}, Ptr{Ptr{Any}}, Ptr{Ptr{Any}}) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_get_subconstraint::Cstring, get_subconstraint_ptr::Ptr{Cvoid} + )::Cvoid + + get_sample_from_trace_ptr = @cfunction( + getSampleFromTrace, + Cvoid, + ( + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{UInt64}, + Ptr{UInt64}, + Ptr{Ptr{UInt64}}, + Ptr{UInt64}, + ) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_get_sample_from_trace::Cstring, + get_sample_from_trace_ptr::Ptr{Cvoid}, + )::Cvoid + + get_subtrace_ptr = @cfunction( + getSubtrace, Cvoid, (Ptr{Ptr{Any}}, Ptr{Ptr{Any}}, Ptr{Ptr{Any}}) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_get_subtrace::Cstring, get_subtrace_ptr::Ptr{Cvoid} + )::Cvoid + + get_weight_from_trace_ptr = @cfunction( + getWeightFromTrace, Cvoid, (Ptr{Ptr{Any}}, Ptr{Any}) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_get_weight_from_trace::Cstring, + get_weight_from_trace_ptr::Ptr{Cvoid}, + )::Cvoid + + get_flattened_samples_from_trace_ptr = @cfunction( + getFlattenedSamplesFromTrace, + Cvoid, + (Ptr{Ptr{Any}}, Ptr{UInt64}, Ptr{UInt64}, Ptr{UInt64}, Ptr{UInt64}, Ptr{Any}) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_get_flattened_samples_from_trace::Cstring, + get_flattened_samples_from_trace_ptr::Ptr{Cvoid}, + )::Cvoid + + dump_ptr = @cfunction( + dump, Cvoid, (Ptr{Any}, Ptr{UInt8}, Ptr{UInt64}, Ptr{UInt64}, Ptr{UInt64}) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_dump::Cstring, dump_ptr::Ptr{Cvoid} + )::Cvoid + + return nothing +end diff --git a/src/probprog/HMC.jl b/src/probprog/HMC.jl new file mode 100644 index 0000000000..d523f5827e --- /dev/null +++ b/src/probprog/HMC.jl @@ -0,0 +1,117 @@ +using ..Reactant: TracedRArray + +function hmc( + rng::AbstractRNG, + original_trace, + f::Function, + args::Vararg{Any,Nargs}; + selection::Selection, + mass=nothing, + step_size=nothing, + num_steps=nothing, + initial_momentum=nothing, +) where {Nargs} + args = (rng, args...) + (; f_name, mlir_caller_args, mlir_result_types, traced_result, linear_results, fnwrapped, argprefix, resprefix) = process_probprog_function( + f, args, "hmc" + ) + fn_attr = MLIR.IR.FlatSymbolRefAttribute(f_name) + + trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet( + MLIR.IR.context()::MLIR.API.MlirContext + )::MLIR.IR.Type + + trace_val = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [TracedUtils.get_mlir_data(original_trace)]; outputs=[trace_ty] + ), + 1, + ) + + selection_attr = MLIR.IR.Attribute[] + for address in selection + address_attr = MLIR.IR.Attribute[] + for sym in address.path + sym_addr = reinterpret(UInt64, pointer_from_objref(sym)) + push!( + address_attr, + @ccall MLIR.API.mlir_c.enzymeSymbolAttrGet( + MLIR.IR.context()::MLIR.API.MlirContext, sym_addr::UInt64 + )::MLIR.IR.Attribute + ) + end + push!(selection_attr, MLIR.IR.Attribute(address_attr)) + end + + trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet( + MLIR.IR.context()::MLIR.API.MlirContext + )::MLIR.IR.Type + accepted_ty = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Bool)) + + alg_attr = @ccall MLIR.API.mlir_c.enzymeMCMCAlgorithmAttrGet( + MLIR.IR.context()::MLIR.API.MlirContext, + 0::Int32, # 0 = HMC + )::MLIR.IR.Attribute + + mass_val = nothing + if !isnothing(mass) + mass_val = TracedUtils.get_mlir_data(mass) + end + + step_size_val = nothing + if !isnothing(step_size) + step_size_val = TracedUtils.get_mlir_data(step_size) + end + + num_steps_val = nothing + if !isnothing(num_steps) + num_steps_val = TracedUtils.get_mlir_data(num_steps) + end + + initial_momentum_val = nothing + if !isnothing(initial_momentum) + initial_momentum_val = TracedUtils.get_mlir_data(initial_momentum) + end + + hmc_op = MLIR.Dialects.enzyme.mcmc( + mlir_caller_args, + trace_val, + mass_val; + step_size=step_size_val, + num_steps=num_steps_val, + initial_momentum=initial_momentum_val, + new_trace=trace_ty, + accepted=accepted_ty, + output_rng_state=mlir_result_types[1], # by convention + alg=alg_attr, + fn=fn_attr, + selection=MLIR.IR.Attribute(selection_attr), + ) + + # (new_trace, accepted, output_rng_state) + traced_result = process_probprog_outputs( + hmc_op, + linear_results, + traced_result, + f, + args, + fnwrapped, + resprefix, + argprefix, + 2, + true, + ) + + new_trace_val = MLIR.IR.result(hmc_op, 1) + new_trace_ptr = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [new_trace_val]; outputs=[MLIR.IR.TensorType(Int64[], MLIR.IR.Type(UInt64))] + ), + 1, + ) + + new_trace = TracedRArray{UInt64,0}((), new_trace_ptr, ()) + accepted = TracedRArray{Bool,0}((), MLIR.IR.result(hmc_op, 2), ()) + + return new_trace, accepted, traced_result +end diff --git a/src/probprog/MH.jl b/src/probprog/MH.jl new file mode 100644 index 0000000000..2c2c93c8a8 --- /dev/null +++ b/src/probprog/MH.jl @@ -0,0 +1,85 @@ +using ..Reactant: TracedRArray + +function mh( + rng::AbstractRNG, + original_trace, + f::Function, + args::Vararg{Any,Nargs}; + selection::Selection, +) where {Nargs} + args = (rng, args...) + (; f_name, mlir_caller_args, mlir_result_types, traced_result, linear_results, fnwrapped, argprefix, resprefix) = process_probprog_function( + f, args, "mh" + ) + fn_attr = MLIR.IR.FlatSymbolRefAttribute(f_name) + + trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet( + MLIR.IR.context()::MLIR.API.MlirContext + )::MLIR.IR.Type + + trace_val = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [TracedUtils.get_mlir_data(original_trace)]; outputs=[trace_ty] + ), + 1, + ) + + selection_attr = MLIR.IR.Attribute[] + for address in selection + address_attr = MLIR.IR.Attribute[] + for sym in address.path + sym_addr = reinterpret(UInt64, pointer_from_objref(sym)) + push!( + address_attr, + @ccall MLIR.API.mlir_c.enzymeSymbolAttrGet( + MLIR.IR.context()::MLIR.API.MlirContext, sym_addr::UInt64 + )::MLIR.IR.Attribute + ) + end + push!(selection_attr, MLIR.IR.Attribute(address_attr)) + end + + trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet( + MLIR.IR.context()::MLIR.API.MlirContext + )::MLIR.IR.Type + accepted_ty = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Bool)) + + mh_op = MLIR.Dialects.enzyme.mh( + mlir_caller_args, + trace_val; + new_trace=trace_ty, + accepted=accepted_ty, + output_rng_state=mlir_result_types[1], # by convention + fn=fn_attr, + selection=MLIR.IR.Attribute(selection_attr), + ) + + # Return (new_trace, accepted, output_rng_state) + traced_result = process_probprog_outputs( + mh_op, + linear_results, + traced_result, + f, + args, + fnwrapped, + resprefix, + argprefix, + 2, + true, + ) + + new_trace_val = MLIR.IR.result(mh_op, 1) + new_trace_ptr = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [new_trace_val]; outputs=[MLIR.IR.TensorType(Int64[], MLIR.IR.Type(UInt64))] + ), + 1, + ) + + new_trace = TracedRArray{UInt64,0}((), new_trace_ptr, ()) + accepted = TracedRArray{Bool,0}((), MLIR.IR.result(mh_op, 2), ()) + + return new_trace, accepted, traced_result +end + +const metropolis_hastings = mh diff --git a/src/probprog/Modeling.jl b/src/probprog/Modeling.jl new file mode 100644 index 0000000000..3dc1b11b16 --- /dev/null +++ b/src/probprog/Modeling.jl @@ -0,0 +1,248 @@ +using ..Reactant: MLIR, TracedUtils, AbstractRNG, TracedRArray +using ..Compiler: @compile + +include("Utils.jl") + +function sample( + rng::AbstractRNG, + f::Function, + args::Vararg{Any,Nargs}; + symbol::Symbol=gensym("sample"), + logpdf::Union{Nothing,Function}=nothing, +) where {Nargs} + args_with_rng = (rng, args...) + (; f_name, mlir_caller_args, mlir_result_types, traced_result, linear_results, fnwrapped, argprefix, resprefix) = process_probprog_function( + f, args_with_rng, "sample" + ) + + fn_attr = MLIR.IR.FlatSymbolRefAttribute(f_name) + symbol_addr = reinterpret(UInt64, pointer_from_objref(symbol)) + symbol_attr = @ccall MLIR.API.mlir_c.enzymeSymbolAttrGet( + MLIR.IR.context()::MLIR.API.MlirContext, symbol_addr::UInt64 + )::MLIR.IR.Attribute + + logpdf_attr = nothing + if logpdf isa Function + samples = f(args_with_rng...) + + # Logpdf calling convention: `(sample, args...)` (no rng state) + logpdf_args = (samples, args...) + + logpdf_attr = MLIR.IR.FlatSymbolRefAttribute( + process_probprog_function(logpdf, logpdf_args, "logpdf", false).f_name + ) + end + + sample_op = MLIR.Dialects.enzyme.sample( + mlir_caller_args; + outputs=mlir_result_types, + fn=fn_attr, + logpdf=logpdf_attr, + symbol=symbol_attr, + name=Base.String(symbol), + ) + + traced_result = process_probprog_outputs( + sample_op, + linear_results, + traced_result, + f, + args_with_rng, + fnwrapped, + resprefix, + argprefix, + ) + + return traced_result +end + +function untraced_call(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs} + args_with_rng = (rng, args...) + + (; f_name, mlir_caller_args, mlir_result_types, traced_result, linear_results, fnwrapped, argprefix, resprefix) = process_probprog_function( + f, args_with_rng, "untraced_call" + ) + + fn_attr = MLIR.IR.FlatSymbolRefAttribute(f_name) + + call_op = MLIR.Dialects.enzyme.untracedCall( + mlir_caller_args; outputs=mlir_result_types, fn=fn_attr + ) + + traced_result = process_probprog_outputs( + call_op, + linear_results, + traced_result, + f, + args_with_rng, + fnwrapped, + resprefix, + argprefix, + ) + + return traced_result +end + +# Gen-like helper function. +function simulate_(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs} + trace = nothing + + compiled_fn = @compile optimize = :probprog simulate(rng, f, args...) + + seed_buffer = only(rng.seed.data).buffer + GC.@preserve seed_buffer begin + t, _, _ = compiled_fn(rng, f, args...) + trace = ProbProgTrace(t) + end + + return trace, trace.weight +end + +function simulate(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs} + args = (rng, args...) + (; f_name, mlir_caller_args, mlir_result_types, traced_result, linear_results, fnwrapped, argprefix, resprefix) = process_probprog_function( + f, args, "simulate" + ) + fn_attr = MLIR.IR.FlatSymbolRefAttribute(f_name) + + trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet( + MLIR.IR.context()::MLIR.API.MlirContext + )::MLIR.IR.Type + weight_ty = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Float64)) + + simulate_op = MLIR.Dialects.enzyme.simulate( + mlir_caller_args; + trace=trace_ty, + weight=weight_ty, + outputs=mlir_result_types, + fn=fn_attr, + ) + + traced_result = process_probprog_outputs( + simulate_op, + linear_results, + traced_result, + f, + args, + fnwrapped, + resprefix, + argprefix, + 2, + ) + + trace = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [MLIR.IR.result(simulate_op, 1)]; + outputs=[MLIR.IR.TensorType(Int64[], MLIR.IR.Type(UInt64))], + ), + 1, + ) + + trace = TracedRArray{UInt64,0}((), trace, ()) + weight = TracedRArray{Float64,0}((), MLIR.IR.result(simulate_op, 2), ()) + + return trace, weight, traced_result +end + +# Gen-like helper function. +function generate_( + rng::AbstractRNG, constraint::Constraint, f::Function, args::Vararg{Any,Nargs} +) where {Nargs} + trace = nothing + + constrained_addresses = extract_addresses(constraint) + + compiled_fn = @compile optimize = :probprog generate( + rng, constraint, f, args...; constrained_addresses + ) + + seed_buffer = only(rng.seed.data).buffer + GC.@preserve seed_buffer constraint begin + t, _, _ = compiled_fn(rng, constraint, f, args...) + trace = ProbProgTrace(t) + end + + return trace, trace.weight +end + +function generate( + rng::AbstractRNG, + constraint, + f::Function, + args::Vararg{Any,Nargs}; + constrained_addresses::Set{Address}, +) where {Nargs} + args = (rng, args...) + + (; f_name, mlir_caller_args, mlir_result_types, traced_result, linear_results, fnwrapped, argprefix, resprefix) = process_probprog_function( + f, args, "generate" + ) + + fn_attr = MLIR.IR.FlatSymbolRefAttribute(f_name) + + constraint_ty = @ccall MLIR.API.mlir_c.enzymeConstraintTypeGet( + MLIR.IR.context()::MLIR.API.MlirContext + )::MLIR.IR.Type + + constraint_val = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [TracedUtils.get_mlir_data(constraint)]; outputs=[constraint_ty] + ), + 1, + ) + + constrained_addresses_attr = MLIR.IR.Attribute[] + for address in constrained_addresses + address_attr = MLIR.IR.Attribute[] + for sym in address.path + sym_addr = reinterpret(UInt64, pointer_from_objref(sym)) + push!( + address_attr, + @ccall MLIR.API.mlir_c.enzymeSymbolAttrGet( + MLIR.IR.context()::MLIR.API.MlirContext, sym_addr::UInt64 + )::MLIR.IR.Attribute + ) + end + push!(constrained_addresses_attr, MLIR.IR.Attribute(address_attr)) + end + + trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet( + MLIR.IR.context()::MLIR.API.MlirContext + )::MLIR.IR.Type + weight_ty = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Float64)) + + generate_op = MLIR.Dialects.enzyme.generate( + mlir_caller_args, + constraint_val; + trace=trace_ty, + weight=weight_ty, + outputs=mlir_result_types, + fn=fn_attr, + constrained_addresses=MLIR.IR.Attribute(constrained_addresses_attr), + ) + + traced_result = process_probprog_outputs( + generate_op, + linear_results, + traced_result, + f, + args, + fnwrapped, + resprefix, + argprefix, + 2, + ) + + trace = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [MLIR.IR.result(generate_op, 1)]; + outputs=[MLIR.IR.TensorType(Int64[], MLIR.IR.Type(UInt64))], + ), + 1, + ) + + trace = TracedRArray{UInt64,0}((), trace, ()) + weight = TracedRArray{Float64,0}((), MLIR.IR.result(generate_op, 2), ()) + + return trace, weight, traced_result +end diff --git a/src/probprog/ProbProg.jl b/src/probprog/ProbProg.jl new file mode 100644 index 0000000000..2c9c043538 --- /dev/null +++ b/src/probprog/ProbProg.jl @@ -0,0 +1,26 @@ +module ProbProg + +using ..Reactant: + MLIR, TracedUtils, AbstractRNG, TracedRArray, TracedRNumber, ConcreteRNumber +using ..Compiler: @jit, @compile + +include("Types.jl") +include("FFI.jl") +include("Modeling.jl") +include("Display.jl") +include("MH.jl") +include("HMC.jl") + +# Types. +export ProbProgTrace, Constraint, Selection, Address + +# Utility functions. +export get_choices, select + +# Core MLIR ops. +export sample, untraced_call, simulate, generate, mh, hmc + +# Gen-like helper functions. +export simulate_, generate_ + +end diff --git a/src/probprog/Types.jl b/src/probprog/Types.jl new file mode 100644 index 0000000000..c1c94d7031 --- /dev/null +++ b/src/probprog/Types.jl @@ -0,0 +1,84 @@ +using Base: ReentrantLock +using ..Reactant: AbstractConcreteNumber, AbstractConcreteArray + +mutable struct ProbProgTrace + choices::Dict{Symbol,Any} + retval::Any + weight::Any + subtraces::Dict{Symbol,Any} + + function ProbProgTrace() + return new(Dict{Symbol,Any}(), nothing, nothing, Dict{Symbol,Any}()) + end + function ProbProgTrace(x::Union{AbstractConcreteNumber,AbstractConcreteArray}) + return convert(ProbProgTrace, x) + end +end + +struct Address + path::Vector{Symbol} + + Address(path::Vector{Symbol}) = new(path) +end + +Address(sym::Symbol) = Address([sym]) +Address(syms::Symbol...) = Address([syms...]) + +Base.:(==)(a::Address, b::Address) = a.path == b.path +Base.hash(a::Address, h::UInt) = hash(a.path, h) + +mutable struct Constraint <: AbstractDict{Address,Any} + dict::Dict{Address,Any} + + function Constraint(pairs::Pair...) + dict = Dict{Address,Any}() + for pair in pairs + symbols = Symbol[] + current = pair + while isa(current, Pair) && isa(current.first, Symbol) + push!(symbols, current.first) + current = current.second + end + dict[Address(symbols...)] = current + end + return new(dict) + end + + Constraint() = new(Dict{Address,Any}()) + Constraint(d::Dict{Address,Any}) = new(d) + function Constraint(x::Union{AbstractConcreteNumber,AbstractConcreteArray}) + return convert(Constraint, x) + end +end + +Base.getindex(c::Constraint, k::Address) = c.dict[k] +Base.setindex!(c::Constraint, v, k::Address) = (c.dict[k] = v) +Base.delete!(c::Constraint, k::Address) = delete!(c.dict, k) +Base.keys(c::Constraint) = keys(c.dict) +Base.values(c::Constraint) = values(c.dict) +Base.iterate(c::Constraint) = iterate(c.dict) +Base.iterate(c::Constraint, state) = iterate(c.dict, state) +Base.length(c::Constraint) = length(c.dict) +Base.isempty(c::Constraint) = isempty(c.dict) +Base.haskey(c::Constraint, k::Address) = haskey(c.dict, k) +Base.get(c::Constraint, k::Address, default) = get(c.dict, k, default) + +extract_addresses(constraint::Constraint) = Set(keys(constraint)) + +const Selection = Set{Address} + +const _probprog_ref_lock = ReentrantLock() +const _probprog_refs = IdDict() + +function _keepalive!(tr::Any) + lock(_probprog_ref_lock) + try + _probprog_refs[tr] = tr + finally + unlock(_probprog_ref_lock) + end + return tr +end + +get_choices(trace::ProbProgTrace) = trace.choices +select(addrs::Address...) = Set{Address}([addrs...]) diff --git a/src/probprog/Utils.jl b/src/probprog/Utils.jl new file mode 100644 index 0000000000..2907ba312d --- /dev/null +++ b/src/probprog/Utils.jl @@ -0,0 +1,197 @@ +using ..Reactant: + MLIR, + TracedUtils, + Ops, + TracedRArray, + TracedRNumber, + Compiler, + OrderedIdDict, + TracedToTypes, + TracedType, + TracedTrack, + TracedSetPath, + ConcreteToTraced, + AbstractConcreteArray, + Sharding, + to_number +import ..Reactant: promote_to, make_tracer +import ..Compiler: donate_argument! + +function process_probprog_function(f, args, op_name, with_rng=true) + seen = OrderedIdDict() + cache_key = [] + make_tracer(seen, (f, args...), cache_key, TracedToTypes) + cache = Compiler.callcache() + + if haskey(cache, cache_key) + (; f_name, mlir_result_types, traced_result, mutated_args, linear_results, fnwrapped, argprefix, resprefix, resargprefix) = cache[cache_key] + else + f_name = String(gensym(Symbol(f))) + argprefix::Symbol = gensym(op_name * "arg") + resprefix::Symbol = gensym(op_name * "result") + resargprefix::Symbol = gensym(op_name * "resarg") + + wrapper_fn = if !with_rng + f + else + (all_args...) -> begin + res = f(all_args...) + (all_args[1], (res isa Tuple ? res : (res,))...) + end + end + + temp = TracedUtils.make_mlir_fn( + wrapper_fn, + args, + (), + f_name, + false; + do_transpose=false, + args_in_result=:result, + argprefix, + resprefix, + resargprefix, + ) + + (; traced_result, ret, mutated_args, linear_results, fnwrapped) = temp + mlir_result_types = [ + MLIR.IR.type(MLIR.IR.operand(ret, i)) for i in 1:MLIR.IR.noperands(ret) + ] + cache[cache_key] = (; + f_name, + mlir_result_types, + traced_result, + mutated_args, + linear_results, + fnwrapped, + argprefix, + resprefix, + resargprefix, + ) + end + + seen_cache = OrderedIdDict() + make_tracer(seen_cache, fnwrapped ? (f, args) : args, (), TracedTrack; toscalar=false) + linear_args = [] + mlir_caller_args = MLIR.IR.Value[] + for (_, v) in seen_cache + v isa TracedType || continue + push!(linear_args, v) + push!(mlir_caller_args, v.mlir_data) + v.paths = v.paths[1:(end - 1)] + end + + return (; + f_name, + linear_args, + mlir_caller_args, + mlir_result_types, + traced_result, + linear_results, + fnwrapped, + argprefix, + resprefix, + resargprefix, + ) +end + +function process_probprog_outputs( + op, + linear_results, + traced_result, + f, + args, + fnwrapped, + resprefix, + argprefix, + offset=0, + rng_only=false, +) + seen_results = OrderedIdDict() + traced_result = make_tracer( + seen_results, traced_result, (), TracedSetPath; toscalar=false + ) + + num_to_process = rng_only ? 1 : length(linear_results) + + for i in 1:num_to_process + res = linear_results[i] + resv = MLIR.IR.result(op, i + offset) + + for path in res.paths + if length(path) == 0 + continue + end + if path[1] == resprefix + TracedUtils.set!(traced_result, path[2:end], resv) + elseif path[1] == argprefix + idx = path[2]::Int + if fnwrapped && idx == 2 + TracedUtils.set!(f, path[3:end], resv) + else + if fnwrapped && idx > 2 + idx -= 1 + end + TracedUtils.set!(args[idx], path[3:end], resv) + end + end + end + end + + return traced_result +end + +function promote_to(::Type{TracedRArray{UInt64,0}}, t::Union{ProbProgTrace,Constraint}) + return Ops.fill(reinterpret(UInt64, pointer_from_objref(t)), Int64[]) +end + +function Base.convert( + ::Type{T}, x::AbstractConcreteArray +) where {T<:Union{ProbProgTrace,Constraint}} + while !isready(x) + yield() + end + return unsafe_pointer_to_objref(Ptr{Any}(collect(x)[1]))::T +end + +function Base.convert( + ::Type{T}, x::AbstractConcreteNumber +) where {T<:Union{ProbProgTrace,Constraint}} + while !isready(x) + yield() + end + return unsafe_pointer_to_objref(Ptr{Any}(to_number(x)))::T +end + +function Base.getproperty(t::Union{ProbProgTrace,Constraint}, s::Symbol) + if s === :data + return ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(t))).data + else + return getfield(t, s) + end +end + +function donate_argument!(::Any, ::Union{ProbProgTrace,Constraint}, ::Int, ::Any, ::Any) + return nothing +end + +Base.@nospecializeinfer function make_tracer( + seen, + @nospecialize(prev::Union{ProbProgTrace,Constraint}), + @nospecialize(path), + mode; + @nospecialize(sharding = Sharding.NoSharding()), + kwargs..., +) + if mode == ConcreteToTraced + haskey(seen, prev) && return seen[prev]::TracedRNumber{UInt64} + result = TracedRNumber{UInt64}((path,), nothing) + seen[prev] = result + return result + elseif mode == TracedToTypes + push!(path, typeof(prev)) + return nothing + else + error("Unsupported mode for $(typeof(prev)): $mode") + end +end diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl new file mode 100644 index 0000000000..403f2e67e3 --- /dev/null +++ b/test/probprog/generate.jl @@ -0,0 +1,133 @@ +using Reactant, Test, Random, Statistics +using Reactant: ProbProg, ReactantRNG + +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) + +function normal_logpdf(x, μ, σ, _) + return -length(x) * log(σ) - length(x) / 2 * log(2π) - + sum((x .- μ) .^ 2 ./ (2 .* (σ .^ 2))) +end + +function model(rng, μ, σ, shape) + _, s = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:s, logpdf=normal_logpdf) + _, t = ProbProg.sample(rng, normal, s, σ, shape; symbol=:t, logpdf=normal_logpdf) + return t +end + +function two_normals(rng, μ, σ, shape) + _, x = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:x, logpdf=normal_logpdf) + _, y = ProbProg.sample(rng, normal, x, σ, shape; symbol=:y, logpdf=normal_logpdf) + return y +end + +function nested_model(rng, μ, σ, shape) + _, s = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:s, logpdf=normal_logpdf) + _, t = ProbProg.sample(rng, two_normals, s, σ, shape; symbol=:t) + _, u = ProbProg.sample(rng, two_normals, t, σ, shape; symbol=:u) + return u +end + +@testset "Generate" begin + @testset "unconstrained" begin + shape = (1000,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + trace, weight = ProbProg.generate_(rng, ProbProg.Constraint(), model, μ, σ, shape) + @test mean(trace.retval[1]) ≈ 0.0 atol = 0.05 rtol = 0.05 + end + + @testset "constrained" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + constraint = ProbProg.Constraint(:s => (fill(0.1, shape),)) + + trace, weight = ProbProg.generate_(rng, constraint, model, μ, σ, shape) + + @test trace.choices[:s][1] == constraint[ProbProg.Address(:s)][1] + + expected_weight = + normal_logpdf(constraint[ProbProg.Address(:s)][1], 0.0, 1.0, shape) + + normal_logpdf( + trace.choices[:t][1], constraint[ProbProg.Address(:s)][1], 1.0, shape + ) + @test weight ≈ expected_weight atol = 1e-6 + end + + @testset "composite addresses" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + constraint = ProbProg.Constraint( + :s => (fill(0.1, shape),), + :t => :x => (fill(0.2, shape),), + :u => :y => (fill(0.3, shape),), + ) + + trace, weight = ProbProg.generate_(rng, constraint, nested_model, μ, σ, shape) + + @test trace.choices[:s][1] == fill(0.1, shape) + @test trace.subtraces[:t].choices[:x][1] == fill(0.2, shape) + @test trace.subtraces[:u].choices[:y][1] == fill(0.3, shape) + + s_weight = normal_logpdf(fill(0.1, shape), 0.0, 1.0, shape) + tx_weight = normal_logpdf(fill(0.2, shape), fill(0.1, shape), 1.0, shape) + ty_weight = normal_logpdf( + trace.subtraces[:t].choices[:y][1], fill(0.2, shape), 1.0, shape + ) + ux_weight = normal_logpdf( + trace.subtraces[:u].choices[:x][1], + trace.subtraces[:t].choices[:y][1], + 1.0, + shape, + ) + uy_weight = normal_logpdf( + fill(0.3, shape), trace.subtraces[:u].choices[:x][1], 1.0, shape + ) + + expected_weight = s_weight + tx_weight + ty_weight + ux_weight + uy_weight + @test weight ≈ expected_weight atol = 1e-6 + end + + @testset "compiled" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + constraint1 = ProbProg.Constraint(:s => (fill(0.1, shape),)) + + constrained_addresses = ProbProg.extract_addresses(constraint1) + + compiled_fn = @compile optimize = :probprog ProbProg.generate( + rng, constraint1, model, μ, σ, shape; constrained_addresses + ) + + trace1 = nothing + seed_buffer = only(rng.seed.data).buffer + GC.@preserve seed_buffer constraint1 begin + trace1, _ = compiled_fn(rng, constraint1, model, μ, σ, shape) + trace1 = ProbProg.ProbProgTrace(trace1) + end + + constraint2 = ProbProg.Constraint(:s => (fill(0.2, shape),)) + + trace2 = nothing + seed_buffer = only(rng.seed.data).buffer + GC.@preserve seed_buffer constraint2 begin + trace2, _ = compiled_fn(rng, constraint2, model, μ, σ, shape) + trace2 = ProbProg.ProbProgTrace(trace2) + end + + @test trace1.choices[:s][1] != trace2.choices[:s][1] + end +end diff --git a/test/probprog/hmc.jl b/test/probprog/hmc.jl new file mode 100644 index 0000000000..eb5b967633 --- /dev/null +++ b/test/probprog/hmc.jl @@ -0,0 +1,130 @@ +using Reactant, Test, Random +using Statistics +using Reactant: ProbProg, ReactantRNG + +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) + +function normal_logpdf(x, μ, σ, _) + return -length(x) * log(σ) - length(x) / 2 * log(2π) - + sum((x .- μ) .^ 2 ./ (2 .* (σ .^ 2))) +end + +function model(rng, xs) + _, param_a = ProbProg.sample( + rng, normal, 0.0, 5.0, (1,); symbol=:param_a, logpdf=normal_logpdf + ) + _, param_b = ProbProg.sample( + rng, normal, 0.0, 5.0, (1,); symbol=:param_b, logpdf=normal_logpdf + ) + + _, ys_a = ProbProg.sample( + rng, normal, param_a .+ xs[1:5], 0.5, (5,); symbol=:ys_a, logpdf=normal_logpdf + ) + + _, ys_b = ProbProg.sample( + rng, normal, param_b .+ xs[6:10], 0.5, (5,); symbol=:ys_b, logpdf=normal_logpdf + ) + + return vcat(ys_a, ys_b) +end + +function hmc_program( + rng, + model, + xs, + step_size, + num_steps, + mass, + initial_momentum, + constraint, + constrained_addresses, +) + t, _, _ = ProbProg.generate(rng, constraint, model, xs; constrained_addresses) + + t, accepted, _ = ProbProg.hmc( + rng, + t, + model, + xs; + selection=ProbProg.select(ProbProg.Address(:param_a), ProbProg.Address(:param_b)), + mass, + step_size, + num_steps, + initial_momentum, + ) + + return t, accepted +end + +@testset "hmc" begin + seed = Reactant.to_rarray(UInt64[1, 5]) + rng = ReactantRNG(seed) + + xs = [-4.5, -3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5, 4.5] + ys_a = [-2.3, -1.6, -0.4, 0.6, 1.4] + ys_b = [-2.6, -1.4, -0.6, 0.4, 1.6] + obs = ProbProg.Constraint( + :param_a => ([0.0],), :param_b => ([0.0],), :ys_a => (ys_a,), :ys_b => (ys_b,) + ) + constrained_addresses = ProbProg.extract_addresses(obs) + + step_size = ConcreteRNumber(0.001) + num_steps_compile = ConcreteRNumber(1000) + num_steps_run = ConcreteRNumber(40000000) + mass = nothing + initial_momentum = ConcreteRArray([0.0, 0.0]) + + code = @code_hlo optimize = :probprog hmc_program( + rng, + model, + xs, + step_size, + num_steps_compile, + mass, + initial_momentum, + obs, + constrained_addresses, + ) + @test contains(repr(code), "enzyme_probprog_get_flattened_samples_from_trace") + @test contains(repr(code), "enzyme_probprog_get_weight_from_trace") + @test !contains(repr(code), "enzyme.mcmc") + + compile_time_s = @elapsed begin + compiled_fn = @compile optimize = :probprog hmc_program( + rng, + model, + xs, + step_size, + num_steps_compile, + mass, + initial_momentum, + obs, + constrained_addresses, + ) + end + println("HMC compile time: $(round(compile_time_s * 1000, digits=2)) ms") + + seed_buffer = only(rng.seed.data).buffer + trace = nothing + GC.@preserve seed_buffer obs begin + run_time_s = @elapsed begin + trace, _ = compiled_fn( + rng, + model, + xs, + step_size, + num_steps_run, + mass, + initial_momentum, + obs, + constrained_addresses, + ) + trace = ProbProg.ProbProgTrace(trace) + end + println("HMC run time: $(round(run_time_s * 1000, digits=2)) ms") + end + + # NumPyro results + @test only(trace.choices[:param_a])[1] ≈ 0.01327671 rtol = 1e-6 + @test only(trace.choices[:param_b])[1] ≈ -0.01965474 rtol = 1e-6 +end diff --git a/test/probprog/mh.jl b/test/probprog/mh.jl new file mode 100644 index 0000000000..f331d8abd1 --- /dev/null +++ b/test/probprog/mh.jl @@ -0,0 +1,103 @@ +using Reactant, Test, Random +using Reactant: ProbProg, ReactantRNG + +# Reference: https://www.gen.dev/docs/stable/getting_started/linear_regression/ + +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) + +function normal_logpdf(x, μ, σ, _) + return -length(x) * log(σ) - length(x) / 2 * log(2π) - + sum((x .- μ) .^ 2 ./ (2 .* (σ .^ 2))) +end + +function model(rng, xs) + _, slope = ProbProg.sample( + rng, normal, 0.0, 2.0, (1,); symbol=:slope, logpdf=normal_logpdf + ) + _, intercept = ProbProg.sample( + rng, normal, 0.0, 10.0, (1,); symbol=:intercept, logpdf=normal_logpdf + ) + + _, ys = ProbProg.sample( + rng, + normal, + slope .* xs .+ intercept, + 1.0, + (length(xs),); + symbol=:ys, + logpdf=normal_logpdf, + ) + + return ys +end + +function mh_program(rng, model, xs, num_iters, constraint, constrained_addresses) + init_trace, _, _ = ProbProg.generate( + rng, constraint, model, xs; constrained_addresses=constrained_addresses + ) + + trace = init_trace + @trace for _ in 1:num_iters + trace, _ = ProbProg.mh( + rng, trace, model, xs; selection=ProbProg.select(ProbProg.Address(:slope)) + ) + trace, _ = ProbProg.mh( + rng, trace, model, xs; selection=ProbProg.select(ProbProg.Address(:intercept)) + ) + end + + return trace +end + +@testset "linear_regression" begin + @testset "simulate" begin + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + + xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + xs_r = Reactant.to_rarray(xs) + + trace, _ = ProbProg.simulate_(rng, model, xs_r) + + @test haskey(trace.choices, :slope) + @test haskey(trace.choices, :intercept) + @test haskey(trace.choices, :ys) + end + + @testset "inference" begin + seed = Reactant.to_rarray(UInt64[1, 5]) + rng = ReactantRNG(seed) + + xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + ys = [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, -9.90] + obs = ProbProg.Constraint(:ys => (ys,)) + num_iters = ConcreteRNumber(10000) + constrained_addresses = ProbProg.extract_addresses(obs) + + code = @code_hlo optimize = :probprog mh_program( + rng, model, xs, num_iters, obs, constrained_addresses + ) + @test contains(repr(code), "enzyme_probprog_get_sample_from_trace") + @test contains(repr(code), "enzyme_probprog_get_weight_from_trace") + @test !contains(repr(code), "enzyme.mh") + + compiled_fn = @compile optimize = :probprog mh_program( + rng, model, xs, num_iters, obs, constrained_addresses + ) + + trace = nothing + seed_buffer = only(rng.seed.data).buffer + num_iters = ConcreteRNumber(1000) + GC.@preserve seed_buffer obs begin + trace = compiled_fn(rng, model, xs, num_iters, obs, constrained_addresses) + trace = ProbProg.ProbProgTrace(trace) + end + + slope = only(trace.choices[:slope])[1] + intercept = only(trace.choices[:intercept])[1] + @show slope, intercept + + @test slope ≈ -2.0 rtol = 0.1 + @test intercept ≈ 10.0 rtol = 0.1 + end +end diff --git a/test/probprog/sample.jl b/test/probprog/sample.jl new file mode 100644 index 0000000000..410fa716a1 --- /dev/null +++ b/test/probprog/sample.jl @@ -0,0 +1,95 @@ +using Reactant, Test, Random +using Reactant: ProbProg, ReactantRNG + +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) + +function normal_logpdf(x, μ, σ, _) + return -length(x) * log(σ) - length(x) / 2 * log(2π) - + sum((x .- μ) .^ 2 ./ (2 .* (σ .^ 2))) +end + +function one_sample(rng, μ, σ, shape) + _, s = ProbProg.sample(rng, normal, μ, σ, shape; logpdf=normal_logpdf) + return s +end + +function two_samples(rng, μ, σ, shape) + _ = ProbProg.sample(rng, normal, μ, σ, shape; logpdf=normal_logpdf) + _, t = ProbProg.sample(rng, normal, μ, σ, shape; logpdf=normal_logpdf) + return t +end + +function compose(rng, μ, σ, shape) + _, s = ProbProg.sample(rng, normal, μ, σ, shape) + _, t = ProbProg.sample(rng, normal, s, σ, shape) + return t +end + +@testset "test" begin + @testset "normal_hlo" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + code = @code_hlo optimize = false ProbProg.sample( + rng, normal, μ, σ, shape; logpdf=normal_logpdf + ) + @test contains(repr(code), "enzyme.sample") + end + + @testset "two_samples_hlo" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + code = @code_hlo optimize = false ProbProg.sample(rng, two_samples, μ, σ, shape) + @test contains(repr(code), "enzyme.sample") + end + + @testset "compose" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + before = @code_hlo optimize = false ProbProg.untraced_call( + rng, compose, μ, σ, shape + ) + @test contains(repr(before), "enzyme.sample") + + after = @code_hlo optimize = :probprog ProbProg.untraced_call( + rng, compose, μ, σ, shape + ) + @test !contains(repr(after), "enzyme.sample") + end + + @testset "rng_state" begin + shape = (10,) + + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + rng1 = ReactantRNG(copy(seed)) + + _, X = @jit optimize = :probprog ProbProg.untraced_call( + rng1, one_sample, μ, σ, shape + ) + @test !all(rng1.seed .== seed) + + rng2 = ReactantRNG(copy(seed)) + _, Y = @jit optimize = :probprog ProbProg.untraced_call( + rng2, two_samples, μ, σ, shape + ) + + @test !all(rng2.seed .== seed) + @test !all(rng2.seed .== rng1.seed) + + @test !all(X .≈ Y) + end +end diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl new file mode 100644 index 0000000000..a1b0cec3d8 --- /dev/null +++ b/test/probprog/simulate.jl @@ -0,0 +1,154 @@ +using Reactant, Test, Random +using Reactant: ProbProg, ReactantRNG + +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) + +function normal_logpdf(x, μ, σ, _) + return -length(x) * log(σ) - length(x) / 2 * log(2π) - + sum((x .- μ) .^ 2 ./ (2 .* (σ .^ 2))) +end + +function product_two_normals(rng, μ, σ, shape) + _, a = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:a, logpdf=normal_logpdf) + _, b = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:b, logpdf=normal_logpdf) + return a .* b +end + +function model(rng, μ, σ, shape) + _, s = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:s, logpdf=normal_logpdf) + _, t = ProbProg.sample(rng, normal, s, σ, shape; symbol=:t, logpdf=normal_logpdf) + return t +end + +function model2(rng, μ, σ, shape) + _, s = ProbProg.sample(rng, product_two_normals, μ, σ, shape; symbol=:s) + _, t = ProbProg.sample(rng, product_two_normals, s, σ, shape; symbol=:t) + return t +end + +function model3(rng, μ, σ, shape) + _, s = ProbProg.sample(rng, product_two_normals, μ, σ, shape; symbol=:s) + _, t = ProbProg.sample(rng, product_two_normals, μ, σ, shape; symbol=:t) + return s, t +end + +@testset "Simulate" begin + @testset "hlo" begin + shape = (3, 3, 3) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + before = @code_hlo optimize = false ProbProg.simulate(rng, model, μ, σ, shape) + @test contains(repr(before), "enzyme.simulate") + + after = @code_hlo optimize = :probprog ProbProg.simulate(rng, model, μ, σ, shape) + @test !contains(repr(after), "enzyme.simulate") + @test !contains(repr(after), "enzyme.addSampleToTrace") + @test !contains(repr(after), "enzyme.addWeightToTrace") + @test !contains(repr(after), "enzyme.addRetvalToTrace") + end + + @testset "normal_simulate" begin + shape = (3, 3, 3) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + trace, weight = ProbProg.simulate_(rng, model, μ, σ, shape) + + @test size(trace.retval[1]) == shape + @test haskey(trace.choices, :s) + @test haskey(trace.choices, :t) + @test size(trace.choices[:s][1]) == shape + @test size(trace.choices[:t][1]) == shape + @test trace.weight isa Float64 + end + + @testset "simple_fake" begin + op(_, x, y) = x * y' + logpdf(res, _, _) = sum(res) + function fake_model(rng, x, y) + _, res = ProbProg.sample(rng, op, x, y; symbol=:matmul, logpdf=logpdf) + return res + end + + x = reshape(collect(Float64, 1:12), (4, 3)) + y = reshape(collect(Float64, 1:12), (4, 3)) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + + trace, weight = ProbProg.simulate_(rng, fake_model, x_ra, y_ra) + + @test Array(trace.retval[1]) == op(rng, x, y) + @test haskey(trace.choices, :matmul) + @test trace.choices[:matmul][1] == op(rng, x, y) + @test trace.weight == logpdf(op(rng, x, y), x, y) + end + + @testset "submodel_fake" begin + shape = (3, 3, 3) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + trace, weight = ProbProg.simulate_(rng, model2, μ, σ, shape) + + @test size(trace.retval[1]) == shape + + @test length(trace.choices) == 2 + @test haskey(trace.choices, :s) + @test haskey(trace.choices, :t) + + @test length(trace.subtraces) == 2 + @test haskey(trace.subtraces[:s].choices, :a) + @test haskey(trace.subtraces[:s].choices, :b) + @test haskey(trace.subtraces[:t].choices, :a) + @test haskey(trace.subtraces[:t].choices, :b) + @test trace.subtraces[:s].choices[:a][1] !== trace.subtraces[:t].choices[:a][1] + @test trace.subtraces[:s].choices[:b][1] !== trace.subtraces[:t].choices[:b][1] + + @test size(trace.choices[:s][1]) == shape + @test size(trace.choices[:t][1]) == shape + + @test trace.weight isa Float64 + + @test trace.weight ≈ trace.subtraces[:s].weight + trace.subtraces[:t].weight + end + + @testset "submodel_subtraces" begin + shape = (3, 3, 3) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + trace, weight = ProbProg.simulate_(rng, model3, μ, σ, shape) + + @test size(trace.retval[1]) == shape + + @test length(trace.choices) == 2 + @test haskey(trace.choices, :s) + @test haskey(trace.choices, :t) + + @test length(trace.subtraces) == 2 + @test haskey(trace.subtraces[:s].choices, :a) + @test haskey(trace.subtraces[:s].choices, :b) + @test haskey(trace.subtraces[:t].choices, :a) + @test haskey(trace.subtraces[:t].choices, :b) + @test trace.subtraces[:s].choices[:a][1] !== trace.subtraces[:t].choices[:a][1] + @test trace.subtraces[:s].choices[:b][1] !== trace.subtraces[:t].choices[:b][1] + + @test size(trace.choices[:s][1]) == shape + @test size(trace.choices[:t][1]) == shape + + @test trace.weight isa Float64 + + @test trace.weight ≈ trace.subtraces[:s].weight + trace.subtraces[:t].weight + end +end diff --git a/test/runtests.jl b/test/runtests.jl index bc10705e43..ffa6d6cba0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -73,4 +73,12 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Lux Integration" include("nn/lux.jl") end end + + if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "probprog" + @safetestset "ProbProg Sample" include("probprog/sample.jl") + @safetestset "ProbProg Simulate" include("probprog/simulate.jl") + @safetestset "ProbProg Generate" include("probprog/generate.jl") + @safetestset "ProbProg MH" include("probprog/mh.jl") + @safetestset "ProbProg HMC" include("probprog/hmc.jl") + end end