diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 0e73cff527..d1e1420db4 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -86,11 +86,13 @@ function set_reactant_abi( ) end + +current_interpreter = Ref{Enzyme.Compiler.Interpreter.EnzymeInterpreter{typeof(Reactant.set_reactant_abi)}}() @static if Enzyme.GPUCompiler.HAS_INTEGRATED_CACHE struct ReactantCacheToken end function ReactantInterpreter(; world::UInt=Base.get_world_counter()) - return Enzyme.Compiler.Interpreter.EnzymeInterpreter( + current_interpreter[] = Enzyme.Compiler.Interpreter.EnzymeInterpreter( ReactantCacheToken(), REACTANT_METHOD_TABLE, world, @@ -108,7 +110,7 @@ else function ReactantInterpreter(; world::UInt=Base.get_world_counter(), code_cache=REACTANT_CACHE ) - return Enzyme.Compiler.Interpreter.EnzymeInterpreter( + current_interpreter[] = Enzyme.Compiler.Interpreter.EnzymeInterpreter( REACTANT_CACHE, REACTANT_METHOD_TABLE, world, diff --git a/src/Precompile.jl b/src/Precompile.jl index c467bf7fd5..00e2bfa73d 100644 --- a/src/Precompile.jl +++ b/src/Precompile.jl @@ -49,7 +49,7 @@ end # Precompilation on 1.10 hits an apparent bug: https://github.com/JuliaLang/julia/issues/56947 function precompilation_supported() - return (VERSION >= v"1.11" || VERSION >= v"1.10.8") && (VERSION < v"1.12-") + return false && (VERSION >= v"1.11" || VERSION >= v"1.10.8") && (VERSION < v"1.12-") end if Reactant_jll.is_available() diff --git a/src/TracedRange.jl b/src/TracedRange.jl index 657fb4f711..7cfc37d910 100644 --- a/src/TracedRange.jl +++ b/src/TracedRange.jl @@ -177,9 +177,10 @@ function Base._reshape(parent::TracedUnitRange, dims::Dims) return Base.__reshape((parent, IndexStyle(parent)), dims) end -function (C::Base.Colon)(start::TracedRNumber{T}, stop::TracedRNumber{T}) where {T} +#=function (C::Base.Colon)(start::TracedRNumber{T}, stop::TracedRNumber{T}) where {T} return TracedUnitRange(start, stop) end +=# function (C::Base.Colon)(start::TracedRNumber{T}, stop::T) where {T} return C(start, TracedRNumber{T}(stop)) end diff --git a/src/auto_cf/AutoCF.jl b/src/auto_cf/AutoCF.jl new file mode 100644 index 0000000000..edbb73217e --- /dev/null +++ b/src/auto_cf/AutoCF.jl @@ -0,0 +1,6 @@ +include("debug_utils.jl") +include("new_inference.jl") +include("code_info_mut.jl") +include("code_ir_utils.jl") +include("mlir_utils.jl") +include("code_gen.jl") diff --git a/src/auto_cf/analysis.jl b/src/auto_cf/analysis.jl new file mode 100644 index 0000000000..74eeb08e7c --- /dev/null +++ b/src/auto_cf/analysis.jl @@ -0,0 +1,84 @@ +@enum UpgradeSlot NoUpgrade UpgradeLocally UpgradeDefinition UpgradeDefinitionGlobal EqualizeBranches + +@enum State Traced Upgraded Maybe NotTraced + +mutable struct ForStructure + accus::Tuple + header_bb::Int + latch_bb::Int + terminal_bb::Int + body_bbs::Set{Int} + state::State +end + +struct IfStructure + ssa_cond + header_bb::Int + terminal_bb::Int + true_bbs::Set{Int} + false_bbs::Set{Int} + owned_true_bbs::Set{Int} + owned_false_bbs::Set{Int} + legalize::Ref{Bool} #inform that the if traced GotoIfNot can pass type inference + unbalanced_slots::Set{Core.SlotNumber} +end + +mutable struct SlotAnalysis + slot_stmt_def::Vector{Integer} #0 for argument + slot_bb_usage::Vector{Set{Int}} +end + + +CFStructure = Union{IfStructure,ForStructure} +mutable struct Tree + node::Union{Nothing,Base.uniontypes(CFStructure)...} + children::Vector{Tree} + parent::Ref{Tree} +end + +Base.isempty(tree::Tree) = isnothing(tree.node) && length(tree.children) == 0 + +Base.show(io::IO, t::Tree) = begin + Base.print(io, 
'(') + Base.show(io, t.node) + Base.print(io, ',') + Base.show(io, t.children) + Base.print(io, ')') +end + +mutable struct Analysis + tree::Tree + domtree::Union{Nothing,Vector{CC.DomTreeNode}} + postdomtree::Union{Nothing,Vector{CC.DomTreeNode}} + slotanalysis::Union{Nothing,SlotAnalysis} + pending_tree::Union{Nothing,Tree} +end + + +#leak each argument to a global variable +macro lk(args...) + quote + $([:( + let val = $(esc(p)) + global $(esc(p)) = val + end + ) for p in args]...) + end +end + + + + + +MethodInstanceKey = Vector{Type} +function mi_key(mi::Core.MethodInstance) + return collect(Base.unwrap_unionall(mi.specTypes).parameters) +end +@kwdef struct MetaData + traced_tree_map::Dict{MethodInstanceKey,Tree} = Dict() +end + +meta = Ref(MetaData()) +function get_meta(_::Reactant.ReactantInterp)::MetaData + meta[] +end \ No newline at end of file diff --git a/src/auto_cf/code_gen.jl b/src/auto_cf/code_gen.jl new file mode 100644 index 0000000000..4abf751dc0 --- /dev/null +++ b/src/auto_cf/code_gen.jl @@ -0,0 +1,803 @@ +#TODO: remove this +returning_type(X) = X +get_traced_type(X) = X + +struct TUnitRange{T} + min::Union{T,Reactant.TracedRNumber{T}} + max::Union{T,Reactant.TracedRNumber{T}} +end + +struct TStepRange{T} + min::Union{T,Reactant.TracedRNumber{T}} + step::T #TODO:add support to traced step + max::Union{T,Reactant.TracedRNumber{T}} +end + +#Needed otherwise standard lib defined a more specialized method +function (::Colon)(min::Reactant.TracedRNumber{T}, max::Reactant.TracedRNumber{T}) where {T} + return TUnitRange(min, max) +end + +function (::Colon)( + min::Union{T,Reactant.TracedRNumber{T}}, max::Union{T,Reactant.TracedRNumber{T}} +) where {T} + return TUnitRange(min, max) +end +Base.first(a::TUnitRange) = a.min +Base.last(a::TUnitRange) = a.max +@noinline Base.iterate(i::TUnitRange{T}, _::Nothing=nothing) where {T} = + CC.inferencebarrier(i)::Union{Nothing,Tuple{Reactant.TracedRNumber{T},Nothing}} + +function (::Colon)( + min::Union{T,Reactant.TracedRNumber{T}}, + step::T, + max::Union{T,Reactant.TracedRNumber{T}}, +) where {T} + return TStepRange(min, step, max) +end +Base.first(a::TStepRange) = a.min +Base.last(a::TStepRange) = a.max +@noinline Base.iterate(i::TStepRange{T}, _::Nothing=nothing) where {T} = + CC.inferencebarrier(i)::Union{Nothing,Tuple{Reactant.TracedRNumber{T},Nothing}} + +#keep using the base iterate for upgraded loop. +Base.iterate(T::Type, args...) = CC.inferencebarrier(args)::T + +""" + Hidden + + struct to hide the default print of a type, use to show a CodeIR containing inlined CodeIR + #TODO: add parametric +""" +struct Hidden + value +end + +function Base.show(io::IO, x::Hidden) + return print(io, "<$(typeof(x.value))>") +end + +""" + juliair_to_mlir(ir::Core.Compiler.IRCode, args...) -> Vector + +Execute the `ir` and add a MLIR `return` operation to the traced `ir` return variables. Return all `ir` return variable +TODO: remove masked_traced +`args` must follow types restriction in `ir.argtypes`, otherwise completely break Julia +""" +@noinline function juliair_to_mlir(ir::Core.Compiler.IRCode, args...)::Tuple + @warn typeof.(args) + @warn ir.argtypes[2:end] + #Cannot use .<: -> dispatch to Reactant `materialize` + equal = length(args) == length(ir.argtypes[2:end]) + for (a, b) in zip(typeof.(args), ir.argtypes[2:end]) + equal || break + equal = a <: b + end + @assert equal "$(typeof.(args)) \n $(ir.argtypes[2:end])" + @warn ir + f = Core.OpaqueClosure(ir) + result = f(args...) + isnothing(result) && return () + result = result isa Tuple ? 
result : tuple(result) + return result +end + +@skip_rewrite_func juliair_to_mlir + +function remove_phi_node_for_body!(ir::CC.IRCode, f::ForStructure) + first_bb = min(f.body_bbs...) + traced_ssa = [] + type_traced_ssa = Type[] + for index in ir.cfg.blocks[first_bb].stmts + stmt = ir.stmts.stmt[index] + isnothing(stmt) && continue #phi node can be simplified during IR compact + stmt isa Core.PhiNode || break + ir.stmts.stmt[index] = stmt.values[1] + type = ir.stmts.type[index] + is_traced(type) || continue + push!(traced_ssa, stmt.values[1]) + push!(type_traced_ssa, type) + end + return traced_ssa, type_traced_ssa +end + +using Debugger +""" + apply_transformation!(ir::Core.Compiler.IRCode, if_::IfStructure) + Apply static Julia IR change to `ir` in order to tracing the if defined in `if_`. + Create a call to `jit_if_controlflow` which will during runtime trace the two branch of it following the two extracted IRCode. +""" +function apply_transformation!(ir::Core.Compiler.IRCode, if_::IfStructure) + (; + header_bb::Int, + terminal_bb::Int, + true_bbs::Set{Int}, + false_bbs::Set{Int}, + owned_true_bbs::Set{Int}, + owned_false_bbs::Set{Int}, + ) = if_ + true_phi_ssa = [] + false_phi_ssa = [] + if_returned_types = Type[] + phi_index = [] + #In the last block of if, collect all phi_values + for index in ir.cfg.blocks[terminal_bb].stmts + ir.stmts.stmt[index] isa Core.PhiNode || break + push!(phi_index, index) + phi = ir.stmts.stmt[index] + phi_type::Type = ir.stmts.type[index] + if_returned_type::Union{Type, Nothing} = returning_type(phi_type) #TODO: deal with promotion here + if_returned_type isa Nothing && error("transformation failed") + push!(if_returned_types, if_returned_type) + add_phi_value!(true_phi_ssa, phi, true_bbs,header_bb) + add_phi_value!(false_phi_ssa, phi, false_bbs,header_bb) + end + #Debugger.@bp + #map the old argument with the new ones + new_args_dict = Dict() + @warn "r1" ir true_bbs new_args_dict true_phi_ssa + r1 = extract_multiple_block_ir(ir, true_bbs, new_args_dict, true_phi_ssa) + clear_block_ir!(ir, owned_true_bbs) + + @warn "r2" ir false_bbs new_args_dict false_phi_ssa + r2 = extract_multiple_block_ir(ir, false_bbs, new_args_dict, false_phi_ssa) + clear_block_ir!(ir, owned_false_bbs) + + #common arguments for both branch + (value, new_args_v) = vec_args(ir, new_args_dict) + @lk new_args_dict + + r1 = finish(r1, new_args_v) + r2 = finish(r2, new_args_v) + + @warn "r1/r2" r1 r2 + + #remove MethodInstance name (needed for OpaqueClosure) + new_args_v = new_args_v[2:end] + + cond = cond_ssa(ir, header_bb) + owned_bbs = union(owned_true_bbs, owned_false_bbs) + #Mutate IR + #replace GotoIfNot -> GotoNode + #TODO: can cond be defined before goto? + ssa_goto = terminator_index(ir, header_bb) + change_stmt!( + ir, terminator_index(ir, max(owned_bbs...)), Core.GotoNode(terminal_bb), Any + ) + change_stmt!(ir, ssa_goto, nothing, Nothing) + change_stmt!(ir, ssa_goto - 1, nothing, Nothing) + + #PhiNodes simplifications: + n_result = length(phi_index) + + new_phi = [] + for phi_i in phi_index + removed_index = [] + phi = ir.stmts.stmt[phi_i] + for (i, edge) in enumerate(phi.edges) + (edge in owned_bbs || edge == header_bb) || continue + push!(removed_index, i) + end + push!( + new_phi, + length(phi.edges) - length(removed_index) == 0 ? nothing : removed_index, + ) + end + + sign = (Reactant.TracedRNumber{Bool}, Hidden, Hidden, Int, new_args_v...) 
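+    #The tuple mirrors the parameters of `jit_if_controlflow(cond, r1, r2, n_result, args...)`:
+    #a traced boolean condition, the two extracted branch IRs wrapped in `Hidden`,
+    #the number of phi results, and the types of the arguments shared by both branches.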
+ @error sign + mi = method_instance(jit_if_controlflow, sign, current_interpreter[].world) + isnothing(mi) && error("invalid Method Instance") + + @lk r1 r2 new_phi + @assert(!isnothing(cond)) + all_args = (cond, Hidden(r1), Hidden(r2), length(true_phi_ssa), value...) + @lk all_args sign + expr = Expr(:invoke, mi, GlobalRef(@__MODULE__, :jit_if_controlflow), all_args...) + + #all phi nodes are replaced and return one result: special case: the if can be created in the final block + if all((==).(new_phi, nothing)) && n_result == 1 + change_stmt!( + ir, + first(phi_index), + expr, + get_traced_type(returning_type(ir.stmts.type[first(phi_index)])), + ) + @goto out + end + + if_ssa = if n_result == 1 + ni = Core.Compiler.NewInstruction( + expr, get_traced_type(returning_type(only(if_returned_types))) + ) + if_ssa = Core.Compiler.insert_node!(ir, ssa_goto, ni, false) + else + tuple = Core.Compiler.NewInstruction(expr, Tuple{if_returned_types...}) + Core.Compiler.insert_node!(ir, ssa_goto, tuple, false) + end + + for (i, removed_index_phi) in enumerate(new_phi) + if isnothing(removed_index_phi) + ir.stmts.stmt[phi_index[i]] = if n_result == 1 + if_ssa + else + Expr(:call, Core.GlobalRef(Base, :getindex), if_ssa, i) + end + else + current_phi = ir.stmts.stmt[phi_index[i]] + isempty(removed_index_phi) && continue + deleteat!(current_phi.edges, removed_index_phi) + deleteat!(current_phi.values, removed_index_phi) + push!(current_phi.edges, header_bb) + #modify phi branch: in the case of several result, get result_i in if definition block + if n_result == 1 + push!(current_phi.values, if_ssa) + else + expr = Expr(:call, Core.GlobalRef(Base, :getindex), if_ssa, i) + ni = Core.Compiler.NewInstruction(expr, Tuple{if_returned_types...}) + result_i = Core.Compiler.insert_node!(ir, ssa_goto, ni, false) + push!(current_phi.values, result_i) + end + end + end + @label out + return ir +end + +function runtime_inner_type(e::Union{Reactant.RArray,Reactant.RNumber}) + return Reactant.MLIR.IR.type(e.mlir_data) +end +runtime_inner_type(e) = typeof(e) + +Base.getindex(::Tuple{}, ::Tuple{}) = () + +""" + jit_if_controlflow(cond::Reactant.TracedRNumber{Bool}, true_b::Core.Compiler.IRCode, false_b::Core.Compiler.IRCode, args...) -> Type + +During runtime, create an if MLIR operation from two branches `true_b` `false_b` Julia IRCode using the arguments `args`. +Return either a traced value or a tuple of traced values. + +""" +@noinline function jit_if_controlflow( + cond::Reactant.TracedRNumber{Bool}, r1, r2, n_result, args... +) + tmp_if_op = Reactant.MLIR.Dialects.stablehlo.if_( + cond.mlir_data; + true_branch=Reactant.MLIR.IR.Region(), + false_branch=Reactant.MLIR.IR.Region(), + result_0=[Reactant.MLIR.IR.Type(Nothing)], + ) + + b1 = Reactant.MLIR.IR.Block() + push!(Reactant.MLIR.IR.region(tmp_if_op, 1), b1) + Reactant.MLIR.IR.activate!(b1) + local_args_r1 = deepcopy.(args) + before_r1 = get_mlir_pointer_or_nothing.(local_args_r1) + tr1 = !isnothing(r1.value) ? juliair_to_mlir(r1.value, local_args_r1...) : () + tr1 = upgrade.(tr1) + after_r1 = get_mlir_pointer_or_nothing.(local_args_r1) + masked_muted_r1 = before_r1 .!== after_r1 + Reactant.MLIR.IR.deactivate!(b1) + + b2 = Reactant.MLIR.IR.Block() + push!(Reactant.MLIR.IR.region(tmp_if_op, 2), b2) + + Reactant.MLIR.IR.activate!(b2) + local_args_r2 = deepcopy.(args) + before_r2 = get_mlir_pointer_or_nothing.(local_args_r2) + tr2 = !isnothing(r2.value) ? juliair_to_mlir(r2.value, local_args_r2...) 
: () + tr2 = upgrade.(tr2) + after_r2 = get_mlir_pointer_or_nothing.(local_args_r2) + masked_muted_r2 = before_r2 .!== after_r2 + Reactant.MLIR.IR.deactivate!(b2) + + t1 = typeof.(tr1) + t2 = typeof.(tr2) + @lk t1 t2 + @error t1 t2 + #Assume results types are equal now: TODO: can be probably be relaxed by promoting types (need change to `juliair_to_mlir` and static IRCode Analysis) + @assert t1 == t2 "each branch $t1 $t2 must have the same type" + + #TODO: select special case + + @lk before_r1 before_r2 + @lk args local_args_r1 local_args_r2 + both_mut = collect((&).(masked_muted_r1, masked_muted_r2)) + masked_unique_muted_r1 = collect((&).(masked_muted_r1, (!).(both_mut))) + masked_unique_muted_r2 = collect((&).(masked_muted_r2, (!).(both_mut))) + @lk both_mut masked_unique_muted_r1 masked_unique_muted_r2 masked_muted_r1 masked_muted_r2 + tr1_muted = ( + local_args_r1[masked_unique_muted_r1]..., upgrade.(args[masked_unique_muted_r2])... + ) + tr2_muted = ( + upgrade.(args[masked_unique_muted_r1])..., local_args_r2[masked_unique_muted_r2]... + ) + @lk tr1_muted tr2_muted + + arg1 = (tr1..., local_args_r1[both_mut]..., tr1_muted...) + if !isempty(arg1) + Reactant.MLIR.IR.activate!(b1) + #TODO: promotion here + Reactant.Ops.return_(arg1...) + Reactant.MLIR.IR.deactivate!(b1) + end + + arg2 = (tr2..., local_args_r2[both_mut]..., tr2_muted...) + if !isempty(arg2) + Reactant.MLIR.IR.activate!(b2) + #TODO: promotion here + Reactant.Ops.return_(arg2...) + Reactant.MLIR.IR.deactivate!(b2) + end + + return_types = Reactant.MLIR.IR.type.(getfield.(tr1, :mlir_data)) + mut_types = Reactant.MLIR.IR.type.(getfield.(local_args_r1[both_mut], :mlir_data)) + mut_types2 = Reactant.MLIR.IR.type.(getfield.(tr1_muted, :mlir_data)) + + if_op = Reactant.MLIR.Dialects.stablehlo.if_( + cond.mlir_data; + true_branch=Reactant.MLIR.IR.Region(), + false_branch=Reactant.MLIR.IR.Region(), + result_0=Reactant.MLIR.IR.Type[return_types..., mut_types..., mut_types2...], + ) + Reactant.MLIR.API.mlirRegionTakeBody( + Reactant.MLIR.IR.region(if_op, 1), Reactant.MLIR.IR.region(tmp_if_op, 1) + ) + Reactant.MLIR.API.mlirRegionTakeBody( + Reactant.MLIR.IR.region(if_op, 2), Reactant.MLIR.IR.region(tmp_if_op, 2) + ) + + results = Vector(undef, length(t1)) + for (i, e) in enumerate(tr1) + traced = deepcopy(e) + traced.mlir_data = Reactant.MLIR.IR.result(if_op, i) #TODO: setmlirdata + results[i] = traced + end + + @lk if_op + + arg_offset = length(t1) + for (i, index) in enumerate(findall((|).(masked_muted_r1, masked_muted_r2))) + Reactant.TracedUtils.set_mlir_data!( + args[index], Reactant.MLIR.IR.result(if_op, arg_offset + i) + ) + end + + Reactant.MLIR.API.mlirOperationDestroy(tmp_if_op.operation) + + #TODO: add a runtime type check here using static analysis + return length(results) == 1 ? 
only(results) : Tuple(results)
+end
+
+#Remove the iterator usage from the Julia IR while keeping the branch
+function remove_iterator(ir::CC.IRCode, bb::Int)
+    terminator_pos = terminator_index(ir.cfg, bb)
+    cond = ir.stmts.stmt[terminator_pos].cond
+    cond isa Core.SSAValue || return nothing
+    iterator_index = cond.id - 2
+    iterator_expr = ir.stmts.stmt[iterator_index]
+    @assert iterator_expr isa Expr &&
+        iterator_expr.head == :call &&
+        iterator_expr.args[1] == GlobalRef(Base, :iterate)
+
+    iterator_def = iterator_expr.args[end]
+    for i in iterator_index:(iterator_index + 2)
+        change_stmt!(ir, i, nothing, Nothing)
+    end
+    return iterator_def
+end
+
+function list_phi_nodes_values(ir::CC.IRCode, in_bb::Int32, phi_bb::Int32)
+    r = []
+    for index in ir.cfg.blocks[in_bb].stmts
+        stmt = ir.stmts.stmt[index]
+        isnothing(stmt) && continue #phi nodes can be simplified during IR compaction
+        stmt isa Core.PhiNode || break
+        index_phi = findfirst(x -> x == phi_bb, stmt.edges)
+        isnothing(index_phi) && continue
+        push!(r, stmt.values[index_phi])
+    end
+    return r
+end
+
+@skip_rewrite_func jit_if_controlflow
+
+function apply_transformation!(ir::CC.IRCode, f::ForStructure)
+    f.state == Maybe && return nothing
+    body_phi_ssa = list_phi_nodes_values(ir, Int32(min(f.body_bbs...)), Int32(f.header_bb))
+    terminal_phi_ssa = list_phi_nodes_values(ir, Int32(f.terminal_bb), Int32(f.header_bb))
+    #Check the terminal-block phi nodes and find the accumulators by subtracting the first-body-block phi values from the terminal-block ones
+    accumulars_mask = Vector()
+    for ssa in terminal_phi_ssa
+        push!(accumulars_mask, ssa in body_phi_ssa)
+    end
+
+    new_args_dict = Dict()
+    #TODO: rewrite this to use terminal_phi_ssa directly
+    (traced_ssa_for_bodies, traced_ssa_for_bodies_types) = remove_phi_node_for_body!(ir, f)
+
+    ir_back = CC.copy(ir)
+    @lk ir_back
+    #iteration that re-enters the loop
+    remove_iterator(ir, max(f.body_bbs...))
+
+    last_bb = max(f.body_bbs...)
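+    #Collect, from the phi nodes of the terminal block, the values flowing in from
+    #the last body block `last_bb`: these are the results the loop body produces.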
+ results = [] + for index in ir.cfg.blocks[f.terminal_bb].stmts + stmt = ir.stmts.stmt[index] + stmt isa Core.PhiNode || break + for (e_index, bb) in enumerate(stmt.edges) + bb == last_bb || continue + push!(results, stmt.values[e_index]) + end + end + body_bbs = f.body_bbs + @warn ir body_bbs new_args_dict results traced_ssa_for_bodies traced_ssa_for_bodies_types + @lk ir body_bbs new_args_dict results traced_ssa_for_bodies traced_ssa_for_bodies_types + #TODO: replace result with terminal_phi_ssa + loop_body = extract_multiple_block_ir(ir, f.body_bbs, new_args_dict, results).ir + #value doesn't contain the function name unlike new_args_v + (value, new_args_v) = vec_args(ir, new_args_dict) + iterator_index = 0 + for (i, t) in enumerate(new_args_v) + (t isa Union && Nothing in Base.uniontypes(t)) || continue + iterator_index = i - 1 + break + end + #iteration to enter the loop + iterator_def = remove_iterator(ir, f.header_bb) + #fix cfg + + change_stmt!(ir, terminator_index(ir, f.header_bb), Core.GotoNode(f.terminal_bb), Any) + change_stmt!(ir, terminator_index(ir, last_bb), Core.GotoNode(f.terminal_bb), Any) + clear_block_ir!(ir, f.body_bbs) + + t = if iterator_def isa QuoteNode #constant iterator: + #IMPORTANT: object must be copied: QuoteNode.value cannot be reused in Opaque Closure + iterator_def = copy(iterator_def.value) + typeof(iterator_def) + else + ir.stmts.type[iterator_def.id] + end + + @lk value new_args_v terminal_phi_ssa t + while_output_type = (typeof_ir(ir, ssa) for ssa in terminal_phi_ssa) + #first element in new_args_v/ value is the iterator first step: only the iterator definition is needed + sign = (t, Hidden, Int, Vector{Bool}, while_output_type..., new_args_v[2:end]...) + @lk sign + mi = method_instance( + jit_loop_controlflow, CC.widenconst.(sign), current_interpreter[].world + ) + isnothing(mi) && error("invalid Method Instance") + expr = Expr( + :invoke, + mi, + GlobalRef(@__MODULE__, :jit_loop_controlflow), + iterator_def, + Hidden(loop_body), + iterator_index, + accumulars_mask, + terminal_phi_ssa..., + value..., + ) + @warn expr + phi_index = [] + #In the last block of for, collect all phi_values + for index in ir.cfg.blocks[f.terminal_bb].stmts + ir.stmts.stmt[index] isa Core.PhiNode || break + push!(phi_index, index) + end + + if length(phi_index) == 0 + CC.insert_node!( + ir, + CC.SSAValue(start_index(ir, f.terminal_bb)), + Core.Compiler.NewInstruction(expr, Any), + false, + ) + elseif length(phi_index) == 1 + phi = only(phi_index) + change_stmt!(ir, phi, expr, returning_type(ir.stmts.type[phi])) + else + while_ssa = Core.SSAValue(terminator_index(ir, f.header_bb) - 1) + change_stmt!(ir, while_ssa.id, expr, Tuple{while_output_type...}) + for (i, index) in enumerate(phi_index) + ir.stmts.stmt[index] = Expr( + :call, Core.GlobalRef(Base, :getindex), while_ssa, i + ) + end + end +end + +function get_mlir_pointer_or_nothing(x::Union{Reactant.TracedRNumber,Reactant.TracedRArray}) + return Reactant.TracedUtils.get_mlir_data(x).value +end + +get_mlir_pointer_or_nothing(_) = nothing + +#iterator for_body iterator_type n_init traced_ssa_for_bodies args +@noinline function jit_loop_controlflow( + iterator, for_body::Hidden, iterator_index::Int, accu_mask::Vector{Bool}, args_full... +) + #only support UnitRange atm + (start, stop, iterator_begin, iter_step) = + if iterator isa Union{Base.OneTo,UnitRange,TUnitRange,StepRange,TStepRange} + start = first(iterator) + stop = last(iterator) + iter_step = iterator isa Union{StepRange,TStepRange} ? 
iterator.step : 1 + (start, stop, Reactant.Ops.constant(start), iter_step) + else + error("unsupported type $(typeof(iterator))") + end + + start = first(iterator) + stop = last(iterator) + @lk start + iterator_ = if is_traced(typeof(start)) + start + else + Reactant.TracedRNumber{typeof(start)}((), nothing) + end + n_accu = length(accu_mask) + @lk n_accu args_full accu_mask iterator_index + accus = args_full[1:n_accu] + julia_use_iter = iterator_index != 0 + args = args_full[(n_accu + 1):end] + @lk args accus + tmp_while_op = Reactant.MLIR.Dialects.stablehlo.while_( + Reactant.MLIR.IR.Value[]; + cond=Reactant.MLIR.IR.Region(), + body=Reactant.MLIR.IR.Region(), + result_0=Reactant.MLIR.IR.Type[Reactant.Ops.mlir_type.(accus)...], + ) + + mlir_loop_args = Reactant.MLIR.IR.Type[ + Reactant.Ops.mlir_type(iterator_), Reactant.Ops.mlir_type.(accus)... + ] + cond = Reactant.MLIR.IR.Block( + mlir_loop_args, [Reactant.MLIR.IR.Location() for _ in mlir_loop_args] + ) + push!(Reactant.MLIR.IR.region(tmp_while_op, 1), cond) + + @lk cond mlir_loop_args + + Reactant.MLIR.IR.activate!(cond) + Reactant.Ops.activate_constant_context!(cond) + t1 = deepcopy(iterator_) + Reactant.TracedUtils.set_mlir_data!(t1, Reactant.MLIR.IR.argument(cond, 1)) + r = iter_step > 0 ? t1 <= stop : t1 >= stop + Reactant.Ops.return_(r) + Reactant.Ops.deactivate_constant_context!(cond) + Reactant.MLIR.IR.deactivate!(cond) + + body = Reactant.MLIR.IR.Block( + mlir_loop_args, [Reactant.MLIR.IR.Location() for _ in mlir_loop_args] + ) + push!(Reactant.MLIR.IR.region(tmp_while_op, 2), body) + + for (i, arg) in enumerate(accus) + arg_ = deepcopy(arg) + Reactant.TracedUtils.set_mlir_data!(arg_, Reactant.MLIR.IR.argument(body, i + 1)) + end + + #TODO: add try finally + Reactant.MLIR.IR.activate!(body) + Reactant.Ops.activate_constant_context!(body) + iter_reactant = deepcopy(iterator_) + Reactant.TracedUtils.set_mlir_data!(iter_reactant, Reactant.MLIR.IR.argument(body, 1)) + + @lk iter_reactant args for_body + + block_accus = [] + for j in eachindex(args) + if args[j] isa Union{Reactant.TracedRNumber,Reactant.TracedRArray} + for k in eachindex(accus) + (isnothing(args[j]) || isnothing(accus[k])) && continue + args[j].mlir_data == accus[k].mlir_data || continue + tmp = Reactant.TracedUtils.set_mlir_data!( + deepcopy(args[j]), Reactant.MLIR.IR.argument(body, 1 + k) + ) + push!(block_accus, tmp) + @goto break2 + end + end + push!(block_accus, args[j]) + @label break2 + end + + pointer_before = get_mlir_pointer_or_nothing.(args) + + if iterator_index != 0 + block_accus[iterator_index] = (iter_reactant, nothing) + end + + @lk block_accus + + t = juliair_to_mlir(for_body.value, block_accus...) + @lk t + #we use a local defined variable inside of for outside: the argument must be added to while operation (cond and body) + + pointer_after = get_mlir_pointer_or_nothing.(args) + + muted_mask = collect(pointer_before .!= pointer_after) + args_muted = args[muted_mask] + + for (am, old_value) in zip(args_muted, pointer_before[muted_mask]) + type = Reactant.MLIR.IR.type(am.mlir_data) + Reactant.MLIR.IR.push_argument!(cond, type) + new_value = Reactant.MLIR.IR.push_argument!(body, type) + @warn "changed $(Reactant.MLIR.IR.Value(old_value)) to $new_value" + @lk new_value + change_value!(Reactant.MLIR.IR.Value(old_value), new_value, body) + end + + @lk pointer_before pointer_after t body args_muted accus + + iter_next = iter_step > 0 ? iter_reactant + iter_step : iter_reactant - abs(iter_step) + Reactant.Ops.return_(iter_next, t..., args_muted...) 
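+    #The operand order above (next iterator value, loop-carried body results,
+    #mutated arguments) fixes the result layout of the final while operation
+    #rebuilt below.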
+
+    Reactant.MLIR.IR.deactivate!(body)
+    Reactant.Ops.deactivate_constant_context!(body)
+
+    @lk iterator_begin
+
+    while_op = Reactant.MLIR.Dialects.stablehlo.while_(
+        Reactant.MLIR.IR.Value[
+            Reactant.TracedUtils.get_mlir_data(iterator_begin),
+            Reactant.TracedUtils.get_mlir_data.(accus)...,
+            Reactant.MLIR.IR.Value.(pointer_before[muted_mask])...,
+        ];
+        cond=Reactant.MLIR.IR.Region(),
+        body=Reactant.MLIR.IR.Region(),
+        result_0=Reactant.MLIR.IR.Type[
+            Reactant.Ops.mlir_type(iterator_begin),
+            Reactant.Ops.mlir_type.(accus)...,
+            Reactant.Ops.mlir_type.(args_muted)...,
+        ],
+    )
+
+    Reactant.MLIR.API.mlirRegionTakeBody(
+        Reactant.MLIR.IR.region(while_op, 1), Reactant.MLIR.IR.region(tmp_while_op, 1)
+    )
+    Reactant.MLIR.API.mlirRegionTakeBody(
+        Reactant.MLIR.IR.region(while_op, 2), Reactant.MLIR.IR.region(tmp_while_op, 2)
+    )
+
+    init_mlir_result_offset = max(1, julia_use_iter ? 1 : 0) #TODO: suspicious, probably should be min
+    n = init_mlir_result_offset + length(accus)
+    for (i, muted) in enumerate(args_muted)
+        Reactant.TracedUtils.set_mlir_data!(muted, Reactant.MLIR.IR.result(while_op, n + i))
+    end
+
+    Reactant.MLIR.API.mlirOperationDestroy(tmp_while_op.operation)
+
+    results = []
+    for (i, accu) in enumerate(accus)
+        r_i = deepcopy(accu) #TODO: is this needed?
+        Reactant.TracedUtils.set_mlir_data!(
+            r_i, Reactant.MLIR.IR.result(while_op, i + init_mlir_result_offset)
+        )
+        @info "r_i" r_i
+        push!(results, r_i)
+    end
+
+    #a loop can also return values that are not accumulators, e.g.:
+    # x = 5
+    # for i in 1:10
+    #     x = 2
+    # end
+    # x
+
+    return length(results) == 1 ? only(results) : Tuple(results)
+end
+
+@skip_rewrite_func jit_loop_controlflow
+
+function post_order(tree::Tree)
+    v = []
+    for c in tree.children
+        push!(v, post_order(c)...)
+    end
+    return push!(v, tree.node)
+end
+
+"""
+    control_flow_transform!(tree::Tree, ir::Core.Compiler.IRCode) -> Core.Compiler.IRCode
+    Apply the traced control-flow transformations; the `ir` argument is no longer valid afterwards.
+"""
+function control_flow_transform!(tree::Tree, ir::CC.IRCode)::CC.IRCode
+    for node in post_order(tree)[1:(end - 1)]
+        apply_transformation!(ir, node)
+        ir = CC.compact!(ir, false)
+    end
+    return CC.compact!(ir, true)
+end
+
+#=
+    analysis_reassign_block_id!(tree::Tree, ir::CC.IRCode, src::CC.CodeInfo)
+    slot2reg can change the CFG of the type-inferred CodeInfo by removing unreachable blocks;
+    the control-flow analysis relies on block information, which must be shifted accordingly.
+=#
+function analysis_reassign_block_id!(tree::Tree, ir::CC.IRCode, src::CC.CodeInfo)
+    isempty(tree) && return false
+    cfg = CC.compute_basic_blocks(src.code)
+    length(ir.cfg.blocks) == length(cfg.blocks) && return false
+    @info "rewrite analysis blocks"
+    new_block_map = []
+    i = 0
+    for block in cfg.blocks
+        unreachable_block = all(x -> src.ssavaluetypes[x] === Union{}, block.stmts)
+        i = unreachable_block ? i : i + 1
+        push!(new_block_map, i)
+    end
+    @info new_block_map
+    function reassign_tree!(s::Set{Int})
+        n = [new_block_map[i] for i in s]
+        empty!(s)
+        return push!(s, n...)
+ end + + function reassign_tree!(is::IfStructure) + is.header_bb = new_block_map[is.header_bb] + is.terminal_bb = new_block_map[is.terminal_bb] + reassign_tree!(is.true_bbs) + reassign_tree!(is.false_bbs) + reassign_tree!(is.owned_true_bbs) + return reassign_tree!(is.owned_false_bbs) + end + + function reassign_tree!(fs::ForStructure) + fs.header_bb = new_block_map[fs.header_bb] + fs.latch_bb = new_block_map[fs.latch_bb] + fs.terminal_bb = new_block_map[fs.terminal_bb] + return reassign_tree!(fs.body_bbs) + end + + function reassign_tree!(t::Tree) + isnothing(t.node) || reassign_tree!(t.node) + for c in t.children + reassign_tree!(c) + end + end + reassign_tree!(tree) + return true +end + +function run_passes_ipo_safe_auto_cf( + ci::CC.CodeInfo, + sv::CC.OptimizationState, + caller::CC.InferenceResult, + tree::Tree, + optimize_until=nothing, # run all passes by default +) + __stage__ = 0 # used by @pass + # NOTE: The pass name MUST be unique for `optimize_until::AbstractString` to work + CC.@pass "convert" ir = CC.convert_to_ircode(ci, sv) + CC.@pass "slot2reg" ir = CC.slot2reg(ir, ci, sv) + + analysis_reassign_block_id!(tree, ir, ci) + # TODO: Domsorting can produce an updated domtree - no need to recompute here + CC.@pass "compact 1" ir = CC.compact!(ir) + @error "before" ir + bir = CC.copy(ir) + @lk bir tree + ir = control_flow_transform!(tree, ir) + @error "after" ir + CC.@pass "Inlining" ir = CC.ssa_inlining_pass!(ir, sv.inlining, ci.propagate_inbounds) + # @timeit "verify 2" verify_ir(ir) + CC.@pass "compact 2" ir = CC.compact!(ir) + CC.@pass "SROA" ir = CC.sroa_pass!(ir, sv.inlining) + @info sv.linfo + CC.@pass "ADCE" (ir, made_changes) = CC.adce_pass!(ir, sv.inlining) + if made_changes + CC.@pass "compact 3" ir = CC.compact!(ir, true) + end + if CC.is_asserts() + CC.@timeit "verify 3" begin + CC.verify_ir(ir, true, false, CC.optimizer_lattice(sv.inlining.interp)) + CC.verify_linetable(ir.linetable) + end + end + CC.@label __done__ # used by @pass + return ir +end \ No newline at end of file diff --git a/src/auto_cf/code_info_mut.jl b/src/auto_cf/code_info_mut.jl new file mode 100644 index 0000000000..0eab493b14 --- /dev/null +++ b/src/auto_cf/code_info_mut.jl @@ -0,0 +1,106 @@ +struct ShiftedSSA + e::Int +end + +struct ShiftedCF + e::Int +end + +function offset_stmt!(stmt, index, next_bb = true) + if stmt isa Expr + Expr( + stmt.head, (offset_stmt!(a, index) for a in stmt.args)...) + elseif stmt isa Core.ReturnNode + Core.ReturnNode(offset_stmt!(stmt.val, index)) + elseif stmt isa Core.SSAValue + Core.SSAValue(offset_stmt!(ShiftedSSA(stmt.id), index)) + elseif stmt isa Core.GotoIfNot + Core.GotoIfNot(offset_stmt!(stmt.cond, index), offset_stmt!(ShiftedCF(stmt.dest), index, next_bb)) + elseif stmt isa Core.GotoNode + Core.GotoNode(offset_stmt!(ShiftedCF(stmt.label), index, next_bb)) + elseif stmt isa ShiftedSSA + stmt.e + (stmt.e < index ? 0 : 1) + elseif stmt isa ShiftedCF + stmt.e + (stmt.e < index + (next_bb ? 1 : 0) ? 
0 : 1)
+    else
+        stmt
+    end
+end
+
+#Insert `stmt` into `frame` after `index`
+function add_instruction!(frame, index, stmt; type=CC.NotFound(), next_bb = true)
+    add_instruction!(frame.src, index, stmt; type, next_bb)
+    frame.ssavalue_uses = CC.find_ssavalue_uses(frame.src.code, length(frame.src.code)) #TODO: a more fine-grained update here
+    insert!(frame.stmt_info, index + 1, CC.NoCallInfo())
+    insert!(frame.stmt_edges, index + 1, nothing)
+    insert!(frame.handler_at, index + 1, (0,0))
+    frame.cfg = CC.compute_basic_blocks(frame.src.code)
+    Core.SSAValue(index + 1)
+end
+
+function modify_instruction!(frame, index, stmt)
+    frame.src.code[index] = stmt
+    frame.ssavalue_uses = CC.find_ssavalue_uses(frame.src.code, length(frame.src.code)) #TODO: refine this
+end
+
+"""
+    add_instruction!(ir::CC.CodeInfo, index, stmt; next_bb)
+
+Insert `stmt` into the `CodeInfo` after `index`, shifting the SSA values and branch targets of all existing statements accordingly.
+"""
+function add_instruction!(ir::CC.CodeInfo, index, stmt; type=CC.NotFound(), next_bb=true)
+    for (i, c) in enumerate(ir.code)
+        ir.code[i] = offset_stmt!(c, index + 1, next_bb)
+    end
+    insert!(ir.code, index + 1, stmt)
+    insert!(ir.codelocs, index + 1, 0)
+    insert!(ir.ssaflags, index + 1, 0x00000000)
+    if ir.ssavaluetypes isa Int
+        ir.ssavaluetypes = ir.ssavaluetypes + 1
+    else
+        insert!(ir.ssavaluetypes, index + 1, type)
+    end
+end
+
+
+function create_slot!(ir::CC.CodeInfo)::Core.SlotNumber
+    push!(ir.slotflags, 0x00)
+    push!(ir.slotnames, Symbol(""))
+    Core.SlotNumber(length(ir.slotflags))
+end
+
+function create_slot!(frame)::Core.SlotNumber
+    push!(frame.slottypes, Union{})
+    for s in frame.bb_vartables
+        isnothing(s) && continue
+        push!(s, CC.VarState(Union{}, true))
+    end
+    create_slot!(frame.src)
+end
+
+add_slot_change!(ir::CC.CodeInfo, index, old_slot::Int) = add_slot_change!(ir, index, Core.SlotNumber(old_slot))
+
+function add_slot_change!(ir::CC.CodeInfo, index, old_slot::Core.SlotNumber)
+    push!(ir.slotflags, 0x00)
+    push!(ir.slotnames, Symbol(""))
+    new_slot = Core.SlotNumber(length(ir.slotflags))
+    add_instruction!(ir, index, Expr(:(=), new_slot, Expr(:call, GlobalRef(@__MODULE__, :upgrade), old_slot)))
+    update_ir_new_slot(ir, index, old_slot, new_slot)
+end
+
+function update_ir_new_slot(ir, index, old_slot, new_slot)
+    for i in index+2:length(ir.code) #TODO: probably need to refine this
+        ir.code[i] = replace_slot_stmt(ir.code[i], old_slot, new_slot)
+    end
+end
+
+function replace_slot_stmt(stmt, old_slot, new_slot)
+    if stmt isa Core.NewvarNode
+        stmt
+    elseif stmt isa Expr
+        Expr(stmt.head, (replace_slot_stmt(e, old_slot, new_slot) for e in stmt.args)...)
+    elseif stmt isa Core.SlotNumber
+        stmt == old_slot ?
new_slot : stmt
+    else
+        stmt
+    end
+end
\ No newline at end of file
diff --git a/src/auto_cf/code_ir_utils.jl b/src/auto_cf/code_ir_utils.jl
new file mode 100644
index 0000000000..1774dbff4f
--- /dev/null
+++ b/src/auto_cf/code_ir_utils.jl
@@ -0,0 +1,374 @@
+"""
+    method_instance(f::Function, sign::Tuple{Vararg{Type}}, world) -> Union{Base.MethodInstance, Nothing}
+
+Same as `Base.method_instance` except that it also works inside generated functions such as `call_with_reactant`
+"""
+function method_instance(f::Function, sign::Tuple{Vararg{Type}}, world)
+    tt = Base.signature_type(f, sign)
+    match, _ = Core.Compiler._findsup(tt, nothing, world)
+    isnothing(match) && return nothing
+    mi = Core.Compiler.specialize_method(match)
+    return mi
+end
+
+"""
+    change_stmt!(ir::Core.Compiler.IRCode, ssa::Int, stmt, return_type::Type) -> Core.Compiler.Instruction
+
+Replace the statement of `ir` at position `ssa` with `stmt`, giving it `return_type`
+TODO: when `stmt` replaces a terminator (Goto -> nothing), the CFG must be updated
+"""
+function change_stmt!(ir::Core.Compiler.IRCode, ssa::Int, stmt, return_type::Type)
+    return Core.Compiler.inst_from_newinst!(
+        ir[Core.SSAValue(ssa)],
+        Core.Compiler.NewInstruction(stmt, return_type),
+        Int32(0),
+        UInt32(0),
+    )
+end
+
+"""
+    change_stmt!(ir::Core.Compiler.IRCode, ssa::Int, goto::Core.GotoNode, return_type::Type) -> Core.Compiler.Instruction
+Specialization of [`change_stmt!`](@ref) for `Core.GotoNode` that also applies the control-flow-graph changes
+"""
+function change_stmt!(
+    ir::Core.Compiler.IRCode, ssa::Int, goto::Core.GotoNode, return_type::Type
+)
+    bb::Int64 = Core.Compiler.block_for_inst(ir, ssa)
+    succs = ir.cfg.blocks[bb].succs
+    empty!(succs)
+    push!(succs, goto.label)
+    push!(ir.cfg.blocks[goto.label].preds, bb)
+    @invoke change_stmt!(ir, ssa, goto::Any, return_type)
+end
+
+"""
+    clear_block_ir!(ir::Core.Compiler.IRCode, blocks::Set{Int})
+    Replace every instruction of the basic blocks `blocks` of `ir` with `nothing`
+"""
+function clear_block_ir!(ir::Core.Compiler.IRCode, blocks::Set{Int})
+    for block in blocks
+        stmt_range::Core.Compiler.StmtRange = ir.cfg.blocks[block].stmts
+        (f, l) = (first(stmt_range), last(stmt_range))
+        for i in f:l
+            change_stmt!(ir, i, nothing, Nothing)
+        end
+    end
+end
+
+"""
+    type_from_ssa(ir::Core.Compiler.IRCode, argtypes::Vector, v::Vector)::Vector
+    Get the type of each statement of `v` within `ir`
+"""
+function type_from_ssa(ir::Core.Compiler.IRCode, argtypes::Vector, v)
+    cir = ir
+    @lk cir
+    return [
+        begin
+            if e isa Core.SSAValue
+                ir.stmts.type[e.id]
+            elseif e isa Core.Argument
+                argtypes[e.n]
+            else
+                typeof(e)
+            end
+        end for e in v
+    ]
+end
+
+"""
+    apply_map(array, block_map::Dict)::Vector
+    For each element of `array`, look up the associated value in the dictionary `block_map`
+"""
+function apply_map(array, block_map)
+    return [block_map[a] for a in array if haskey(block_map, a)]
+end
+
+"""
+    new_cfg(ir, to_extract::Vector, block_map)::Core.Compiler.CFG
+    Get the new CFG of `ir` after the extraction of the `to_extract` blocks
+"""
+function new_cfg(ir, to_extract, block_map)
+    n = 1
+    bbs = Core.Compiler.BasicBlock[]
+    index = Int64[]
+    for b in to_extract
+        bb = ir.cfg.blocks[b]
+        (; start, stop) = bb.stmts
+        diff = stop - start
+        push!(
+            bbs,
+            Core.Compiler.BasicBlock(
+                Core.Compiler.StmtRange(n, diff + n),
+                apply_map(bb.preds, block_map),
+                apply_map(bb.succs, block_map),
+            ),
+        )
+        n += diff + 1
+        push!(index, n)
+    end
+    return Core.Compiler.CFG(bbs, index)
+end
+
+"""
+    WipExtracting
+    struct for an extracted IRCode that is not yet fully constructed
+"""
+struct WipExtracting
+    ir::Core.Compiler.IRCode
+end
+
+"""
+    is_a_terminator(stmt)
+    Check whether `stmt` is a terminator
+"""
+function is_a_terminator(stmt)
+    return stmt isa Union{Core.GotoNode,Core.ReturnNode,Core.GotoIfNot}
+end
+
+"""
+    offset_stmt!(dict::Dict, stmt, offset::Dict, ir::Core.Compiler.IRCode, bb_map)
+    internal recursive helper of [`extract_multiple_block_ir`](@ref) that shifts `SSAValue`s, `Argument`s and basic-block targets in `ir`
+"""
+function offset_stmt!(dict::Dict, stmt, offset::Dict, ir::Core.Compiler.IRCode, bb_map)
+    if stmt isa Expr
+        Expr(stmt.head, (offset_stmt!(dict, a, offset, ir, bb_map) for a in stmt.args)...)
+    elseif stmt isa Core.Argument
+        tmp = Core.Argument(length(dict) + 2)
+        get!(dict, stmt, (tmp, ir.argtypes[stmt.n]))[1]
+    elseif stmt isa Core.ReturnNode
+        Core.ReturnNode(offset_stmt!(dict, stmt.val, offset, ir, bb_map))
+    elseif stmt isa Core.SSAValue
+        stmt_bb = Core.Compiler.block_for_inst(ir, stmt.id)
+        if stmt_bb in keys(offset) #TODO: remove? && stmt.id > offset[stmt_bb]
+            Core.SSAValue(stmt.id - offset[stmt_bb])
+        else
+            #the stmt is turned into an argument of the new IR
+            tmp = Core.Argument(length(dict) + 2)
+            get!(dict, stmt, (tmp, ir.stmts.type[stmt.id]))[1]
+        end
+    elseif stmt isa Core.GotoNode
+        Core.GotoNode(get(bb_map, stmt.label, 0))
+    elseif stmt isa Core.GotoIfNot
+        Core.GotoIfNot(
+            offset_stmt!(dict, stmt.cond, offset, ir, bb_map), get(bb_map, stmt.dest, 0)
+        )
+    elseif stmt isa Core.PhiNode
+        Core.PhiNode(
+            Int32[bb_map[edge] for edge in stmt.edges],
+            Any[offset_stmt!(dict, value, offset, ir, bb_map) for value in stmt.values],
+        )
+    elseif stmt isa Core.PiNode
+        Core.PiNode(offset_stmt!(dict, stmt.val, offset, ir, bb_map), stmt.typ)
+    else
+        stmt
+    end
+end
+
+
+using Debugger
+"""
+    extract_multiple_block_ir(ir, to_extract_set::Set, args::Dict, new_returns::Vector)::WipExtracting
+    Extract the blocks `to_extract_set` from `ir`, creating a new independent IR containing only these blocks.
+    All unlinked SSA values are added to the `args` dictionary, and all values of `new_returns` are returned by the new IR.
+"""
+function extract_multiple_block_ir(
+    ir::Core.Compiler.IRCode, to_extract_set::Set{Int}, args::Dict, new_returns::Vector
+)::WipExtracting
+    @assert isempty(ir.new_nodes.stmts)
+    to_extract = sort(collect(to_extract_set))
+
+    #For each extracted basic block, compute its new offset.
+ #useful to deal with non-contiguous extraction because in this case, the offset doesn't follow `ir` block offset anymore + bb_offset::Dict{Int,Int} = Dict() + new_n_stmt = 0 + if !isempty(to_extract) + cumulative_offset = (first(ir.cfg.blocks[first(to_extract)].stmts)) - 1 + for bb in minimum(to_extract):maximum(to_extract) + n_stmt = length(ir.cfg.blocks[bb].stmts) + if bb in to_extract + bb_offset[bb] = cumulative_offset + new_n_stmt += n_stmt + else + cumulative_offset += n_stmt + end + end + end + + block_map = Dict() + for (i, b) in enumerate(to_extract) + block_map[b] = i + end + + cfg = new_cfg(ir, to_extract, block_map) + + #f = first(ir.cfg.blocks[first(to_extract)].stmts) + #l = last(ir.cfg.blocks[last(to_extract)].stmts) + + #PhiNode uses the global IR, either shift it or add it to the new IR argument + for (i, rb) in enumerate(new_returns) + rb isa Union{Core.SSAValue,Core.Argument} || continue + new_returns[i] = offset_stmt!(args, rb, bb_offset, ir, block_map) + end + + #recreate instruction_stream of the block + instruction_stream = Core.Compiler.InstructionStream(new_n_stmt) + dico = Dict() + new_stmt = 0 + for bb in to_extract + range_bb = ir.cfg.blocks[bb].stmts[[1, end]] + for old_stmt in range_bb[1]:range_bb[2] + new_stmt += 1 + Core.Compiler.setindex!(instruction_stream, ir.stmts[old_stmt], new_stmt) #TODO: check if needed + #ssa offset + instruction_stream.stmt[new_stmt] = offset_stmt!( + args, ir.stmts.stmt[old_stmt], bb_offset, ir, block_map + ) + #line_info + line_info = ir.stmts.line[old_stmt] + line_info == 0 && continue + instruction_stream.line[new_stmt] = get!(dico, line_info, length(dico) + 1) + end + end + + linetable = ir.linetable[sort(collect(keys(dico)))] + linetable = [ + Core.LineInfoNode(l.module, l.method, l.file, l.line, Int32(0)) for l in linetable + ] + #Build the new IR argtypes from args dictionnary + (_, argtypes) = vec_args(ir, args) + #JuliaIR block can end without a terminator + new_ir, has_terminator, n_ssa = if !isempty(instruction_stream) + (Core.Compiler.IRCode( + instruction_stream, + cfg, + linetable, + argtypes, + Expr[], + Core.Compiler.VarState[], + ), is_a_terminator(instruction_stream.stmt[end]), length(instruction_stream)) + else + new_ir = CC.IRCode() + empty!(new_ir.argtypes) + push!(new_ir.argtypes, argtypes...) + (new_ir, true, 1) + end + + @lk new_returns args + @error "" args argtypes + #Debugger.@bp + retu = if length(new_returns) > 1 + tuple = Core.Compiler.NewInstruction( + Expr(:call, Core.GlobalRef(Core, :tuple), new_returns...), + Tuple{type_from_ssa(new_ir, argtypes, new_returns)...}, + ) + Core.Compiler.insert_node!( + new_ir, Core.Compiler.SSAValue(n_ssa), tuple, !has_terminator + ) + else + length(new_returns) == 1 ? 
only(new_returns) : nothing
+    end
+
+    if has_terminator
+        change_stmt!(new_ir, n_ssa, Core.ReturnNode(retu), Nothing)
+    else
+        terminator = Core.Compiler.NewInstruction(Core.ReturnNode(retu), Nothing)
+        @lk new_ir n_ssa terminator
+        Core.Compiler.insert_node!(new_ir, Core.Compiler.SSAValue(n_ssa), terminator, true)
+    end
+    return WipExtracting(Core.Compiler.compact!(new_ir, true))
+end
+
+function mlir_type(x)
+    return Reactant.MLIR.IR.TensorType(
+        size(x), Reactant.MLIR.IR.Type(Reactant.unwrapped_eltype(x))
+    )
+end
+
+"""
+    vec_args(ir::Core.Compiler.IRCode, new_args::Dict)::Vector
+    Construct the args `Vector` from the `new_args` dictionary
+"""
+function vec_args(ir::Core.Compiler.IRCode, new_args::Dict)
+    argtypes = Vector(undef, length(new_args) + 1)
+    argtypes[1] = Core.Const("opaque")
+    value = Vector(undef, length(new_args))
+    for (arg, index) in new_args
+        value[index[1].n - 1] = arg
+        argtypes[index[1].n] = if arg isa Core.Argument #TODO: reuse function
+            index[2]
+        else
+            ir.stmts.type[arg.id]
+        end
+    end
+    return (value, argtypes)
+end
+
+"""
+    typeof_ir(ir::CC.IRCode, e::Union{Core.Argument, Core.SSAValue})
+    Return the type of a statement in `ir`
+    TODO: replace with CC.argextype
+"""
+function typeof_ir(ir::CC.IRCode, e::Union{Core.Argument,Core.SSAValue})
+    if e isa Core.Argument
+        ir.argtypes[e.n]
+    else
+        ir.stmts.type[e.id]
+    end
+end
+
+"""
+    finish(wir::WipExtracting, new_args::Vector)::Core.Compiler.IRCode
+
+    Construct the extracted IR by applying the full argument list
+"""
+function finish(wir::WipExtracting, new_args::Vector)
+    (; ir) = wir
+    empty!(ir.argtypes)
+    append!(ir.argtypes, new_args)
+    ir = rewrite_insts!(ir, current_interpreter[], false)[1]
+    return ir
+end
+
+"""
+    add_phi_value!(v::Vector, phi::Core.PhiNode, edge::Set{Int}, header_bb::Int)
+
+    Add to `v` the `Core.PhiNode` values whose edge belongs to `edge`.
+    If the phi node contains `header_bb` and no element of `edge`, insert the value associated with the header instead.
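+
+    For example (illustrative): with `phi = Core.PhiNode(Int32[2, 4], Any[a, b])`,
+    `edge = Set([4])` and `header_bb = 2`, the value `b` is pushed onto `v`;
+    with `edge = Set([7])`, no edge matches and the header value `a` is pushed instead.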
+""" +function add_phi_value!(v::Vector, phi::Core.PhiNode, edge::Set{Int}, header_bb::Int) + header_index = nothing + find_element = false + for (i, e) in enumerate(phi.edges) + e == header_bb && (header_index = i) + e in edge || continue + find_element = true + push!(v, phi.values[i]) #TODO: add break after + end + (!find_element && !isnothing(header_index)) && (push!(v, phi.values[header_index])) +end + +""" + cond_ssa(ir::CC.IRCode, bb::Int) + + Return the SSA value in a traced GotoIfNot +""" +function cond_ssa(ir::CC.IRCode, bb::Int) + ti = terminator_index(ir, bb) + terminator = ir.stmts.stmt[ti] + terminator isa Core.GotoIfNot || return nothing + protection = ir.stmts.stmt[terminator.cond.id] + ( + protection isa Expr && + protection.head == :call && + protection.args[1] == Core.GlobalRef(@__MODULE__, :traced_protection) + ) || return nothing + return protection.args[2] +end + +""" + check_integrity(ir::CC.IRCode)::Bool + check if `unreachable` is present in the IR, return true if none +""" +function check_integrity(ir::CC.IRCode)::Bool + return !any(ir.stmts.stmt .== [Core.ReturnNode()]) +end \ No newline at end of file diff --git a/src/auto_cf/debug_utils.jl b/src/auto_cf/debug_utils.jl new file mode 100644 index 0000000000..d8ad1863a4 --- /dev/null +++ b/src/auto_cf/debug_utils.jl @@ -0,0 +1,26 @@ +macro stop(n::Int) + u = :counter #gensym() + e = esc(u) + quote + isdefined(@__MODULE__, $(QuoteNode(u))) || global $e = $n + global $e + $e<2 && error("stop") + $e -= 1 + end +end + + + +#leak each argument to a global variable and store each instance of it +macro lks(args...) + nargs = [ Symbol(string(arg) * "s") for arg in args] + quote + $([:( + let val = $(esc(p)) + isdefined(@__MODULE__, $(QuoteNode(n))) || global $(esc(n)) = [] + global $(esc(n)) + push!($(esc(n)), val) + end + ) for (p,n) in zip(args, nargs)]...) + end +end \ No newline at end of file diff --git a/src/auto_cf/mlir_utils.jl b/src/auto_cf/mlir_utils.jl new file mode 100644 index 0000000000..059dc5f315 --- /dev/null +++ b/src/auto_cf/mlir_utils.jl @@ -0,0 +1,23 @@ +function change_value!(from::Reactant.MLIR.IR.Value, to::Reactant.MLIR.IR.Value, op::Reactant.MLIR.IR.Operation) + for i in 1:Reactant.MLIR.IR.noperands(op) + Reactant.MLIR.IR.operand(op, i) == from || continue + Reactant.MLIR.IR.operand!(op, i, to) + end + + for i in 1:Reactant.MLIR.IR.nregions(op) + r = Reactant.MLIR.IR.region(op, i) + change_value!(from, to, r) + end +end + +function change_value!(from::Reactant.MLIR.IR.Value, to::Reactant.MLIR.IR.Value, region::Reactant.MLIR.IR.Region) + for block in Reactant.MLIR.IR.BlockIterator(region) + change_value!(from, to, block) + end +end + +function change_value!(from::Reactant.MLIR.IR.Value, to::Reactant.MLIR.IR.Value, block::Reactant.MLIR.IR.Block) + for op in Reactant.MLIR.IR.OperationIterator(block) + change_value!(from, to, op) + end +end \ No newline at end of file diff --git a/src/auto_cf/new_inference.jl b/src/auto_cf/new_inference.jl new file mode 100644 index 0000000000..36ac735582 --- /dev/null +++ b/src/auto_cf/new_inference.jl @@ -0,0 +1,976 @@ +#simple version of `CC.scan_slot_def_use` +function fill_slot_definition_map(frame) + n_slot = length(frame.src.slotnames) + n_args = length(frame.linfo.specTypes.types) + v = [0 for _ in 1:n_slot] + for (i, stmt) in enumerate(frame.src.code) + stmt isa Expr || continue + stmt.head == :(=) || continue + slot = stmt.args[1] + slot isa Core.SlotNumber || continue + slot.id > n_args || continue + v[slot.id] = v[slot.id] == 0 ? 
i : v[slot.id] + end + return v +end + +function fill_slot_usage_map(frame) + n_slot = length(frame.src.slotnames) + v = [Set() for _ in 1:n_slot] + for (pos, stmt) in enumerate(frame.src.code) + get_slot(v, stmt, frame, pos) + end + return v +end + +function get_slot(vec, stmt, frame, pos) + if stmt isa Expr + stmt.head == :(=) && return get_slot(vec, stmt.args[2], frame, pos) + for e in stmt.args + get_slot(vec, e, frame, pos) + end + elseif stmt isa Core.SlotNumber + push!(vec[stmt.id], CC.block_for_inst(frame.cfg, pos)) + else + stmt + end +end + +#an = Analysis(Tree(nothing, [], Ref{Tree}()), nothing, nothing, nothing, nothing) + +function update_tree!(an::Analysis, bb::Int) + for c in an.tree.children + c.node.header_bb == bb || continue + an.pending_tree = c + return true + end + return false +end + +function add_tree!(an::Analysis, tl) + parent = an.tree + t = Tree(tl, [], Ref{Tree}(parent)) + push!(parent.children, t) + return an.pending_tree = t +end + +#Several TCF can end in the same bb +function up_tree!(an::Analysis, bb) + terminal = false + while is_terminal_bb(an.tree, bb) + an.tree = an.tree.parent[] + terminal = true + end + terminal && return nothing + #terminal bb is not always reach: for instance, if bodies are more precisely inferred and nothing change + while !isnothing(an.tree.node) + in_header(bb, an.tree.node) && break + an.tree = an.tree.parent[] + end +end + +function down_tree!(an::Analysis, bb) + for child in an.tree.children + if child.node.header_bb == bb + an.tree = child + break + end + end +end + +function is_terminal_bb(tree::Tree, bb) + isnothing(tree.node) && return false + return tree.node.terminal_bb == bb +end + +Base.in(bb::Int, is::IfStructure) = bb in is.true_bbs || bb in is.false_bbs +Base.in(bb::Int, is::ForStructure) = bb in is.body_bbs + +function in_header(bb::Int, is::IfStructure) + return bb in is.true_bbs || bb in is.false_bbs || bb == is.header_bb +end +function in_header(bb::Int, is::ForStructure) + return bb in is.body_bbs || bb == is.header_bb || bb == is.latch_bb +end + +function in_stack(tree::Tree, bb::Int) + while !isnothing(tree.node) + in_header(bb, tree.node) && return true + tree = tree.parent[] + end + return false +end + +#TODO: don't recompute TCF each time +function add_cf!(an, frame, currbb, currpc, condt) + update_tree!(an, frame.currbb) && return false + + tl = is_a_traced_loop(an, frame.src, frame.cfg, frame.currbb) + if tl !== nothing + add_tree!(an, tl) + return false + end + + tl = is_a_traced_if(an, frame, frame.currbb, condt) + if tl !== nothing + add_tree!(an, tl) + !tl.legalize[] || return false + #legalize if by inserting a call + goto_if_not_index = terminator_index(frame.cfg, frame.currbb) + cond = tl.ssa_cond + ssa = add_instruction!( + frame, + goto_if_not_index - 1, + Expr(:call, GlobalRef(@__MODULE__, :traced_protection), cond), + ) + invalidate_slot_definition_analysis!(an) + (; dest::Int) = frame.src.code[goto_if_not_index + 1]::Core.GotoIfNot #shifted because of the insertion + modify_instruction!(frame, goto_if_not_index + 1, Core.GotoIfNot(ssa, dest)) + tl.legalize[] = true + return true + end + return false +end + +@noinline traced_protection(x::Reactant.TracedRNumber{Bool}) = CC.inferencebarrier(x)::Bool +Reactant.@skip_rewrite_func traced_protection + +@noinline upgrade(x) = Reactant.Ops.constant(x) +@noinline upgrade(x::Union{Reactant.TracedRNumber,Reactant.TracedRArray}) = x + +Reactant.@skip_rewrite_func upgrade +#TODO: need a new traced mode Julia Type Non-concrete -> Traced 
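+#For example (illustrative): `upgrade_traced_type(Int64)` yields
+#`Reactant.TracedRNumber{Int64}`; a `Core.Const` is widened first, so
+#`upgrade_traced_type(Core.Const(1))` yields the same type.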
+upgrade_traced_type(t::Core.Const) = upgrade_traced_type(CC.widenconst(t)) +upgrade_traced_type(t::Type{<:Number}) = Reactant.TracedRNumber{t} +upgrade_traced_type(t::Type{<:Reactant.TracedRNumber}) = t + +in_tcf(an::Analysis) = begin + !isnothing(an.tree.node) +end + +invalidate_slot_definition_analysis!(an) = an.slotanalysis = nothing + +function if_type_passing!(an, frame) + in_tcf(an) || return false + last_cf = an.tree.node + last_cf isa IfStructure || return false + last_cf.header_bb == frame.currbb || return false + !last_cf.legalize[] || return false + cond = last_cf.ssa_cond + goto_if_not_index = terminator_index(frame.cfg, frame.currbb) + ssa = add_instruction!( + frame, + goto_if_not_index - 1, + Expr(:call, GlobalRef(@__MODULE__, :traced_protection), cond), + ) + invalidate_slot_definition_analysis!(an) + + (; dest::Int) = frame.src.code[goto_if_not_index + 1]::Core.GotoIfNot #shifted because of the insertion + modify_instruction!(frame, goto_if_not_index + 1, Core.GotoIfNot(ssa, dest)) + last_cf.legalize[] = true + #update frame + return true +end + +function can_upgrade_loop(an, rt) + in_tcf(an) || return false + last_cf = an.tree.node + last_cf isa ForStructure || return false + last_cf.state == Maybe || return false + is_traced(rt) || return false + return true +end + +# a = expr +# => +# a = upgrade(expr) +# do nothing if expr is already an upgrade call +#TODO: rt suspicious +function apply_slot_upgrade!(frame, pos::Int, rt)::Bool + @warn "upgrade slot $pos $rt" + stmt = frame.src.code[pos] + @assert Base.isexpr(stmt, :(=)) "$stmt" + r = stmt.args[2] + #TODO: iterate can be upgraded to a traced iterate. SSAValue, slots & literal only need stmt change. Others need a new stmt + if Base.isexpr(r, :call) + r.args[1] == GlobalRef(@__MODULE__, :upgrade) && return false + if r.args[1] == GlobalRef(Base, :iterate) + new_type = traced_iterator(rt) + frame.src.code[pos] = Expr( + :(=), + stmt.args[1], + Expr(:call, GlobalRef(Base, :iterate), new_type, r.args[2:end]...), + ) + return true + end + frame.src.code[pos] = stmt.args[2] + add_instruction!( + frame, + pos, + Expr( + :(=), + stmt.args[1], + Expr(:call, GlobalRef(@__MODULE__, :upgrade), Core.SSAValue(pos)), + ); + next_bb=false, + ) + elseif r isa Core.SlotNumber || r isa Core.SSAValue || true #TODO: for expr we must create a new call and the expr + frame.src.code[pos] = Expr( + :(=), stmt.args[1], Expr(:call, GlobalRef(@__MODULE__, :upgrade), stmt.args[2]) + ) + else + error("unsupported slot upgrade $stmt") + end + return true +end + +function current_top_struct(tree) + top_struct = nothing + while !isnothing(tree.node) + top_struct = tree.node + tree = tree.parent[] + end + return top_struct +end + +function get_root(tree) + while !isnothing(tree.node) + tree = tree.parent[] + end + return tree +end + +#TODO: remove +function get_first_slot_read_stack(frame, tree, slot::Core.SlotNumber, stop::Int) + node = current_top_struct(tree) + start_stmt = frame.cfg.blocks[node.header_bb].stmts.start + for stmt_index in start_stmt:stop + s = frame.src.code[stmt_index] + s isa Core.SlotNumber || continue + s.id == slot.id && return CC.block_for_inst(frame.cfg.index, stmt_index) + end + return nothing +end + +@inline function check_and_upgrade_slot!(an, frame, stmt, rt, currstate) + in_tcf(an) || return (NoUpgrade,) + stmt isa Expr || return (NoUpgrade,) + stmt.head == :(=) || return (NoUpgrade,) + last_cf = an.tree.node + rt_traced = is_traced(rt) + slot = stmt.args[1].id + slot_type::Type = CC.widenconst(currstate[slot].typ) + + #If 
the stmt is traced: if the slot is traced or not set, don't need to upgrade the slot + #TODO: Nothing suspicions + rt_traced && + (is_traced(slot_type) || slot_type === Union{} || slot_type == Nothing) && + return (NoUpgrade,) + + if last_cf isa IfStructure + (frame.currbb in last_cf.true_bbs || frame.currbb in last_cf.false_bbs) || + return (NoUpgrade,) + #inside a traced_if, slot must be upgraded to a traced type + sa = get_slot_analysis(an, frame)::SlotAnalysis + #TODO: approximation: use liveness analysis to precise promote local slot + # if traced + isempty( + setdiff(sa.slot_bb_usage[slot], union(last_cf.true_bbs, last_cf.false_bbs)) + ) && return (NoUpgrade,) + + #invalidate_slot_definition_analysis!(an) + return if apply_slot_upgrade!(frame, frame.currpc, rt) + (UpgradeLocally,) + else + (NoUpgrade,) + end + + #no need to change frame furthermore + elseif last_cf isa ForStructure + (last_cf.state == Traced || last_cf.state == Upgraded) || return (NoUpgrade,) + if (!rt_traced && is_traced(slot_type)) + return if apply_slot_upgrade!(frame, frame.currpc, rt) + (UpgradeLocally,) + else + (NoUpgrade,) + end + end + sa = get_slot_analysis(an, frame)::SlotAnalysis + slot_definition_pos = sa.slot_stmt_def[slot] + slot_definition_bb = CC.block_for_inst(frame.cfg, slot_definition_pos) + #local slot doesn't need to be upgrade TODO: suspicious + slot_definition_bb in last_cf.body_bbs && return (NoUpgrade,) + if slot_definition_bb == last_cf.header_bb || in_stack(an.tree, slot_definition_bb) + #stack upgrade + #the slot has been upgraded: find read of the slot inside the current traced stack: if any, we must restart the inference from there + return if apply_slot_upgrade!(frame, slot_definition_pos, rt) + (UpgradeDefinition, stmt.args[1]) + else + (NoUpgrade,) + end + else + #global upgrade: add a new slot + new_slot_def_pos = if last_cf.header_bb == 1 + #first block contains argument to slot write: new instructions must be placed after (otherwise all the IR is dead) + new_index = 0 + for i in frame.cfg.blocks[1].stmts + local_stmt = frame.src.code[i] + local_stmt isa Expr && + local_stmt.head == :(=) && + typeof.(frame.src.code[i].args) == + [Core.SlotNumber, Core.SlotNumber] && + continue + local_stmt isa Core.NewvarNode && continue + new_index = i + break + end + new_index + else + frame.cfg.blocks[last_cf.header_bb].stmts.start - 1 + end + #add_slot_change!(frame.src, new_slot_def_pos, slot) + slot = stmt.args[1] + #CodeInfo: Cannot use a slot inside a call + add_instruction!(frame, new_slot_def_pos, slot) + add_instruction!( + frame, + new_slot_def_pos + 1, + Expr( + :(=), + slot, + Expr( + :call, + GlobalRef(@__MODULE__, :upgrade), + Core.SSAValue(new_slot_def_pos + 1), + ), + ), + ) + invalidate_slot_definition_analysis!(an) + return (UpgradeDefinitionGlobal,) + end + return (UpgradeDefinition,) + end +end + +terminator_index(ir::Core.Compiler.IRCode, bb::Int) = terminator_index(ir.cfg, bb) +terminator_index(cfg::CC.CFG, bb::Int) = cfg.blocks[bb].stmts.stop +start_index(ir::CC.IRCode, bb::Int) = start_index(ir.cfg, bb) +start_index(cfg::CC.CFG, bb::Int) = bb == 1 ? 
1 : cfg.index[bb - 1] + +#TODO: proper support this by walking the IR +function is_traced_loop_iterator(src::CC.CodeInfo, cfg::CC.CFG, bb::Int) + terminator_pos = terminator_index(cfg, bb) + iterator_index = src.code[terminator_pos].cond.id - 3 + iterator_type = src.ssavaluetypes[iterator_index] + return is_traced(iterator_type) +end + +is_traced(t::Type) = parentmodule(t) == Reactant +is_traced(::Core.TypeofBottom) = false +is_traced(t::UnionAll) = is_traced(CC.unwrap_unionall(t)) +is_traced(u::Union) = (|)(is_traced.(Base.uniontypes(u))...) +function is_traced(t::Type{<:Tuple}) + t isa Union && return @invoke is_traced(t::Union) + t = Base.unwrap_unionall(t) + t isa UnionAll && return is_traced(Base.unwrap_unionall(t)) + if typeof(t) == UnionAll + t = t.body + end + return (|)(is_traced.(t.types)...) +end +is_traced(::Type{Tuple{}}) = false +is_traced(t) = false + +#TODO: add support to while loop / general loop +function is_a_traced_loop(an, src::CC.CodeInfo, cfg::CC.CFG, bb_header) + bb_body_first = min(cfg.blocks[bb_header].succs...) + preds::Vector{Int} = cfg.blocks[bb_body_first].preds + (max(preds...) < bb_body_first) && return nothing #No loop + bb_latch = max(preds...) + bb_end = max(cfg.blocks[bb_header].succs...) + bb_body_last = only(cfg.blocks[bb_latch].preds) + #TODO: proper accu and block + return ForStructure( + (), + bb_header, + bb_latch, + bb_end, + Set(bb_body_first:bb_body_last), + is_traced_loop_iterator(src, cfg, bb_header) ? Traced : Maybe, + ) +end + +function bb_owned_branch(domtree, bb::Int)::Set{Int} + bbs = Set(bb) + for c in domtree[bb].children + bbs = union(bbs, bb_owned_branch(domtree, c)) + end + return bbs +end + +function bb_branch(cfg, bb::Int, t_bb::Int)::Set{Int} + bbs = Set() + work = [bb] + while !isempty(work) + c_bb = pop!(work) + (c_bb in bbs || c_bb == t_bb) && continue + push!(bbs, c_bb) + for s in cfg.blocks[c_bb].succs + push!(work, s) + end + end + return bbs +end + +function get_doms(an, frame) + if an.domtree === nothing + an.domtree = CC.construct_domtree(frame.cfg).nodes + an.postdomtree = CC.construct_postdomtree(frame.cfg).nodes + end + return (an.domtree, an.postdomtree) +end + +function get_slot_analysis(an::Analysis, frame)::SlotAnalysis + if an.slotanalysis === nothing + an.slotanalysis = SlotAnalysis( + fill_slot_definition_map(frame), fill_slot_usage_map(frame) + ) + end + return an.slotanalysis +end + +#TODO:remove currbb +function is_a_traced_if(an, frame, currbb, condt) + condt == Reactant.TracedRNumber{Bool} || return nothing + (domtree, postdomtree) = get_doms(an, frame) #compute dominance analysis only when needed + bb = frame.cfg.blocks[currbb] + succs::Vector{Int64} = bb.succs + if_goto_stmt::Core.GotoIfNot = frame.src.code[last(bb.stmts)] + #CodeInfo GotoIfNot.dest is a stmt + first_false_bb = CC.block_for_inst(frame.cfg.index, if_goto_stmt.dest) + first_true_bb = succs[1] == first_false_bb ? 
succs[2] : succs[1] + last_child = last(domtree[currbb].children) + is_diamond = currbb in postdomtree[last_child].children + final_bb = if is_diamond + last_child + else + if_final_bb = nothing + for (final_bb, nodes) in enumerate(postdomtree) + if currbb in nodes.children + if_final_bb = final_bb + break + end + end + @assert !isnothing(if_final_bb) + if_final_bb + end + true_bbs = bb_branch(frame.cfg, first_true_bb, final_bb) + false_bbs = bb_branch(frame.cfg, first_false_bb, final_bb) + all_owned = bb_owned_branch(domtree, currbb) + true_owned_bbs = intersect(bb_owned_branch(domtree, first_true_bb), all_owned, true_bbs) + false_owned_bbs = intersect( + bb_owned_branch(domtree, first_false_bb), all_owned, false_bbs + ) + return IfStructure( + if_goto_stmt.cond, + currbb, + final_bb, + true_bbs, + false_bbs, + true_owned_bbs, + false_owned_bbs, + Ref{Bool}(false), + Set(), + ) +end + +#HACK: add a general T to Traced{T} conversion +function traced_iterator(::Type{Union{Nothing,Tuple{T,T}}}) where {T} + is_traced(T) && return T + Tout = Reactant.TracedRNumber{T} + return Union{Nothing,Tuple{Tout,Nothing}} #TODO: replace INT -> Nothing +end + +traced_iterator(t::Type{Tuple{T,T}}) where {T} = traced_iterator(Union{Nothing,t}) + +traced_iterator(t) = begin + if !is_traced(t) + error("fallback $t") + end + t +end + +function get_new_iterator_type(src::CC.CodeInfo, cfg::CC.CFG, bb::Int) + terminator_pos = terminator_index(cfg, bb) + iterator_index = src.code[terminator_pos].cond.id - 3 + iterator_type = src.ssavaluetypes[iterator_index] + iterator_type = CC.widenconst(iterator_type) + return traced_iterator(iterator_type) +end + +#TODO: proper check if the iterator exists and replace -3 +function rewrite_iterator(src::CC.CodeInfo, cfg::CC.CFG, bb::Int, new_type::Type) + terminator_pos = terminator_index(cfg, bb) + iterator_index = src.code[terminator_pos].cond.id - 3 + iterator = src.code[iterator_index] + iterator_arg = iterator.args[end].args[end] + iterator.args[end] = Expr(:call, GlobalRef(Base, :iterate), new_type, iterator_arg) + return iterator.args[1] +end + +function reset_slot!(state::Union{Nothing,Vector{Core.Compiler.VarState}}, slot::Int) + return isnothing(state) ? 
state : state[slot] = CC.VarState(Union{}, true)
+end
+
+function reset_slot!(
+    state::Union{Nothing,Vector{Core.Compiler.VarState}}, slot::Core.SlotNumber
+)
+    return reset_slot!(state, slot.id)
+end
+
+function reset_slot!(states)
+    for i in eachindex(states)
+        states[i] = nothing
+    end
+end
+
+function reset_slot!(states, fs::ForStructure, slot::Core.SlotNumber)
+    reset_slot!(states[fs.header_bb], slot)
+    for bb in fs.body_bbs
+        reset_slot!(states[bb], slot)
+    end
+    reset_slot!(states[fs.latch_bb], slot)
+    return reset_slot!(states[fs.terminal_bb], slot)
+end
+
+#TODO: stack -> branch
+function rewrite_loop_stack!(an::Analysis, frame, states, currstate)
+    (; src::CC.CodeInfo, cfg::CC.CFG) = frame
+    ct = an.tree
+    top_loop_tcf = nothing
+    while !isnothing(ct.node)
+        node = ct.node
+        ct = ct.parent[]
+        node isa ForStructure || continue
+        node.state == Maybe || continue
+        #TODO: while loop
+        new_iterator_type = get_new_iterator_type(src, cfg, node.header_bb)
+        slot = rewrite_iterator(src, cfg, node.header_bb, new_iterator_type)
+        last_for_bb = last(sort(collect(node.body_bbs)))
+        slot = rewrite_iterator(frame.src, frame.cfg, last_for_bb, new_iterator_type)
+        top_loop_tcf = ct
+        node.state = Upgraded
+    end
+    @assert(!isnothing(top_loop_tcf))
+    return top_loop_tcf
+    #restart type inference from: top_header_rewritten
+end
+
+#Transform an n-terminator bb IR into a 1-terminator bb IR
+#TODO: improve the algorithm: avoid using `frame` inside the loop
+function normalize_exit!(frame)
+    terminator_bbs = findall(isempty.(getfield.(frame.cfg.blocks, :succs)))
+    length(terminator_bbs) <= 1 && return nothing
+    new_slot = create_slot!(frame)
+    add_instruction!(frame, 0, Core.NewvarNode(new_slot))
+
+    n = length(frame.src.code)
+    add_instruction!(frame, n, new_slot)
+    add_instruction!(frame, n + 1, Core.ReturnNode(Core.SSAValue(n + 1)))
+    push!(frame.bb_vartables, nothing)
+    offset = 0
+    tis = [terminator_index(frame.cfg, tbb) for tbb in terminator_bbs]
+    for tbb in tis
+        return_index = offset + tbb
+        return_ = frame.src.code[return_index]
+        @assert(return_ isa Core.ReturnNode)
+        exit_bb_start_pos = terminator_index(frame.cfg, length(frame.cfg.blocks))
+        offset += if return_.val isa Core.SSAValue
+            temp = frame.src.code[return_.val.id]
+            frame.src.code[return_.val.id] = Expr(:(=), new_slot, temp)
+            frame.src.code[return_index] = Core.GotoNode(exit_bb_start_pos)
+            0
+        else
+            add_instruction!(
+                frame, return_index, Core.GotoNode(exit_bb_start_pos); next_bb=false
+            )
+            frame.src.code[return_index] = Expr(:(=), new_slot, return_.val)
+            1
+        end
+    end
+    return frame.cfg = CC.compute_basic_blocks(frame.src.code)
+end
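+# Illustration (a minimal sketch of what `normalize_exit!` does; the slot name is
+# arbitrary): each
+#     return %7
+# becomes
+#     _ret = %7
+#     goto exit
+# and a single shared exit block
+#     exit: return _ret
+# is appended, so later passes can assume exactly one terminator block.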
+#=
+    CC.typeinf_local(interp::Reactant.ReactantInterp, frame::CC.InferenceState)
+
+    Specializes type inference to support control-flow-aware tracing.
+    TODO: enable this only for user code, because the new type inference is costly for now (several inference runs can be needed for the same function)
+=#
+function CC.typeinf_local(interp::Reactant.ReactantInterp, frame::CC.InferenceState)
+    mod = frame.mod
+    if @static (VERSION < v"1.12" && VERSION > v"1.11") &&
+        has_ancestor(mod, Main) &&
+        is_traced(frame.linfo.specTypes) &&
+        !has_ancestor(mod, Core) &&
+        !has_ancestor(mod, Base) &&
+        !has_ancestor(mod, Reactant)
+        @info "auto control flow tracing enabled: $(frame.linfo)"
+        normalize_exit!(frame)
+        an = Analysis(Tree(nothing, [], Ref{Tree}()), nothing, nothing, nothing, nothing)
+        typeinf_local_traced(interp, frame, an)
+        @error frame.src
+        isempty(an.tree) ||
+            (get_meta(interp).traced_tree_map[mi_key(frame.linfo)] = an.tree)
+    else
+        @invoke CC.typeinf_local(interp::CC.AbstractInterpreter, frame::CC.InferenceState)
+    end
+end
+
+function update_context!(an::Analysis, currbb::Int)
+    isnothing(an.pending_tree) && return nothing
+    currbb in an.pending_tree.node || return nothing
+    an.tree = an.pending_tree
+    return an.pending_tree = nothing
+end
+
+function handle_different_branches() end
+
+#=
+    typeinf_local_traced(interp::ReactantInterp, frame::CC.InferenceState, an::Analysis)
+
+    Type-infers `frame` using the Reactant interpreter; notably detects traced control flow and upgrades traced slots.
+=#
+function typeinf_local_traced(
+    interp::Reactant.ReactantInterp, frame::CC.InferenceState, an::Analysis
+)
+    @assert !CC.is_inferred(frame)
+    frame.dont_work_on_me = true # mark that this function is currently on the stack
+    W = frame.ip
+    ssavaluetypes = frame.ssavaluetypes
+    bbs = frame.cfg.blocks
+    nbbs = length(bbs)
+    𝕃ᡒ = CC.typeinf_lattice(interp)
+
+    currbb = frame.currbb
+    if currbb != 1
+        currbb = frame.currbb = CC._bits_findnext(W.bits, 1)::Int # next basic block
+    end
+
+    states = frame.bb_vartables
+    init_state = CC.copy(states[currbb])
+    currstate = CC.copy(states[currbb]::CC.VarTable)
+    terminal_block_if::Union{Nothing,IfStructure} = nothing
+
+    while currbb <= nbbs
+        CC.delete!(W, currbb)
+        bbstart = first(bbs[currbb].stmts)
+        bbend = last(bbs[currbb].stmts)
+        currpc = bbstart - 1
+        terminal_block_if =
+            if an.tree.node isa IfStructure && is_terminal_bb(an.tree, currbb)
+                an.tree.node
+            else
+                nothing
+            end
+        update_context!(an, currbb)
+        up_tree!(an, currbb)
+        @warn frame.linfo currbb an.tree.node get_root(an.tree)
+        while currpc < bbend
+            currpc += 1
+            frame.currpc = currpc
+            CC.empty_backedges!(frame, currpc)
+            stmt = frame.src.code[currpc]
+            # If we're at the end of the basic block ...
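+            # NOTE: terminators are special-cased below; in particular, a `GotoIfNot`
+            # whose condition infers as `TracedRNumber{Bool}` is handed to `add_cf!`,
+            # which records the control-flow structure and restarts inference.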
+            if currpc == bbend
+                # Handle control flow
+                if isa(stmt, Core.GotoNode)
+                    succs = bbs[currbb].succs
+                    @assert length(succs) == 1
+                    nextbb = succs[1]
+                    ssavaluetypes[currpc] = Any
+                    CC.handle_control_backedge!(interp, frame, currpc, stmt.label)
+                    CC.add_curr_ssaflag!(frame, CC.IR_FLAG_NOTHROW)
+                    @goto branch
+                elseif isa(stmt, Core.GotoIfNot)
+                    condx = stmt.cond
+                    condxslot = CC.ssa_def_slot(condx, frame)
+                    condt = CC.abstract_eval_value(interp, condx, currstate, frame)
+
+                    if add_cf!(an, frame, currbb, currpc, condt)
+                        @goto reset_inference
+                    end
+
+                    if condt === CC.Bottom
+                        ssavaluetypes[currpc] = CC.Bottom
+                        CC.empty!(frame.pclimitations)
+                        @goto find_next_bb
+                    end
+                    orig_condt = condt
+                    if !(isa(condt, Core.Const) || isa(condt, CC.Conditional)) &&
+                        isa(condxslot, Core.SlotNumber)
+                        # if this non-`Conditional` object is a slot, we form and propagate
+                        # the conditional constraint on it
+                        condt = CC.Conditional(
+                            condxslot, Core.Const(true), Core.Const(false)
+                        )
+                    end
+                    condval = CC.maybe_extract_const_bool(condt)
+                    nothrow = (condval !== nothing) || CC.:(⊑)(𝕃ᡒ, orig_condt, Bool)
+                    if nothrow
+                        CC.add_curr_ssaflag!(frame, CC.IR_FLAG_NOTHROW)
+                    else
+                        CC.update_exc_bestguess!(interp, TypeError, frame)
+                        CC.propagate_to_error_handler!(currstate, frame, 𝕃ᡒ)
+                        CC.merge_effects!(interp, frame, CC.EFFECTS_THROWS)
+                    end
+
+                    if !CC.isempty(frame.pclimitations)
+                        # we can't model the possible effect of control
+                        # dependencies on the return value, so we propagate it
+                        # directly to all the return values (unless we error first)
+                        condval isa Bool ||
+                            CC.union!(frame.limitations, frame.pclimitations)
+                        empty!(frame.pclimitations)
+                    end
+                    ssavaluetypes[currpc] = Any
+                    if condval === true
+                        @goto fallthrough
+                    else
+                        if !nothrow && !CC.hasintersect(CC.widenconst(orig_condt), Bool)
+                            ssavaluetypes[currpc] = CC.Bottom
+                            @goto find_next_bb
+                        end
+
+                        succs = bbs[currbb].succs
+                        if length(succs) == 1
+                            @assert condval === false || (stmt.dest === currpc + 1)
+                            nextbb = succs[1]
+                            @goto branch
+                        end
+                        @assert length(succs) == 2
+                        truebb = currbb + 1
+                        falsebb = succs[1] == truebb ? succs[2] : succs[1]
+                        if condval === false
+                            nextbb = falsebb
+                            CC.handle_control_backedge!(interp, frame, currpc, stmt.dest)
+                            @goto branch
+                        end
+                        # We continue with the true branch, but process the false
+                        # branch here.
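+                        # (Standard inference behavior: refine the slot type along each
+                        # branch via `conditional_change`, merge the refined state into
+                        # `falsebb`, and re-queue it when anything changed.)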
+                        if isa(condt, CC.Conditional)
+                            else_change = CC.conditional_change(
+                                𝕃ᡒ, currstate, condt.elsetype, condt.slot
+                            )
+                            if else_change !== nothing
+                                false_vartable = CC.stoverwrite1!(
+                                    copy(currstate), else_change
+                                )
+                            else
+                                false_vartable = currstate
+                            end
+                            changed = CC.update_bbstate!(𝕃ᡒ, frame, falsebb, false_vartable)
+                            then_change = CC.conditional_change(
+                                𝕃ᡒ, currstate, condt.thentype, condt.slot
+                            )
+
+                            if then_change !== nothing
+                                CC.stoverwrite1!(currstate, then_change)
+                            end
+                        else
+                            changed = CC.update_bbstate!(𝕃ᡒ, frame, falsebb, currstate)
+                        end
+                        if changed
+                            CC.handle_control_backedge!(interp, frame, currpc, stmt.dest)
+                            CC.push!(W, falsebb)
+                        end
+                        @goto fallthrough
+                    end
+                elseif isa(stmt, Core.ReturnNode)
+                    rt = CC.abstract_eval_value(interp, stmt.val, currstate, frame)
+                    if CC.update_bestguess!(interp, frame, currstate, rt)
+                        CC.update_cycle_worklists!(
+                            frame
+                        ) do caller::CC.InferenceState, caller_pc::Int
+                            # no reason to revisit if that call-site doesn't affect the final result
+                            return caller.ssavaluetypes[caller_pc] !== Any
+                        end
+                    end
+                    ssavaluetypes[frame.currpc] = Any
+                    @goto find_next_bb
+                elseif isa(stmt, Core.EnterNode)
+                    ssavaluetypes[currpc] = Any
+                    CC.add_curr_ssaflag!(frame, CC.IR_FLAG_NOTHROW)
+                    if isdefined(stmt, :scope)
+                        scopet = CC.abstract_eval_value(
+                            interp, stmt.scope, currstate, frame
+                        )
+                        handler = frame.handlers[frame.handler_at[frame.currpc + 1][1]]
+                        @assert handler.scopet !== nothing
+                        if !CC.:(⊑)(𝕃ᡒ, scopet, handler.scopet)
+                            handler.scopet = CC.tmerge(𝕃ᡒ, scopet, handler.scopet)
+                            if isdefined(handler, :scope_uses)
+                                for bb in handler.scope_uses
+                                    push!(W, bb)
+                                end
+                            end
+                        end
+                    end
+                    @goto fallthrough
+                elseif CC.isexpr(stmt, :leave)
+                    ssavaluetypes[currpc] = Any
+                    @goto fallthrough
+                end
+                # Fall through terminator - treat as regular stmt
+            end
+
+            # Process non control-flow statements
+            (; changes, rt, exct) = CC.abstract_eval_basic_statement(
+                interp, stmt, currstate, frame
+            )
+            if !CC.has_curr_ssaflag(frame, CC.IR_FLAG_NOTHROW)
+                if exct !== Union{}
+                    CC.update_exc_bestguess!(interp, exct, frame)
+                    # TODO: assert that these conditions match. For now, we assume the `nothrow` flag
+                    # to be correct, but allow the exct to be an over-approximation.
+                end
+                CC.propagate_to_error_handler!(currstate, frame, 𝕃ᡒ)
+            end
+
+            if !isnothing(terminal_block_if)
+                #check that both branches handle the slot correctly; otherwise an upgrade call must be inserted before the if
+                # e.g. `if b; c = 10; end` => `c` is assigned in only one branch
+                if !(stmt isa Core.SlotNumber)
+                    #only bare slot reads (the value produced by the if) are considered.
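+                    # (When the statement *is* a bare slot read whose type infers as an
+                    # unbalanced union such as `Union{Int64,TracedRNumber{Int64}}`, the
+                    # `else` branch below equalizes it by inserting `slot = upgrade(slot)`
+                    # just before the header block's `GotoIfNot`.)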
+                    terminal_block_if = nothing
+                else
+                    if rt isa Union
+                        union_types = Base.uniontypes(rt)
+                        if length(union_types) > 1
+                            #TODO: proper check: only rewrite a union of the form Union{Int, Traced{Int}}
+                            if count(is_traced, union_types) == 1
+                                stmt in terminal_block_if.unbalanced_slots &&
+                                    error("cannot support $stmt : $rt")
+                                push!(terminal_block_if.unbalanced_slots, stmt)
+                                if_head_block = terminal_block_if.header_bb
+                                terminator_pos = frame.cfg.blocks[if_head_block].stmts.stop
+                                upgrade_pos = terminator_pos - 2
+                                add_instruction!(frame, upgrade_pos, stmt)
+                                @error "missing slot in if branch: upgrade $stmt : $rt"
+                                @error frame.src
+                                add_instruction!(
+                                    frame,
+                                    upgrade_pos + 1,
+                                    Expr(
+                                        :(=),
+                                        stmt,
+                                        Expr(
+                                            :call,
+                                            GlobalRef(@__MODULE__, :upgrade),
+                                            Core.SSAValue(upgrade_pos + 1),
+                                        ),
+                                    ),
+                                )
+                                invalidate_slot_definition_analysis!(an)
+                                @goto reset_inference
+                            end
+                        end
+                    end
+                end
+            end
+
+            #maybe upgrade the for-loop here: eagerly restart type inference if we detect a traced type
+            #NOTE: must be placed before the CC.Bottom check: in a traced context, an iterator with invalid arguments should still be upgraded
+            #@info stmt an.tree
+            upgrade_result = check_and_upgrade_slot!(an, frame, stmt, rt, currstate)
+            slot_state = first(upgrade_result)
+            if slot_state === UpgradeDefinition #Slot Upgrade ...
+                bbs = frame.cfg.blocks
+                bbend = last(bbs[currbb].stmts)
+                @goto reset_inference
+            elseif slot_state === UpgradeDefinitionGlobal
+                @goto reset_inference
+            elseif slot_state === UpgradeLocally
+                @goto reset_inference
+            end
+
+            if rt === CC.Bottom
+                ssavaluetypes[currpc] = CC.Bottom
+                # Special case: Bottom-typed PhiNodes do not error (but must also be unused)
+                if isa(stmt, Core.PhiNode)
+                    continue
+                end
+                @goto find_next_bb
+            end
+
+            #Slot upgrade must be placed before any slot/ssa table change
+            if changes !== nothing
+                CC.stoverwrite1!(currstate, changes)
+            end
+            if rt === nothing
+                ssavaluetypes[currpc] = Any
+                continue
+            end
+
+            if can_upgrade_loop(an, rt)
+                rewrite_loop_stack!(an, frame, states, currstate)
+                @goto reset_inference
+            end
+
+            # IMPORTANT: set the type
+            CC.record_ssa_assign!(𝕃ᡒ, currpc, rt, frame)
+        end # while currpc < bbend
+
+        # Case 1: Fallthrough termination
+        begin
+            @label fallthrough
+            nextbb = currbb + 1
+        end
+
+        # Case 2: Directly branch to a different BB
+        begin
+            @label branch
+            if CC.update_bbstate!(𝕃ᡒ, frame, nextbb, currstate)
+                CC.push!(W, nextbb)
+            end
+        end
+
+        # Case 3: Control flow ended along the current path (converged, return or throw)
+        begin
+            @label find_next_bb
+            currbb = frame.currbb = CC._bits_findnext(W.bits, 1)::Int # next basic block
+            currbb == -1 && break # the working set is empty
+            currbb > nbbs && break
+            nexttable = states[currbb]
+            if nexttable === nothing
+                CC.init_vartable!(currstate, frame)
+            else
+                CC.stoverwrite!(currstate, nexttable)
+            end
+        end
+
+        begin
+            continue
+            @label reset_inference
+            CC.empty!(W)
+            currbb = 1
+            frame.currbb = 1
+            currpc = 1
+            frame.currpc = 1
+            reset_slot!(states)
+            an.tree = get_root(an.tree)
+            states[1] = copy(init_state)
+            currstate = copy(init_state)
+            for i in eachindex(frame.ssavaluetypes)
+                frame.ssavaluetypes[i] = CC.NotFound()
+            end
+            bbs = frame.cfg.blocks
+            nbbs = length(bbs)
+            ssavaluetypes = frame.ssavaluetypes
+        end
+    end # while currbb <= nbbs
+    @lk an
+    frame.dont_work_on_me = false
+    return nothing
+end
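+# Overall effect, as a minimal sketch (using the `traced_protection`/`upgrade`
+# helpers referenced above). Given user code such as
+#     function f(c::Reactant.TracedRNumber{Bool})
+#         a = 0
+#         if c
+#             a += 1
+#         end
+#         return a
+#     end
+# the traced `GotoIfNot` is recorded as an `IfStructure` and legalized through
+# `traced_protection(c)`, the slot `a` is rewritten to `a = upgrade(a)` in the
+# header block so both branches see a `TracedRNumber{Int64}`, and inference is
+# restarted (`@label reset_inference`) until no more rewrites are needed.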
diff --git a/src/auto_cf/utils_bench.jl b/src/auto_cf/utils_bench.jl
new file mode 100644
index 0000000000..ec0bd9bd58
--- /dev/null
+++ b/src/auto_cf/utils_bench.jl
@@ -0,0 +1,156 @@
+function init_mlir()
+    ctx = Reactant.MLIR.IR.Context()
+    @ccall Reactant.MLIR.API.mlir_c.RegisterDialects(ctx::Reactant.MLIR.API.MlirContext)::Cvoid
+end
+
+get_traced_object(::Type{Reactant.TracedRNumber{T}}) where {T} = Reactant.Ops.constant(rand(T))
+
+get_traced_object(::Type{Reactant.TracedRArray{T,N}}) where {T,N} = Reactant.Ops.constant(rand(T, [1 for i in 1:N]...))
+
+get_traced_object(t) = begin
+    @error t
+    rand(t)
+end
+
+#=
+    analysis_reassign_block_id!(an::Analysis, ir::CC.IRCode, src::CC.CodeInfo)
+
+    slot2reg can change the type-inferred CodeInfo CFG by removing unreachable blocks;
+    the control-flow analysis stores block indices, which must be shifted accordingly.
+=#
+function analysis_reassign_block_id!(an::Analysis, ir::CC.IRCode, src::CC.CodeInfo)
+    cfg = CC.compute_basic_blocks(src.code)
+    length(ir.cfg.blocks) == length(cfg.blocks) && return false
+    @info "rewrite analysis blocks"
+    new_block_map = []
+    i = 0
+    for block in cfg.blocks
+        unreachable_block = all(x -> src.ssavaluetypes[x] === Union{}, block.stmts)
+        i = unreachable_block ? i : i + 1
+        push!(new_block_map, i)
+    end
+    @info new_block_map
+    function reassign_tree!(s::Set{Int})
+        n = [new_block_map[i] for i in s]
+        empty!(s)
+        push!(s, n...)
+    end
+
+    function reassign_tree!(is::IfStructure)
+        is.header_bb = new_block_map[is.header_bb]
+        is.terminal_bb = new_block_map[is.terminal_bb]
+        reassign_tree!(is.true_bbs)
+        reassign_tree!(is.false_bbs)
+        reassign_tree!(is.owned_true_bbs)
+        reassign_tree!(is.owned_false_bbs)
+    end
+
+    function reassign_tree!(fs::ForStructure)
+        fs.header_bb = new_block_map[fs.header_bb]
+        fs.latch_bb = new_block_map[fs.latch_bb]
+        fs.terminal_bb = new_block_map[fs.terminal_bb]
+        reassign_tree!(fs.body_bbs)
+    end
+
+    function reassign_tree!(t::Tree)
+        isnothing(t.node) || reassign_tree!(t.node)
+        for c in t.children
+            reassign_tree!(c)
+        end
+    end
+    reassign_tree!(an.tree)
+    @error an.tree
+    return true
+end
+
+function test(f)
+    m = methods(f)[1]
+    types = m.sig.parameters[2:end]
+    mi = Base.method_instance(f, types)
+    @lk mi
+    world = Base.get_world_counter()
+    interp = Reactant.ReactantInterpreter(; world)
+    result = CC.InferenceResult(mi, CC.typeinf_lattice(interp))
+    src = CC.retrieve_code_info(result.linfo, world)
+    osrc = CC.copy(src)
+    @lk osrc src
+    frame = CC.InferenceState(result, src, :no, interp)
+    CC.typeinf(interp, frame)
+    opt = CC.OptimizationState(frame, interp)
+    ir0 = CC.convert_to_ircode(opt.src, opt)
+    ir = CC.slot2reg(ir0, opt.src, opt)
+    #TODO: `an` is not defined in this scope; it should be the Analysis recorded for `mi` during inference
+    analysis_reassign_block_id!(an, ir, src)
+    ir = CC.compact!(ir)
+    bir = CC.copy(ir)
+    @lk bir
+    ir_final = control_flow_transform!(an, ir)
+
+    modu = Reactant.MLIR.IR.Module()
+    @lk modu
+    #init_caches()
+    Reactant.MLIR.IR.activate!(modu)
+    Reactant.MLIR.IR.activate!(Reactant.MLIR.IR.body(modu))
+
+    ttypes = collect(types)[is_traced.(types)]
+    @lk types ttypes
+
+    to_mlir(::Type{Reactant.TracedRArray{T,N}}) where {T,N} = Reactant.MLIR.IR.TensorType(repeat([4096], N), Reactant.MLIR.IR.Type(T))
+    to_mlir(x) = Reactant.Ops.mlir_type(x)
+    f_args = to_mlir.(ttypes)
+
+    temporal_func = Reactant.MLIR.Dialects.func.func_(;
+        sym_name="main_",
+        function_type=Reactant.MLIR.IR.FunctionType(f_args, []),
+        body=Reactant.MLIR.IR.Region(),
+        sym_visibility=Reactant.MLIR.IR.Attribute("private"),
+    )
+
+    main = Reactant.MLIR.IR.Block(f_args, [Reactant.MLIR.IR.Location() for _ in f_args])
+    push!(Reactant.MLIR.IR.region(temporal_func, 1), main)
+    Reactant.Ops.activate_constant_context!(main)
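+    # Build one traced argument per MLIR block argument below: a TracedRArray gets a
+    # placeholder shape of 4096 per dimension, a TracedRNumber gets a scalar, and
+    # non-traced Julia arguments are materialized with `rand`.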
Reactant.MLIR.IR.activate!(main)
+
+    args = []
+    i = 1
+    for tt in types
+        if !is_traced(tt)
+            push!(args, rand(tt))
+            continue
+        end
+
+        arg = if ttypes[i] <: Reactant.TracedRArray
+            ttypes[i]((), nothing, repeat([4096], ttypes[i].parameters[2]))
+        else
+            ttypes[i]((), nothing)
+        end
+        Reactant.TracedUtils.set_mlir_data!(arg, Reactant.MLIR.IR.argument(main, i))
+        push!(args, arg)
+        i += 1
+    end
+
+    #A = Reactant.Ops.constant(rand(Int, 2, 2));
+    #B = Reactant.Ops.constant(rand(Int, 2, 2));
+    r = juliair_to_mlir(ir_final, args...)[2]
+    Reactant.Ops.return_(r...)
+    Reactant.Ops.deactivate_constant_context!(main)
+    Reactant.MLIR.IR.deactivate!(main)
+
+    func = Reactant.MLIR.Dialects.func.func_(;
+        sym_name="main",
+        function_type=Reactant.MLIR.IR.FunctionType(f_args, Reactant.MLIR.IR.Type[Reactant.Ops.mlir_type.(r)...]),
+        body=Reactant.MLIR.IR.Region(),
+        sym_visibility=Reactant.MLIR.IR.Attribute("private"),
+    )
+
+    Reactant.MLIR.API.mlirRegionTakeBody(
+        Reactant.MLIR.IR.region(func, 1), Reactant.MLIR.IR.region(temporal_func, 1))
+
+    Reactant.MLIR.API.mlirOperationDestroy(temporal_func.operation)
+
+    Reactant.MLIR.IR.verifyall(Reactant.MLIR.IR.Operation(modu); debug=true) || error("fail")
+    modu
+end
\ No newline at end of file
diff --git a/test/auto_cf/basic.jl b/test/auto_cf/basic.jl
new file mode 100644
index 0000000000..75fca16d36
--- /dev/null
+++ b/test/auto_cf/basic.jl
@@ -0,0 +1,236 @@
+using
Test + +function promote_loop(a) + for j in 1:2 + a += 5 + end + a +end + +function promote_loop_twin(a) + for j in 1:2 + a += 5 + end + + for j in 1:2 + a += 5 + end + a +end + +function promote_loop_slot_write(a) + for j in 1:2 + a += 5 + j + end + a +end + +function double_promote_loop_slot_write(a) + for i in 1:5 + for j in 1:2 + a += i + j + end + end + a +end + +function simple_promote_loop_mutable(A) + for i in 1:10 + A[i, 1] = A[1, i] + end + A +end + +function simple_promote_loop_mutable_repeated(A) + for j in 1:10 + A[2, 1] = A[1, 2] + end + A +end + +function simple_promote_loop_mutable_repeated_twin(A) + for j in 1:10 + A[2, 1] = A[1, 2] + end + for j in 1:10 + A[2, 1] = A[1, 2] + end + A +end + +function simple_promote_loop_mutable_twin(A) + for k in 1:10 + A[2, k] = A[1, 2] + end + for j in 1:10 + A[2, 1] = A[1, 2] + end + A +end + +function simple_branch_for(A) + p = 1 + x = Int32(0) + for _ in axes(A, 2) + p = A[1, 1] + x += Int32(1) + end + return p, x +end + +function promote_loop_non_upgraded_slot(A, x) + p = 1 + for i in axes(A, 2) + p = A[i, 1] + x + end + return p +end + + +@testset "basic promotion" begin + n = 64 + a = Reactant.ConcreteRNumber(n) + @test @jit(promote_loop(a)) == promote_loop(n) + @test @jit(promote_loop_twin(a)) == promote_loop_twin(n) + @test @jit(promote_loop_slot_write(a)) == promote_loop_slot_write(n) + @test @jit(double_promote_loop_slot_write(a)) == double_promote_loop_slot_write(n) + + A = collect(reshape(1:100, 10, 10)) + tA = Reactant.to_rarray(A) + @test @jit(simple_promote_loop_mutable(tA)) == simple_promote_loop_mutable(A) + @test @jit(simple_promote_loop_mutable_repeated(tA)) == simple_promote_loop_mutable_repeated(A) + @test @jit(simple_promote_loop_mutable_repeated_twin(tA)) == simple_promote_loop_mutable_repeated_twin(A) + @test @jit(simple_branch_for(tA)) == simple_branch_for(A) + @test @jit(promote_loop_non_upgraded_slot(tA, a)) == promote_loop_non_upgraded_slot(A,n) +end + + + +function simple_traced_iterator(A) + a = A[1, 1] + p = 0 + for i in 1:a + p = i + end + return p +end + +function double_loop_traced_iterator(A) + p = 0 + for i in axes(A, 1) + for j in 1:i + p += A[i, j] + end + end + return p +end + +function simple_reverse_iterator(a) + p = 0 + for i in a:-1:1 + p += i + end + return p +end + +@testset "basic traced iterator" begin + n = 32 + a = Reactant.ConcreteRNumber(n) + A = collect(reshape(1:4, 2, 2)) + tA = Reactant.to_rarray(A) + @test @jit(simple_traced_iterator(tA)) == simple_traced_iterator(A) + @test @jit(double_loop_traced_iterator(tA)) == double_loop_traced_iterator(A) + @test @jit(simple_reverse_iterator(a)) == simple_reverse_iterator(n) +end + + +function basic_if(c) + r = c ? 1 : 0 + r +end + +function normalized_if(c) + c ? 1 : 0 +end + +function slot_if(c) + a = 0 + if c + a += 1 + else + a -= 1 + end + a +end + +function partial_if(c) + a1 = a2 = a3 = a4 = 0 + if c + a1 = 1 + a2 = 2 + else + a3 = 3 + a4 = 4 + end + a1 + a2 + a3 + a4 +end + +function asymetric_slot_if(c) + a = 0 + if c + a += 1 + end + a +end + +function asymetric_slot__argument_if(c, b) + a = 0 + if c + a += 1 + b = 2 + else + a -= 1 + end + a + b +end + +function mutable_if(c, A) + if c + A[1] += 1 + end + A +end + +function mutable_both_if(c, A) + if c + A[1] += 1 + else + A[2] -= 1 + end + A +end + +function multiple_layer(x) + if x > 10 + x > 15 ? 
5 : 17 + else + 34 + end +end + +@testset "basic if" begin + v = true + a = Reactant.ConcreteRNumber(v) + n = Reactant.ConcreteRNumber(16) + A = collect(reshape(1:4, 4)) + tA = Reactant.to_rarray(A) + @test @jit(basic_if(a)) == basic_if(v) + @test @jit(normalized_if(a)) == normalized_if(v) + @test @jit(slot_if(a)) == slot_if(v) + @test @jit(partial_if(a)) == partial_if(v) + @test @jit(asymetric_slot_if(a)) == asymetric_slot_if(v) + @test @jit(asymetric_slot__argument_if(a,n)) == asymetric_slot__argument_if(v,n) + @test @jit(mutable_if(a, tA)) == mutable_if(v, A) + @test @jit(mutable_both_if(a, tA)) == mutable_both_if(v, A) + @test @jit(multiple_layer(a)) == multiple_layer(v) +end \ No newline at end of file diff --git a/test/auto_cf/polybench.jl b/test/auto_cf/polybench.jl new file mode 100644 index 0000000000..d25d78fb80 --- /dev/null +++ b/test/auto_cf/polybench.jl @@ -0,0 +1,791 @@ +module Polybench +using Reactant + +function kernel_correlation(D) + m = size(D, 2) + mean = zeros(eltype(D), m)::Reactant.TracedRArray{Float64,1} + stddev = zeros(eltype(D), m)::Reactant.TracedRArray{Float64,1} + corr = zeros(eltype(D), m, m)::Reactant.TracedRArray{Float64,2} + for j in axes(D, 2) + mean[j] = 0.0 + for i in axes(D, 1) + mean[j] += D[i, j] + end + mean[j] /= size(D, 1) + end + + for j in axes(D, 2) + stddev[j] = 0.0 + for i in axes(D, 1) + stddev[j] += (D[i, j] - mean[j]) * (D[i, j] - mean[j]) + end + stddev[j] /= size(D, 1) + stddev[j] = sqrt(stddev[j]) + stddev[j] = stddev[j] <= 0.1 ? 1.0 : stddev[j] + end + + for i in axes(D, 1) + for j in axes(D, 2) + D[i, j] -= mean[j] + D[i, j] /= sqrt(size(D, 1)) * stddev[j] + end + end + + for i in axes(D, 2) + corr[i, i] = 1.0 + for j in (i + 1):m + corr[i, j] = 0.0 + for k in axes(D, 1) + corr[i, j] += D[k, i] * D[k, j] + end + corr[j, i] = corr[i, j] + end + end + corr[m, m] = 1.0 + return corr +end + +function kernel_covariance(D) + m = size(D, 2) + mean = zeros(eltype(D), m) + cov = zeros(eltype(D), m, m) + for j in axes(D, 2) + mean[j] = 0.0 + for i in axes(D, 1) + mean[j] += D[i, j] + end + mean[j] /= size(D, 1) + end + + for i in axes(D, 1) + for j in axes(D, 2) + D[i, j] -= mean[j] + end + end + + for i in axes(D, 2) + for j in axes(D, 2) + cov[i, j] = 0.0 + for k in axes(D, 1) + cov[i, j] += D[k, i] * D[k, j] + end + cov[i, j] /= size(D, 1) - 1.0 + cov[i, j] = cov[j, i] + end + end + return cov +end + +function kernel_gemm(alpha, beta, A, B) + C = zeros(eltype(A), axes(A, 1), axes(B, 2)) + for i in axes(A, 1) + for j in axes(B, 2) + C[i, j] *= beta + end + for k in axes(A, 2) + for j in axes(B, 2) + C[i, j] += alpha * A[i, k] * B[k, j] + end + end + end + return C +end + +function kernel_gemmver(alpha, beta, u1, u2, v1, v2, A, x, y, z) + w = zeros(eltype(A), axes(A, 1)) + for i in axes(A, 1) + for j in axes(A, 2) + A[i, j] = A[i, j] + u1[i] * v1[j] + u2[i] * v2[j] + end + end + + for i in axes(A, 1) + for j in axes(A, 2) + x[i] = x[i] + beta * A[j, i] * y[j] + end + end + + for i in axes(A, 1) + x[i] = x[i] + z[i] + end + + for i in axes(A, 1) + for j in axes(A, 2) + w[i] = w[i] + alpha * A[i, j] * x[j] + end + end + return w +end + +function kernel_gesummv(alpha, beta, A, B, x) + tmp = zeros(eltype(A), axes(A, 1)) + y = zeros(eltype(A), axes(A, 1)) + for i in axes(A, 1) + tmp[i] = 0.0 + y[i] = 0.0 + for j in axes(A, 2) + tmp[i] = A[i, j] * x[j] + tmp[i] + y[i] = B[i, j] * x[j] + y[i] + end + y[i] = alpha * tmp[i] + beta * y[i] + end + return y +end + +function kernel_symm(alpha, beta, A, B) + C = zeros(eltype(A), axes(A, 1), axes(B, 2)) + 
for i in axes(A, 1) + for j in axes(B, 2) + temp2 = 0.0 + for k in axes(A, 2) + C[k, j] += alpha * B[i, j] * A[i, k] + temp2 += B[k, j] * A[i, k] + end + C[i, j] = beta * C[i, j] + alpha * B[i, j] * A[i, i] + alpha * temp2 + end + end + return C +end + +function kernel_syr2k(alpha, beta, A, B) + C = zeros(eltype(A), axes(A, 1), axes(A, 1)) + for i in axes(A, 1) + for j in 1:i + C[i, j] *= beta + end + for k in axes(A, 2) + for j in 1:i + C[i, j] += A[j, k] * alpha * B[i, k] + B[j, k] * alpha * A[i, k] + end + end + end + return C +end + +function kernel_syrk(alpha, beta, A) + C = zeros(eltype(A), axes(A, 1), axes(A, 1)) + for i in axes(A, 1) + for j in 1:i + C[i, j] *= beta + end + for k in axes(A, 2) + for j in 1:i + C[i, j] += alpha * A[i, k] * A[j, k] + end + end + end + return C +end + +function kernel_trmm(alpha, A, B) + for i in axes(A, 1) + for j in axes(B, 2) + for k in axes(A, 2) + B[i, j] += A[k, i] * B[k, j] + end + B[i, j] = alpha * B[i, j] + end + end + return B +end + +function kernel_2mm(alpha, beta, A, B, C, D) + tmp = zeros(eltype(A), axes(A, 1), axes(B, 2)) + for i in axes(A, 1) + for j in axes(B, 2) + tmp[i, j] = 0 + for k in axes(A, 2) + tmp[i, j] += alpha * A[i, k] * B[k, j] + end + end + end + for i in axes(A, 1) + for j in axes(C, 2) + D[i, j] *= beta + for k in axes(B, 1) + D[i, j] += tmp[i, k] * C[k, j] + end + end + end + return D +end + +function kernel_3mm(A, B, C, D) + E = zeros(eltype(A), axes(A, 1), axes(A, 2)) + F = zeros(eltype(A), axes(A, 1), axes(A, 2)) + G = zeros(eltype(A), axes(A, 1), axes(A, 2)) + for i in axes(A, 1) + for j in axes(B, 2) + E[i, j] = 0.0 + for k in axes(A, 2) + E[i, j] += A[i, k] * B[k, j] + end + end + end + + for i in axes(C, 1) + for j in axes(D, 2) + F[i, j] = 0.0 + for k in axes(C, 2) + F[i, j] += C[i, k] * D[k, j] + end + end + end + + for i in axes(E, 1) + for j in axes(F, 2) + G[i, j] = 0.0 + for k in axes(E, 2) + G[i, j] += E[i, k] * F[k, j] + end + end + end + return G +end + +function kernel_atax(A, x) + tmp = zeros(eltype(A), axes(A, 1)) + y = zeros(eltype(A), axes(A, 2)) + + for i in axes(A, 2) + y[i] = 0 + end + for i in axes(A, 1) + tmp[i] = 0 + for j in axes(A, 2) + tmp[i] = tmp[i] + A[i, j] * x[j] + end + for j in axes(A, 2) + y[j] = y[j] + A[i, j] * tmp[i] + end + end + return y +end + +function kernel_bicg(A, p, r) + s = zeros(eltype(A), axes(A, 2)) + q = zeros(eltype(A), axes(A, 1)) + + for i in axes(A, 2) + s[i] = 0 + end + for i in axes(A, 1) + q[i] = 0 + for j in axes(A, 2) + s[j] = s[j] + r[i] * A[i, j] + q[i] = q[i] + A[i, j] * p[j] + end + end + return s, q +end + +function kernel_doitgen(A, C4) + sum = similar(A, size(A, 2)) + for r in axes(A, 1) + for q in axes(A, 2) + for p in axes(A, 2) + sum[p] = 0.0 + for s in axes(A, 2) + sum[p] += A[r, q, s] * C4[s, p] + end + end + for p in axes(A, 2) + A[r, q, p] = sum[p] + end + end + end + return A +end + +function kernel_mvt(x1, x2, y1, y2, A) + for i in axes(x1, 1) + for j in axes(y1, 1) + x1[i] += A[i, j] * y1[j] + end + end + + for i in axes(x2, 1) + for j in axes(y2, 1) + x2[i] += A[i, j] * y2[j] + end + end + + return x1, x2 +end + +function kernel_cholesky(A) + for i in axes(A, 1) + for j in 1:(i - 1) + for k in 1:(j - 1) + A[i, j] -= A[i, k] * A[j, k] + end + A[i, j] /= A[j, j] + end + + for k in 1:(i - 1) + A[i, i] -= A[i, k] * A[i, k] + end + A[i, i] = sqrt(A[i, i]) + end + return A +end + +function kernel_durbin(r, y) + y[1] = -r[1] + beta = 1.0 + alpha = -r[1] + z = zero(y) + for k in 2:size(y, 1) + beta = (1 - alpha * alpha) * beta + sum = 0.0 + for i 
in 1:k + sum += r[k - i] * y[i] + end + alpha = -(r[k] + sum) / beta + + for i in 1:k + z[i] = y[i] + alpha * y[k - i] + end + + for i in 1:k + y[i] = z[i] + end + y[k] = alpha + end + return y +end + +function kernel_gramschidt(A, R, Q) + for k in axes(A, 2) + nrm = 0.0 + for i in axes(A, 1) + nrm += A[i, k]^2 + end + R[k, k] = sqrt(nrm) + for i in axes(Q, 1) + Q[i, k] = A[i, k] / R[k, k] + end + + for j in (k + 1):size(R, 1) + R[k, j] = 0.0 + for i in axes(A, 1) + R[k, j] += Q[i, k] * A[i, j] + end + for i in axes(A, 1) + A[i, j] -= Q[i, k] * R[k, j] + end + end + end + return A +end + +function kernel_lu(A) + for i in axes(A, 1) + for j in 1:(i - 1) + for k in 1:(j - 1) + A[i, j] -= A[i, k] * A[k, j] + end + A[i, j] /= A[j, j] + end + for j in axes(A, 1) + for k in 1:(i - 1) + A[i, j] -= A[i, k] * A[k, j] + end + end + end + return A +end + +function kernel_ludcmp(A, b, x, y) + for i in axes(A, 1) + for j in 1:(i - 1) + w = A[i, j] + for k in 1:(j - 1) + w -= A[i, k] * A[k, j] + end + A[i, j] = w / A[j, j] + end + for j in 1:(i - 1) + w = A[i, j] + for k in 1:(j - 1) + w -= A[i, k] * A[k, j] + end + A[i, j] = w + end + end + + for i in axes(b, 1) + w = b[i] + for j in 1:(i - 1) + w -= A[i, j] * y[j] + end + y[i] = w + end + + for i in size(y, 1):-1:1 + w = y[i] + for j in (i + 1):size(A, 2) + w -= A[i, j] * x[j] + end + x[i] = w / A[i, i] + end + return A +end + +function kernel_trisolv(L, x, b) + for i in axes(x, 1) + x[i] = b[i] + for j in 1:(i - 1) + x[i] -= L[i, j] * x[j] + end + x[i] = x[i] / L[i, i] + end + return x +end + +function kernel_adi(tsteps::Int, u, v, p, q) + N = size(u, 1) + DX = 1.0 / N + DY = 1.0 / N + DT = 1.0 / tsteps + B1 = 2.0 + B2 = 1.0 + mul1 = B1 * DT / (DX * DX) + mul2 = B2 * DT / (DY * DY) + + a = -mul1 / 2.0 + b = 1.0 + mul1 + c = a + d = -mul2 / 2.0 + e = 1.0 + mul2 + f = d + + for t in 1:tsteps + #//Column Sweep + for i in 2:(N - 1) + v[1, i] = 1.0 + p[i, 1] = 0.0 + q[i, 1] = v[1, i] + for j in 2:(N - 1) + p[i, j] = -c / (a * p[i, j - 1] + b) + q[i, j] = + ( + -d * u[j, i - 1] + (1.0 + 2.0 * d) * u[j, i] - f * u[j, i + 1] - + a * q[i, j - 1] + ) / (a * p[i, j - 1] + b) + end + + v[N, i] = 1.0 + for j in (N - 1):-1:2 + v[j, i] = p[i, j] * v[j + 1, i] + q[i, j] + end + end + #//Row Sweep + for i in 2:(N - 1) + u[i, 1] = 1.0 + p[i, 1] = 0.0 + q[i, 1] = u[i, 1] + for j in 2:(N - 1) + p[i, j] = -f / (d * p[i, j - 1] + e) + q[i, j] = + ( + -a * v[i - 1, j] + (1.0 + 2.0 * a) * v[i, j] - c * v[i + 1, j] - + d * q[i, j - 1] + ) / (d * p[i, j - 1] + e) + end + u[i, N] = 1.0 + for j in (N - 1):-1:2 + u[i, j] = p[i, j] * u[i, j + 1] + q[i, j] + end + end + end + return u, v +end + +function kernel_fdtd_2d(EX, EY, HZ, fict) + for t in axes(fict, 1) + for j in axes(EY, 2) + EY[1, j] = fict[t] + end + for i in 2:size(EY, 1) + for j in axes(EY, 2) + EY[i, j] = EY[i, j] - 0.5 * (HZ[i, j] - HZ[i - 1, j]) + end + end + for i in axes(EX, 1) + for j in 2:size(EX, 2) + EX[i, j] = EX[i, j] - 0.5 * (HZ[i, j] - HZ[i, j - 1]) + end + end + for i in 1:(size(HZ, 1) - 1) + for j in 1:(size(HZ, 2) - 1) + HZ[i, j] = + HZ[i, j] - 0.7 * (EX[i, j + 1] - EX[i, j] + EY[i + 1, j] - EY[i, j]) + end + end + end + + return HZ +end + +function kernel_heat_3d(tsteps::Int, A, B) + N = size(A, 1) + + for t in 1:tsteps + for i in 2:(N - 1) + for j in 2:(N - 1) + for k in 2:(N - 1) + B[i, j, k] = + 0.125 * (A[i + 1, j, k] - 2.0 * A[i, j, k] + A[i - 1, j, k]) + +0.125 * (A[i, j + 1, k] - 2.0 * A[i, j, k] + A[i, j - 1, k]) + +0.125 * (A[i, j, k + 1] - 2.0 * A[i, j, k] + A[i, j, k - 1]) + +A[i, j, k] + 
end + end + end + for i in 2:(N - 1) + for j in 2:(N - 1) + for k in 2:(N - 1) + A[i, j, k] = + 0.125 * (B[i + 1, j, k] - 2.0 * B[i, j, k] + B[i - 1, j, k]) + +0.125 * (B[i, j + 1, k] - 2.0 * B[i, j, k] + B[i, j - 1, k]) + +0.125 * (B[i, j, k + 1] - 2.0 * B[i, j, k] + B[i, j, k - 1]) + +B[i, j, k] + end + end + end + end + return A, B +end + +function kernel_jacobi_1d(tsteps::Int, A, B) + N = size(A, 1) + for t in 1:tsteps + for i in 2:(N - 1) + B[i] = 1 / 3 * (A[i - 1] + A[i] + A[i + 1]) + end + for i in 2:(N - 1) + A[i] = 1 / 3 * (B[i - 1] + B[i] + B[i + 1]) + end + end + return A, B +end + +function kernel_jacobi_2d(tsteps::Int, A, B) + N = size(A, 1) + + for t in 1:tsteps + for i in 2:(N - 1) + for j in 2:(N - 1) + B[i, j] = + 0.2 * (A[i, j] + A[i, j - 1] + A[i, 1 + j] + A[1 + i, j] + A[i - 1, j]) + end + end + for i in 2:(N - 1) + for j in 2:(N - 1) + A[i, j] = + 0.2 * (B[i, j] + B[i, j - 1] + B[i, 1 + j] + B[1 + i, j] + B[i - 1, j]) + end + end + end + return A, B +end + +function kernel_seidel_2d(tsteps::Int, A) + N = size(A, 1) + for _ in 1:tsteps + for i in 2:(N - 1) + for j in 2:(N - 1) + A[i, j] = + ( + A[i - 1, j - 1] + + A[i - 1, j] + + A[i - 1, j + 1] + + A[i, j - 1] + + A[i, j] + + A[i, j + 1] + + A[i + 1, j - 1] + + A[i + 1, j] + + A[i + 1, j + 1] + ) / 9.0 + end + end + end + return A +end + +function kernel_deriche(alpha, I) + Y1 = zero(I) + Y2 = zero(I) + O = zero(I) + k = (1.0 - exp(-alpha))^2 / (1.0 + 2.0 * exp(-alpha) - exp(2.0 * alpha)) + a1 = a5 = k + a2 = a6 = k * exp(-alpha) * (alpha - 1.0) + a3 = a7 = k * exp(-alpha) * (alpha + 1.0) + a4 = a8 = -k * exp(-2.0 * alpha) + b1 = 2.0^(-alpha) + b2 = -exp(-2.0 * alpha) + c1 = c2 = 1 + + for i in axes(I, 1) + ym1 = ym2 = xm1 = 0.0 + for j in axes(I, 2) + Y1[i, j] = a1 * I[i, j] + a2 * xm1 + b1 * ym1 + b2 * ym2 + xm1 = I[i, j] + ym2 = ym1 + ym1 = Y1[i, j] + end + end + + for i in axes(I, 1) + yp1 = yp2 = xp1 = xp2 = 0.0 + for j in size(I, 2):-1:1 + Y2[i, j] = a3 * xp1 + a4 * xp2 + b1 * yp1 + b2 * yp2 + xp2 = xp1 + xp1 = I[i, j] + yp1 = Y2[i, j] + end + end + + for i in axes(I, 1) + for j in axes(I, 2) + O[i, j] = c1 * (Y1[i, j] + Y2[i, j]) + end + end + + for j in axes(I, 2) + tm1 = ym1 = ym2 = 0.0 + for i in axes(I, 1) + Y1[i, j] = a5 * O[i, j] + a6 * tm1 + b1 * ym1 + b2 * ym2 + tm1 = O[i, j] + ym2 = ym1 + ym1 = Y1[i, j] + end + end + + for j in axes(I, 2) + tp1 = tp2 = yp1 = yp2 = 0.0 + for i in size(I, 1):-1:1 + Y2[i, j] = a7 * tp1 + a8 * tp2 + b1 * yp1 + b2 * yp2 + tp2 = tp1 + tp1 = O[i, j] + yp2 = yp1 + yp1 = Y2[i, j] + end + end + + for i in axes(I, 1) + for j in axes(I, 2) + O[i, j] = c2 * (Y1[i, j] + Y2[i, j]) + end + end + return O +end + +function kernel_floyd_warshall(path) + for k in axes(path, 1) + for i in axes(path, 1) + for j in axes(path, 1) + path[i, j] = if path[i, j] < path[i, k] + path[k, j] + path[i, j] + else + path[i, k] + path[k, j] + end + end + end + end + return path +end + +function kernel_nussinov(seq, T) + for i in size(seq, 1):-1:1 + for j in (i + 1):size(seq, 1) + if j >= 1 + T[i, j] = max(T[i, j], T[i, j - 1]) + end + + if i + 1 <= size(seq, 1) + T[i, j] = max(T[i, j], T[i + 1, j]) + end + + if j >= 1 && (i + 1 <= size(seq, 1)) + if i < j - 1 + tmp = (((seq[i] + seq[j]) - 3.0) < 10e-5) ? 
1.0 : 0.0
+                    T[i, j] = max(T[i, j], T[i + 1, j - 1] + tmp)
+                else
+                    T[i, j] = max(T[i, j], T[i + 1, j - 1])
+                end
+            end
+
+            for k in (i + 1):(j - 1)
+                T[i, j] = max(T[i, j], T[i, j] + T[k + 1, j])
+            end
+        end
+    end
+    return T
+end
+
+end
+
+function kernel_deriche2(I)
+    Y1 = zero(I)
+    ym1 = 0.0
+    ym2 = 0.0
+    i = 1
+    for j in axes(I, 2)
+        Y1[i, j] = I[i, j] + ym1 + ym2
+        ym2 = ym1 + 1.0
+        ym1 = Y1[i, j]
+    end
+    return Y1
+end
+
+# make the module's (unexported) kernels and the test tooling visible to the testset below
+using Test
+using Reactant
+using .Polybench:
+    kernel_correlation, kernel_gemm, kernel_gemmver, kernel_gesummv, kernel_symm,
+    kernel_syr2k, kernel_syrk, kernel_trmm, kernel_2mm, kernel_3mm, kernel_atax,
+    kernel_bicg, kernel_doitgen, kernel_mvt, kernel_cholesky, kernel_durbin,
+    kernel_gramschidt, kernel_lu, kernel_ludcmp, kernel_trisolv, kernel_adi,
+    kernel_fdtd_2d, kernel_heat_3d, kernel_jacobi_1d, kernel_jacobi_2d,
+    kernel_seidel_2d, kernel_deriche, kernel_floyd_warshall, kernel_nussinov
+
+@testset "Polybench" begin
+    A = collect(reshape(1:1.0:256, 16, 16))
+    c = collect(reshape(1:16, 1, 16))
+    tA = Reactant.to_rarray(A)
+    f = collect(reshape(0:0.065:1, 1, 16))
+    tf = Reactant.to_rarray(f)
+
+    @jit(kernel_correlation(tA))
+    @jit(kernel_correlation(tA))
+
+    @test @jit(kernel_gemm(5, 2, tA, tA)) == kernel_gemm(5, 2, A, A)
+    @test @jit(kernel_gemmver(2.0, 0.01, tf, tf, tf, tf, tA, tf, tf, tf)) ≈
+        kernel_gemmver(2.0, 0.01, f, f, f, f, A, f, f, f)
+    @test @jit(kernel_gesummv(2.0, 0.01, tA, tA, tf)) ≈ kernel_gesummv(2.0, 0.01, A, A, f)
+    @test @jit(kernel_symm(2.0, 0.01, tA, tA)) ≈ kernel_symm(2.0, 0.01, A, A)
+    @test @jit(kernel_syr2k(2.0, 0.01, tA, tA)) ≈ kernel_syr2k(2.0, 0.01, A, A)
+    @test @jit(kernel_syrk(2.0, 0.01, tA)) ≈ kernel_syrk(2.0, 0.01, A)
+    @test @jit(kernel_trmm(2.0, tA, tA)) ≈ kernel_trmm(2.0, A, A)
+    @test @jit(kernel_2mm(2.0, 0.01, tA, tA, tA, tA)) ≈ kernel_2mm(2.0, 0.01, A, A, A, A)
+    @test @jit(kernel_3mm(tA, tA, tA, tA)) ≈ kernel_3mm(A, A, A, A)
+    @test @jit(kernel_atax(tA, tf)) ≈ kernel_atax(A, f)
+    @test all((@jit(kernel_bicg(tA, tf, tf)) .≈ kernel_bicg(A, f, f)))
+    AA = ones(16, 16, 16)
+    tAA = Reactant.to_rarray(AA)
+    @test @jit((kernel_doitgen(tAA, tA))) ≈ kernel_doitgen(AA, A)
+    @test all((@jit(kernel_mvt(tf, tf, tf, tf, tA)) .≈ kernel_mvt(f, f, f, f, A)))
+    @test @jit(kernel_cholesky(tA)) ≈ kernel_cholesky(A)
+    @test @jit(kernel_durbin(tf, tf)) ≈ kernel_durbin(f, f)
+    @test @jit(kernel_gramschidt(tA, tA, tA)) ≈ kernel_gramschidt(A, A, A)
+    @test @jit(kernel_lu(tA)) ≈ kernel_lu(A)
+    @test @jit(kernel_ludcmp(tA, tA, tf, tf)) ≈ kernel_ludcmp(A, A, f, f)
+    @test @jit(kernel_trisolv(tA, tf, tf)) ≈ kernel_trisolv(A, f, f)
+    tA1 = Reactant.to_rarray(A)
+    tA2 = Reactant.to_rarray(A)
+    tA3 = Reactant.to_rarray(A)
+    tA4 = Reactant.to_rarray(A)
+    @test all(
+        @jit(kernel_adi(5, tA1, tA2, tA3, tA4)) .≈
+        kernel_adi(5, copy(A), copy(A), copy(A), copy(A)),
+    )
+    @test @jit(kernel_fdtd_2d(tA, tA, tA, tf)) ≈ kernel_fdtd_2d(A, A, A, f)
+    @test all(@jit(kernel_heat_3d(5, tAA, tAA)) .≈ kernel_heat_3d(5, AA, AA))
+    @test all(@jit(kernel_jacobi_1d(5, tf, tf)) .≈ kernel_jacobi_1d(5, f, f))
+    @test all(@jit(kernel_jacobi_2d(5, tA, tA)) .≈ kernel_jacobi_2d(5, A, A))
+    @test all(@jit(kernel_seidel_2d(5, tA)) .≈ kernel_seidel_2d(5, A))
+    @test all(@jit(kernel_deriche(0.1, tA)) .≈ kernel_deriche(0.1, A)) #TODO: FAIL
+    @test @jit(kernel_floyd_warshall(tA)) == kernel_floyd_warshall(A)
+    @test @jit(kernel_nussinov(tf, tA)) ≈ kernel_nussinov(f, A)
+end
\ No newline at end of file
diff --git a/test/runtests.jl b/test/runtests.jl
index f812deee5c..2e3dba3c9c 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -39,6 +39,10 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
     @safetestset "Config" include("config.jl")
     @safetestset "Batching" include("batching.jl")
     @safetestset "QA" include("qa.jl")
+    @static (VERSION < v"1.12" && VERSION > v"1.11") && begin
+        @safetestset "Automatic Control Flow" include("auto_cf/basic.jl")
+
+    end
 end
 
 if REACTANT_TEST_GROUP == "all" ||
REACTANT_TEST_GROUP == "integration"