Skip to content

Commit e120768

Browse files
committed
static HMC
1 parent 837e752 commit e120768

File tree

19 files changed

+2278
-0
lines changed

19 files changed

+2278
-0
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,18 @@ enzymeSymbolAttrGet(MlirContext ctx, uint64_t symbol) {
388388
return wrap(attr);
389389
}
390390

391+
extern "C" MLIR_CAPI_EXPORTED MlirAttribute
392+
enzymeRngDistributionAttrGet(MlirContext ctx, int32_t val) {
393+
return wrap(mlir::enzyme::RngDistributionAttr::get(
394+
unwrap(ctx), (mlir::enzyme::RngDistribution)val));
395+
}
396+
397+
extern "C" MLIR_CAPI_EXPORTED MlirAttribute
398+
enzymeMCMCAlgorithmAttrGet(MlirContext ctx, int32_t val) {
399+
return wrap(mlir::enzyme::MCMCAlgorithmAttr::get(
400+
unwrap(ctx), (mlir::enzyme::MCMCAlgorithm)val));
401+
}
402+
391403
// Create profiler session and start profiling
392404
REACTANT_ABI tsl::ProfilerSession *
393405
CreateProfilerSession(uint32_t device_tracer_level,

src/CompileOptions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ function CompileOptions(;
229229
:canonicalize,
230230
:just_batch,
231231
:none,
232+
:probprog,
232233
]
233234
end
234235

src/Compiler.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1311,6 +1311,7 @@ end
13111311
# TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate
13121312
# However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass].
13131313
const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}"
1314+
const probprog_pass::String = "probprog{postpasses=\"arith-raise{stablehlo=true}\"}"
13141315

