@@ -81,7 +81,7 @@ julia> Flux.glorot_uniform(2, 3)
8181[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010.
8282"""
8383glorot_uniform (rng:: AbstractRNG , dims... ) = (rand (rng, Float32, dims... ) .- 0.5f0 ) .* sqrt (24.0f0 / sum (nfan (dims... )))
84- glorot_uniform (dims... ) = glorot_uniform (Random . GLOBAL_RNG , dims... )
84+ glorot_uniform (dims... ) = glorot_uniform (rng_from_array () , dims... )
8585glorot_uniform (rng:: AbstractRNG ) = (dims... ) -> glorot_uniform (rng, dims... )
8686
8787"""
@@ -114,7 +114,7 @@ julia> Flux.glorot_normal(3, 2)
114114[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010.
115115"""
116116glorot_normal (rng:: AbstractRNG , dims... ) = randn (rng, Float32, dims... ) .* sqrt (2.0f0 / sum (nfan (dims... )))
117- glorot_normal (dims... ) = glorot_normal (Random . GLOBAL_RNG , dims... )
117+ glorot_normal (dims... ) = glorot_normal (rng_from_array () , dims... )
118118glorot_normal (rng:: AbstractRNG ) = (dims... ) -> glorot_normal (rng, dims... )
119119
120120"""
@@ -151,7 +151,7 @@ function kaiming_uniform(rng::AbstractRNG, dims...; gain = √2)
151151 return (rand (rng, Float32, dims... ) .- 0.5f0 ) .* 2 bound
152152end
153153
154- kaiming_uniform (dims... ; kwargs... ) = kaiming_uniform (Random . GLOBAL_RNG , dims... ; kwargs... )
154+ kaiming_uniform (dims... ; kwargs... ) = kaiming_uniform (rng_from_array () , dims... ; kwargs... )
155155kaiming_uniform (rng:: AbstractRNG ; init_kwargs... ) = (dims... ; kwargs... ) -> kaiming_uniform (rng, dims... ; init_kwargs... , kwargs... )
156156
157157"""
@@ -188,9 +188,58 @@ function kaiming_normal(rng::AbstractRNG, dims...; gain = √2f0)
188188 return randn (rng, Float32, dims... ) .* std
189189end
190190
191- kaiming_normal (dims... ; kwargs... ) = kaiming_normal (Random . GLOBAL_RNG , dims... ; kwargs... )
191+ kaiming_normal (dims... ; kwargs... ) = kaiming_normal (rng_from_array () , dims... ; kwargs... )
192192kaiming_normal (rng:: AbstractRNG ; init_kwargs... ) = (dims... ; kwargs... ) -> kaiming_normal (rng, dims... ; init_kwargs... , kwargs... )
193193
194+ """
195+ truncated_normal([rng = rng_from_array()], dims...; mean = 0, std = 1, lo = -2, hi = 2)
196+
197+ Return an `Array{Float32}` of size `dims` where each element is drawn from a truncated normal distribution.
198+ The numbers are distributed like `filter(x -> lo<=x<=hi, mean .+ std .* randn(dims...))`.
199+
200+ The values are generated by sampling a Uniform(0, 1) (`rand()`) and then
201+ applying the inverse CDF of the truncated normal distribution
202+ (see the references for more info).
203+ This method works best when `lo ≤ mean ≤ hi`.
204+
205+ # Examples
206+ ```jldoctest
207+ julia> using Statistics
208+
209+ julia> Flux.truncated_normal(3, 4) |> summary
210+ "3×4 Matrix{Float32}"
211+
212+ julia> round.(extrema(Flux.truncated_normal(10^6)); digits=3)
213+ (-2.0f0, 2.0f0)
214+
215+ julia> round(std(Flux.truncated_normal(10^6; lo = -100, hi = 100)))
216+ 1.0f0
217+ ```
218+
219+ # References
220+ [1] Burkardt, John. "The Truncated Normal Distribution"
221+ [PDF](https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf).
222+ Department of Scientific Computing website.
223+ """
function truncated_normal(rng::AbstractRNG, dims...; mean = 0, std = 1, lo = -2, hi = 2)
  # Inverse-CDF sampling: uniform draws are pushed through the inverse of the
  # normal CDF restricted to [lo, hi], then shifted and scaled by `mean` and `std`.
  if (mean < lo - 2 * std) || (mean > hi + 2 * std)
    @warn " Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." maxlog=1
  end

  # CDF of the standard normal distribution.
  std_normal_cdf = z -> 0.5 * (1 + erf(z / √2))
  cdf_lo = std_normal_cdf((lo - mean) / std)
  cdf_hi = std_normal_cdf((hi - mean) / std)

  # Affine map taking a Uniform(0, 1) draw onto [2cdf_lo - 1, 2cdf_hi - 1),
  # which is the slice of erf's range corresponding to [lo, hi].
  scale = 2(cdf_hi - cdf_lo)
  offset = 2cdf_lo - 1

  # The final clamp guards against round-off (and erfinv(±1) = ±Inf)
  # pushing a sample outside the requested bounds.
  to_truncated(p) = clamp(erfinv(p * scale + offset) * std * √2 + mean, lo, hi)

  samples = rand(rng, Float32, dims...)
  broadcast!(to_truncated, samples, samples)  # transform in place; the array stays Float32
  return samples
end
239+
240+ truncated_normal (dims... ; kwargs... ) = truncated_normal (rng_from_array (), dims... ; kwargs... )
241+ truncated_normal (rng:: AbstractRNG ; init_kwargs... ) = (dims... ; kwargs... ) -> truncated_normal (rng, dims... ; init_kwargs... , kwargs... )
242+
194243"""
195244 orthogonal([rng = rng_from_array()], dims...; gain = 1)
196245
232281* sparse initialization: [`sparse_init`](@ref Flux.sparse_init)
233282
234283# References
284+
235285[1] Saxe, McClelland, Ganguli. "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks", ICLR 2014, https://arxiv.org/abs/1312.6120
236286
237287"""
@@ -254,7 +304,7 @@ function orthogonal(rng::AbstractRNG, d1::Integer, ds::Integer...; kwargs...)
254304 return reshape (orthogonal (rng, rows, cols; kwargs... ), dims)
255305end
256306
257- orthogonal (dims:: Integer... ; kwargs... ) = orthogonal (Random . GLOBAL_RNG , dims... ; kwargs... )
307+ orthogonal (dims:: Integer... ; kwargs... ) = orthogonal (rng_from_array () , dims... ; kwargs... )
258308orthogonal (rng:: AbstractRNG ; init_kwargs... ) = (dims:: Integer... ; kwargs... ) -> orthogonal (rng, dims... ; init_kwargs... , kwargs... )
259309
260310"""
@@ -298,7 +348,7 @@ function sparse_init(rng::AbstractRNG, dims...; sparsity, std = 0.01)
298348 return mapslices (shuffle, sparse_array, dims= 1 )
299349end
300350
301- sparse_init (dims... ; kwargs... ) = sparse_init (Random . GLOBAL_RNG , dims... ; kwargs... )
351+ sparse_init (dims... ; kwargs... ) = sparse_init (rng_from_array () , dims... ; kwargs... )
302352sparse_init (rng:: AbstractRNG ; init_kwargs... ) = (dims... ; kwargs... ) -> sparse_init (rng, dims... ; init_kwargs... , kwargs... )
303353
304354"""
@@ -382,7 +432,7 @@ function identity_init(dims...; gain=1, shift=0)
382432end
383433
384434identity_init (:: AbstractRNG , dims... ; kwargs... ) = identity_init (dims... ; kwargs... )
385- identity_init (; init_kwargs... ) = identity_init (Random . GLOBAL_RNG ; init_kwargs... )
435+ identity_init (; init_kwargs... ) = identity_init (rng_from_array () ; init_kwargs... )
386436identity_init (rng:: AbstractRNG ; init_kwargs... ) = (args... ;kwargs... ) -> identity_init (rng, args... ; init_kwargs... , kwargs... )
387437
388438ones32 (dims... ) = Base. ones (Float32, dims... )
437487
438488Flatten a model's parameters into a single weight vector.
439489
440- julia> m = Chain(Dense(10, 5, σ ), Dense(5, 2), softmax)
441- Chain(Dense(10, 5, σ ), Dense(5, 2), softmax)
490+ julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
491+ Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
442492
443493 julia> θ, re = destructure(m);
444494
0 commit comments