1- struct Foo
2- x
3- y
4- end
1+
2+ using Functors : functor
3+
4+ struct Foo; x; y; end
55@functor Foo
66
7- struct Bar
8- x
9- end
7+ struct Bar; x; end
108@functor Bar
119
12- struct Baz
13- x
14- y
15- z
16- end
17- @functor Baz (y,)
10+ struct OneChild3; x; y; z; end
11+ @functor OneChild3 (y,)
1812
19- struct NoChildren
20- x
21- y
22- end
13+ struct NoChildren2; x; y; end
2314
2415@static if VERSION >= v " 1.6"
2516 @testset " ComposedFunction" begin
3122 end
3223end
3324
25+ # ##
26+ # ## Basic functionality
27+ # ##
28+
3429@testset " Nested" begin
3530 model = Bar (Foo (1 , [1 , 2 , 3 ]))
3631
5348 @test fmap (f, x; exclude = x -> x isa AbstractArray) == x
5449end
5550
51+ @testset " Property list" begin
52+ model = OneChild3 (1 , 2 , 3 )
53+ model′ = fmap (x -> 2 x, model)
54+
55+ @test (model′. x, model′. y, model′. z) == (1 , 4 , 3 )
56+ end
57+
58+ @testset " cache" begin
59+ shared = [1 ,2 ,3 ]
60+ m1 = Foo (shared, Foo ([1 ,2 ,3 ], Foo (shared, [1 ,2 ,3 ])))
61+ m1f = fmap (float, m1)
62+ @test m1f. x === m1f. y. y. x
63+ @test m1f. x != = m1f. y. x
64+ m1p = fmapstructure (identity, m1; prune = nothing )
65+ @test m1p == (x = [1 , 2 , 3 ], y = (x = [1 , 2 , 3 ], y = (x = nothing , y = [1 , 2 , 3 ])))
66+
67+ # A non-leaf node can also be repeated:
68+ m2 = Foo (Foo (shared, 4 ), Foo (shared, 4 ))
69+ @test m2. x === m2. y
70+ m2f = fmap (float, m2)
71+ @test m2f. x. x === m2f. y. x
72+ m2p = fmapstructure (identity, m2; prune = Bar (0 ))
73+ @test m2p == (x = (x = [1 , 2 , 3 ], y = 4 ), y = Bar (0 ))
74+
75+ # Repeated isbits types should not automatically be regarded as shared:
76+ m3 = Foo (Foo (shared, 1 : 3 ), Foo (1 : 3 , shared))
77+ m3p = fmapstructure (identity, m3; prune = 0 )
78+ @test m3p. y. y == 0
79+ @test_broken m3p. y. x == 1 : 3
80+ end
81+
82+ @testset " functor(typeof(x), y) from @functor" begin
83+ nt1, re1 = functor (Foo, (x= 1 , y= 2 , z= 3 ))
84+ @test nt1 == (x = 1 , y = 2 )
85+ @test re1 ((x = 10 , y = 20 )) == Foo (10 , 20 )
86+ re1 ((y = 22 , x = 11 )) # gives Foo(22, 11), is that a bug?
87+
88+ nt2, re2 = functor (Foo, (z= 33 , x= 1 , y= 2 ))
89+ @test nt2 == (x = 1 , y = 2 )
90+ @test re2 ((x = 10 , y = 20 )) == Foo (10 , 20 )
91+
92+ @test_throws Exception functor (Foo, (z= 33 , x= 1 )) # type NamedTuple has no field y
93+
94+ nt3, re3 = functor (OneChild3, (x= 1 , y= 2 , z= 3 ))
95+ @test nt3 == (y = 2 ,)
96+ @test re3 ((y = 20 ,)) == OneChild3 (1 , 20 , 3 )
97+ re3 (22 ) # gives OneChild3(1, 22, 3), is that a bug?
98+ end
99+
100+ @testset " functor(typeof(x), y) for Base types" begin
101+ nt11, re11 = functor (NamedTuple{(:x , :y )}, (x= 1 , y= 2 , z= 3 ))
102+ @test nt11 == (x = 1 , y = 2 )
103+ @test re11 ((x = 10 , y = 20 )) == (x = 10 , y = 20 )
104+ re11 ((y = 22 , x = 11 ))
105+ re11 ((11 , 22 )) # passes right through
106+
107+ nt12, re12 = functor (NamedTuple{(:x , :y )}, (z= 33 , x= 1 , y= 2 ))
108+ @test nt12 == (x = 1 , y = 2 )
109+ @test re12 ((x = 10 , y = 20 )) == (x = 10 , y = 20 )
110+
111+ @test_throws Exception functor (NamedTuple{(:x , :y )}, (z= 33 , x= 1 ))
112+ end
113+
114+ # ##
115+ # ## Extras
116+ # ##
117+
56118@testset " Walk" begin
57119 model = Foo ((0 , Bar ([1 , 2 , 3 ])), [4 , 5 ])
58120
59121 model′ = fmapstructure (identity, model)
60122 @test model′ == (; x= (0 , (; x= [1 , 2 , 3 ])), y= [4 , 5 ])
61123end
62124
63- @testset " Property list" begin
64- model = Baz (1 , 2 , 3 )
65- model′ = fmap (x -> 2 x, model)
66-
67- @test (model′. x, model′. y, model′. z) == (1 , 4 , 3 )
68- end
69-
70125@testset " fcollect" begin
71126 m1 = [1 , 2 , 3 ]
72127 m2 = 1
78133
79134 m1 = [1 , 2 , 3 ]
80135 m2 = Bar (m1)
81- m0 = NoChildren (:a , :b )
136+ m0 = NoChildren2 (:a , :b )
82137 m3 = Foo (m2, m0)
83138 m4 = Bar (m3)
84139 @test all (fcollect (m4) .=== [m4, m3, m2, m1, m0])
89144 @test all (fcollect (m3) .=== [m3, m1, m2])
90145end
91146
147+ # ##
148+ # ## Vararg forms
149+ # ##
150+
151+ @testset " fmap(f, x, y)" begin
152+ m1 = (x = [1 ,2 ], y = 3 )
153+ n1 = (x = [4 ,5 ], y = 6 )
154+ @test fmap (+ , m1, n1) == (x = [5 , 7 ], y = 9 )
155+
156+ # Reconstruction type comes from the first argument
157+ foo1 = Foo ([7 ,8 ], 9 )
158+ @test fmap (+ , m1, foo1) == (x = [8 , 10 ], y = 12 )
159+ @test fmap (+ , foo1, n1) isa Foo
160+ @test fmap (+ , foo1, n1). x == [11 , 13 ]
161+
162+ # Mismatched trees should be an error
163+ m2 = (x = [1 ,2 ], y = (a = [3 ,4 ], b = 5 ))
164+ n2 = (x = [6 ,7 ], y = 8 )
165+ @test_throws Exception fmap (first∘ tuple, m2, n2) # ERROR: type Int64 has no field a
166+ @test_throws Exception fmap (first∘ tuple, m2, n2)
167+
168+ # The cache uses IDs from the first argument
169+ shared = [1 ,2 ,3 ]
170+ m3 = (x = shared, y = [4 ,5 ,6 ], z = shared)
171+ n3 = (x = shared, y = shared, z = [7 ,8 ,9 ])
172+ @test fmap (+ , m3, n3) == (x = [2 , 4 , 6 ], y = [5 , 7 , 9 ], z = [2 , 4 , 6 ])
173+ z3 = fmap (+ , m3, n3)
174+ @test z3. x === z3. z
175+
176+ # Pruning of duplicates:
177+ @test fmap (+ , m3, n3; prune = nothing ) == (x = [2 ,4 ,6 ], y = [5 ,7 ,9 ], z = nothing )
178+
179+ # More than two arguments:
180+ z4 = fmap (+ , m3, n3, m3, n3)
181+ @test z4 == fmap (x -> 2 x, z3)
182+ @test z4. x === z4. z
183+
184+ @test fmap (+ , foo1, m1, n1) isa Foo
185+ @static if VERSION >= v " 1.6" # fails on Julia 1.0
186+ @test fmap (.* , m1, foo1, n1) == (x = [4 * 7 , 2 * 5 * 8 ], y = 3 * 6 * 9 )
187+ end
188+ end
189+
190+ @static if VERSION >= v " 1.6" # Julia 1.0: LoadError: error compiling top-level scope: type definition not allowed inside a local scope
191+ @testset " old test update.jl" begin
192+ struct M{F,T,S}
193+ σ:: F
194+ W:: T
195+ b:: S
196+ end
197+
198+ @functor M
199+
200+ (m:: M )(x) = m. σ .(m. W * x .+ m. b)
201+
202+ m = M (identity, ones (Float32, 3 , 4 ), zeros (Float32, 3 ))
203+ x = ones (Float32, 4 , 2 )
204+ m̄, _ = gradient ((m,x) -> sum (m (x)), m, x)
205+ m̂ = Functors. fmap (m, m̄) do x, y
206+ isnothing (x) && return y
207+ isnothing (y) && return x
208+ x .- 0.1f0 .* y
209+ end
210+
211+ @test m̂. W ≈ fill (0.8f0 , size (m. W))
212+ @test m̂. b ≈ fill (- 0.2f0 , size (m. b))
213+ end
214+ end # VERSION
215+
216+ # ##
217+ # ## FlexibleFunctors.jl
218+ # ##
219+
92220struct FFoo
93221 x
94222 y
@@ -102,13 +230,13 @@ struct FBar
102230end
103231@flexiblefunctor FBar p
104232
105- struct FBaz
233+ struct FOneChild4
106234 x
107235 y
108236 z
109237 p
110238end
111- @flexiblefunctor FBaz p
239+ @flexiblefunctor FOneChild4 p
112240
113241@testset " Flexible Nested" begin
114242 model = FBar (FFoo (1 , [1 , 2 , 3 ], (:y , )), (:x ,))
132260end
133261
134262@testset " Flexible Property list" begin
135- model = FBaz (1 , 2 , 3 , (:x , :z ))
263+ model = FOneChild4 (1 , 2 , 3 , (:x , :z ))
136264 model′ = fmap (x -> 2 x, model)
137265
138266 @test (model′. x, model′. y, model′. z) == (2 , 2 , 6 )
147275 @test all (fcollect (m4, exclude = x -> x isa Array) .=== [m4, m3])
148276 @test all (fcollect (m4, exclude = x -> x isa FFoo) .=== [m4])
149277
150- m0 = NoChildren (:a , :b )
278+ m0 = NoChildren2 (:a , :b )
151279 m1 = [1 , 2 , 3 ]
152280 m2 = FBar (m1, ())
153281 m3 = FFoo (m2, m0, (:x , :y ,))
0 commit comments