@@ -30,6 +30,13 @@ To be able to work with multi-threading, it should also implement:
3030- `split(acc::T)`
3131- `combine(acc::T, acc2::T)`
3232
33+ If two accumulators of the same type should be merged in some non-trivial way, other than
34+ always keeping the second one over the first, `merge(acc1::T, acc2::T)` should be defined.
35+
36+ If limiting the accumulator to a subset of `VarName`s is a meaningful operation and should
37+ do something other than copy the original accumulator, then
38+ `subset(acc::T, vns::AbstractVector{<:VarnName})` should be defined.`
39+
3340See the documentation for each of these functions for more details.
3441"""
3542abstract type AbstractAccumulator end
@@ -113,6 +120,24 @@ used by various AD backends, should implement a method for this function.
113120"""
114121convert_eltype (:: Type , acc:: AbstractAccumulator ) = acc
115122
123+ """
124+ subset(acc::AbstractAccumulator, vns::AbstractVector{<:VarName})
125+
126+ Return a new accumulator that only contains the information for the `VarName`s in `vns`.
127+
128+ By default returns a copy of `acc`. Subtypes should override this behaviour as needed.
129+ """
130+ subset (acc:: AbstractAccumulator , :: AbstractVector{<:VarName} ) = copy (acc)
131+
132+ """
133+ merge(acc1::AbstractAccumulator, acc2::AbstractAccumulator)
134+
135+ Merge two accumulators of the same type. Returns a new accumulator of the same type.
136+
137+ By default returns a copy of `acc2`. Subtypes should override this behaviour as needed.
138+ """
139+ Base. merge (acc1:: AbstractAccumulator , acc2:: AbstractAccumulator ) = copy (acc2)
140+
116141"""
117142 AccumulatorTuple{N,T<:NamedTuple}
118143
@@ -158,6 +183,50 @@ function Base.convert(::Type{AccumulatorTuple{N,T}}, accs::AccumulatorTuple{N})
158183 return AccumulatorTuple (convert (T, accs. nt))
159184end
160185
186+ """
187+ subset(at::AccumulatorTuple, vns::AbstractVector{<:VarName})
188+
189+ Replace each accumulator `acc` in `at` with `subset(acc, vns)`.
190+ """
191+ function subset (at:: AccumulatorTuple , vns:: AbstractVector{<:VarName} )
192+ return AccumulatorTuple (map (Base. Fix2 (subset, vns), at. nt))
193+ end
194+
195+ """
196+ _joint_keys(nt1::NamedTuple, nt2::NamedTuple)
197+
198+ A helper function that returns three tuples of keys given two `NamedTuple`s:
199+ The keys only in `nt1`, only in `nt2`, and in both, and in that order.
200+
201+ Implemented as a generated function to enable constant propagation of the result in `merge`.
202+ """
203+ @generated function _joint_keys (
204+ nt1:: NamedTuple{names1} , nt2:: NamedTuple{names2}
205+ ) where {names1,names2}
206+ only_in_nt1 = tuple (setdiff (names1, names2)... )
207+ only_in_nt2 = tuple (setdiff (names2, names1)... )
208+ in_both = tuple (intersect (names1, names2)... )
209+ return :($ only_in_nt1, $ only_in_nt2, $ in_both)
210+ end
211+
212+ """
213+ merge(at1::AccumulatorTuple, at2::AccumulatorTuple)
214+
215+ Merge two `AccumulatorTuple`s.
216+
217+ For any `accumulator_name` that exists in both `at1` and `at2`, we call `merge` on the two
218+ accumulators themselves. Other accumulators are copied.
219+ """
220+ function Base. merge (at1:: AccumulatorTuple , at2:: AccumulatorTuple )
221+ keys_in_at1, keys_in_at2, keys_in_both = _joint_keys (at1. nt, at2. nt)
222+ accs_in_at1 = (getfield (at1. nt, key) for key in keys_in_at1)
223+ accs_in_at2 = (getfield (at2. nt, key) for key in keys_in_at2)
224+ accs_in_both = (
225+ merge (getfield (at1. nt, key), getfield (at2. nt, key)) for key in keys_in_both
226+ )
227+ return AccumulatorTuple (accs_in_at1... , accs_in_both... , accs_in_at2... )
228+ end
229+
161230"""
162231 setacc!!(at::AccumulatorTuple, acc::AbstractAccumulator)
163232
0 commit comments