File tree Expand file tree Collapse file tree 2 files changed +14
-37
lines changed Expand file tree Collapse file tree 2 files changed +14
-37
lines changed Original file line number Diff line number Diff line change @@ -2,7 +2,7 @@ using ..Reactant: ConcreteRNumber, TracedRArray
22
33function hmc (
44 rng:: AbstractRNG ,
5- original_trace:: Union{ProbProgTrace,TracedRArray{UInt64,0}} ,
5+ original_trace,
66 f:: Function ,
77 args:: Vararg{Any,Nargs} ;
88 selection:: Selection ,
@@ -21,23 +21,12 @@ function hmc(
2121 MLIR. IR. context ():: MLIR.API.MlirContext
2222 ):: MLIR.IR.Type
2323
24- trace_val = if original_trace isa TracedRArray{UInt64,0 }
25- MLIR. IR. result (
26- MLIR. Dialects. builtin. unrealized_conversion_cast (
27- [original_trace. mlir_data]; outputs= [trace_ty]
28- ),
29- 1 ,
30- )
31- else
32- # First iteration: promote a ProbProgTrace to tensor<ui64>
33- promoted = to_trace_tensor (original_trace)
34- MLIR. IR. result (
35- MLIR. Dialects. builtin. unrealized_conversion_cast (
36- [TracedUtils. get_mlir_data (promoted)]; outputs= [trace_ty]
37- ),
38- 1 ,
39- )
40- end
24+ trace_val = MLIR. IR. result (
25+ MLIR. Dialects. builtin. unrealized_conversion_cast (
26+ [TracedUtils. get_mlir_data (original_trace)]; outputs= [trace_ty]
27+ ),
28+ 1 ,
29+ )
4130
4231 selection_attr = MLIR. IR. Attribute[]
4332 for address in selection
Original file line number Diff line number Diff line change @@ -2,7 +2,7 @@ using ..Reactant: ConcreteRNumber, TracedRArray
22
33function mh (
44 rng:: AbstractRNG ,
5- original_trace:: Union{ProbProgTrace,TracedRArray{UInt64,0}} ,
5+ original_trace,
66 f:: Function ,
77 args:: Vararg{Any,Nargs} ;
88 selection:: Selection ,
@@ -17,24 +17,12 @@ function mh(
1717 MLIR. IR. context ():: MLIR.API.MlirContext
1818 ):: MLIR.IR.Type
1919
20- if original_trace isa TracedRArray{UInt64,0 }
21- # Use MLIR data from previous iteration
22- trace_val = MLIR. IR. result (
23- MLIR. Dialects. builtin. unrealized_conversion_cast (
24- [original_trace. mlir_data]; outputs= [trace_ty]
25- ),
26- 1 ,
27- )
28- else
29- # First iteration: create constant from pointer
30- promoted = to_trace_tensor (original_trace)
31- trace_val = MLIR. IR. result (
32- MLIR. Dialects. builtin. unrealized_conversion_cast (
33- [TracedUtils. get_mlir_data (promoted)]; outputs= [trace_ty]
34- ),
35- 1 ,
36- )
37- end
20+ trace_val = MLIR. IR. result (
21+ MLIR. Dialects. builtin. unrealized_conversion_cast (
22+ [TracedUtils. get_mlir_data (original_trace)]; outputs= [trace_ty]
23+ ),
24+ 1 ,
25+ )
3826
3927 selection_attr = MLIR. IR. Attribute[]
4028 for address in selection
You can’t perform that action at this time.
0 commit comments