|
| 1 | +using System; |
| 2 | +using System.Collections.Generic; |
| 3 | +using System.Text; |
| 4 | +using NumSharp; |
| 5 | + |
| 6 | +namespace Tensorflow.Hub |
| 7 | +{ |
| 8 | + public class Datasets<TDataSet> where TDataSet : IDataSet |
| 9 | + { |
| 10 | + public TDataSet Train { get; private set; } |
| 11 | + |
| 12 | + public TDataSet Validation { get; private set; } |
| 13 | + |
| 14 | + public TDataSet Test { get; private set; } |
| 15 | + |
| 16 | + public Datasets(TDataSet train, TDataSet validation, TDataSet test) |
| 17 | + { |
| 18 | + Train = train; |
| 19 | + Validation = validation; |
| 20 | + Test = test; |
| 21 | + } |
| 22 | + |
| 23 | + public (NDArray, NDArray) Randomize(NDArray x, NDArray y) |
| 24 | + { |
| 25 | + var perm = np.random.permutation(y.shape[0]); |
| 26 | + np.random.shuffle(perm); |
| 27 | + return (x[perm], y[perm]); |
| 28 | + } |
| 29 | + |
| 30 | + /// <summary> |
| 31 | + /// selects a few number of images determined by the batch_size variable (if you don't know why, read about Stochastic Gradient Method) |
| 32 | + /// </summary> |
| 33 | + /// <param name="x"></param> |
| 34 | + /// <param name="y"></param> |
| 35 | + /// <param name="start"></param> |
| 36 | + /// <param name="end"></param> |
| 37 | + /// <returns></returns> |
| 38 | + public (NDArray, NDArray) GetNextBatch(NDArray x, NDArray y, int start, int end) |
| 39 | + { |
| 40 | + var slice = new Slice(start, end); |
| 41 | + var x_batch = x[slice]; |
| 42 | + var y_batch = y[slice]; |
| 43 | + return (x_batch, y_batch); |
| 44 | + } |
| 45 | + } |
| 46 | +} |
0 commit comments