@@ -2,6 +2,7 @@ using Optimisers
22using ChainRulesCore, Functors, StaticArrays, Zygote, Yota
33using LinearAlgebra, Statistics, Test, Random
44using Optimisers: @. ., @lazy
5+ using Base. Broadcast: broadcasted, instantiate, Broadcasted
56
67Random. seed! (1 )
78
@@ -89,7 +90,8 @@ y2z(x) = x
8990 _, m2 = Optimisers. update (s2, m, (α = ([0.1 ], nothing ), γ = [1 ,10 ,100 ],))
9091 @test only (m. α[1 ] .- m2. α[1 ]) ≈ 0.1
9192 @test norm (m. γ .- m2. γ) ≈ 10
92- @test_throws DomainError Optimisers. update (s2, m, (α = [0.1 ], γ = [1 ,10 ,NaN ],))
93+ # This error is thrown by apply! due to NaN input.
94+ @test_throws DomainError Optimisers. update (s2, m, (α = ([0.1 ], nothing ), γ = [1 ,10 ,NaN ],))
9395
9496 s3 = Optimisers. setup (ClipNorm (5 , 1 ; throw= false ), m)
9597 _, m3 = Optimisers. update (s3, m, (α = ([0.1 ], nothing ), γ = [1 ,10 ,100 ],))
@@ -506,6 +508,19 @@ y2z(x) = x
506508 y = Optimisers. subtract! (x, nothing )
507509 @test y === x
508510 end
511+
512+ @testset " _norm(dx, p) works" begin
513+ bc = instantiate (broadcasted (+ , randn (Float32, 10 ), randn (Float32, 10 )' ));
514+ arr = collect (bc)
515+ bc2 = instantiate (broadcasted (+ , [1 , 0 , - 3 , 4 ], 0 ))
516+ arr2 = collect (bc2)
517+ for p in (- Inf , - 3 , - 1 , 0 , 0.5 , 1 , 1.5 , 2 , 3f0 , Inf32 )
518+ @test Optimisers. _norm (bc, p) ≈ norm (arr, p)
519+ @test Optimisers. _norm (bc, p) isa Float32
520+ @test Optimisers. _norm (bc2, p) ≈ norm (arr2, p)
521+ @test Optimisers. _norm (bc2, p) isa Float64
522+ end
523+ end
509524 end
510525 @testset verbose= true " Destructure" begin
511526 include (" destructure.jl" )
0 commit comments