-
Notifications
You must be signed in to change notification settings - Fork 2
Solver: Iterative Hessian Sketch #119
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
8e18d8b
082eb34
7be38a3
009c6b5
db48731
c48420f
4bb21ee
7cc79b0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,7 +12,9 @@ SolverRecipe | |
|
|
||
| ## Solver Structures | ||
| ```@docs | ||
| IHS | ||
|
|
||
| IHSRecipe | ||
| ``` | ||
|
|
||
| ## Exported Functions | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,234 @@ | ||||||
| """ | ||||||
| IHS <: Solver | ||||||
| An implementation of the Iterative Hessian Sketch solver for solving over determined | ||||||
| least squares problems (@cite)[pilanci2014iterative]. | ||||||
| # Mathematical Description | ||||||
| Let ``A \\in \\mathbb{R}^{m \\times n}`` and consider the least square problem ``\\min_x | ||||||
| \\|Ax - b \\|_2^2``. If we let ``S \\in \\mathbb{R}^{s \\times m}`` be a compression matrix, then | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps add a comment |
||||||
| Iterative Hessian Sketch iteratively finds a solution to this problem | ||||||
| by repeatedly updating ``x_{k+1} = x_k + \\alpha u_k``where ``u_k`` is the solution to the | ||||||
| convex optimization problem, | ||||||
| ``u_k = \\min_u \\{\\|S_k Au\\|_2^2 - \\langle A, b - Ax_k \\rangle \\}.`` This method | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Usually, we write |
||||||
| has been to shown to converge geometrically at a rate ``\\rho \\in (0, 1/2]``, typically the | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think there should be a period before |
||||||
| required compression dimension needs to be 4-8 times the size of n for the algorithm to | ||||||
| perform successfully. | ||||||
| # Fields | ||||||
| - `alpha::Float64`, a step size parameter. | ||||||
| - `compressor::Compressor`, a technique for forming the compressed linear system. | ||||||
| - `log::Logger`, a technique for logging the progress of the solver. | ||||||
| - `error::SolverError`, a method for estimating the progress of the solver. | ||||||
| # Constructor | ||||||
| function IHS(; | ||||||
| compressor::Compressor = SparseSign(cardinality = Left()), | ||||||
| log::Logger = BasicLogger(), | ||||||
| error::SolverError = FullResidual(), | ||||||
| alpha::Float64 = 1.0 | ||||||
| ) | ||||||
| ## Keywords | ||||||
| - `compressor::Compressor`, a technique for forming the compressed linear system. | ||||||
| - `log::Logger`, a technique for logging the progress of the solver. | ||||||
| - `error::SolverError', a method for estimating the progress of the solver. | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
The single quote should have been whatever this character is: ` |
||||||
| - `alpha::Float64`, a step size parameter. | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Throw in a warning or an information admonition with some guidance about the value of |
||||||
| # Returns | ||||||
| - A `IHS` object. | ||||||
| """ | ||||||
| mutable struct IHS <: Solver | ||||||
| alpha::Float64 | ||||||
| log::Logger | ||||||
| compressor::Compressor | ||||||
| error::SolverError | ||||||
| function IHS(alpha, log, compressor, error) | ||||||
| if typeof(compressor.cardinality) != Left | ||||||
| @warn "Compressor has cardinality `Right` but IHS compresses from the `Left`." | ||||||
| end | ||||||
|
|
||||||
| new(alpha, log, compressor, error) | ||||||
| end | ||||||
|
|
||||||
| end | ||||||
|
|
||||||
| function IHS(; | ||||||
| compressor::Compressor = SparseSign(cardinality = Left()), | ||||||
| log::Logger = BasicLogger(), | ||||||
| error::SolverError = FullResidual(), | ||||||
| alpha::Float64 = 1.0 | ||||||
| ) | ||||||
| return IHS( | ||||||
| alpha, | ||||||
| log, | ||||||
| compressor, | ||||||
| error | ||||||
| ) | ||||||
| end | ||||||
|
|
||||||
| """ | ||||||
| IHSRecipe{ | ||||||
| Type<:Number, | ||||||
| LR<:LoggerRecipe, | ||||||
| CR<:CompressorRecipe, | ||||||
| ER<:ErrorRecipe, | ||||||
| M<:AbstractArray, | ||||||
| MV<:SubArray, | ||||||
| V<:AbstractVector | ||||||
| } <: SolverRecip | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||
| A mutable structure containing all information relevant to the Iterative Hessian Sketch | ||||||
| solver. It is formed by calling the function `complete_solver` on a `IHS` solver, which | ||||||
| includes all the user controlled parameters, the linear system `A`, and the constant | ||||||
| vector `b`. | ||||||
| # Fields | ||||||
| - `compressor::CompressorRecipe`, a technique for compressing the matrix ``A``. | ||||||
| - `logger::LoggerRecipe`, a technique for logging the progress of the solver. | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||
| - `error::SolverErrorRecipe`, a technique for estimating the progress of the solver. | ||||||
| - `compressed_mat::AbstractMatrix`, a buffer for storing the compressed matrix. | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. missing |
||||||
| - `mat_view::SubArray`, a container for storing a view of the compressed matrix buffer. | ||||||
| - `residual_vec::AbstractVector`, a vector that contains the residual of the linear system | ||||||
| ``Ax-b``. | ||||||
| - `gradient_vec::AbstractVector`, a vector that contains the gradient of the least squares | ||||||
| problem, ``A^\\top(b-Ax)``. | ||||||
| - `buffer_vec::AbstractVector`, a buffer vector for storing intermediate linear system solves. | ||||||
| - `solution_vec::AbstractVector`, a vector storing the current IHS solution. | ||||||
| - `R::UpperTriangular`, a container for storing the upper triangular portion of the R | ||||||
| factor from a QR factorization of `mat_view`. This is used to solve the IHS sub-problem. | ||||||
| """ | ||||||
| mutable struct IHSRecipe{ | ||||||
| Type<:Number, | ||||||
| LR<:LoggerRecipe, | ||||||
| CR<:CompressorRecipe, | ||||||
| ER<:SolverErrorRecipe, | ||||||
| M<:AbstractArray, | ||||||
| MV<:SubArray, | ||||||
| V<:AbstractVector | ||||||
| } <: SolverRecipe | ||||||
| log::LR | ||||||
| compressor::CR | ||||||
| error::ER | ||||||
| alpha::Float64 | ||||||
| compressed_mat::M | ||||||
| mat_view::MV | ||||||
| residual_vec::V | ||||||
| gradient_vec::V | ||||||
| buffer_vec::V | ||||||
| solution_vec::V | ||||||
| R::UpperTriangular{Type, M} | ||||||
| end | ||||||
|
|
||||||
| function complete_solver( | ||||||
| ingredients::IHS, | ||||||
| x::AbstractVector, | ||||||
| A::AbstractMatrix, | ||||||
| b::AbstractVector | ||||||
| ) | ||||||
| compressor = complete_compressor(ingredients.compressor, x, A, b) | ||||||
| logger = complete_logger(ingredients.log) | ||||||
| error = complete_error(ingredients.error, ingredients, A, b) | ||||||
| sample_size::Int64 = compressor.n_rows | ||||||
| rows_a, cols_a = size(A) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add check if |
||||||
| # Check that required fields are in the types | ||||||
| if !isdefined(error, :residual) | ||||||
| throw( | ||||||
| ArgumentError( | ||||||
| "ErrorRecipe $(typeof(error)) does not contain the \ | ||||||
| field 'residual' and is not valid for an IHS solver." | ||||||
| ) | ||||||
| ) | ||||||
| end | ||||||
|
|
||||||
| if !isdefined(logger, :converged) | ||||||
| throw( | ||||||
| ArgumentError( | ||||||
| "LoggerRecipe $(typeof(logger)) does not contain \ | ||||||
| the field 'converged' and is not valid for an IHS solver." | ||||||
| ) | ||||||
| ) | ||||||
| end | ||||||
|
|
||||||
| # Check that the sketch size is larger than the column dimension and return a warning | ||||||
| # otherwise | ||||||
| if cols_a > sample_size | ||||||
| throw( | ||||||
| ArgumentError( | ||||||
| "Compression dimension not larger than column dimension this will lead to \ | ||||||
| singular QR decompositions, which cannot be inverted." | ||||||
| ) | ||||||
| ) | ||||||
| end | ||||||
|
|
||||||
| compressed_mat = zeros(eltype(A), sample_size, cols_a) | ||||||
| res = zeros(eltype(A), rows_a) | ||||||
| grad = zeros(eltype(A), cols_a) | ||||||
| buffer_vec = zeros(eltype(A), cols_a) | ||||||
| solution_vec = x | ||||||
| mat_view = view(compressed_mat, 1:sample_size, :) | ||||||
| R = UpperTriangular(mat_view[1:cols_a, :]) | ||||||
|
|
||||||
| return IHSRecipe{ | ||||||
| eltype(compressed_mat), | ||||||
| typeof(logger), | ||||||
| typeof(compressor), | ||||||
| typeof(error), | ||||||
| typeof(compressed_mat), | ||||||
| typeof(mat_view), | ||||||
| typeof(buffer_vec) | ||||||
| }( | ||||||
| logger, | ||||||
| compressor, | ||||||
| error, | ||||||
| ingredients.alpha, | ||||||
| compressed_mat, | ||||||
| mat_view, | ||||||
| res, | ||||||
| grad, | ||||||
| buffer_vec, | ||||||
| solution_vec, | ||||||
| R | ||||||
| ) | ||||||
| end | ||||||
|
|
||||||
| function rsolve!(solver::IHSRecipe, x::AbstractVector, A::AbstractMatrix, b::AbstractVector) | ||||||
| reset_logger!(solver.log) | ||||||
| solver.solution_vec = x | ||||||
| err = 0.0 | ||||||
| copyto!(solver.residual_vec, b) | ||||||
| # compute the initial residual r = b - Ax | ||||||
| mul!(solver.residual_vec, A, solver.solution_vec, -1.0, 1.0) | ||||||
| for i in 1:solver.log.max_it | ||||||
| # compute the gradient A'r | ||||||
| mul!(solver.gradient_vec, A', solver.residual_vec) | ||||||
| err = compute_error(solver.error, solver, A, b) | ||||||
| update_logger!(solver.log, err, i) | ||||||
| if solver.log.converged | ||||||
| return nothing | ||||||
| end | ||||||
|
|
||||||
| # generate a new compressor | ||||||
| update_compressor!(solver.compressor, x, A, b) | ||||||
| # Based on the size of the compressor update views of the matrix | ||||||
| rows_s, cols_s = size(solver.compressor) | ||||||
| solver.mat_view = view(solver.compressed_mat, 1:rows_s, :) | ||||||
| # Compress the matrix | ||||||
| mul!(solver.mat_view, solver.compressor, A) | ||||||
| # Update the subsolver | ||||||
| # This is the only piece of allocating code | ||||||
| solver.R = UpperTriangular(qr!(solver.mat_view).R) | ||||||
| # Compute first R' solver R'R x = g | ||||||
| ldiv!(solver.buffer_vec, solver.R', solver.gradient_vec) | ||||||
| # Compute second R Solve Rx = (R')^(-1)g will be stored in gradient_vec | ||||||
| ldiv!(solver.gradient_vec, solver.R, solver.buffer_vec) | ||||||
| # update the solution | ||||||
| # solver.solution_vec = solver.solution_vec + alpha * solver.gradient_vec | ||||||
|
Comment on lines
+219
to
+225
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should add "Iterative Hessian Sketch" to the wiki page (either v0.2 or v0.3, your choice). And indicate that we want to allow alternative subsolvers in the inner loop. |
||||||
| axpy!(solver.alpha, solver.gradient_vec, solver.solution_vec) | ||||||
| # compute the fast update of r = r - A * gradient_vec | ||||||
| # note: in this case gradient vec stores the update | ||||||
| mul!(solver.residual_vec, A, solver.gradient_vec, -1.0, 1.0) | ||||||
| end | ||||||
|
|
||||||
| return nothing | ||||||
|
|
||||||
| end | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this reference is listed twice. I think the JMLR version (linked in the PR description) should be the appropriate reference to use: