@@ -7,6 +7,10 @@ functor(::Type{<:NamedTuple}, x) = x, y -> y
77functor (:: Type{<:AbstractArray} , x) = x, y -> y
88functor (:: Type{<:AbstractArray{<:Number}} , x) = (), _ -> x
99
10+ @static if VERSION >= v " 1.6"
11+ functor (:: Type{<:Base.ComposedFunction} , x) = (outer = x. outer, inner = x. inner), y -> Base. ComposedFunction (y. outer, y. inner)
12+ end
13+
1014function makefunctor (m:: Module , T, fs = fieldnames (T))
1115 yᵢ = 0
1216 escargs = map (fieldnames (T)) do f
@@ -20,15 +24,42 @@ function makefunctor(m::Module, T, fs = fieldnames(T))
2024end
2125
2226function functorm (T, fs = nothing )
23- fs == nothing || isexpr (fs, :tuple ) || error (" @functor T (a, b)" )
24- fs = fs == nothing ? [] : [:($ (map (QuoteNode, fs. args)... ),)]
27+ fs === nothing || Meta . isexpr (fs, :tuple ) || error (" @functor T (a, b)" )
28+ fs = fs === nothing ? [] : [:($ (map (QuoteNode, fs. args)... ),)]
2529 :(makefunctor (@__MODULE__ , $ (esc (T)), $ (fs... )))
2630end
2731
2832macro functor (args... )
2933 functorm (args... )
3034end
3135
36+ function makeflexiblefunctor (m:: Module , T, pfield)
37+ pfield = QuoteNode (pfield)
38+ @eval m begin
39+ function $Functors. functor (:: Type{<:$T} , x)
40+ pfields = getproperty (x, $ pfield)
41+ function re (y)
42+ all_args = map (fn -> getproperty (fn in pfields ? y : x, fn), fieldnames ($ T))
43+ return $ T (all_args... )
44+ end
45+ func = NamedTuple {pfields} (map (p -> getproperty (x, p), pfields))
46+ return func, re
47+ end
48+
49+ end
50+
51+ end
52+
53+ function flexiblefunctorm (T, pfield = :params )
54+ pfield isa Symbol || error (" @flexiblefunctor T param_field" )
55+ pfield = QuoteNode (pfield)
56+ :(makeflexiblefunctor (@__MODULE__ , $ (esc (T)), $ (esc (pfield))))
57+ end
58+
59+ macro flexiblefunctor (args... )
60+ flexiblefunctorm (args... )
61+ end
62+
3263"""
3364 isleaf(x)
3465
@@ -137,7 +168,8 @@ fmapstructure(f, x; kwargs...) = fmap(f, x; walk = (f, x) -> map(f, children(x))
137168 fcollect(x; exclude = v -> false)
138169
139170Traverse `x` by recursing each child of `x` as defined by [`functor`](@ref)
140- and collecting the results into a flat array.
171+ and collecting the results into a flat array, ordered by a breadth-first
172+ traversal of `x`, respecting the iteration order of `children` calls.
141173
142174Doesn't recurse inside branches rooted at nodes `v`
143175for which `exclude(v) == true`.
@@ -180,13 +212,17 @@ julia> fcollect(m, exclude = v -> Functors.isleaf(v))
180212 Bar([1, 2, 3])
181213```
182214"""
183- function fcollect (x; cache = [], exclude = v -> false )
184- x in cache && return cache
185- if ! exclude (x)
186- push! (cache, x)
187- foreach (y -> fcollect (y; cache = cache, exclude = exclude), children (x))
188- end
189- return cache
215+ function fcollect (x; output = [], cache = Base. IdSet (), exclude = v -> false )
216+ # note: we don't have an `OrderedIdSet`, so we use an `IdSet` for the cache
217+ # (to ensure we get exactly 1 copy of each distinct array), and a usual `Vector`
218+ # for the results, to preserve traversal order (important downstream!).
219+ x in cache && return output
220+ if ! exclude (x)
221+ push! (cache, x)
222+ push! (output, x)
223+ foreach (y -> fcollect (y; cache= cache, output= output, exclude= exclude), children (x))
224+ end
225+ return output
190226end
191227
192228# Allow gradients and other constructs that match the structure of the functor
0 commit comments