Skip to content

Commit ccd6fb9

Browse files
authored
fix: ensure randn respects exclusivity of maxval (#1815)
* fix: ensure randn respects exclusivity of maxval * feat: allow traced inputs
1 parent c68a8b0 commit ccd6fb9

File tree

2 files changed

+88
-20
lines changed

2 files changed

+88
-20
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>", "Sergio Sánchez Ramírez <sergio.sanchez.ramirez@bsc.es>", "Paul Berg <paul@plutojl.org>", "Avik Pal <avikpal@mit.edu>", "Mosè Giordano <mose@gnu.org>"]
4-
version = "0.2.174"
4+
version = "0.2.175"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/Ops.jl

Lines changed: 87 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1482,13 +1482,15 @@ end
14821482
::Type{T},
14831483
seed::TracedRArray{UInt64,1},
14841484
shape;
1485+
minval::Union{T,Nothing}=nothing,
1486+
maxval::Union{T,Nothing}=nothing,
14851487
algorithm::String="DEFAULT",
14861488
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
14871489
)
14881490
14891491
Generate a random array of type `T` with the given shape and seed from a uniform random
1490-
distribution between 0 and 1 (for floating point types). Returns a NamedTuple with the
1491-
following fields:
1492+
distribution between `[minval, maxval)` (for floating point types). Returns a NamedTuple
1493+
with the following fields:
14921494
14931495
- `output_state`: The state of the random number generator after the operation.
14941496
- `output`: The generated array.
@@ -1498,17 +1500,25 @@ following fields:
14981500
- `T`: The type of the generated array.
14991501
- `seed`: The seed for the random number generator.
15001502
- `shape`: The shape of the generated array.
1503+
- `minval`: The minimum value of the generated random numbers. (Only for floating point
1504+
types). Defaults to `0`.
1505+
- `maxval`: The maximum value of the generated random numbers. (Only for floating point
1506+
types). Defaults to `1`.
15011507
- `algorithm`: The algorithm to use for generating the random numbers. Defaults to
15021508
"DEFAULT". Other options include "PHILOX" and "THREE_FRY".
15031509
"""
15041510
@noinline function rng_bit_generator(
15051511
::Type{T},
15061512
seed::TracedRArray{UInt64,1},
15071513
shape;
1514+
minval::Union{T,Nothing}=nothing,
1515+
maxval::Union{T,Nothing}=nothing,
15081516
algorithm::String="DEFAULT",
15091517
location=mlir_stacktrace("rng_bit_generator", @__FILE__, @__LINE__),
15101518
) where {T<:Integer}
15111519
@assert algorithm in ("DEFAULT", "PHILOX", "THREE_FRY")
1520+
@assert minval === nothing "minval is not supported for integer rng_bit_generator"
1521+
@assert maxval === nothing "maxval is not supported for integer rng_bit_generator"
15121522
if algorithm == "PHILOX"
15131523
@assert length(seed) (2, 3)
15141524
elseif algorithm == "THREE_FRY"
@@ -1527,35 +1537,70 @@ following fields:
15271537
)
15281538
end
15291539

1540+
function _get_uint_from_bitwidth(width::Int)
1541+
@assert width (8, 16, 32, 64) "Unsupported bitwidth: $width"
1542+
return width == 8 ? UInt8 : (width == 16 ? UInt16 : (width == 32 ? UInt32 : UInt64))
1543+
end
1544+
15301545
# https://github.com/jax-ml/jax/blob/474dcd409d6fa4c048014851922460f9d4fc199e/jax/_src/random.py#L444-L464
15311546
@noinline function rng_bit_generator(
15321547
::Type{T},
15331548
seed::TracedRArray{UInt64,1},
15341549
shape;
1550+
minval::T=T(0),
1551+
maxval::T=T(1),
15351552
algorithm::String="DEFAULT",
15361553
location=mlir_stacktrace("rng_bit_generator", @__FILE__, @__LINE__),
15371554
) where {T<:AbstractFloat}
15381555
nbits = sizeof(T) * 8
1539-
@assert nbits (8, 16, 32, 64) "Unsupported type: $(T)"
1540-
uT = nbits == 8 ? UInt8 : (nbits == 16 ? UInt16 : (nbits == 32 ? UInt32 : UInt64))
1541-
(; output_state, output) = rng_bit_generator(uT, seed, shape; algorithm, location)
1556+
nmantissa = Reactant.nmantissa(T)
1557+
rng_bits = nbits
1558+
nmantissa < 8 && (rng_bits = 8)
1559+
uint_gen_dtype = _get_uint_from_bitwidth(rng_bits)
1560+
(; output_state, output) = rng_bit_generator(
1561+
uint_gen_dtype, seed, shape; algorithm, location
1562+
)
1563+
uint_dtype = _get_uint_from_bitwidth(nbits)
1564+
bits = output
1565+
if rng_bits != nbits
1566+
bits = convert(TracedRArray{uint_dtype,length(shape)}, bits)
1567+
end
1568+
15421569
float_bits = or(
15431570
shift_right_logical(
1544-
output,
1545-
fill(uT(nbits - Reactant.nmantissa(T)), size(output); location);
1546-
location,
1571+
bits, fill(uint_dtype(rng_bits - nmantissa), size(bits); location); location
15471572
),
1548-
fill(reinterpret(uT, T(1)), size(output); location);
1573+
fill(reinterpret(uint_dtype, T(1)), size(bits); location);
15491574
location,
15501575
)
1551-
output = subtract(
1576+
floats = subtract(
15521577
bitcast_convert(TracedRArray{T,length(shape)}, float_bits; location),
15531578
fill(T(1), size(output); location);
15541579
location,
15551580
)
1581+
1582+
maxval = prevfloat(maxval) # make maxval exclusive
1583+
minval_ = fill(minval, size(floats); location)
1584+
maxval_ = fill(maxval, size(floats); location)
1585+
output = clamp(
1586+
minval_,
1587+
add(
1588+
multiply(floats, subtract(maxval_, minval_; location); location),
1589+
minval_;
1590+
location,
1591+
),
1592+
maxval_;
1593+
location,
1594+
)
15561595
return (; output_state, output)
15571596
end
15581597

1598+
@noinline function rng_bit_generator(
1599+
::Type{TracedRNumber{T}}, seed::TracedRArray{UInt64,1}, shape; kwargs...
1600+
) where {T}
1601+
return rng_bit_generator(T, seed, shape; kwargs...)
1602+
end
1603+
15591604
"""
15601605
randn(
15611606
::Type{T},
@@ -1585,20 +1630,37 @@ fields:
15851630
seed::TracedRArray{UInt64,1},
15861631
shape;
15871632
algorithm::String="DEFAULT",
1588-
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
1589-
) where {T}
1590-
res = rng_bit_generator(T, seed, shape; algorithm, location)
1633+
location=mlir_stacktrace("randn", @__FILE__, @__LINE__),
1634+
) where {T<:AbstractFloat}
1635+
res = rng_bit_generator(
1636+
T, seed, shape; algorithm, location, minval=nextfloat(T(-1)), maxval=T(1)
1637+
)
15911638
rand_uniform = res.output
15921639
seed = res.output_state
1593-
scaled_uniform = subtract(
1594-
multiply(rand_uniform, fill(T(2), size(rand_uniform))),
1595-
fill(T(1), size(rand_uniform)),
1596-
)
1597-
probit = erf_inv(scaled_uniform)
1640+
probit = erf_inv(rand_uniform)
15981641
rand_normal = multiply(probit, fill(Base.sqrt(T(2)), size(rand_uniform)))
15991642
return (; output_state=seed, output=rand_normal)
16001643
end
16011644

