@@ -382,57 +382,112 @@ end
382382 # have fantastic support for this stuff at the minute.
383383 # also we might be missing some overloads for different tangent-types in the rules
384384 @testset " cholesky" begin
385- @testset " Real" begin
386- test_rrule (cholesky, 0.8 )
385+ @testset " Number" begin
386+ @testset " uplo=$uplo " for uplo in (:U , :L )
387+ test_rrule (cholesky, 0.8 , uplo)
388+ test_rrule (cholesky, - 0.3 , uplo)
389+ test_rrule (cholesky, 0.23 + 0im , uplo)
390+ test_rrule (cholesky, 0.78 + 0.5im , uplo)
391+ test_rrule (cholesky, - 0.34 + 0.1im , uplo)
392+ end
387393 end
388- @testset " Diagonal{<:Real}" begin
389- D = Diagonal (rand (5 ) .+ 0.1 )
390- C = cholesky (D)
391- test_rrule (
392- cholesky, D ⊢ Diagonal (randn (5 )), Val (false );
393- output_tangent= Tangent {typeof(C)} (factors= Diagonal (randn (5 )))
394- )
394+
395+ @testset " Diagonal" begin
396+ @testset " Diagonal{<:Real}" begin
397+ test_rrule (cholesky, Diagonal ([0.3 , 0.2 , 0.5 , 0.6 , 0.9 ]), Val (false ))
398+ end
399+ @testset " Diagonal{<:Complex}" begin
400+ # finite differences in general will produce matrices with non-real
401+ # diagonals, which cause factorization to fail. If we turn off the check and
402+ # ensure the cotangent is real, then test_rrule still works.
403+ D = Diagonal ([0.3 + 0im , 0.2 , 0.5 , 0.6 , 0.9 ])
404+ C = cholesky (D)
405+ test_rrule (
406+ cholesky, D, Val (false );
407+ output_tangent= Tangent {typeof(C)} (factors= complex (randn (5 , 5 ))),
408+ fkwargs= (; check= false ),
409+ )
410+ end
411+ @testset " check has correct default and passed to primal" begin
412+ @test_throws Exception rrule (cholesky, Diagonal (- rand (5 )), Val (false ))
413+ rrule (cholesky, Diagonal (- rand (5 )), Val (false ); check= false )
414+ end
415+ @testset " failed factorization" begin
416+ A = Diagonal (vcat (rand (4 ), - rand (4 ), rand (4 )))
417+ test_rrule (cholesky, A, Val (false ); fkwargs= (; check= false ))
418+ end
395419 end
396420
397- X = generate_well_conditioned_matrix (10 )
398- V = generate_well_conditioned_matrix (10 )
399- F, dX_pullback = rrule (cholesky, X, Val (false ))
400- F_1arg, dX_pullback_1arg = rrule (cholesky, X) # to test not passing the Val(false)
401- @test F == F_1arg
402- @testset " uplo=$p " for p in [:U , :L ]
403- Y, dF_pullback = rrule (getproperty, F, p)
404- Ȳ = (p === :U ? UpperTriangular : LowerTriangular)(randn (size (Y)))
405- (dself, dF, dp) = dF_pullback (Ȳ)
406- @test dself === NoTangent ()
407- @test dp === NoTangent ()
408-
409- # NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp`
410- # machinery from FiniteDifferences because that isn't set up to respect
411- # necessary special properties of the input. In the case of the Cholesky
412- # factorization, we need the input to be Hermitian.
413- ΔF = unthunk (dF)
414- _, dX, darg2 = dX_pullback (ΔF)
415- _, dX_1arg = dX_pullback_1arg (ΔF)
416- @test dX == dX_1arg
417- @test darg2 === NoTangent ()
418- X̄_ad = dot (unthunk (dX), V)
419- X̄_fd = central_fdm (5 , 1 )(0.000_001 ) do ε
420- dot (Ȳ, getproperty (cholesky (X .+ ε .* V), p))
421+ @testset " StridedMatrix" begin
422+ @testset " Matrix{$T }" for T in (Float64, ComplexF64)
423+ X = generate_well_conditioned_matrix (T, 10 )
424+ V = generate_well_conditioned_matrix (T, 10 )
425+ F, dX_pullback = rrule (cholesky, X, Val (false ))
426+ @testset " uplo=$p , cotangent eltype=$T " for p in [:U , :L ], S in unique ([T, complex (T)])
427+ Y, dF_pullback = rrule (getproperty, F, p)
428+ Ȳ = randn (S, size (Y))
429+ (dself, dF, dp) = dF_pullback (Ȳ)
430+ @test dself === NoTangent ()
431+ @test dp === NoTangent ()
432+
433+ # NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp`
434+ # machinery from FiniteDifferences because that isn't set up to respect
435+ # necessary special properties of the input. In the case of the Cholesky
436+ # factorization, we need the input to be Hermitian.
437+ ΔF = unthunk (dF)
438+ _, dX, darg2 = dX_pullback (ΔF)
439+ @test darg2 === NoTangent ()
440+ X̄_ad = real (dot (unthunk (dX), V))
441+ X̄_fd = central_fdm (5 , 1 )(0.000_0001 ) do ε
442+ real (dot (Ȳ, getproperty (cholesky (X .+ ε .* V), p)))
443+ end
444+ @test X̄_ad ≈ X̄_fd rtol= 1e-4
445+ end
446+ end
447+ @testset " check has correct default and passed to primal" begin
448+ # this will almost certainly be a non-PD matrix
449+ X = Matrix (Symmetric (randn (10 , 10 )))
450+ @test_throws Exception rrule (cholesky, X, Val (false ))
451+ rrule (cholesky, X, Val (false ); check= false ) # just check it doesn't throw
421452 end
422- @test X̄_ad ≈ X̄_fd rtol= 1e-4
423453 end
424454
425455 # Ensure that cotangents of cholesky(::StridedMatrix) and
426456 # (cholesky ∘ Symmetric)(::StridedMatrix) are equal.
427457 @testset " Symmetric" begin
458+ X = generate_well_conditioned_matrix (10 )
459+ F, dX_pullback = rrule (cholesky, X, Val (false ))
460+
428461 X_symmetric, sym_back = rrule (Symmetric, X, :U )
429462 C, chol_back_sym = rrule (cholesky, X_symmetric, Val (false ))
430463
431- Δ = Tangent {typeof(C)} ((U = UpperTriangular ( randn (size (X) ))))
464+ Δ = Tangent {typeof(C)} ((factors = randn (size (X))))
432465 ΔX_symmetric = chol_back_sym (Δ)[2 ]
433466 @test sym_back (ΔX_symmetric)[2 ] ≈ dX_pullback (Δ)[2 ]
434467 end
435468
469+ # Ensure that cotangents of cholesky(::StridedMatrix) and
470+ # (cholesky ∘ Hermitian)(::StridedMatrix) are equal.
471+ @testset " Hermitian" begin
472+ @testset " Hermitian{$T }" for T in (Float64, ComplexF64)
473+ X = generate_well_conditioned_matrix (T, 10 )
474+ F, dX_pullback = rrule (cholesky, X, Val (false ))
475+
476+ X_hermitian, herm_back = rrule (Hermitian, X, :U )
477+ C, chol_back_herm = rrule (cholesky, X_hermitian, Val (false ))
478+
479+ Δ = Tangent {typeof(C)} ((factors= randn (T, size (X))))
480+ ΔX_hermitian = chol_back_herm (Δ)[2 ]
481+ @test herm_back (ΔX_hermitian)[2 ] ≈ dX_pullback (Δ)[2 ]
482+ end
483+ @testset " check has correct default and passed to primal" begin
484+ # this will almost certainly be a non-PD matrix
485+ X = Hermitian (randn (10 , 10 ))
486+ @test_throws Exception rrule (cholesky, X, Val (false ))
487+ rrule (cholesky, X, Val (false ); check= false )
488+ end
489+ end
490+
436491 @testset " det and logdet (uplo=$p )" for p in (:U , :L )
437492 @testset " $op " for op in (det, logdet)
438493 @testset " $T " for T in (Float64, ComplexF64)
0 commit comments