Skip to content

Commit 2880e6c

Browse files
authored
Add overloads for fld, cld, and div (#198)
* Add overloads for cld, fld, and div * Test new rules * Increment minor version number * Match method defined in base * Revert "Match method defined in base" This reverts commit d98f5e5.
1 parent c3c96eb commit 2880e6c

File tree

3 files changed

+33
-1
lines changed

3 files changed

+33
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ReverseDiff"
22
uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
3-
version = "1.12.0"
3+
version = "1.13.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/tracked.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,16 @@ Base.floor(::Type{R}, t::TrackedReal) where {R<:Real} = floor(R, value(t))
428428
Base.ceil(t::TrackedReal) = ceil(value(t))
429429
Base.ceil(::Type{R}, t::TrackedReal) where {R<:Real} = ceil(R, value(t))
430430

431+
Base.fld(a::TrackedReal, b::TrackedReal) = fld(value(a), value(b))
432+
433+
Base.cld(a::TrackedReal, b::TrackedReal) = cld(value(a), value(b))
434+
435+
if VERSION v"1.4"
436+
Base.div(x::TrackedReal, y::TrackedReal, r::RoundingMode) = div(value(x), value(y), r)
437+
else
438+
Base.div(x::TrackedReal, y::TrackedReal) = div(value(x), value(y))
439+
end
440+
431441
Base.trunc(t::TrackedReal) = trunc(value(t))
432442
Base.trunc(::Type{R}, t::TrackedReal) where {R<:Real} = trunc(R, value(t))
433443

test/TrackedTests.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -703,9 +703,11 @@ empty!(tp)
703703
####################
704704

705705
v_int, v_float, d = rand(Int), rand(), rand()
706+
v_float2, d2 = rand(), rand()
706707
tp = InstructionTape()
707708
tr_int = TrackedReal(v_int, d, tp)
708709
tr_float = TrackedReal(v_float, d, tp)
710+
tr_float2 = TrackedReal(v_float2, d2, tp)
709711

710712
@test hash(tr_float) === hash(v_float)
711713
@test hash(tr_float, hash(1)) === hash(v_float, hash(1))
@@ -735,6 +737,26 @@ tr_rand = rand(MersenneTwister(1), TrackedReal{Int,Float64,Nothing})
735737
@test ceil(tr_float) === ceil(v_float)
736738
@test ceil(Int, tr_float) === ceil(Int, v_float)
737739

740+
@test fld(tr_float, tr_float2) === fld(v_float, v_float2)
741+
@test fld(tr_float, v_float2) === fld(v_float, v_float2)
742+
@test fld(v_float, tr_float2) === fld(v_float, v_float2)
743+
744+
@test cld(tr_float, tr_float2) === cld(v_float, v_float2)
745+
@test cld(tr_float, v_float2) === cld(v_float, v_float2)
746+
@test cld(v_float, tr_float2) === cld(v_float, v_float2)
747+
748+
@test div(tr_float, tr_float2) === div(v_float, v_float2)
749+
@test div(v_float, tr_float2) === div(v_float, v_float2)
750+
@test div(tr_float, v_float2) === div(v_float, v_float2)
751+
752+
if VERSION v"1.4"
753+
for r in (RoundUp, RoundDown)
754+
@test div(tr_float, tr_float2, r) === div(v_float, v_float2, r)
755+
@test div(v_float, tr_float2, r) === div(v_float, v_float2, r)
756+
@test div(tr_float, v_float2, r) === div(v_float, v_float2, r)
757+
end
758+
end
759+
738760
@test trunc(tr_float) === trunc(v_float)
739761
@test trunc(Int, tr_float) === trunc(Int, v_float)
740762

0 commit comments

Comments
 (0)