1645+
@noinline function randn(
1646+
::Type{Complex{T}},
1647+
seed::TracedRArray{UInt64,1},
1648+
shape;
1649+
algorithm::String="DEFAULT",
1650+
location=mlir_stacktrace("randn", @__FILE__, @__LINE__),
1651+
) where {T<:AbstractFloat}
1652+
real_result = randn(T, seed, shape; algorithm, location)
1653+
imag_result = randn(T, real_result.output_state, shape; algorithm, location)
1654+
output = complex.(real_result.output, imag_result.output)
1655+
return (; output_state=imag_result.output_state, output)
1656+
end
1657+
1658+
@noinline function randn(
1659+
::Type{TracedRNumber{T}}, seed::TracedRArray{UInt64,1}, shape; kwargs...
1660+
) where {T}
1661+
return randn(T, seed, shape; kwargs...)
1662+
end
1663+
16021664
"""
16031665
randexp(
16041666
::Type{T},
@@ -1628,14 +1690,20 @@ distribution with rate 1. Returns a NamedTuple with the following fields:
16281690
shape;
16291691
algorithm::String="DEFAULT",
16301692
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
1631-
) where {T}
1693+
) where {T<:AbstractFloat}
16321694
res = rng_bit_generator(T, seed, shape; algorithm, location)
16331695
rand_uniform = res.output
16341696
seed = res.output_state
16351697
rand_exp = negate(log_plus_one(negate(rand_uniform)))
16361698
return (; output_state=seed, output=rand_exp)
16371699
end
16381700

1701+
@noinline function randexp(
1702+
::Type{TracedRNumber{T}}, seed::TracedRArray{UInt64,1}, shape; kwargs...
1703+
) where {T}
1704+
return randexp(T, seed, shape; kwargs...)
1705+
end
1706+
16391707
# functional ops
16401708
@noinline function return_(
16411709
results::Union{TracedRArray,TracedRNumber}...;

0 commit comments

Comments
 (0)