@@ -1482,13 +1482,15 @@ end
14821482 ::Type{T},
14831483 seed::TracedRArray{UInt64,1},
14841484 shape;
1485+ minval::Union{T,Nothing}=nothing,
1486+ maxval::Union{T,Nothing}=nothing,
14851487 algorithm::String="DEFAULT",
14861488 location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
14871489 )
14881490
14891491Generate a random array of type `T` with the given shape and seed from a uniform random
1490- distribution between 0 and 1 (for floating point types). Returns a NamedTuple with the
1491- following fields:
1492+ distribution between `[minval, maxval)` (for floating point types). Returns a NamedTuple
1493+ with the following fields:
14921494
14931495- `output_state`: The state of the random number generator after the operation.
14941496- `output`: The generated array.
@@ -1498,17 +1500,25 @@ following fields:
14981500- `T`: The type of the generated array.
14991501- `seed`: The seed for the random number generator.
15001502- `shape`: The shape of the generated array.
1503+ - `minval`: The minimum value of the generated random numbers. (Only for floating point
1504+ types). Defaults to `0`.
1505+ - `maxval`: The maximum value of the generated random numbers. (Only for floating point
1506+ types). Defaults to `1`.
15011507- `algorithm`: The algorithm to use for generating the random numbers. Defaults to
15021508 "DEFAULT". Other options include "PHILOX" and "THREE_FRY".
15031509"""
15041510@noinline function rng_bit_generator (
15051511 :: Type{T} ,
15061512 seed:: TracedRArray{UInt64,1} ,
15071513 shape;
1514+ minval:: Union{T,Nothing} = nothing ,
1515+ maxval:: Union{T,Nothing} = nothing ,
15081516 algorithm:: String = " DEFAULT" ,
15091517 location= mlir_stacktrace (" rng_bit_generator" , @__FILE__ , @__LINE__ ),
15101518) where {T<: Integer }
15111519 @assert algorithm in (" DEFAULT" , " PHILOX" , " THREE_FRY" )
1520+ @assert minval === nothing " minval is not supported for integer rng_bit_generator"
1521+ @assert maxval === nothing " maxval is not supported for integer rng_bit_generator"
15121522 if algorithm == " PHILOX"
15131523 @assert length (seed) ∈ (2 , 3 )
15141524 elseif algorithm == " THREE_FRY"
@@ -1527,35 +1537,70 @@ following fields:
15271537 )
15281538end
15291539
1540+ function _get_uint_from_bitwidth (width:: Int )
1541+ @assert width ∈ (8 , 16 , 32 , 64 ) " Unsupported bitwidth: $width "
1542+ return width == 8 ? UInt8 : (width == 16 ? UInt16 : (width == 32 ? UInt32 : UInt64))
1543+ end
1544+
15301545# https://github.com/jax-ml/jax/blob/474dcd409d6fa4c048014851922460f9d4fc199e/jax/_src/random.py#L444-L464
15311546@noinline function rng_bit_generator (
15321547 :: Type{T} ,
15331548 seed:: TracedRArray{UInt64,1} ,
15341549 shape;
1550+ minval:: T = T (0 ),
1551+ maxval:: T = T (1 ),
15351552 algorithm:: String = " DEFAULT" ,
15361553 location= mlir_stacktrace (" rng_bit_generator" , @__FILE__ , @__LINE__ ),
15371554) where {T<: AbstractFloat }
15381555 nbits = sizeof (T) * 8
1539- @assert nbits ∈ (8 , 16 , 32 , 64 ) " Unsupported type: $(T) "
1540- uT = nbits == 8 ? UInt8 : (nbits == 16 ? UInt16 : (nbits == 32 ? UInt32 : UInt64))
1541- (; output_state, output) = rng_bit_generator (uT, seed, shape; algorithm, location)
1556+ nmantissa = Reactant. nmantissa (T)
1557+ rng_bits = nbits
1558+ nmantissa < 8 && (rng_bits = 8 )
1559+ uint_gen_dtype = _get_uint_from_bitwidth (rng_bits)
1560+ (; output_state, output) = rng_bit_generator (
1561+ uint_gen_dtype, seed, shape; algorithm, location
1562+ )
1563+ uint_dtype = _get_uint_from_bitwidth (nbits)
1564+ bits = output
1565+ if rng_bits != nbits
1566+ bits = convert (TracedRArray{uint_dtype,length (shape)}, bits)
1567+ end
1568+
15421569 float_bits = or (
15431570 shift_right_logical (
1544- output,
1545- fill (uT (nbits - Reactant. nmantissa (T)), size (output); location);
1546- location,
1571+ bits, fill (uint_dtype (rng_bits - nmantissa), size (bits); location); location
15471572 ),
1548- fill (reinterpret (uT , T (1 )), size (output ); location);
1573+ fill (reinterpret (uint_dtype , T (1 )), size (bits ); location);
15491574 location,
15501575 )
1551- output = subtract (
1576+ floats = subtract (
15521577 bitcast_convert (TracedRArray{T,length (shape)}, float_bits; location),
15531578 fill (T (1 ), size (output); location);
15541579 location,
15551580 )
1581+
1582+ maxval = prevfloat (maxval) # make maxval exclusive
1583+ minval_ = fill (minval, size (floats); location)
1584+ maxval_ = fill (maxval, size (floats); location)
1585+ output = clamp (
1586+ minval_,
1587+ add (
1588+ multiply (floats, subtract (maxval_, minval_; location); location),
1589+ minval_;
1590+ location,
1591+ ),
1592+ maxval_;
1593+ location,
1594+ )
15561595 return (; output_state, output)
15571596end
15581597
1598+ @noinline function rng_bit_generator (
1599+ :: Type{TracedRNumber{T}} , seed:: TracedRArray{UInt64,1} , shape; kwargs...
1600+ ) where {T}
1601+ return rng_bit_generator (T, seed, shape; kwargs... )
1602+ end
1603+
15591604"""
15601605 randn(
15611606 ::Type{T},
@@ -1585,20 +1630,37 @@ fields:
15851630 seed:: TracedRArray{UInt64,1} ,
15861631 shape;
15871632 algorithm:: String = " DEFAULT" ,
1588- location= mlir_stacktrace (" rand" , @__FILE__ , @__LINE__ ),
1589- ) where {T}
1590- res = rng_bit_generator (T, seed, shape; algorithm, location)
1633+ location= mlir_stacktrace (" randn" , @__FILE__ , @__LINE__ ),
1634+ ) where {T<: AbstractFloat }
1635+ res = rng_bit_generator (
1636+ T, seed, shape; algorithm, location, minval= nextfloat (T (- 1 )), maxval= T (1 )
1637+ )
15911638 rand_uniform = res. output
15921639 seed = res. output_state
1593- scaled_uniform = subtract (
1594- multiply (rand_uniform, fill (T (2 ), size (rand_uniform))),
1595- fill (T (1 ), size (rand_uniform)),
1596- )
1597- probit = erf_inv (scaled_uniform)
1640+ probit = erf_inv (rand_uniform)
15981641 rand_normal = multiply (probit, fill (Base. sqrt (T (2 )), size (rand_uniform)))
15991642 return (; output_state= seed, output= rand_normal)
16001643end
16011644
1645+ @noinline function randn (
1646+ :: Type{Complex{T}} ,
1647+ seed:: TracedRArray{UInt64,1} ,
1648+ shape;
1649+ algorithm:: String = " DEFAULT" ,
1650+ location= mlir_stacktrace (" randn" , @__FILE__ , @__LINE__ ),
1651+ ) where {T<: AbstractFloat }
1652+ real_result = randn (T, seed, shape; algorithm, location)
1653+ imag_result = randn (T, real_result. output_state, shape; algorithm, location)
1654+ output = complex .(real_result. output, imag_result. output)
1655+ return (; output_state= imag_result. output_state, output)
1656+ end
1657+
1658+ @noinline function randn (
1659+ :: Type{TracedRNumber{T}} , seed:: TracedRArray{UInt64,1} , shape; kwargs...
1660+ ) where {T}
1661+ return randn (T, seed, shape; kwargs... )
1662+ end
1663+
16021664"""
16031665 randexp(
16041666 ::Type{T},
@@ -1628,14 +1690,20 @@ distribution with rate 1. Returns a NamedTuple with the following fields:
16281690 shape;
16291691 algorithm:: String = " DEFAULT" ,
16301692 location= mlir_stacktrace (" rand" , @__FILE__ , @__LINE__ ),
1631- ) where {T}
1693+ ) where {T<: AbstractFloat }
16321694 res = rng_bit_generator (T, seed, shape; algorithm, location)
16331695 rand_uniform = res. output
16341696 seed = res. output_state
16351697 rand_exp = negate (log_plus_one (negate (rand_uniform)))
16361698 return (; output_state= seed, output= rand_exp)
16371699end
16381700
1701+ @noinline function randexp (
1702+ :: Type{TracedRNumber{T}} , seed:: TracedRArray{UInt64,1} , shape; kwargs...
1703+ ) where {T}
1704+ return randexp (T, seed, shape; kwargs... )
1705+ end
1706+
16391707# functional ops
16401708@noinline function return_ (
16411709 results:: Union{TracedRArray,TracedRNumber} ...;
0 commit comments