|
16 | 16 | from .array import * |
17 | 17 | from .util import * |
18 | 18 | from .util import _is_number |
| 19 | +from .random import randu, randn, set_seed, get_seed |
19 | 20 |
|
20 | 21 | def constant(val, d0, d1=None, d2=None, d3=None, dtype=Dtype.f32): |
21 | 22 | """ |
@@ -186,105 +187,6 @@ def iota(d0, d1=None, d2=None, d3=None, dim=-1, tile_dims=None, dtype=Dtype.f32) |
186 | 187 | 4, ct.pointer(tdims), dtype.value)) |
187 | 188 | return out |
188 | 189 |
|
189 | | -def randu(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32): |
190 | | - """ |
191 | | - Create a multi dimensional array containing values from a uniform distribution. |
192 | | -
|
193 | | - Parameters |
194 | | - ---------- |
195 | | - d0 : int. |
196 | | - Length of first dimension. |
197 | | -
|
198 | | - d1 : optional: int. default: None. |
199 | | - Length of second dimension. |
200 | | -
|
201 | | - d2 : optional: int. default: None. |
202 | | - Length of third dimension. |
203 | | -
|
204 | | - d3 : optional: int. default: None. |
205 | | - Length of fourth dimension. |
206 | | -
|
207 | | - dtype : optional: af.Dtype. default: af.Dtype.f32. |
208 | | - Data type of the array. |
209 | | -
|
210 | | - Returns |
211 | | - ------- |
212 | | -
|
213 | | - out : af.Array |
214 | | - Multi dimensional array whose elements are sampled uniformly between [0, 1]. |
215 | | - - If d1 is None, `out` is 1D of size (d0,). |
216 | | - - If d1 is not None and d2 is None, `out` is 2D of size (d0, d1). |
217 | | - - If d1 and d2 are not None and d3 is None, `out` is 3D of size (d0, d1, d2). |
218 | | - - If d1, d2, d3 are all not None, `out` is 4D of size (d0, d1, d2, d3). |
219 | | - """ |
220 | | - out = Array() |
221 | | - dims = dim4(d0, d1, d2, d3) |
222 | | - |
223 | | - safe_call(backend.get().af_randu(ct.pointer(out.arr), 4, ct.pointer(dims), dtype.value)) |
224 | | - return out |
225 | | - |
226 | | -def randn(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32): |
227 | | - """ |
228 | | - Create a multi dimensional array containing values from a normal distribution. |
229 | | -
|
230 | | - Parameters |
231 | | - ---------- |
232 | | - d0 : int. |
233 | | - Length of first dimension. |
234 | | -
|
235 | | - d1 : optional: int. default: None. |
236 | | - Length of second dimension. |
237 | | -
|
238 | | - d2 : optional: int. default: None. |
239 | | - Length of third dimension. |
240 | | -
|
241 | | - d3 : optional: int. default: None. |
242 | | - Length of fourth dimension. |
243 | | -
|
244 | | - dtype : optional: af.Dtype. default: af.Dtype.f32. |
245 | | - Data type of the array. |
246 | | -
|
247 | | - Returns |
248 | | - ------- |
249 | | -
|
250 | | - out : af.Array |
251 | | - Multi dimensional array whose elements are sampled from a normal distribution with mean 0 and sigma of 1. |
252 | | - - If d1 is None, `out` is 1D of size (d0,). |
253 | | - - If d1 is not None and d2 is None, `out` is 2D of size (d0, d1). |
254 | | - - If d1 and d2 are not None and d3 is None, `out` is 3D of size (d0, d1, d2). |
255 | | - - If d1, d2, d3 are all not None, `out` is 4D of size (d0, d1, d2, d3). |
256 | | - """ |
257 | | - |
258 | | - out = Array() |
259 | | - dims = dim4(d0, d1, d2, d3) |
260 | | - |
261 | | - safe_call(backend.get().af_randn(ct.pointer(out.arr), 4, ct.pointer(dims), dtype.value)) |
262 | | - return out |
263 | | - |
264 | | -def set_seed(seed=0): |
265 | | - """ |
266 | | - Set the seed for the random number generator. |
267 | | -
|
268 | | - Parameters |
269 | | - ---------- |
270 | | - seed: int. |
271 | | - Seed for the random number generator |
272 | | - """ |
273 | | - safe_call(backend.get().af_set_seed(ct.c_ulonglong(seed))) |
274 | | - |
275 | | -def get_seed(): |
276 | | - """ |
277 | | - Get the seed for the random number generator. |
278 | | -
|
279 | | - Returns |
280 | | - ---------- |
281 | | - seed: int. |
282 | | - Seed for the random number generator |
283 | | - """ |
284 | | - seed = ct.c_ulonglong(0) |
285 | | - safe_call(backend.get().af_get_seed(ct.pointer(seed))) |
286 | | - return seed.value |
287 | | - |
288 | 190 | def identity(d0, d1, d2=None, d3=None, dtype=Dtype.f32): |
289 | 191 | """ |
290 | 192 | Create an identity matrix or batch of identity matrices. |
|
0 commit comments