Skip to content

Commit 615105c

Browse files
committed
Fix epoch bug for dataset. #666
1 parent 878226b commit 615105c

File tree

6 files changed

+48
-44
lines changed

6 files changed

+48
-44
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Keras.ArgsDefinition;
5+
6+
namespace Tensorflow.Keras.Engine.DataAdapters
7+
{
8+
public abstract class DataAdapter
9+
{
10+
protected DataAdapterArgs args;
11+
protected IDatasetV2 dataset;
12+
13+
public virtual bool CanHandle(Tensor x, Tensor y = null)
14+
=> throw new NotImplementedException();
15+
16+
public virtual IDatasetV2 GetDataset()
17+
=> dataset;
18+
19+
public virtual int GetSize()
20+
=> throw new NotImplementedException("");
21+
22+
public virtual (Tensor, Tensor) Expand1d(Tensor x, Tensor y)
23+
{
24+
if (y.TensorShape.ndim == 1)
25+
y = array_ops.expand_dims(y, axis: -1);
26+
return (x, y);
27+
}
28+
29+
public virtual bool ShouldRecreateIterator()
30+
{
31+
return true;
32+
}
33+
}
34+
}

src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,14 @@ int _infer_steps(int steps_per_epoch, IDatasetV2 dataset)
9191

9292
public IEnumerable<(int, OwnedIterator)> enumerate_epochs()
9393
{
94-
using var ownedIterator = new OwnedIterator(_dataset);
94+
var data_iterator = new OwnedIterator(_dataset);
9595
foreach (var epoch in range(_initial_epoch, _epochs))
9696
{
9797
if (_insufficient_data)
9898
break;
99-
yield return (epoch, ownedIterator);
99+
if (_adapter.ShouldRecreateIterator())
100+
data_iterator = new OwnedIterator(_dataset);
101+
yield return (epoch, data_iterator);
100102
}
101103
}
102104

src/TensorFlowNET.Keras/Engine/DataAdapters/DatasetAdapter.cs

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,31 +5,15 @@
55

66
namespace Tensorflow.Keras.Engine.DataAdapters
77
{
8-
public class DatasetAdapter : IDataAdapter
8+
public class DatasetAdapter : DataAdapter, IDataAdapter
99
{
10-
DataAdapterArgs args;
11-
IDatasetV2 _dataset => args.Dataset;
1210
public DatasetAdapter(DataAdapterArgs args)
1311
{
1412
this.args = args;
13+
dataset = args.Dataset;
1514
}
1615

17-
public bool CanHandle(Tensor x, Tensor y = null)
18-
{
19-
throw new NotImplementedException();
20-
}
21-
22-
public IDatasetV2 GetDataset()
23-
=> _dataset;
24-
25-
public int GetSize()
16+
public override int GetSize()
2617
=> -1;
27-
28-
public (Tensor, Tensor) Expand1d(Tensor x, Tensor y)
29-
{
30-
if (y.TensorShape.ndim == 1)
31-
y = array_ops.expand_dims(y, axis: -1);
32-
return (x, y);
33-
}
3418
}
3519
}

src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,6 @@ public interface IDataAdapter
1717
IDatasetV2 GetDataset();
1818
int GetSize();
1919
(Tensor, Tensor) Expand1d(Tensor x, Tensor y);
20+
bool ShouldRecreateIterator();
2021
}
2122
}

src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,12 @@ namespace Tensorflow.Keras.Engine.DataAdapters
77
/// <summary>
88
/// Adapter that handles Tensor-like objects, e.g. EagerTensor and NumPy.
99
/// </summary>
10-
public class TensorLikeDataAdapter : IDataAdapter
10+
public class TensorLikeDataAdapter : DataAdapter, IDataAdapter
1111
{
12-
DataAdapterArgs args;
1312
int _size;
1413
int _batch_size;
1514
int num_samples;
1615
int num_full_batches;
17-
IDatasetV2 _dataset;
1816

1917
public TensorLikeDataAdapter(DataAdapterArgs args)
2018
{
@@ -31,7 +29,7 @@ public TensorLikeDataAdapter(DataAdapterArgs args)
3129
indices_dataset = indices_dataset.repeat();
3230
indices_dataset = indices_dataset.map(permutation).prefetch(1);
3331
indices_dataset = indices_dataset.flat_map(slice_batch_indices);
34-
_dataset = slice_inputs(indices_dataset, args.X, args.Y);
32+
dataset = slice_inputs(indices_dataset, args.X, args.Y);
3533
}
3634

3735
Tensor permutation(Tensor tensor)
@@ -73,26 +71,11 @@ IDatasetV2 slice_inputs(IDatasetV2 indices_dataset, Tensor x, Tensor y)
7371
return dataset;
7472
}
7573

76-
public bool CanHandle(Tensor x, Tensor y = null)
77-
{
78-
throw new NotImplementedException();
79-
}
80-
81-
void _process_tensorlike()
82-
{
83-
}
84-
85-
public IDatasetV2 GetDataset()
86-
=> _dataset;
87-
88-
public int GetSize()
74+
public override int GetSize()
8975
=> _size;
9076

91-
public (Tensor, Tensor) Expand1d(Tensor x, Tensor y)
77+
void _process_tensorlike()
9278
{
93-
if (y.TensorShape.ndim == 1)
94-
y = array_ops.expand_dims(y, axis: -1);
95-
return (x, y);
9679
}
9780
}
9881
}

src/TensorFlowNET.Keras/Utils/Web.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,15 @@ public static bool Download(string url, string destDir, string destFileName)
4141
}
4242

4343
var wc = new WebClient();
44-
Console.WriteLine($"Downloading {relativeFilePath}");
44+
Console.WriteLine($"Downloading from {url}");
4545
var download = Task.Run(() => wc.DownloadFile(url, relativeFilePath));
4646
while (!download.IsCompleted)
4747
{
4848
Thread.Sleep(1000);
4949
Console.Write(".");
5050
}
5151
Console.WriteLine("");
52-
Console.WriteLine($"Downloaded {relativeFilePath}");
52+
Console.WriteLine($"Downloaded to {relativeFilePath}");
5353

5454
return true;
5555
}

0 commit comments

Comments
 (0)