Skip to content

Commit a70077b

Browse files
committed
BasicLSTMCell
1 parent f15a608 commit a70077b

File tree

12 files changed

+260
-7
lines changed

12 files changed

+260
-7
lines changed

src/TensorFlowNET.Core/APIs/tf.math.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ public Tensor _clip_by_value(Tensor t, Tensor clip_value_min, Tensor clip_value_
251251
/// greater than <c>clip_value_max</c> are set to <c>clip_value_max</c>.
252252
/// </remarks>
253253
public Tensor clip_by_value (Tensor t, Tensor clip_value_min, Tensor clip_value_max, string name = "ClipByValue")
254-
=> gen_ops.clip_by_value(t, clip_value_min, clip_value_max, name);
254+
=> clip_ops.clip_by_value(t, clip_value_min, clip_value_max, name);
255255

256256
public Tensor sub(Tensor a, Tensor b)
257257
=> gen_math_ops.sub(a, b);

src/TensorFlowNET.Core/Framework/tensor_shape.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@ public static void assert_is_compatible_with(this Tensor self, Tensor other)
2424
}
2525
}
2626

27+
public static Dimension dimension_at_index(TensorShape shape, int index)
28+
{
29+
return shape.rank < 0 ?
30+
new Dimension(-1) :
31+
new Dimension(shape.dims[index]);
32+
}
33+
34+
public static int dimension_value(Dimension dimension)
35+
=> dimension.value;
36+
2737
public static TensorShape as_shape(this Shape shape)
2838
=> new TensorShape(shape.Dimensions);
2939
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using System.Threading.Tasks;
6+
using static Tensorflow.Binding;
7+
using Tensorflow.Operations.Activation;
8+
using Tensorflow.Keras.Engine;
9+
using Tensorflow.Operations;
10+
11+
namespace Tensorflow
12+
{
13+
/// <summary>
14+
/// Basic LSTM recurrent network cell.
15+
/// The implementation is based on: http://arxiv.org/abs/1409.2329.
16+
/// </summary>
17+
public class BasicLSTMCell : LayerRnnCell
18+
{
19+
int _num_units;
20+
float _forget_bias;
21+
bool _state_is_tuple;
22+
IActivation _activation;
23+
24+
/// <summary>
25+
/// Initialize the basic LSTM cell.
26+
/// </summary>
27+
/// <param name="num_units">The number of units in the LSTM cell.</param>
28+
/// <param name="forget_bias"></param>
29+
/// <param name="state_is_tuple"></param>
30+
/// <param name="activation"></param>
31+
/// <param name="reuse"></param>
32+
/// <param name="name"></param>
33+
/// <param name="dtype"></param>
34+
public BasicLSTMCell(int num_units, float forget_bias = 1.0f, bool state_is_tuple = true,
35+
IActivation activation = null, bool? reuse = null, string name = null,
36+
TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: reuse, name: name, dtype: dtype)
37+
{
38+
input_spec = new InputSpec(ndim: 2);
39+
_num_units = num_units;
40+
_forget_bias = forget_bias;
41+
_state_is_tuple = state_is_tuple;
42+
_activation = activation;
43+
if (_activation == null)
44+
_activation = tf.nn.tanh();
45+
}
46+
47+
public LSTMStateTuple state_size
48+
{
49+
get
50+
{
51+
return _state_is_tuple ?
52+
new LSTMStateTuple(_num_units, _num_units) :
53+
(LSTMStateTuple)(2 * _num_units);
54+
}
55+
}
56+
}
57+
}

src/TensorFlowNET.Core/Operations/BasicRNNCell.cs renamed to src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616

1717
using System;
1818
using Tensorflow.Keras.Engine;
19+
using Tensorflow.Operations;
1920
using static Tensorflow.Binding;
2021

2122
namespace Tensorflow
@@ -25,7 +26,7 @@ public class BasicRnnCell : LayerRnnCell
2526
int _num_units;
2627
Func<Tensor, string, Tensor> _activation;
2728

28-
public override int state_size => _num_units;
29+
public override LSTMStateTuple state_size => _num_units;
2930
public override int output_size => _num_units;
3031
public VariableV1 _kernel;
3132
string _WEIGHTS_VARIABLE_NAME = "kernel";
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Operations
6+
{
7+
/// <summary>
8+
/// Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state.
9+
///
10+
/// Stores two elements: `(c, h)`, in that order. Where `c` is the hidden state
11+
/// and `h` is the output.
12+
///
13+
/// Only used when `state_is_tuple=True`.
14+
/// </summary>
15+
public class LSTMStateTuple
16+
{
17+
int c;
18+
int h;
19+
20+
public LSTMStateTuple(int c)
21+
{
22+
this.c = c;
23+
}
24+
25+
public LSTMStateTuple(int c, int h)
26+
{
27+
this.c = c;
28+
this.h = h;
29+
}
30+
31+
public static implicit operator int(LSTMStateTuple tuple)
32+
{
33+
return tuple.c;
34+
}
35+
36+
public static implicit operator LSTMStateTuple(int c)
37+
{
38+
return new LSTMStateTuple(c);
39+
}
40+
}
41+
}

src/TensorFlowNET.Core/Operations/RNNCell.cs renamed to src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public abstract class RnnCell : Layers.Layer
4949
/// difference between TF and Keras RNN cell.
5050
/// </summary>
5151
protected bool _is_tf_rnn_cell = false;
52-
public virtual int state_size { get; }
52+
public virtual LSTMStateTuple state_size { get; }
5353

5454
public virtual int output_size { get; }
5555

src/TensorFlowNET.Core/Operations/NnOps/rnn.cs

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,106 @@ limitations under the License.
1818
using System;
1919
using System.Collections.Generic;
2020
using System.Linq;
21+
using Tensorflow.Framework;
2122
using Tensorflow.Util;
2223
using static Tensorflow.Binding;
2324

