Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions src/Compressors/sparse_sign.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,35 @@ function SparseSign(;
return SparseSign(cardinality, compression_dim, nnz, type)
end

# Reload property
function setproperty!(obj::SparseSign, sym::Symbol, val)
if sym === :compression_dim
new_dim = val
if new_dim <= 0
throw(ArgumentError("Field `compression_dim` must be positive."))
elseif obj.nnz > new_dim
throw(
ArgumentError("New `compression_dim`, $new_dim, must be greater than \
or equal to current `nnz`, $(obj.nnz)."
)
)
end
elseif sym === :nnz
new_nnz = val
if new_nnz <= 0
throw(ArgumentError("Field `nnz` must be positive."))
elseif new_nnz > obj.compression_dim
throw(
ArgumentError("New `nnz`, $new_nnz, must be less than or equal to \
current `compression_dim`, $(obj.compression_dim)."
)
)
end
end

return setfield!(obj, sym, val)
end

"""
sparse_idx_update!(
values::Vector{Int64},
Expand Down
2 changes: 1 addition & 1 deletion src/RLinearAlgebra.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module RLinearAlgebra
import Base.:*
import Base: transpose, adjoint
import Base: transpose, adjoint, setproperty!
import LinearAlgebra: Adjoint, axpby!, dot, I, ldiv!, lmul!, lq!, lq, LQ, lu!
import LinearAlgebra: mul!, norm, qr!, svd
import StatsBase: sample, sample!, ProbabilityWeights, wsample!
Expand Down
28 changes: 28 additions & 0 deletions test/Compressors/sparse_sign.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,34 @@ Random.seed!(2131)

end

@testset "Verify field modification" begin
compressor = SparseSign(cardinality=Left(), compression_dim=5, nnz=3, type=Float64)

# Test compression_dim
@test_throws ArgumentError compressor.compression_dim = 0
@test_throws ArgumentError compressor.compression_dim = -1
@test_throws ArgumentError compressor.compression_dim = 2
@test_throws TypeError compressor.compression_dim = 5.5

# Test nnz
@test_throws ArgumentError compressor.nnz = 0
@test_throws ArgumentError compressor.nnz = -1
@test_throws ArgumentError compressor.nnz = 6
@test_throws TypeError compressor.nnz = 2.5

# Test correct assignments
compressor.compression_dim = 10
@test compressor.compression_dim == 10
compressor.nnz = 8
@test compressor.nnz == 8

# Test no checking assignments
compressor.cardinality = Right()
@test typeof(compressor.cardinality) == Right
compressor.type = Float32
@test compressor.type == Float32
end

@testset "Sparse Sign: CompressorRecipe" begin
@test_compressor SparseSignRecipe
@test fieldnames(SparseSignRecipe) ==
Expand Down
Loading