@@ -274,3 +274,81 @@ function test_AD(AD::Symbol, k::MOKernel, dims=(in=3, out=2, obs=3))
274274 end
275275 end
276276end
277+
"""
    count_allocs(f, args...)

Evaluate `f(args...)` exactly once and return the number of allocations recorded for that
call, computed from the garbage-collector statistics gathered by `@timed`.

No warm-up call is made here, so callers that want to exclude first-call effects should
run `f(args...)` themselves beforehand.
"""
function count_allocs(f, args...)
    timing = @timed f(args...)
    gc_diff = timing.gcstats
    return Base.gc_alloc_count(gc_diff)
end
282+
"""
    constant_allocs_heuristic(f, args1::T, args2::T) where {T}

Return `true` if evaluating `f(args1...)` performs exactly as many allocations as
evaluating `f(args2...)`, and `false` otherwise. `f` is called once on each argument
tuple before anything is measured, so that compilation-related allocations are excluded.

The rationale: the total amount of memory a function allocates will often grow with the
size of its inputs, but the _number_ of allocations usually ought to stay fixed. A common
performance bug (e.g. one caused by a type instability) makes the number of allocations
scale with the input size, and this check is designed to expose exactly that.

Passing this check is a necessary condition for good performance, but typically not a
sufficient one.

The check is cheap -- `f` is run 4 times in total -- and is easier to write than a test
asserting that the allocation count stays below some arbitrary, `f`-dependent threshold.
"""
function constant_allocs_heuristic(f, args1::T, args2::T) where {T}
    # Warm-up calls: keep compilation-related allocations out of the measurements below.
    f(args1...)
    f(args2...)

    # Measure each argument tuple once (first `args1`, then `args2`) and compare counts.
    return count_allocs(f, args1...) == count_allocs(f, args2...)
end
313+
"""
    ad_constant_allocs_heuristic(f, args1::T, args2::T; Δ1=nothing, Δ2=nothing) where {T}

Assess the constant-allocations heuristic for three computations, comparing `args1`
against `args2` in each case:

1. the primal, `f(args...)`;
2. the forwards-pass, `Zygote.pullback(f, args...)`;
3. the pullback returned by `Zygote.pullback(f, args...)`.

# Keywords
- `Δ1`: cotangent passed to the pullback obtained from `Zygote.pullback(f, args1...)`.
- `Δ2`: cotangent passed to the pullback obtained from `Zygote.pullback(f, args2...)`.

If a cotangent is left as `nothing`, the output of the corresponding primal evaluation is
assumed to be an acceptable cotangent and is used instead.

# Returns
A 3-tuple of `Bool`s, `(primal_heuristic, forwards_heuristic, pullback_heuristic)`, each
`true` when the number of allocations was the same for `args1` and `args2`.
"""
function ad_constant_allocs_heuristic(
    f, args1::T, args2::T; Δ1=nothing, Δ2=nothing
) where {T}
    # Check that the primal has constant allocations.
    primal_heuristic = constant_allocs_heuristic(f, args1, args2)

    # Check that the forwards-pass has constant allocations.
    forwards_heuristic = constant_allocs_heuristic(
        (args...) -> Zygote.pullback(f, args...), args1, args2
    )

    # Check that the pullback has constant allocations. Both argument tuples go through
    # the same helper so the two measurements are guaranteed to be taken identically.
    allocs_1 = _pullback_alloc_count(f, args1, Δ1)
    allocs_2 = _pullback_alloc_count(f, args2, Δ2)
    pullback_heuristic = allocs_1 == allocs_2

    return primal_heuristic, forwards_heuristic, pullback_heuristic
end

# Number of allocations made by one call of the pullback of `f` at `args`. The pullback is
# run once before measuring so that compilation-related allocations are not counted. When
# `Δ === nothing`, the primal output is used as the cotangent.
function _pullback_alloc_count(f, args, Δ)
    out, pb = Zygote.pullback(f, args...)
    Δ_val = Δ === nothing ? out : Δ
    pb(Δ_val)
    return count_allocs(pb, Δ_val)
end
0 commit comments