Skip to content

Commit 5c96c8e

Browse files
committed
clean up
1 parent d3196b4 commit 5c96c8e

File tree

2 files changed

+14
-37
lines changed

2 files changed

+14
-37
lines changed

src/probprog/HMC.jl

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using ..Reactant: ConcreteRNumber, TracedRArray
22

33
function hmc(
44
rng::AbstractRNG,
5-
original_trace::Union{ProbProgTrace,TracedRArray{UInt64,0}},
5+
original_trace,
66
f::Function,
77
args::Vararg{Any,Nargs};
88
selection::Selection,
@@ -21,23 +21,12 @@ function hmc(
2121
MLIR.IR.context()::MLIR.API.MlirContext
2222
)::MLIR.IR.Type
2323

24-
trace_val = if original_trace isa TracedRArray{UInt64,0}
25-
MLIR.IR.result(
26-
MLIR.Dialects.builtin.unrealized_conversion_cast(
27-
[original_trace.mlir_data]; outputs=[trace_ty]
28-
),
29-
1,
30-
)
31-
else
32-
# First iteration: promote a ProbProgTrace to tensor<ui64>
33-
promoted = to_trace_tensor(original_trace)
34-
MLIR.IR.result(
35-
MLIR.Dialects.builtin.unrealized_conversion_cast(
36-
[TracedUtils.get_mlir_data(promoted)]; outputs=[trace_ty]
37-
),
38-
1,
39-
)
40-
end
24+
trace_val = MLIR.IR.result(
25+
MLIR.Dialects.builtin.unrealized_conversion_cast(
26+
[TracedUtils.get_mlir_data(original_trace)]; outputs=[trace_ty]
27+
),
28+
1,
29+
)
4130

4231
selection_attr = MLIR.IR.Attribute[]
4332
for address in selection

src/probprog/MH.jl

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using ..Reactant: ConcreteRNumber, TracedRArray
22

33
function mh(
44
rng::AbstractRNG,
5-
original_trace::Union{ProbProgTrace,TracedRArray{UInt64,0}},
5+
original_trace,
66
f::Function,
77
args::Vararg{Any,Nargs};
88
selection::Selection,
@@ -17,24 +17,12 @@ function mh(
1717
MLIR.IR.context()::MLIR.API.MlirContext
1818
)::MLIR.IR.Type
1919

20-
if original_trace isa TracedRArray{UInt64,0}
21-
# Use MLIR data from previous iteration
22-
trace_val = MLIR.IR.result(
23-
MLIR.Dialects.builtin.unrealized_conversion_cast(
24-
[original_trace.mlir_data]; outputs=[trace_ty]
25-
),
26-
1,
27-
)
28-
else
29-
# First iteration: create constant from pointer
30-
promoted = to_trace_tensor(original_trace)
31-
trace_val = MLIR.IR.result(
32-
MLIR.Dialects.builtin.unrealized_conversion_cast(
33-
[TracedUtils.get_mlir_data(promoted)]; outputs=[trace_ty]
34-
),
35-
1,
36-
)
37-
end
20+
trace_val = MLIR.IR.result(
21+
MLIR.Dialects.builtin.unrealized_conversion_cast(
22+
[TracedUtils.get_mlir_data(original_trace)]; outputs=[trace_ty]
23+
),
24+
1,
25+
)
3826

3927
selection_attr = MLIR.IR.Attribute[]
4028
for address in selection

0 commit comments

Comments
 (0)