diff --git a/src/Compressors/sparse_sign.jl b/src/Compressors/sparse_sign.jl index 46568abe..a306872b 100644 --- a/src/Compressors/sparse_sign.jl +++ b/src/Compressors/sparse_sign.jl @@ -97,6 +97,35 @@ function SparseSign(; return SparseSign(cardinality, compression_dim, nnz, type) end +# Reload property +function setproperty!(obj::SparseSign, sym::Symbol, val) + if sym === :compression_dim + new_dim = val + if new_dim <= 0 + throw(ArgumentError("Field `compression_dim` must be positive.")) + elseif obj.nnz > new_dim + throw( + ArgumentError("New `compression_dim`, $new_dim, must be greater than \ + or equal to current `nnz`, $(obj.nnz)." + ) + ) + end + elseif sym === :nnz + new_nnz = val + if new_nnz <= 0 + throw(ArgumentError("Field `nnz` must be positive.")) + elseif new_nnz > obj.compression_dim + throw( + ArgumentError("New `nnz`, $new_nnz, must be less than or equal to \ + current `compression_dim`, $(obj.compression_dim)." + ) + ) + end + end + + return setfield!(obj, sym, val) +end + """ sparse_idx_update!( values::Vector{Int64}, diff --git a/src/RLinearAlgebra.jl b/src/RLinearAlgebra.jl index 4addd7c8..39d7f5d7 100644 --- a/src/RLinearAlgebra.jl +++ b/src/RLinearAlgebra.jl @@ -1,6 +1,6 @@ module RLinearAlgebra import Base.:* -import Base: transpose, adjoint +import Base: transpose, adjoint, setproperty! import LinearAlgebra: Adjoint, axpby!, dot, I, ldiv!, lmul!, lq!, lq, LQ, lu! import LinearAlgebra: mul!, norm, qr!, svd import StatsBase: sample, sample!, ProbabilityWeights, wsample! diff --git a/test/Compressors/sparse_sign.jl b/test/Compressors/sparse_sign.jl index d1802a18..83d619d7 100644 --- a/test/Compressors/sparse_sign.jl +++ b/test/Compressors/sparse_sign.jl @@ -50,6 +50,34 @@ Random.seed!(2131) end + @testset "Verify field modification" begin + compressor = SparseSign(cardinality=Left(), compression_dim=5, nnz=3, type=Float64) + + # Test compression_dim + @test_throws ArgumentError compressor.compression_dim = 0 + @test_throws ArgumentError compressor.compression_dim = -1 + @test_throws ArgumentError compressor.compression_dim = 2 + @test_throws TypeError compressor.compression_dim = 5.5 + + # Test nnz + @test_throws ArgumentError compressor.nnz = 0 + @test_throws ArgumentError compressor.nnz = -1 + @test_throws ArgumentError compressor.nnz = 6 + @test_throws TypeError compressor.nnz = 2.5 + + # Test correct assignments + compressor.compression_dim = 10 + @test compressor.compression_dim == 10 + compressor.nnz = 8 + @test compressor.nnz == 8 + + # Test no checking assignments + compressor.cardinality = Right() + @test typeof(compressor.cardinality) == Right + compressor.type = Float32 + @test compressor.type == Float32 + end + @testset "Sparse Sign: CompressorRecipe" begin @test_compressor SparseSignRecipe @test fieldnames(SparseSignRecipe) ==