Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/MeasureBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ const Pretty = PrettyPrinting
using ChainRulesCore
import FillArrays
using Static
using Static: StaticInteger
using FunctionChains

export ≪
Expand Down
2 changes: 1 addition & 1 deletion src/density-core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ end
@generated function _logdensity_rel(
μs::Tμ,
νs::Tν,
::Tuple{StaticInt{M},StaticInt{N}},
::Tuple{<:StaticInteger{M},<:StaticInteger{N}},
x::X,
) where {Tμ,Tν,M,N,X}
sμ = schema(Tμ)
Expand Down
2 changes: 1 addition & 1 deletion src/standard/stdmeasure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ end

# Implement transport_to(NU::Type{<:StdMeasure}, μ) and transport_to(ν, MU::Type{<:StdMeasure}):

_std_measure(::Type{M}, ::StaticInt{1}) where {M<:StdMeasure} = M()
_std_measure(::Type{M}, ::StaticInteger{1}) where {M<:StdMeasure} = M()
_std_measure(::Type{M}, dof::IntegerLike) where {M<:StdMeasure} = M()^dof
_std_measure_for(::Type{M}, μ::Any) where {M<:StdMeasure} = _std_measure(M, getdof(μ))

Expand Down
8 changes: 4 additions & 4 deletions src/static.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""
MeasureBase.IntegerLike

Equivalent to `Union{Integer,Static.StaticInt}`.
Equivalent to `Union{Integer,Static.StaticInteger}`.
"""
const IntegerLike = Union{Integer,Static.StaticInt}
const IntegerLike = Union{Integer,Static.StaticInteger}

"""
MeasureBase.one_to(n::IntegerLike)
Expand All @@ -14,7 +14,7 @@ Returns an instance of `Base.OneTo` or `Static.SOneTo`, depending
on the type of `n`.
"""
@inline one_to(n::Integer) = Base.OneTo(n)
@inline one_to(::Static.StaticInt{N}) where {N} = Static.SOneTo{N}()
@inline one_to(::Static.StaticInteger{N}) where {N} = Static.SOneTo{N}()

_dynamic(x::Number) = dynamic(x)
_dynamic(::Static.SOneTo{N}) where {N} = Base.OneTo(N)
Expand Down Expand Up @@ -49,7 +49,7 @@ Returns the length of `x` as a dynamic or static integer.
"""
maybestatic_length(x) = length(x)
maybestatic_length(x::AbstractUnitRange) = length(x)
function maybestatic_length(::Static.OptionallyStaticUnitRange{StaticInt{A},StaticInt{B}}) where {A,B}
function maybestatic_length(::Static.OptionallyStaticUnitRange{<:StaticInteger{A},<:StaticInteger{B}}) where {A,B}
StaticInt{B - A + 1}()
end

Expand Down
8 changes: 4 additions & 4 deletions src/transport.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,14 @@ _origin_depth_pullback(ΔΩ) = NoTangent(), NoTangent()
ChainRulesCore.rrule(::typeof(_origin_depth), ν) = _origin_depth(ν), _origin_depth_pullback

# If both both measures have no origin:
function _transport_between_origins(ν, ::StaticInt{0}, ::StaticInt{0}, μ, x)
function _transport_between_origins(ν, ::StaticInteger{0}, ::StaticInteger{0}, μ, x)
_transport_with_intermediate(ν, _transport_intermediate(ν, μ), μ, x)
end

@generated function _transport_between_origins(
ν,
::StaticInt{n_ν},
::StaticInt{n_μ},
::StaticInteger{n_ν},
::StaticInteger{n_μ},
μ,
x,
) where {n_ν,n_μ}
Expand Down Expand Up @@ -188,7 +188,7 @@ end

@inline _transport_intermediate(ν, μ) = _transport_intermediate(getdof(ν), getdof(μ))
@inline _transport_intermediate(::Integer, n_μ::Integer) = StdUniform()^n_μ
@inline _transport_intermediate(::StaticInt{1}, ::StaticInt{1}) = StdUniform()
@inline _transport_intermediate(::StaticInteger{1}, ::StaticInteger{1}) = StdUniform()

_call_transport_def(ν, μ, x) = transport_def(ν, μ, x)
_call_transport_def(::Any, ::Any, x::NoTransportOrigin) = x
Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ repeatedly until there's no change. That's what this does.
_rootmeasure(μ, static(n))
end

@generated function _rootmeasure(μ, ::StaticInt{n}) where {n}
@generated function _rootmeasure(μ, ::StaticInteger{n}) where {n}
q = quote end
foreach(1:n) do _
push!(q.args, :(μ = basemeasure(μ)))
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ using MeasureBase: test_interface, test_smf
using Aqua
Aqua.test_all(MeasureBase; ambiguities = false)

include("static.jl")

# Aqua._test_ambiguities(
# Aqua.aspkgids(MeasureBase);
# exclude = [LogarithmicNumbers.Logarithmic],
Expand Down
31 changes: 31 additions & 0 deletions test/static.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using Test

import MeasureBase

import Static
using Static: static
import FillArrays

@testset "static" begin
@test 2 isa MeasureBase.IntegerLike
@test static(2) isa MeasureBase.IntegerLike
@test true isa MeasureBase.IntegerLike
@test static(true) isa MeasureBase.IntegerLike

@test @inferred(MeasureBase.one_to(7)) isa Base.OneTo
@test @inferred(MeasureBase.one_to(7)) == 1:7
@test @inferred(MeasureBase.one_to(static(7))) isa Static.SOneTo
@test @inferred(MeasureBase.one_to(static(7))) == static(1):static(7)

@test @inferred(MeasureBase.fill_with(4.2, (7,))) == FillArrays.Fill(4.2, 7)
@test @inferred(MeasureBase.fill_with(4.2, (static(7),))) == FillArrays.Fill(4.2, 7)
@test @inferred(MeasureBase.fill_with(4.2, (3, static(7)))) == FillArrays.Fill(4.2, 3, 7)
@test @inferred(MeasureBase.fill_with(4.2, (3:7,))) == FillArrays.Fill(4.2, (3:7,))
@test @inferred(MeasureBase.fill_with(4.2, (static(3):static(7),))) == FillArrays.Fill(4.2, (3:7,))
@test @inferred(MeasureBase.fill_with(4.2, (3:7, static(2):static(5)))) == FillArrays.Fill(4.2, (3:7, 2:5))

@test MeasureBase.maybestatic_length(MeasureBase.one_to(7)) isa Int
@test MeasureBase.maybestatic_length(MeasureBase.one_to(7)) == 7
@test MeasureBase.maybestatic_length(MeasureBase.one_to(static(7))) isa Static.StaticInt
@test MeasureBase.maybestatic_length(MeasureBase.one_to(static(7))) == static(7)
end