@@ -55,8 +55,11 @@ struct ADInterpreter <: AbstractInterpreter
     unopt::Union{OffsetVector{UnoptCache},Nothing}
     transformed::OffsetVector{OptCache}

+    # Cache results for forward inference over a converged inference (current_level == missing)
+    generic::OptCache
+
     native_interpreter::NativeInterpreter
-    current_level::Int
+    current_level::Union{Int, Missing}
     remarks::OffsetVector{RemarksCache}

     function _ADInterpreter()
@@ -66,6 +69,7 @@ struct ADInterpreter <: AbstractInterpreter
             #=opt::OffsetVector{OptCache}=#OffsetVector([OptCache(), OptCache()], 0:1),
             #=unopt::Union{OffsetVector{UnoptCache},Nothing}=#OffsetVector([UnoptCache(), UnoptCache()], 0:1),
             #=transformed::OffsetVector{OptCache}=#OffsetVector([OptCache(), OptCache()], 0:1),
+            OptCache(),
             #=native_interpreter::NativeInterpreter=#NativeInterpreter(),
             #=current_level::Int=#0,
             #=remarks::OffsetVector{RemarksCache}=#OffsetVector([RemarksCache()], 0:0))
@@ -76,10 +80,11 @@ struct ADInterpreter <: AbstractInterpreter
             opt::OffsetVector{OptCache} = interp.opt,
             unopt::Union{OffsetVector{UnoptCache},Nothing} = interp.unopt,
             transformed::OffsetVector{OptCache} = interp.transformed,
+            generic::OptCache = interp.generic,
             native_interpreter::NativeInterpreter = interp.native_interpreter,
-            current_level::Int = interp.current_level,
+            current_level::Union{Int, Missing} = interp.current_level,
             remarks::OffsetVector{RemarksCache} = interp.remarks)
-        return new(forward, backward, opt, unopt, transformed, native_interpreter, current_level, remarks)
+        return new(forward, backward, opt, unopt, transformed, generic, native_interpreter, current_level, remarks)
     end
 end

@@ -89,6 +94,27 @@ lower_level(interp::ADInterpreter) = change_level(interp, interp.current_level -

 disable_forward(interp::ADInterpreter) = ADInterpreter(interp; forward=false)

+function CC.InferenceState(result::InferenceResult, cache::Symbol, interp::ADInterpreter)
+    if interp.current_level === missing
+        error()
+    end
+    return @invoke CC.InferenceState(result::InferenceResult, cache::Symbol, interp::AbstractInterpreter)
+    # prepare an InferenceState object for inferring lambda
+    world = get_world_counter(interp)
+    src = retrieve_code_info(result.linfo, world)
+    src === nothing && return nothing
+    validate_code_in_debug_mode(result.linfo, src, "lowered")
+    return InferenceState(result, src, cache, interp, Bottom)
+end
+
+
+function CC.initial_bestguess(interp::ADInterpreter, result::InferenceResult)
+    if interp.current_level === missing
+        return CC.typeinf_lattice(interp.native_interpreter, result.linfo)
+    end
+    return Bottom
+end
+
 function Cthulhu.get_optimized_codeinst(interp::ADInterpreter, curs::ADCursor)
     @show curs
     (curs.transformed ? interp.transformed : interp.opt)[curs.level][curs.mi]
@@ -335,15 +361,6 @@ function CC.inlining_policy(interp::ADInterpreter,
     nothing, info::CC.CallInfo, stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any})
 end

-# TODO remove this overload once https://github.com/JuliaLang/julia/pull/49191 gets merged
-function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f),
-    arginfo::ArgInfo, si::StmtInfo, @nospecialize(atype),
-    sv::IRInterpretationState, max_methods::Int)
-    return @invoke CC.abstract_call_gf_by_type(interp::AbstractInterpreter, f::Any,
-        arginfo::ArgInfo, si::StmtInfo, atype::Any,
-        sv::CC.AbsIntState, max_methods::Int)
-end
-
 #=
 function CC.optimize(interp::ADInterpreter, opt::OptimizationState,
     params::OptimizationParams, caller::InferenceResult)