Skip to content

Commit 8200ce5

Browse files
committed
static HMC
1 parent 2a979f8 commit 8200ce5

File tree

18 files changed

+2266
-0
lines changed

18 files changed

+2266
-0
lines changed

src/CompileOptions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ function CompileOptions(;
229229
:canonicalize,
230230
:just_batch,
231231
:none,
232+
:probprog,
232233
]
233234
end
234235

src/Compiler.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1311,6 +1311,7 @@ end
13111311
# TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate
13121312
# However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass].
13131313
const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}"
1314+
const probprog_pass::String = "probprog{postpasses=\"arith-raise{stablehlo=true}\"}"
13141315

13151316
function run_pass_pipeline!(mod, pass_pipeline, key=""; enable_verifier=true)
13161317
pm = MLIR.IR.PassManager()
@@ -1898,6 +1899,71 @@ function compile_mlir!(
18981899
),
18991900
"no_enzyme",
19001901
)
1902+
elseif compile_options.optimization_passes === :probprog
1903+
run_pass_pipeline!(
1904+
mod,
1905+
join(
1906+
if compile_options.raise_first
1907+
[
1908+
"mark-func-memory-effects",
1909+
opt_passes,
1910+
kern,
1911+
raise_passes,
1912+
"enzyme-batch",
1913+
opt_passes2,
1914+
probprog_pass,
1915+
"lower-probprog-to-stablehlo{backend=$backend}",
1916+
"outline-enzyme-regions",
1917+
enzyme_pass,
1918+
opt_passes2,
1919+
"canonicalize",
1920+
"remove-unnecessary-enzyme-ops",
1921+
"enzyme-simplify-math",
1922+
(
1923+
if compile_options.legalize_chlo_to_stablehlo
1924+
["func.func(chlo-legalize-to-stablehlo)"]
1925+
else
1926+
[]
1927+
end
1928+
)...,
1929+
opt_passes2,
1930+
lower_enzymexla_linalg_pass,
1931+
"lower-probprog-trace-ops{backend=$backend}",
1932+
jit,
1933+
]
1934+
else
1935+
[
1936+
"mark-func-memory-effects",
1937+
opt_passes,
1938+
"enzyme-batch",
1939+
opt_passes2,
1940+
probprog_pass,
1941+
"lower-probprog-to-stablehlo{backend=$backend}",
1942+
"outline-enzyme-regions",
1943+
enzyme_pass,
1944+
opt_passes2,
1945+
"canonicalize",
1946+
"remove-unnecessary-enzyme-ops",
1947+
"enzyme-simplify-math",
1948+
(
1949+
if compile_options.legalize_chlo_to_stablehlo
1950+
["func.func(chlo-legalize-to-stablehlo)"]
1951+
else
1952+
[]
1953+
end
1954+
)...,
1955+
opt_passes2,
1956+
kern,
1957+
raise_passes,
1958+
lower_enzymexla_linalg_pass,
1959+
"lower-probprog-trace-ops{backend=$backend}",
1960+
jit,
1961+
]
1962+
end,
1963+
",",
1964+
),
1965+
"probprog",
1966+
)
19011967
elseif compile_options.optimization_passes === :only_enzyme
19021968
run_pass_pipeline!(
19031969
mod,

src/Reactant.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ include("Tracing.jl")
246246
include("Compiler.jl")
247247

248248
include("Overlay.jl")
249+
include("probprog/ProbProg.jl")
249250

250251
# Serialization
251252
include("serialization/Serialization.jl")

src/Types.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@ function ConcretePJRTArray(
241241
end
242242

243243
Base.wait(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = foreach(wait, x.data)
244+
Base.isready(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = all(isready, x.data)
244245
XLA.client(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = XLA.client(x.data)
245246
function XLA.device(x::Union{ConcretePJRTArray,ConcretePJRTNumber})
246247
x.sharding isa Sharding.NoShardInfo && return XLA.device(only(x.data))
@@ -420,6 +421,7 @@ function ConcreteIFRTArray(
420421
end
421422

422423
Base.wait(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = wait(x.data)
424+
Base.isready(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = isready(x.data)
423425
XLA.client(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = XLA.client(x.data)
424426
function XLA.device(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber})
425427
return XLA.device(x.data)

src/probprog/Display.jl

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Reference: https://github.com/probcomp/Gen.jl/blob/91d798f2d2f0c175b1be3dc6daf3a10a8acf5da3/src/choice_map.jl#L104
2+
function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple)
3+
VERT = '\u2502'
4+
PLUS = '\u251C'
5+
HORZ = '\u2500'
6+
LAST = '\u2514'
7+
8+
indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n'])
9+
indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' '])
10+
indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' '])
11+
12+
for i in vert_bars
13+
indent_vert[i] = VERT
14+
indent[i] = VERT
15+
indent_last[i] = VERT
16+
end
17+
18+
indent_vert_str = join(indent_vert)
19+
indent_str = join(indent)
20+
indent_last_str = join(indent_last)
21+
22+
sorted_choices = sort(collect(trace.choices); by=x -> x[1])
23+
n = length(sorted_choices)
24+
25+
if trace.retval !== nothing
26+
n += 1
27+
end
28+
29+
if trace.weight !== nothing
30+
n += 1
31+
end
32+
33+
cur = 1
34+
35+
if trace.retval !== nothing
36+
print(io, indent_vert_str)
37+
print(io, (cur == n ? indent_last_str : indent_str) * "retval : $(trace.retval)\n")
38+
cur += 1
39+
end
40+
41+
if trace.weight !== nothing
42+
print(io, indent_vert_str)
43+
print(io, (cur == n ? indent_last_str : indent_str) * "weight : $(trace.weight)\n")
44+
cur += 1
45+
end
46+
47+
for (key, value) in sorted_choices
48+
print(io, indent_vert_str)
49+
print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n")
50+
cur += 1
51+
end
52+
53+
sorted_subtraces = sort(collect(trace.subtraces); by=x -> x[1])
54+
n += length(sorted_subtraces)
55+
56+
for (key, subtrace) in sorted_subtraces
57+
print(io, indent_vert_str)
58+
print(io, (cur == n ? indent_last_str : indent_str) * "subtrace on $(repr(key))\n")
59+
_show_pretty(
60+
io, subtrace, pre + 4, cur == n ? (vert_bars...,) : (vert_bars..., pre + 1)
61+
)
62+
cur += 1
63+
end
64+
end
65+
66+
function Base.show(io::IO, ::MIME"text/plain", trace::ProbProgTrace)
67+
println(io, "ProbProgTrace:")
68+
if isempty(trace.choices) && trace.retval === nothing && trace.weight === nothing
69+
println(io, " (empty)")
70+
else
71+
_show_pretty(io, trace, 0, ())
72+
end
73+
end
74+
75+
function Base.show(io::IO, trace::ProbProgTrace)
76+
if get(io, :compact, false)
77+
choices_count = length(trace.choices)
78+
has_retval = trace.retval !== nothing
79+
print(io, "ProbProgTrace($(choices_count) choices")
80+
if has_retval
81+
print(io, ", retval=$(trace.retval), weight=$(trace.weight)")
82+
end
83+
print(io, ")")
84+
else
85+
show(io, MIME"text/plain"(), trace)
86+
end
87+
end

0 commit comments

Comments
 (0)