@@ -75,10 +75,29 @@ Equivalent to `functor(x)[1]`.
7575"""
7676children (x) = functor (x)[1 ]
7777
78+ function functor_tuple (f, x:: Tuple , dx:: Tuple )
79+ map (x, dx) do x, x̄
80+ _default_walk (f, x, x̄)
81+ end
82+ end
83+ functor_tuple (f, x, dx) = f (x, dx)
84+ functor_tuple (f, x, :: Nothing ) = x
85+
86+ # @functor Chain
87+ # Chain -> func = (layers = (Dense,Dense),), gs -> (layers...)
88+ function _default_walk (f, x, dx)
89+ func, re = functor (x)
90+ map (func, dx) do x, x̄
91+ # functor_tuple(f, x, x̄)
92+ f (x, x̄)
93+ end |> re
94+ end
95+
7896function _default_walk (f, x)
7997 func, re = functor (x)
8098 re (map (f, func))
8199end
100+ _default_walk (f, :: Nothing , :: Nothing ) = nothing
82101
83102"""
84103 fmap(f, x; exclude = isleaf, walk = Functors._default_walk)
@@ -205,3 +224,11 @@ function fcollect(x; output = [], cache = Base.IdSet(), exclude = v -> false)
205224 end
206225 return output
207226end
227+
228+ # Allow gradients and other constructs that match the structure of the functor
229+ # to allow for `map` style computations and return a modified version of the struct.
230+ # This way we can use `fmap` to update the params with their gradients
231+ function fmap (f, x, dx... ; cache = IdDict ())
232+ haskey (cache, x) && return cache[x]
233+ cache[x] = isleaf (x) ? f (x, dx... ) : _default_walk ((x... ) -> fmap (f, x... , cache = cache), x, dx... )
234+ end
0 commit comments