diff --git a/docs/src/api/solver_errors.md b/docs/src/api/solver_errors.md index 7ce61cdb..3e59bc74 100644 --- a/docs/src/api/solver_errors.md +++ b/docs/src/api/solver_errors.md @@ -16,6 +16,9 @@ FullResidual FullResidualRecipe +LSgradient + +LSgradientRecipe ``` ## Exported Functions diff --git a/src/RLinearAlgebra.jl b/src/RLinearAlgebra.jl index c8706bc1..4afee156 100644 --- a/src/RLinearAlgebra.jl +++ b/src/RLinearAlgebra.jl @@ -56,6 +56,7 @@ export QRSolver, QRSolverRecipe export SolverError, SolverErrorRecipe export complete_error, compute_error export FullResidual, FullResidualRecipe +export LSgradient, LSgradientRecipe # Export ApproximatorError types and functions export ApproximatorError, ApproximatorErrorRecipe diff --git a/src/Solvers/ErrorMethods.jl b/src/Solvers/ErrorMethods.jl index 93506d21..c5d45f59 100644 --- a/src/Solvers/ErrorMethods.jl +++ b/src/Solvers/ErrorMethods.jl @@ -116,3 +116,4 @@ end # Include error method files include("ErrorMethods/full_residual.jl") +include("ErrorMethods/LSgradient.jl") \ No newline at end of file diff --git a/src/Solvers/ErrorMethods/LSgradient.jl b/src/Solvers/ErrorMethods/LSgradient.jl new file mode 100644 index 00000000..c2e05470 --- /dev/null +++ b/src/Solvers/ErrorMethods/LSgradient.jl @@ -0,0 +1,41 @@ +""" + LSgradient <: SolverError + +A `SolverError` structure for computing the least-squares gradient, + ``\\nabla f(x) = A' (A x - b)`` + +# Fields +- None +""" +struct LSgradient <: SolverError end + +""" + LSgradientRecipe <: SolverErrorRecipe +A `SolverErrorRecipe` structure for computing the gradient of least-squares objective. + +# Fields +- `gradient::AbstractVector`, `A'r`. +""" +mutable struct LSgradientRecipe{V<:AbstractVector} <: SolverErrorRecipe + gradient::V +end + +function complete_error( + error::LSgradient, + solver::Solver, + A::AbstractMatrix, + b::AbstractVector +) + gradient = zeros(size(A,2)) + return LSgradientRecipe{typeof(b)}(gradient) +end + +function compute_error( + error::LSgradientRecipe, + solver::SolverRecipe, + A::AbstractMatrix, + b::AbstractVector +)::Float64 + mul!(error.gradient, A', solver.residual_vec, 1.0, 0.0) # grad = A'r + return norm(error.gradient) +end \ No newline at end of file diff --git a/test/Solvers/ErrorMethods/LSgradient.jl b/test/Solvers/ErrorMethods/LSgradient.jl new file mode 100644 index 00000000..fd371275 --- /dev/null +++ b/test/Solvers/ErrorMethods/LSgradient.jl @@ -0,0 +1,84 @@ +module gradient_error +using Test, RLinearAlgebra, Random +import LinearAlgebra: mul!, norm +using ..FieldTest +using ..ApproxTol +Random.seed!(1232) + +mutable struct TestSolver <: Solver end + +mutable struct TestSolverRecipe <: SolverRecipe + residual_vec::AbstractVector +end + +@testset "LS Gradient" begin + @testset "LS Gradient: SolverError" begin + # Verify Supertype + @test supertype(LSgradient) == SolverError + + # Verify fieldnames and types + @test fieldnames(LSgradient) == () + @test fieldtypes(LSgradient) == () + # Verify the internal constructor + + end + + @testset "LS Gradient: SolverErrorRecipe" begin + # Verify Supertype + @test supertype(LSgradientRecipe) == SolverErrorRecipe + + # Verify fieldnames and types + @test fieldnames(LSgradientRecipe) == (:gradient,) + @test fieldtypes(LSgradientRecipe) == (AbstractVector,) + end + + @testset "Residual: Complete error" begin + for type in [Float32, Float64, ComplexF32, ComplexF64] + let n_rows = 4, + n_cols = 6, + A = rand(type, n_rows, n_cols), + b = rand(type, n_rows), + x = rand(type, n_cols), + r = A*x - b, + solver_rec = TestSolverRecipe(r), + error_rec = complete_error(LSgradient(), TestSolver(), A, b) + + # Test the type + @test typeof(error_rec) == LSgradientRecipe{typeof(b)} + # Test type of residual vector + @test eltype(error_rec.gradient) == type + # Test residual vector to be all zeros + @test error_rec.gradient == zeros(type, n_cols) + end + + end + + end + + @testset "Residual: Compute Error" begin + for type in [Float32, Float64] + let n_rows = 4, + n_cols = 6, + A = rand(type, n_rows, n_cols), + b = rand(type, n_rows), + x = rand(type, n_cols), + r = A*x - b, + solver_rec = TestSolverRecipe(r), + solver = TestSolver(), + error_rec = complete_error(LSgradient(), TestSolver(), A, b) + + # compute the error value + err_val = compute_error(error_rec, solver_rec, A, b) + # compute the gradient + res = A' * r + # compute norm squared of residual + @test norm(res) ≈ err_val + end + + end + + end + +end + +end \ No newline at end of file