Skip to content

Commit 8c0feae

Browse files
committed
Fix image_dataset_from_directory #666
1 parent d8afa8c commit 8c0feae

File tree

7 files changed

+88
-92
lines changed

7 files changed

+88
-92
lines changed

src/TensorFlowNET.Core/Operations/image_ops_impl.cs

Lines changed: 63 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1702,74 +1702,79 @@ public static Tensor sobel_edges(Tensor image)
17021702
public static Tensor decode_image(Tensor contents, int channels = 0, TF_DataType dtype = TF_DataType.TF_UINT8,
17031703
string name = null, bool expand_animations = true)
17041704
{
1705-
Func<ITensorOrOperation> _jpeg = () =>
1705+
return tf_with(ops.name_scope(name, "decode_image"), scope =>
17061706
{
1707-
int jpeg_channels = channels;
1708-
var good_channels = math_ops.not_equal(jpeg_channels, 4, name: "check_jpeg_channels");
1709-
string channels_msg = "Channels must be in (None, 0, 1, 3) when decoding JPEG 'images'";
1710-
var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg });
1711-
return tf_with(ops.control_dependencies(new[] { assert_channels }), delegate
1707+
var substr = tf.strings.substr(contents, 0, 3);
1708+
1709+
Func<ITensorOrOperation> _jpeg = () =>
17121710
{
1713-
return convert_image_dtype(gen_image_ops.decode_jpeg(contents, channels), dtype);
1714-
});
1715-
};
1711+
int jpeg_channels = channels;
1712+
var good_channels = math_ops.not_equal(jpeg_channels, 4, name: "check_jpeg_channels");
1713+
string channels_msg = "Channels must be in (None, 0, 1, 3) when decoding JPEG 'images'";
1714+
var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg });
1715+
return tf_with(ops.control_dependencies(new[] { assert_channels }), delegate
1716+
{
1717+
return convert_image_dtype(gen_image_ops.decode_jpeg(contents, channels), dtype);
1718+
});
1719+
};
17161720

1717-
Func<ITensorOrOperation> _gif = () =>
1718-
{
1719-
int gif_channels = channels;
1720-
var good_channels = math_ops.logical_and(
1721-
math_ops.not_equal(gif_channels, 1, name: "check_gif_channels"),
1722-
math_ops.not_equal(gif_channels, 4, name: "check_gif_channels"));
1723-
1724-
string channels_msg = "Channels must be in (None, 0, 3) when decoding GIF images";
1725-
var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg });
1726-
return tf_with(ops.control_dependencies(new[] { assert_channels }), delegate
1721+
/*Func<ITensorOrOperation> _gif = () =>
17271722
{
1728-
var result = convert_image_dtype(gen_image_ops.decode_gif(contents), dtype);
1729-
if (!expand_animations)
1730-
result = array_ops.gather(result, 0);
1731-
return result;
1732-
});
1733-
};
1723+
int gif_channels = channels;
1724+
var good_channels = math_ops.logical_and(
1725+
math_ops.not_equal(gif_channels, 1, name: "check_gif_channels"),
1726+
math_ops.not_equal(gif_channels, 4, name: "check_gif_channels"));
1727+
1728+
string channels_msg = "Channels must be in (None, 0, 3) when decoding GIF images";
1729+
var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg });
1730+
return tf_with(ops.control_dependencies(new[] { assert_channels }), delegate
1731+
{
1732+
var result = convert_image_dtype(gen_image_ops.decode_gif(contents), dtype);
1733+
if (!expand_animations)
1734+
result = array_ops.gather(result, 0);
1735+
return result;
1736+
});
1737+
};
17341738
1735-
Func<ITensorOrOperation> _bmp = () =>
1736-
{
1737-
int bmp_channels = channels;
1738-
var signature = tf.strings.substr(contents, 0, 2);
1739-
var is_bmp = math_ops.equal(signature, "BM", name: "is_bmp");
1740-
string decode_msg = "Unable to decode bytes as JPEG, PNG, GIF, or BMP";
1741-
var assert_decode = control_flow_ops.Assert(is_bmp, new string[] { decode_msg });
1742-
var good_channels = math_ops.not_equal(bmp_channels, 1, name: "check_channels");
1743-
string channels_msg = "Channels must be in (None, 0, 3) when decoding BMP images";
1744-
var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg });
1745-
return tf_with(ops.control_dependencies(new[] { assert_decode, assert_channels }), delegate
1739+
Func<ITensorOrOperation> _bmp = () =>
17461740
{
1747-
return convert_image_dtype(gen_image_ops.decode_bmp(contents), dtype);
1748-
});
1749-
};
1741+
int bmp_channels = channels;
1742+
var signature = tf.strings.substr(contents, 0, 2);
1743+
var is_bmp = math_ops.equal(signature, "BM", name: "is_bmp");
1744+
string decode_msg = "Unable to decode bytes as JPEG, PNG, GIF, or BMP";
1745+
var assert_decode = control_flow_ops.Assert(is_bmp, new string[] { decode_msg });
1746+
var good_channels = math_ops.not_equal(bmp_channels, 1, name: "check_channels");
1747+
string channels_msg = "Channels must be in (None, 0, 3) when decoding BMP images";
1748+
var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg });
1749+
return tf_with(ops.control_dependencies(new[] { assert_decode, assert_channels }), delegate
1750+
{
1751+
return convert_image_dtype(gen_image_ops.decode_bmp(contents), dtype);
1752+
});
1753+
};
17501754
1751-
Func<ITensorOrOperation> _png = () =>
1752-
{
1753-
return convert_image_dtype(gen_image_ops.decode_png(
1754-
contents,
1755-
channels,
1756-
dtype: dtype),
1757-
dtype);
1758-
};
1755+
Func<ITensorOrOperation> _png = () =>
1756+
{
1757+
return convert_image_dtype(gen_image_ops.decode_png(
1758+
contents,
1759+
channels,
1760+
dtype: dtype),
1761+
dtype);
1762+
};
17591763
1760-
Func<ITensorOrOperation> check_gif = () =>
1761-
{
1762-
return control_flow_ops.cond(is_gif(contents), _gif, _bmp, name: "cond_gif");
1763-
};
1764+
Func<ITensorOrOperation> check_gif = () =>
1765+
{
1766+
var gif = tf.constant(new byte[] { 0x47, 0x49, 0x46 }, TF_DataType.TF_STRING);
1767+
var is_gif = math_ops.equal(substr, gif, name: name);
1768+
return control_flow_ops.cond(is_gif, _gif, _bmp, name: "cond_gif");
1769+
};
17641770
1765-
Func<ITensorOrOperation> check_png = () =>
1766-
{
1767-
return control_flow_ops.cond(is_png(contents), _png, check_gif, name: "cond_png");
1768-
};
1771+
Func<ITensorOrOperation> check_png = () =>
1772+
{
1773+
return control_flow_ops.cond(is_png(contents), _png, check_gif, name: "cond_png");
1774+
};*/
17691775

1770-
return tf_with(ops.name_scope(name, "decode_image"), scope =>
1771-
{
1772-
return control_flow_ops.cond(is_jpeg(contents), _jpeg, check_png, name: "cond_jpeg");
1776+
// return control_flow_ops.cond(is_jpeg(contents), _jpeg, check_png, name: "cond_jpeg");
1777+
return _jpeg() as Tensor;
17731778
});
17741779
}
17751780

