@@ -100,15 +100,15 @@ _augmented_return(::Kernel{<:GPU}, subtape, arg_refs, tape_type) =
100100
101101function _create_tape_kernel (
102102 kernel:: Kernel{CPU} ,
103- ModifiedBetween ,
103+ Mode ,
104104 FT,
105105 ctxTy,
106106 ndrange,
107107 iterspace,
108108 args2... ,
109109 )
110110 TapeType = EnzymeCore. tape_type (
111- ReverseSplitModified (ReverseSplitWithPrimal, ModifiedBetween) ,
111+ Mode ,
112112 FT,
113113 Const{Nothing},
114114 Const{ctxTy},
121121
122122function _create_tape_kernel (
123123 kernel:: Kernel{<:GPU} ,
124- ModifiedBetween ,
124+ Mode ,
125125 FT,
126126 ctxTy,
127127 ndrange,
@@ -139,7 +139,7 @@ function _create_tape_kernel(
139139 EnzymeCore. compiler_job_from_backend (backend (kernel), typeof (() -> return ), Tuple{})
140140 TapeType = EnzymeCore. tape_type (
141141 job,
142- ReverseSplitModified (ReverseSplitWithPrimal, ModifiedBetween) ,
142+ Mode ,
143143 FT,
144144 Const{Nothing},
145145 Const{ctxTy},
@@ -159,14 +159,14 @@ _create_rev_kernel(kernel::Kernel{<:GPU}) = similar(kernel, gpu_rev)
159159function cpu_aug_fwd (
160160 ctx,
161161 f:: FT ,
162- :: Val{ModifiedBetween} ,
162+ mode :: Mode ,
163163 subtape,
164164 :: Val{TapeType} ,
165165 args... ,
166- ) where {ModifiedBetween , FT, TapeType}
166+ ) where {Mode , FT, TapeType}
167167 # A2 = Const{Nothing} -- since f->Nothing
168168 forward, _ = EnzymeCore. autodiff_thunk (
169- ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)) ,
169+ mode ,
170170 Const{Core. Typeof (f)},
171171 Const{Nothing},
172172 Const{Core. Typeof (ctx)},
@@ -183,13 +183,13 @@ end
183183function cpu_rev (
184184 ctx,
185185 f:: FT ,
186- :: Val{ModifiedBetween} ,
186+ mode :: Mode ,
187187 subtape,
188188 :: Val{TapeType} ,
189189 args... ,
190- ) where {ModifiedBetween , FT, TapeType}
190+ ) where {Mode , FT, TapeType}
191191 _, reverse = EnzymeCore. autodiff_thunk (
192- ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)) ,
192+ mode ,
193193 Const{Core. Typeof (f)},
194194 Const{Nothing},
195195 Const{Core. Typeof (ctx)},
@@ -205,14 +205,14 @@ end
205205function gpu_aug_fwd (
206206 ctx,
207207 f:: FT ,
208- :: Val{ModifiedBetween} ,
208+ mode :: Mode ,
209209 subtape,
210210 :: Val{TapeType} ,
211211 args... ,
212- ) where {ModifiedBetween , FT, TapeType}
212+ ) where {Mode , FT, TapeType}
213213 # A2 = Const{Nothing} -- since f->Nothing
214214 forward, _ = EnzymeCore. autodiff_deferred_thunk (
215- ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)) ,
215+ mode ,
216216 TapeType,
217217 Const{Core. Typeof (f)},
218218 Const{Nothing},
@@ -232,14 +232,14 @@ end
232232function gpu_rev (
233233 ctx,
234234 f:: FT ,
235- :: Val{ModifiedBetween} ,
235+ mode :: Mode ,
236236 subtape,
237237 :: Val{TapeType} ,
238238 args... ,
239- ) where {ModifiedBetween , FT, TapeType}
239+ ) where {Mode , FT, TapeType}
240240 # XXX : TapeType and A2 as args to autodiff_deferred_thunk
241241 _, reverse = EnzymeCore. autodiff_deferred_thunk (
242- ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)) ,
242+ mode ,
243243 TapeType,
244244 Const{Core. Typeof (f)},
245245 Const{Nothing},
@@ -294,17 +294,17 @@ function EnzymeRules.augmented_primal(
294294 args[i]
295295 end
296296 end
297-
297+ Mode = EnzymeCore . set_runtime_activity ( ReverseSplitModified (ReverseSplitWithPrimal, ModifiedBetween), config)
298298 TapeType, subtape, aug_kernel = _create_tape_kernel (
299299 kernel,
300- ModifiedBetween ,
300+ Mode ,
301301 FT,
302302 ctxTy,
303303 ndrange,
304304 iterspace,
305305 args2... ,
306306 )
307- aug_kernel (f, ModifiedBetween , subtape, Val (TapeType), args2... ; ndrange, workgroupsize)
307+ aug_kernel (f, Mode , subtape, Val (TapeType), args2... ; ndrange, workgroupsize)
308308
309309 # TODO the fact that ctxTy is type unstable means this is all type unstable.
310310 # Since custom rules require a fixed return type, explicitly cast to Any, rather
@@ -336,11 +336,11 @@ function EnzymeRules.reverse(
336336 f = kernel. f
337337
338338 ModifiedBetween = Val ((overwritten (config)[1 ], false , overwritten (config)[2 : end ]. .. ))
339-
339+ Mode = EnzymeCore . set_runtime_activity ( ReverseSplitModified (ReverseSplitWithPrimal, ModifiedBetween), config)
340340 rev_kernel = _create_rev_kernel (kernel)
341341 rev_kernel (
342342 f,
343- ModifiedBetween ,
343+ Mode ,
344344 subtape,
345345 Val (tape_type),
346346 args2... ;
0 commit comments