Skip to content

Commit c68a8b0

Browse files
avik-palsiddharthabishnugithub-actions[bot]
authored
feat: add degree-based wrappers for TracedRNumber (#1814)
* Add degree-based wrappers for TracedRNumber * test: new trig functions * Update test/basic.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: Sid <siddhartha.bishnu@gmail.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 9df3d85 commit c68a8b0

File tree

2 files changed

+55
-1
lines changed

2 files changed

+55
-1
lines changed

src/TracedRNumber.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,27 @@ for (jlop, hloop) in (
505505
@eval $(jlop)(@nospecialize(lhs::TracedRNumber)) = @opcall $(hloop)(lhs)
506506
end
507507

508+
# Degree-based trigonometric wrappers for TracedRNumber
509+
# These convert to radians internally so Reactant can lower to
510+
# StableHLO-supported radian trigonometric operations.
511+
512+
Base.sind(x::TracedRNumber) = sin(deg2rad(x))
513+
Base.cosd(x::TracedRNumber) = cos(deg2rad(x))
514+
Base.tand(x::TracedRNumber) = tan(deg2rad(x))
515+
Base.cscd(x::TracedRNumber) = 1 / sind(x)
516+
Base.secd(x::TracedRNumber) = 1 / cosd(x)
517+
Base.cotd(x::TracedRNumber) = 1 / tand(x)
518+
519+
Base.asind(x::TracedRNumber) = rad2deg(asin(x))
520+
Base.acosd(x::TracedRNumber) = rad2deg(acos(x))
521+
Base.atand(x::TracedRNumber) = rad2deg(atan(x))
522+
523+
Base.atand(y::TracedRNumber, x::TracedRNumber) = rad2deg(atan(y, x))
524+
525+
Base.acscd(x::TracedRNumber) = rad2deg(asin(1 / x))
526+
Base.asecd(x::TracedRNumber) = rad2deg(acos(1 / x))
527+
Base.acotd(x::TracedRNumber) = rad2deg(atan(1 / x))
528+
508529
for (jlop, hloop) in (
509530
(:(Base.sin), :sine),
510531
(:(Base.cos), :cosine),

test/basic.jl

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -754,13 +754,46 @@ end
754754
@test @jit(fn.(x_ra)) isa ConcreteRArray{Float32,2}
755755
end
756756

757+
x2 = inv.(x)
758+
x2_ra = Reactant.to_rarray(x2)
759+
760+
@testset for fn in (acscd, asecd)
761+
@test @jit(fn.(x2_ra)) fn.(x2)
762+
@test @jit(fn.(x2_ra)) isa ConcreteRArray{Float32,2}
763+
end
764+
765+
xrad = deg2rad.(x)
766+
xrad_ra = Reactant.to_rarray(xrad)
767+
768+
@testset for fn in (sind, cosd, tand, cscd, secd, cotd, asind, acosd, atand, acotd)
769+
@test @jit(fn.(xrad_ra)) fn.(xrad)
770+
@test @jit(fn.(xrad_ra)) isa ConcreteRArray{Float32,2}
771+
end
772+
757773
x = 0.235f0
758774
x_ra = Reactant.to_rarray(x; track_numbers=Number)
759775

760-
@testset for fn in (sinpi, cospi, tanpi, sin, cos, tan)
776+
@testset for fn in (sinpi, cospi, tanpi, sin, cos, tan, asind, acosd, atand, acotd)
761777
@test @jit(fn.(x_ra)) fn.(x)
762778
@test @jit(fn.(x_ra)) isa ConcreteRNumber{Float32}
763779
end
780+
781+
x2 = inv(x)
782+
x2_ra = Reactant.to_rarray(x2; track_numbers=Number)
783+
784+
@testset for fn in (acscd, asecd)
785+
@test @jit(fn.(x2_ra)) fn.(x2)
786+
@test @jit(fn.(x2_ra)) isa ConcreteRNumber{Float32}
787+
end
788+
789+
xrad = deg2rad(x)
790+
xrad_ra = Reactant.to_rarray(xrad; track_numbers=Number)
791+
792+
@testset for fn in (sind, cosd, tand, cscd, secd, cotd)
793+
@test @jit(fn.(xrad_ra)) fn.(xrad)
794+
@test @jit(fn.(xrad_ra)) isa ConcreteRNumber{Float32}
795+
end
796+
764797
@testset for fn in (sincospi, sincos)
765798
res = @jit fn(x_ra)
766799
@test res[1] fn(x)[1]

0 commit comments

Comments
 (0)