src/TensorFlowNET.Core/Tensorflow.Binding.csproj

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
<AssemblyName>TensorFlow.NET</AssemblyName>
66
<RootNamespace>Tensorflow</RootNamespace>
77
<TargetTensorFlow>2.2.0</TargetTensorFlow>
8-
<Version>0.31.1</Version>
8+
<Version>0.31.2</Version>
99
<LangVersion>8.0</LangVersion>
1010
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
1111
<Company>SciSharp STACK</Company>
@@ -19,7 +19,7 @@
1919
<Description>Google's TensorFlow full binding in .NET Standard.
2020
Building, training and infering deep learning models.
2121
https://tensorflownet.readthedocs.io</Description>
22-
<AssemblyVersion>0.31.1.0</AssemblyVersion>
22+
<AssemblyVersion>0.31.2.0</AssemblyVersion>
2323
<PackageReleaseNotes>tf.net 0.20.x and above are based on tensorflow native 2.x.
2424

2525
* Eager Mode is added finally.
@@ -30,7 +30,7 @@ https://tensorflownet.readthedocs.io</Description>
3030
TensorFlow .NET v0.30 is focused on making more Keras API work including:
3131
* tf.keras.datasets
3232
* Building keras model in subclass, functional and sequential api</PackageReleaseNotes>
33-
<FileVersion>0.31.1.0</FileVersion>
33+
<FileVersion>0.31.2.0</FileVersion>
3434
<PackageLicenseFile>LICENSE</PackageLicenseFile>
3535
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
3636
<SignAssembly>true</SignAssembly>

src/TensorFlowNET.Core/Tensors/tensor_util.cs

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License.
2020
using System.Linq;
2121
using System.Text;
2222
using Tensorflow.Eager;
23+
using static Tensorflow.Binding;
2324

2425
namespace Tensorflow
2526
{
@@ -410,14 +411,10 @@ bool hasattr(Graph property, string attr)
410411
var value = constant_value(tensor);
411412
if (!(value is null))
412413
{
413-
int[] d_ = { };
414-
foreach (int d in value)
415-
{
416-
if (d >= 0)
417-
d_[d_.Length] = d;
418-
else
419-
d_[d_.Length] = -1; // None
420-
}
414+
var d_ = new int[value.size];
415+
foreach (var (index, d) in enumerate(value.ToArray<int>()))
416+
d_[index] = d >= 0 ? d : -1;
417+
421418
ret = ret.merge_with(new TensorShape(d_));
422419
}
423420
return ret;

src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ public partial class DatasetUtils
4040
labels.AddRange(Enumerable.Range(0, files.Length).Select(x => label));
4141
}
4242

