File tree Expand file tree Collapse file tree 3 files changed +30
-5
lines changed
Expand file tree Collapse file tree 3 files changed +30
-5
lines changed Original file line number Diff line number Diff line change @@ -181,12 +181,13 @@ julia> m.bias
181181"""
182182cpu (x) = fmap (x -> adapt (FluxCPUAdaptor (), x), x, exclude = _isleaf)
183183
184- _isbitsarray (:: AbstractArray{<:Number} ) = true
185- _isbitsarray (:: AbstractArray{T} ) where T = isbitstype (T)
186- _isbitsarray (x) = false
184+ _isleaf (x) = Functors. isleaf (x)
185+
186+ _isleaf (:: AbstractArray{<:Number} ) = true
187+ _isleaf (:: AbstractArray{T} ) where T = isbitstype (T)
188+ _isleaf (:: Union{Transpose, Adjoint, PermutedDimsArray} ) = false
187189
188190_isleaf (:: AbstractRNG ) = true
189- _isleaf (x) = _isbitsarray (x) || Functors. isleaf (x)
190191
191192# the order below is important
192193const GPU_BACKENDS = (" CUDA" , " AMDGPU" , " Metal" , " CPU" )
Original file line number Diff line number Diff line change 109109 # This test should really not go through indirections and pull out Fills for efficiency
110110 # but we forcefully materialise. TODO : remove materialising CuArray here
111111 @test gradient (x -> sum (cpu (x)), ca)[1 ] isa CuArray # This involves FillArray, which should be GPU compatible
112- @test gradient (x -> sum (cpu (x)), ca' )[1 ] isa CuArray
112+ @test gradient (x -> sum (cpu (x)), ca' )[1 ] isa Adjoint{Float32, <: CuArray }
113113
114114 # Even more trivial: no movement
115115 @test gradient (x -> sum (abs, cpu (x)), a)[1 ] isa Matrix
Original file line number Diff line number Diff line change 567567 @test length (Flux. params (oneadj)) == 1 # needs Functors@0.3
568568
569569 @test Flux. destructure (simple)[1 ] == Flux. destructure (oneadj)[1 ] == [1 , 3 , 2 , 4 ]
570+
571+ @testset " issue 2432" begin
572+ x = rand (1 )
573+ m = (; a = x, b = x' )
574+ count = Ref (0 )
575+ mcopy = fmap (m; exclude = Flux. _isleaf) do x
576+ count[] += 1
577+ return copy (x)
578+ end
579+ @test count[] == 1
580+ @test mcopy. a === mcopy. b'
581+
582+ struct BitsType
583+ x:: Int32
584+ y:: Float64
585+ end
586+
587+ for x in [1.0 , ' a' , BitsType (1 , 2.0 )]
588+ @test Flux. _isleaf ([x])
589+ @test ! Flux. _isleaf ([x]' )
590+ @test ! Flux. _isleaf (transpose ([x]))
591+ @test ! Flux. _isleaf (PermutedDimsArray ([x;;], (1 , 2 )))
592+ end
593+ end
570594end
571595
572596@testset " Various destructure bugs" begin
You can’t perform that action at this time.
0 commit comments