@@ -34,60 +34,14 @@ macro functor(args...)
3434 functorm (args... )
3535end
3636
37- function makeflexiblefunctor (m:: Module , T, pfield)
38- pfield = QuoteNode (pfield)
39- @eval m begin
40- function $Functors. functor (:: Type{<:$T} , x)
41- pfields = getproperty (x, $ pfield)
42- function re (y)
43- all_args = map (fn -> getproperty (fn in pfields ? y : x, fn), fieldnames ($ T))
44- return $ T (all_args... )
45- end
46- func = NamedTuple {pfields} (map (p -> getproperty (x, p), pfields))
47- return func, re
48- end
49-
50- end
51-
52- end
53-
54- function flexiblefunctorm (T, pfield = :params )
55- pfield isa Symbol || error (" @flexiblefunctor T param_field" )
56- pfield = QuoteNode (pfield)
57- :(makeflexiblefunctor (@__MODULE__ , $ (esc (T)), $ (esc (pfield))))
58- end
59-
60- macro flexiblefunctor (args... )
61- flexiblefunctorm (args... )
62- end
63-
6437isleaf (x) = children (x) === ()
6538
6639children (x) = functor (x)[1 ]
6740
68- function functor_tuple (f, x:: Tuple , dx:: Tuple )
69- map (x, dx) do x, x̄
70- _default_walk (f, x, x̄)
71- end
72- end
73- functor_tuple (f, x, dx) = f (x, dx)
74- functor_tuple (f, x, :: Nothing ) = x
75-
76- # @functor Chain
77- # Chain -> func = (layers = (Dense,Dense),), gs -> (layers...)
78- function _default_walk (f, x, dx)
79- func, re = functor (x)
80- map (func, dx) do x, x̄
81- # functor_tuple(f, x, x̄)
82- f (x, x̄)
83- end |> re
84- end
85-
8641function _default_walk (f, x)
8742 func, re = functor (x)
8843 re (map (f, func))
8944end
90- _default_walk (f, :: Nothing , :: Nothing ) = nothing
9145
9246function fmap (f, x; exclude = isleaf, walk = _default_walk, cache = IdDict ())
9347 haskey (cache, x) && return cache[x]
@@ -97,6 +51,10 @@ function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = IdDict())
9751 return y
9852end
9953
54+ # ##
55+ # ## Extras
56+ # ##
57+
10058fmapstructure (f, x; kwargs... ) = fmap (f, x; walk = (f, x) -> map (f, children (x)), kwargs... )
10159
10260function fcollect (x; output = [], cache = Base. IdSet (), exclude = v -> false )
@@ -112,10 +70,59 @@ function fcollect(x; output = [], cache = Base.IdSet(), exclude = v -> false)
11270 return output
11371end
11472
115- # Allow gradients and other constructs that match the structure of the functor
116- # to allow for `map` style computations and return a modified version of the struct.
117- # This way we can use `fmap` to update the params with their gradients
73+ # ##
74+ # ## Vararg forms
75+ # ##
76+
11877function fmap (f, x, dx... ; cache = IdDict ())
11978 haskey (cache, x) && return cache[x]
12079 cache[x] = isleaf (x) ? f (x, dx... ) : _default_walk ((x... ) -> fmap (f, x... , cache = cache), x, dx... )
12180end
81+
82+ function functor_tuple (f, x:: Tuple , dx:: Tuple )
83+ map (x, dx) do x, x̄
84+ _default_walk (f, x, x̄)
85+ end
86+ end
87+ functor_tuple (f, x, dx) = f (x, dx)
88+ functor_tuple (f, x, :: Nothing ) = x
89+
90+ function _default_walk (f, x, dx)
91+ func, re = functor (x)
92+ map (func, dx) do x, x̄
93+ # functor_tuple(f, x, x̄)
94+ f (x, x̄)
95+ end |> re
96+ end
97+ _default_walk (f, :: Nothing , :: Nothing ) = nothing
98+
99+ # ##
100+ # ## FlexibleFunctors.jl
101+ # ##
102+
103+ function makeflexiblefunctor (m:: Module , T, pfield)
104+ pfield = QuoteNode (pfield)
105+ @eval m begin
106+ function $Functors. functor (:: Type{<:$T} , x)
107+ pfields = getproperty (x, $ pfield)
108+ function re (y)
109+ all_args = map (fn -> getproperty (fn in pfields ? y : x, fn), fieldnames ($ T))
110+ return $ T (all_args... )
111+ end
112+ func = NamedTuple {pfields} (map (p -> getproperty (x, p), pfields))
113+ return func, re
114+ end
115+
116+ end
117+
118+ end
119+
120+ function flexiblefunctorm (T, pfield = :params )
121+ pfield isa Symbol || error (" @flexiblefunctor T param_field" )
122+ pfield = QuoteNode (pfield)
123+ :(makeflexiblefunctor (@__MODULE__ , $ (esc (T)), $ (esc (pfield))))
124+ end
125+
126+ macro flexiblefunctor (args... )
127+ flexiblefunctorm (args... )
128+ end
0 commit comments