43-
var return_labels = new int[labels.Count];
44-
var return_file_paths = new string[file_paths.Count];
43+
var return_labels = labels.Select(x => x).ToArray();
44+
var return_file_paths = file_paths.Select(x => x).ToArray();
4545

4646
if (shuffle)
4747
{

src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ public IDatasetV2 image_dataset_from_directory(string directory,
4141
int num_channels = 0;
4242
if (color_mode == "rgb")
4343
num_channels = 3;
44-
// C:/Users/haipi/.keras/datasets/flower_photos
44+
4545
var (image_paths, label_list, class_name_list) = keras.preprocessing.dataset_utils.index_directory(directory,
4646
formats: WHITELIST_FORMATS,
4747
class_names: class_names,

src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,11 @@ public IDatasetV2 paths_and_labels_to_dataset(string[] image_paths,
1616
var path_ds = tf.data.Dataset.from_tensor_slices(image_paths);
1717
var img_ds = path_ds.map(x => path_to_image(x, image_size, num_channels, interpolation));
1818

19-
/*Shape shape = (image_paths.Length, image_size.dims[0], image_size.dims[1], num_channels);
20-
Console.WriteLine($"Allocating memory for shape{shape}, {NPTypeCode.Float}");
21-
var data = np.zeros(shape, NPTypeCode.Float);
22-
23-
for (var i = 0; i < image_paths.Length; i++)
24-
{
25-
var image = path_to_image(image_paths[i], image_size, num_channels, interpolation);
26-
data[i] = image.numpy();
27-
if (i % 100 == 0)
28-
Console.WriteLine($"Filled {i}/{image_paths.Length} data into ndarray.");
29-
}
30-
31-
var img_ds = tf.data.Dataset.from_tensor_slices(data);
32-
3319
if (label_mode == "int")
3420
{
35-
var label_ds = tf.keras.preprocessing.dataset_utils.labels_to_dataset(labels, label_mode, num_classes);
21+
var label_ds = dataset_utils.labels_to_dataset(labels, label_mode, num_classes);
3622
img_ds = tf.data.Dataset.zip(img_ds, label_ds);
3723
}
38-
else*/
39-
throw new NotImplementedException("");
4024

4125
return img_ds;
4226
}
@@ -47,6 +31,7 @@ Tensor path_to_image(Tensor path, TensorShape image_size, int num_channels, stri
4731
img = tf.image.decode_image(
4832
img, channels: num_channels, expand_animations: false);
4933
img = tf.image.resize_images_v2(img, image_size, method: interpolation);
34+
img.set_shape((image_size[0], image_size[1], num_channels));
5035
return img;
5136
}
5237
}

src/TensorFlowNET.Keras/Tensorflow.Keras.csproj

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
<LangVersion>8.0</LangVersion>
77
<RootNamespace>Tensorflow.Keras</RootNamespace>
88
<Platforms>AnyCPU;x64</Platforms>
9-
<Version>0.2.1</Version>
9+
<Version>0.3.0</Version>
1010
<Authors>Haiping Chen</Authors>
1111
<Product>Keras for .NET</Product>
1212
<Copyright>Apache 2.0, Haiping Chen 2020</Copyright>
@@ -25,11 +25,13 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
2525
<Company>SciSharp STACK</Company>
2626
<GeneratePackageOnBuild>true</GeneratePackageOnBuild>
2727
<PackageTags>tensorflow, keras, deep learning, machine learning</PackageTags>
28-
<PackageRequireLicenseAcceptance>false</PackageRequireLicenseAcceptance>
28+
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
2929
<RepositoryType>Git</RepositoryType>
3030
<SignAssembly>true</SignAssembly>
3131
<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile>
32-
<AssemblyVersion>0.2.1.0</AssemblyVersion>
32+
<AssemblyVersion>0.3.0.0</AssemblyVersion>
33+
<FileVersion>0.3.0.0</FileVersion>
34+
<PackageLicenseFile>LICENSE</PackageLicenseFile>
3335
</PropertyGroup>
3436

3537
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
@@ -55,4 +57,11 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
5557
<ProjectReference Include="..\TensorFlowNET.Core\Tensorflow.Binding.csproj" />
5658
</ItemGroup>
5759

60+
<ItemGroup>
61+
<None Include="..\..\LICENSE">
62+
<Pack>True</Pack>
63+
<PackagePath></PackagePath>
64+
</None>
65+
</ItemGroup>
66+
5867
</Project>

0 commit comments

Comments
 (0)