2425
namespace Tensorflow.Operations
2526
{
26-
internal class rnn
27+
public class rnn
2728
{
29+
/// <summary>
30+
/// Creates a bidirectional recurrent neural network.
31+
/// </summary>
32+
public static void static_bidirectional_rnn(BasicLSTMCell cell_fw,
33+
BasicLSTMCell cell_bw,
34+
Tensor[] inputs,
35+
Tensor initial_state_fw = null,
36+
Tensor initial_state_bw = null,
37+
TF_DataType dtype = TF_DataType.DtInvalid,
38+
Tensor sequence_length = null,
39+
string scope = null)
40+
{
41+
if (inputs == null || inputs.Length == 0)
42+
throw new ValueError("inputs must not be empty");
43+
44+
tf_with(tf.variable_scope(scope ?? "bidirectional_rnn"), delegate
45+
{
46+
// Forward direction
47+
tf_with(tf.variable_scope("fw"), fw_scope =>
48+
{
49+
static_rnn(
50+
cell_fw,
51+
inputs,
52+
initial_state_fw,
53+
dtype,
54+
sequence_length,
55+
scope: fw_scope);
56+
});
57+
});
58+
}
59+
60+
public static void static_rnn(BasicLSTMCell cell,
61+
Tensor[] inputs,
62+
Tensor initial_state,
63+
TF_DataType dtype = TF_DataType.DtInvalid,
64+
Tensor sequence_length = null,
65+
VariableScope scope = null)
66+
{
67+
// Create a new scope in which the caching device is either
68+
// determined by the parent scope, or is set to place the cached
69+
// Variable using the same placement as for the rest of the RNN.
70+
if (scope == null)
71+
tf_with(tf.variable_scope("rnn"), varscope =>
72+
{
73+
throw new NotImplementedException("static_rnn");
74+
});
75+
else
76+
tf_with(tf.variable_scope(scope), varscope =>
77+
{
78+
Dimension fixed_batch_size = null;
79+
Dimension batch_size = null;
80+
Tensor batch_size_tensor = null;
81+
82+
// Obtain the first sequence of the input
83+
var first_input = inputs[0];
84+
if (first_input.TensorShape.rank != 1)
85+
{
86+
var input_shape = first_input.TensorShape.with_rank_at_least(2);
87+
fixed_batch_size = input_shape.dims[0];
88+
var flat_inputs = nest.flatten2(inputs);
89+
foreach (var flat_input in flat_inputs)
90+
{
91+
input_shape = flat_input.TensorShape.with_rank_at_least(2);
92+
batch_size = tensor_shape.dimension_at_index(input_shape, 0);
93+
var input_size = input_shape[1];
94+
fixed_batch_size.merge_with(batch_size);
95+
foreach (var (i, size) in enumerate(input_size.dims))
96+
{
97+
if (size < 0)
98+
throw new ValueError($"Input size (dimension {i} of inputs) must be accessible via " +
99+
"shape inference, but saw value None.");
100+
}
101+
}
102+
}
103+
else
104+
fixed_batch_size = first_input.TensorShape.with_rank_at_least(1).dims[0];
105+
106+
if (tensor_shape.dimension_value(fixed_batch_size) >= 0)
107+
batch_size = tensor_shape.dimension_value(fixed_batch_size);
108+
else
109+
batch_size_tensor = array_ops.shape(first_input)[0];
110+
111+
Tensor state = null;
112+
if (initial_state != null)
113+
state = initial_state;
114+
else
115+
{
116+
cell.get_initial_state(batch_size: batch_size_tensor, dtype: dtype);
117+
}
118+
});
119+
}
120+
28121
public static (Tensor, Tensor) dynamic_rnn(RnnCell cell, Tensor inputs_tensor,
29122
Tensor sequence_length = null, Tensor initial_state = null,
30123
TF_DataType dtype = TF_DataType.DtInvalid,
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*****************************************************************************
2+
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using System;
18+
using System.Collections.Generic;
19+
using System.Linq;
20+
using System.Text;
21+
using System.Threading.Tasks;
22+
using static Tensorflow.Binding;
23+
24+
namespace Tensorflow
25+
{
26+
public class clip_ops
27+
{
28+
public static Tensor clip_by_value(Tensor t, Tensor clip_value_min, Tensor clip_value_max, string name = null)
29+
{
30+
return tf_with(ops.name_scope(name, "clip_by_value", new { t, clip_value_min, clip_value_max }), delegate
31+
{
32+
var values = ops.convert_to_tensor(t, name: "t");
33+
// Go through list of tensors, for each value in each tensor clip
34+
var t_min = math_ops.minimum(values, clip_value_max);
35+
// Assert that the shape is compatible with the initial shape,
36+
// to prevent unintentional broadcasting.
37+
_ = values.TensorShape.merge_with(t_min.shape);
38+
var t_max = math_ops.maximum(t_min, clip_value_min, name: name);
39+
_ = values.TensorShape.merge_with(t_max.shape);
40+
41+
return t_max;
42+
});
43+
}
44+
}
45+
}

src/TensorFlowNET.Core/TensorFlow.Binding.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
<Project Sdk="Microsoft.NET.Sdk">
22

33
<PropertyGroup>
4-
<TargetFrameworks>net472;netstandard2.0</TargetFrameworks>
4+
<TargetFramework>netstandard2.0</TargetFramework>
55
<AssemblyName>TensorFlow.NET</AssemblyName>
66
<RootNamespace>Tensorflow</RootNamespace>
77
<TargetTensorFlow>1.14.1</TargetTensorFlow>

0 commit comments

Comments
 (0)