Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/Interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,13 @@ function set_reactant_abi(
)
end


current_interpreter = Ref{Enzyme.Compiler.Interpreter.EnzymeInterpreter{typeof(Reactant.set_reactant_abi)}}()
@static if Enzyme.GPUCompiler.HAS_INTEGRATED_CACHE
struct ReactantCacheToken end

function ReactantInterpreter(; world::UInt=Base.get_world_counter())
return Enzyme.Compiler.Interpreter.EnzymeInterpreter(
current_interpreter[] = Enzyme.Compiler.Interpreter.EnzymeInterpreter(
ReactantCacheToken(),
REACTANT_METHOD_TABLE,
world,
Expand All @@ -108,7 +110,7 @@ else
function ReactantInterpreter(;
world::UInt=Base.get_world_counter(), code_cache=REACTANT_CACHE
)
return Enzyme.Compiler.Interpreter.EnzymeInterpreter(
current_interpreter[] = Enzyme.Compiler.Interpreter.EnzymeInterpreter(
REACTANT_CACHE,
REACTANT_METHOD_TABLE,
world,
Expand Down
2 changes: 1 addition & 1 deletion src/Precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ end

# Precompilation on 1.10 hits an apparent bug: https://github.com/JuliaLang/julia/issues/56947
function precompilation_supported()
return (VERSION >= v"1.11" || VERSION >= v"1.10.8") && (VERSION < v"1.12-")
return false && (VERSION >= v"1.11" || VERSION >= v"1.10.8") && (VERSION < v"1.12-")
end

if Reactant_jll.is_available()
Expand Down
3 changes: 2 additions & 1 deletion src/TracedRange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,10 @@ function Base._reshape(parent::TracedUnitRange, dims::Dims)
return Base.__reshape((parent, IndexStyle(parent)), dims)
end

function (C::Base.Colon)(start::TracedRNumber{T}, stop::TracedRNumber{T}) where {T}
#=function (C::Base.Colon)(start::TracedRNumber{T}, stop::TracedRNumber{T}) where {T}
return TracedUnitRange(start, stop)
end
=#
function (C::Base.Colon)(start::TracedRNumber{T}, stop::T) where {T}
return C(start, TracedRNumber{T}(stop))
end
Expand Down
6 changes: 6 additions & 0 deletions src/auto_cf/AutoCF.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
include("debug_utils.jl")
include("new_inference.jl")
include("code_info_mut.jl")
include("code_ir_utils.jl")
include("mlir_utils.jl")
include("code_gen.jl")
84 changes: 84 additions & 0 deletions src/auto_cf/analysis.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
@enum UpgradeSlot NoUpgrade UpgradeLocally UpgradeDefinition UpgradeDefinitionGlobal EqualizeBranches

@enum State Traced Upgraded Maybe NotTraced

mutable struct ForStructure
accus::Tuple
header_bb::Int
latch_bb::Int
terminal_bb::Int
body_bbs::Set{Int}
state::State
end

struct IfStructure
ssa_cond
header_bb::Int
terminal_bb::Int
true_bbs::Set{Int}
false_bbs::Set{Int}
owned_true_bbs::Set{Int}
owned_false_bbs::Set{Int}
legalize::Ref{Bool} #inform that the if traced GotoIfNot can pass type inference
unbalanced_slots::Set{Core.SlotNumber}
end

mutable struct SlotAnalysis
slot_stmt_def::Vector{Integer} #0 for argument
slot_bb_usage::Vector{Set{Int}}
end


CFStructure = Union{IfStructure,ForStructure}
mutable struct Tree
node::Union{Nothing,Base.uniontypes(CFStructure)...}
children::Vector{Tree}
parent::Ref{Tree}
end

Base.isempty(tree::Tree) = isnothing(tree.node) && length(tree.children) == 0

Base.show(io::IO, t::Tree) = begin
Base.print(io, '(')
Base.show(io, t.node)
Base.print(io, ',')
Base.show(io, t.children)
Base.print(io, ')')
end

mutable struct Analysis
tree::Tree
domtree::Union{Nothing,Vector{CC.DomTreeNode}}
postdomtree::Union{Nothing,Vector{CC.DomTreeNode}}
slotanalysis::Union{Nothing,SlotAnalysis}
pending_tree::Union{Nothing,Tree}
end


#leak each argument to a global variable
macro lk(args...)
quote
$([:(
let val = $(esc(p))
global $(esc(p)) = val
end
) for p in args]...)
end
end





MethodInstanceKey = Vector{Type}
function mi_key(mi::Core.MethodInstance)
return collect(Base.unwrap_unionall(mi.specTypes).parameters)
end
@kwdef struct MetaData
traced_tree_map::Dict{MethodInstanceKey,Tree} = Dict()
end

meta = Ref(MetaData())
function get_meta(_::Reactant.ReactantInterp)::MetaData
meta[]
end
Loading
Loading