13151316
function run_pass_pipeline!(mod, pass_pipeline, key=""; enable_verifier=true)
13161317
pm = MLIR.IR.PassManager()
@@ -1893,6 +1894,71 @@ function compile_mlir!(
18931894
),
18941895
"no_enzyme",
18951896
)
1897+
elseif compile_options.optimization_passes === :probprog
1898+
run_pass_pipeline!(
1899+
mod,
1900+
join(
1901+
if compile_options.raise_first
1902+
[
1903+
"mark-func-memory-effects",
1904+
opt_passes,
1905+
kern,
1906+
raise_passes,
1907+
"enzyme-batch",
1908+
opt_passes2,
1909+
probprog_pass,
1910+
"lower-probprog-to-stablehlo{backend=$backend}",
1911+
"outline-enzyme-regions",
1912+
enzyme_pass,
1913+
opt_passes2,
1914+
"canonicalize",
1915+
"remove-unnecessary-enzyme-ops",
1916+
"enzyme-simplify-math",
1917+
(
1918+
if compile_options.legalize_chlo_to_stablehlo
1919+
["func.func(chlo-legalize-to-stablehlo)"]
1920+
else
1921+
[]
1922+
end
1923+
)...,
1924+
opt_passes2,
1925+
lower_enzymexla_linalg_pass,
1926+
"lower-probprog-trace-ops{backend=$backend}",
1927+
jit,
1928+
]
1929+
else
1930+
[
1931+
"mark-func-memory-effects",
1932+
opt_passes,
1933+
"enzyme-batch",
1934+
opt_passes2,
1935+
probprog_pass,
1936+
"lower-probprog-to-stablehlo{backend=$backend}",
1937+
"outline-enzyme-regions",
1938+
enzyme_pass,
1939+
opt_passes2,
1940+
"canonicalize",
1941+
"remove-unnecessary-enzyme-ops",
1942+
"enzyme-simplify-math",
1943+
(
1944+
if compile_options.legalize_chlo_to_stablehlo
1945+
["func.func(chlo-legalize-to-stablehlo)"]
1946+
else
1947+
[]
1948+
end
1949+
)...,
1950+
opt_passes2,
1951+
kern,
1952+
raise_passes,
1953+
lower_enzymexla_linalg_pass,
1954+
"lower-probprog-trace-ops{backend=$backend}",
1955+
jit,
1956+
]
1957+
end,
1958+
",",
1959+
),
1960+
"probprog",
1961+
)
18961962
elseif compile_options.optimization_passes === :only_enzyme
18971963
run_pass_pipeline!(
18981964
mod,

src/Reactant.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ include("Tracing.jl")
246246
include("Compiler.jl")
247247

248248
include("Overlay.jl")
249+
include("probprog/ProbProg.jl")
249250

250251
# Serialization
251252
include("serialization/Serialization.jl")

src/Types.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@ function ConcretePJRTArray(
241241
end
242242

243243
Base.wait(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = foreach(wait, x.data)
244+
Base.isready(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = all(isready, x.data)
244245
XLA.client(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = XLA.client(x.data)
245246
function XLA.device(x::Union{ConcretePJRTArray,ConcretePJRTNumber})
246247
x.sharding isa Sharding.NoShardInfo && return XLA.device(only(x.data))
@@ -420,6 +421,7 @@ function ConcreteIFRTArray(
420421
end
421422

422423
Base.wait(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = wait(x.data)
424+
Base.isready(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = isready(x.data)
423425
XLA.client(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = XLA.client(x.data)
424426
function XLA.device(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber})
425427
return XLA.device(x.data)

src/probprog/Display.jl

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Reference: https://github.com/probcomp/Gen.jl/blob/91d798f2d2f0c175b1be3dc6daf3a10a8acf5da3/src/choice_map.jl#L104
2+
function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple)
3+
VERT = '\u2502'
4+
PLUS = '\u251C'
5+
HORZ = '\u2500'
6+
LAST = '\u2514'
7+
8+
indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n'])
9+
indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' '])
10+
indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' '])
11+
12+
for i in vert_bars
13+
indent_vert[i] = VERT
14+
indent[i] = VERT
15+
indent_last[i] = VERT
16+
end
17+
18+
indent_vert_str = join(indent_vert)
19+
indent_str = join(indent)
20+
indent_last_str = join(indent_last)
21+
22+
sorted_choices = sort(collect(trace.choices); by=x -> x[1])
23+
n = length(sorted_choices)
24+
25+
if trace.retval !== nothing
26+
n += 1
27+
end
28+
29+
if trace.weight !== nothing
30+
n += 1
31+
end
32+
33+
cur = 1
34+
35+
if trace.retval !== nothing
36+
print(io, indent_vert_str)
37+
print(io, (cur == n ? indent_last_str : indent_str) * "retval : $(trace.retval)\n")
38+
cur += 1
39+
end
40+
41+
if trace.weight !== nothing
42+
print(io, indent_vert_str)
43+
print(io, (cur == n ? indent_last_str : indent_str) * "weight : $(trace.weight)\n")
44+
cur += 1
45+
end
46+
47+
for (key, value) in sorted_choices
48+
print(io, indent_vert_str)
49+
print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n")
50+
cur += 1
51+
end
52+
53+
sorted_subtraces = sort(collect(trace.subtraces); by=x -> x[1])
54+
n += length(sorted_subtraces)
55+
56+
for (key, subtrace) in sorted_subtraces
57+
print(io, indent_vert_str)
58+
print(io, (cur == n ? indent_last_str : indent_str) * "subtrace on $(repr(key))\n")
59+
_show_pretty(
60+
io, subtrace, pre + 4, cur == n ? (vert_bars...,) : (vert_bars..., pre + 1)
61+
)
62+
cur += 1
63+
end
64+
end
65+
66+
function Base.show(io::IO, ::MIME"text/plain", trace::ProbProgTrace)
67+
println(io, "ProbProgTrace:")
68+
if isempty(trace.choices) && trace.retval === nothing && trace.weight === nothing
69+
println(io, " (empty)")
70+
else
71+
_show_pretty(io, trace, 0, ())
72+
end
73+
end
74+
75+
function Base.show(io::IO, trace::ProbProgTrace)
76+
if get(io, :compact, false)
77+
choices_count = length(trace.choices)
78+
has_retval = trace.retval !== nothing
79+
print(io, "ProbProgTrace($(choices_count) choices")
80+
if has_retval
81+
print(io, ", retval=$(trace.retval), weight=$(trace.weight)")
82+
end
83+
print(io, ")")
84+
else
85+
show(io, MIME"text/plain"(), trace)
86+
end
87+
end

0 commit comments

Comments
 (0)