Skip to content

Commit d89609a

Browse files
MPnoyOceania2018
authored and committed
Blank SimpleRNN and test for it
1 parent 5801566 commit d89609a

File tree

9 files changed

+307
-8
lines changed

9 files changed

+307
-8
lines changed
Lines changed: 16 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,21 @@
1-
namespace Tensorflow.Keras.ArgsDefinition
1+
using System.Collections.Generic;
2+
3+
namespace Tensorflow.Keras.ArgsDefinition
24
{
35
/// <summary>
/// Options consumed by the <c>RNN</c> layer, in addition to those inherited from <c>LayerArgs</c>.
/// </summary>
public class RNNArgs : LayerArgs
46
{
7+
/// <summary>
/// Contract for anything assignable to <see cref="Cell"/>: an <c>ILayer</c> that also
/// reports a state size. In this commit both <c>RnnCell</c> and <c>StackedRNNCells</c> implement it.
/// </summary>
public interface IRnnArgCell : ILayer
8+
{
9+
/// <summary>
/// State size of the cell. Typed as <c>object</c>; presumably an int or a nested
/// structure of ints as in Keras (TODO confirm against the implementers).
/// </summary>
object state_size { get; }
10+
}
11+
12+
/// <summary>The cell the layer wraps; null unless the caller assigns one.</summary>
public IRnnArgCell Cell { get; set; } = null;
13+
/// <summary>Filled from the <c>return_sequences</c> argument of <c>RNN.New</c>; defaults to false.</summary>
public bool ReturnSequences { get; set; } = false;
14+
/// <summary>Filled from the <c>return_state</c> argument of <c>RNN.New</c>; defaults to false.</summary>
public bool ReturnState { get; set; } = false;
15+
/// <summary>Filled from the <c>go_backwards</c> argument of <c>RNN.New</c>; defaults to false.</summary>
public bool GoBackwards { get; set; } = false;
16+
/// <summary>Filled from the <c>stateful</c> argument of <c>RNN.New</c>; defaults to false.</summary>
public bool Stateful { get; set; } = false;
17+
/// <summary>Filled from the <c>unroll</c> argument of <c>RNN.New</c>; defaults to false.</summary>
public bool Unroll { get; set; } = false;
18+
/// <summary>Filled from the <c>time_major</c> argument of <c>RNN.New</c>; defaults to false.</summary>
public bool TimeMajor { get; set; } = false;
19+
/// <summary>
/// Free-form extra options. <c>RNN.PreConstruct</c> replaces a null value with an empty
/// dictionary and reads the keys "zero_output_for_mask", "input_shape", "input_dim" and
/// "input_length" from it.
/// </summary>
public Dictionary<string, object> Kwargs { get; set; } = null;
520
}
621
}
Lines changed: 30 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,30 @@
1+
namespace Tensorflow.Keras.ArgsDefinition
2+
{
3+
/// <summary>
/// Arguments for the <c>SimpleRNN</c> layer: everything in <c>RNNArgs</c> plus the unit
/// count and the activation.
/// </summary>
public class SimpleRNNArgs : RNNArgs
4+
{
5+
/// <summary>Number of units; set from the <c>units</c> argument of <c>LayersApi.SimpleRNN</c>.</summary>
public int Units { get; set; }
6+
/// <summary>
/// Activation for the layer. It has no initializer here, so it stays null unless the caller sets it.
/// </summary>
public Activation Activation { get; set; }
7+
8+
// NOTE(review): the list below looks like the parameter list of the Python Keras
// SimpleRNN constructor, kept as a porting checklist. Only units and activation are
// declared in this class; return_sequences, return_state, go_backwards, stateful and
// unroll have counterparts on RNNArgs. The remaining entries have no C# property yet.
// units,
9+
// activation='tanh',
10+
// use_bias=True,
11+
// kernel_initializer='glorot_uniform',
12+
// recurrent_initializer='orthogonal',
13+
// bias_initializer='zeros',
14+
// kernel_regularizer=None,
15+
// recurrent_regularizer=None,
16+
// bias_regularizer=None,
17+
// activity_regularizer=None,
18+
// kernel_constraint=None,
19+
// recurrent_constraint=None,
20+
// bias_constraint=None,
21+
// dropout=0.,
22+
// recurrent_dropout=0.,
23+
// return_sequences=False,
24+
// return_state=False,
25+
// go_backwards=False,
26+
// stateful=False,
27+
// unroll=False,
28+
// **kwargs):
29+
}
30+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
using System.Collections.Generic;
2+
3+
namespace Tensorflow.Keras.ArgsDefinition
4+
{
5+
/// <summary>
/// Arguments for <c>StackedRNNCells</c>: the base <c>LayerArgs</c> plus the list of cells to stack.
/// </summary>
public class StackedRNNCellsArgs : LayerArgs
6+
{
7+
/// <summary>
/// Cells to stack; copied into <c>StackedRNNCells.Cells</c> by that class's constructor.
/// It has no initializer, so it is null unless the caller sets it.
/// </summary>
public IList<RnnCell> Cells { get; set; }
8+
}
9+
}

src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@ namespace Tensorflow
4646
/// matching structure of Tensors having shape `[batch_size].concatenate(s)`
4747
/// for each `s` in `self.batch_size`.
4848
/// </summary>
49-
public abstract class RnnCell : ILayer
49+
public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell
5050
{
5151
/// <summary>
5252
/// Attribute that indicates whether the cell is a TF RNN cell, due the slight

src/TensorFlowNET.Keras/Layers/LayersApi.cs

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
using NumSharp;
2+
using System.Collections.Generic;
23
using Tensorflow.Keras.ArgsDefinition;
34
using Tensorflow.Keras.Engine;
45
using static Tensorflow.Binding;
@@ -327,6 +328,24 @@ public Layer LeakyReLU(float alpha = 0.3f)
327328
Alpha = alpha
328329
});
329330

331+
/// <summary>
/// Creates a SimpleRNN layer with <paramref name="units"/> units by forwarding to the
/// overload that takes an activation name, passing "tanh".
/// </summary>
/// <param name="units">Unit count forwarded to the two-argument overload.</param>
public Layer SimpleRNN(int units) => SimpleRNN(units, "tanh");
332+
333+
/// <summary>
/// Creates a SimpleRNN layer from a unit count and an activation object.
/// </summary>
/// <param name="units">Stored in <c>SimpleRNNArgs.Units</c>.</param>
/// <param name="activation">
/// Stored as-is in <c>SimpleRNNArgs.Activation</c>. A null value is passed through
/// unchanged; there is no "tanh" fallback here, unlike the default of the string overload.
/// </param>
public Layer SimpleRNN(int units,
334+
Activation activation = null)
335+
=> new SimpleRNN(new SimpleRNNArgs
336+
{
337+
Units = units,
338+
Activation = activation
339+
});
340+
341+
/// <summary>
/// Creates a SimpleRNN layer from a unit count and an activation name.
/// </summary>
/// <param name="units">Stored in <c>SimpleRNNArgs.Units</c>.</param>
/// <param name="activation">
/// Activation name, "tanh" by default; converted with <c>GetActivationByName</c>
/// (defined outside this hunk; its handling of unknown names is not visible here).
/// </param>
public Layer SimpleRNN(int units,
342+
string activation = "tanh")
343+
=> new SimpleRNN(new SimpleRNNArgs
344+
{
345+
Units = units,
346+
Activation = GetActivationByName(activation)
347+
});
348+
330349
public Layer LSTM(int units,
331350
Activation activation = null,
332351
Activation recurrent_activation = null,

src/TensorFlowNET.Keras/Layers/RNN.cs

Lines changed: 84 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,17 +1,99 @@
11
using System;
2+
using System.Collections.Generic;
23
using Tensorflow.Keras.ArgsDefinition;
34
using Tensorflow.Keras.Engine;
45

56
namespace Tensorflow.Keras.Layers
67
{
78
public class RNN : Layer
89
{
9-
public RNN(RNNArgs args)
10-
: base(args)
10+
private RNNArgs args;
11+
12+
/// <summary>
/// Builds the layer. <paramref name="args"/> is first passed through
/// <c>PreConstruct</c> (which may add entries to <c>args.Kwargs</c>) and the result is
/// handed to the base <c>Layer</c> constructor. The constructor body then keeps a
/// reference to the args and sets <c>SupportsMasking</c> to true.
/// </summary>
/// <param name="args">Layer options; mutated in place by <c>PreConstruct</c>.</param>
public RNN(RNNArgs args) : base(PreConstruct(args))
1113
{
14+
this.args = args;
15+
SupportsMasking = true;
16+
17+
// The input shape is unknown yet, it could have nested tensor inputs, and
18+
// the input spec will be the list of specs for nested inputs, the structure
19+
// of the input_spec will be the same as the input.
1220

21+
// The commented-out Python below is the part of the reference constructor
// that has not been ported yet.
//self.input_spec = None
22+
//self.state_spec = None
23+
//self._states = None
24+
//self.constants_spec = None
25+
//self._num_constants = 0
26+
27+
//if stateful:
28+
// if ds_context.has_strategy():
29+
// raise ValueError('RNNs with stateful=True not yet supported with '
30+
// 'tf.distribute.Strategy.')
1331
}
1432

33+
/// <summary>
/// Normalises <paramref name="args"/> before the base <c>Layer</c> constructor sees it
/// (it is called from the constructor's <c>base(...)</c> initializer). It guarantees a
/// non-null <c>Kwargs</c> dictionary and, when only "input_dim" and/or "input_length"
/// were supplied, synthesises an "input_shape" entry from them.
/// </summary>
/// <param name="args">Options to normalise; modified in place.</param>
/// <returns>The same <paramref name="args"/> instance.</returns>
private static RNNArgs PreConstruct(RNNArgs args)
34+
{
35+
// Later lookups index into Kwargs, so make sure it exists.
if (args.Kwargs == null)
36+
{
37+
args.Kwargs = new Dictionary<string, object>();
38+
}
39+
40+
// If true, the output for masked timestep will be zeros, whereas in the
41+
// false case, output from previous timestep is returned for masked timestep.
42+
// NOTE(review): the value is read and cast but not used anywhere else in this
// method. Get(key, default) is a helper defined outside this hunk; presumably it
// returns the default when the key is absent (confirm).
var zeroOutputForMask = (bool)args.Kwargs.Get("zero_output_for_mask", false);
43+
44+
object input_shape;
45+
var propIS = args.Kwargs.Get("input_shape", null);
46+
var propID = args.Kwargs.Get("input_dim", null);
47+
var propIL = args.Kwargs.Get("input_length", null);
48+
49+
// Only synthesise a shape when none was given explicitly and at least one of
// input_dim / input_length is present.
if (propIS == null && (propID != null || propIL != null))
50+
{
51+
// Tuple order is (input_length, input_dim); a missing element is replaced
// by a NoneValue instance.
input_shape = (
52+
propIL ?? new NoneValue(), // maybe null is needed here
53+
propID ?? new NoneValue()); // and here
54+
args.Kwargs["input_shape"] = input_shape;
55+
}
56+
57+
return args;
58+
}
59+
60+
/// <summary>
/// Creates a new <c>RNN</c> around a single cell, copying each argument into the
/// matching <c>RNNArgs</c> property. This is an instance method, but its body does not
/// read any state of the current instance.
/// </summary>
/// <param name="cell">Assigned to <c>RNNArgs.Cell</c>.</param>
/// <param name="return_sequences">Assigned to <c>RNNArgs.ReturnSequences</c>.</param>
/// <param name="return_state">Assigned to <c>RNNArgs.ReturnState</c>.</param>
/// <param name="go_backwards">Assigned to <c>RNNArgs.GoBackwards</c>.</param>
/// <param name="stateful">Assigned to <c>RNNArgs.Stateful</c>.</param>
/// <param name="unroll">Assigned to <c>RNNArgs.Unroll</c>.</param>
/// <param name="time_major">Assigned to <c>RNNArgs.TimeMajor</c>.</param>
/// <returns>The newly constructed <c>RNN</c>.</returns>
public RNN New(LayerRnnCell cell,
61+
bool return_sequences = false,
62+
bool return_state = false,
63+
bool go_backwards = false,
64+
bool stateful = false,
65+
bool unroll = false,
66+
bool time_major = false)
67+
=> new RNN(new RNNArgs
68+
{
69+
Cell = cell,
70+
ReturnSequences = return_sequences,
71+
ReturnState = return_state,
72+
GoBackwards = go_backwards,
73+
Stateful = stateful,
74+
Unroll = unroll,
75+
TimeMajor = time_major
76+
});
77+
78+
/// <summary>
/// Creates a new <c>RNN</c> from a list of cells by wrapping the list in a
/// <c>StackedRNNCells</c> and using that as the cell. The remaining arguments are copied
/// into the matching <c>RNNArgs</c> properties.
/// </summary>
/// <remarks>
/// In this commit the <c>StackedRNNCells</c> constructor ends by throwing
/// <c>NotImplementedException</c>, so this overload cannot currently return.
/// </remarks>
/// <param name="cell">Cells to stack; passed as <c>StackedRNNCellsArgs.Cells</c>.</param>
/// <param name="return_sequences">Assigned to <c>RNNArgs.ReturnSequences</c>.</param>
/// <param name="return_state">Assigned to <c>RNNArgs.ReturnState</c>.</param>
/// <param name="go_backwards">Assigned to <c>RNNArgs.GoBackwards</c>.</param>
/// <param name="stateful">Assigned to <c>RNNArgs.Stateful</c>.</param>
/// <param name="unroll">Assigned to <c>RNNArgs.Unroll</c>.</param>
/// <param name="time_major">Assigned to <c>RNNArgs.TimeMajor</c>.</param>
/// <returns>The newly constructed <c>RNN</c>.</returns>
public RNN New(IList<RnnCell> cell,
79+
bool return_sequences = false,
80+
bool return_state = false,
81+
bool go_backwards = false,
82+
bool stateful = false,
83+
bool unroll = false,
84+
bool time_major = false)
85+
=> new RNN(new RNNArgs
86+
{
87+
Cell = new StackedRNNCells(new StackedRNNCellsArgs { Cells = cell }),
88+
ReturnSequences = return_sequences,
89+
ReturnState = return_state,
90+
GoBackwards = go_backwards,
91+
Stateful = stateful,
92+
Unroll = unroll,
93+
TimeMajor = time_major
94+
});
95+
96+
1597
protected Tensor get_initial_state(Tensor inputs)
1698
{
1799
return _generate_zero_filled_state_for_cell(null, null);
Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,14 @@
1+
using Tensorflow.Keras.ArgsDefinition;
2+
3+
namespace Tensorflow.Keras.Layers
4+
{
5+
/// <summary>
/// Placeholder SimpleRNN layer. At this commit it adds nothing to <c>RNN</c>: the
/// constructor only forwards its arguments to the base class.
/// </summary>
public class SimpleRNN : RNN
6+
{
7+
8+
/// <summary>
/// Forwards <paramref name="args"/> to the <c>RNN</c> constructor. The parameter is
/// typed as <c>RNNArgs</c>; <c>LayersApi.SimpleRNN</c> passes a <c>SimpleRNNArgs</c>,
/// whose extra properties are not read here.
/// </summary>
public SimpleRNN(RNNArgs args) : base(args)
9+
{
10+
11+
}
12+
13+
}
14+
}
Lines changed: 125 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,125 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using Tensorflow.Keras.ArgsDefinition;
4+
using Tensorflow.Keras.Engine;
5+
6+
namespace Tensorflow.Keras.Layers
7+
{
8+
public class StackedRNNCells : Layer, RNNArgs.IRnnArgCell
9+
{
10+
public IList<RnnCell> Cells { get; set; }
11+
12+
/// <summary>
/// Stores the cells from <paramref name="args"/> and then fails: the rest of the
/// reference constructor (kept below as commented-out Python) has not been ported.
/// </summary>
/// <param name="args">Arguments; <c>args.Cells</c> is copied into <c>Cells</c>.</param>
/// <exception cref="NotImplementedException">Always thrown at the end of the constructor.</exception>
public StackedRNNCells(StackedRNNCellsArgs args) : base(args)
13+
{
14+
Cells = args.Cells;
15+
//Cells.reverse_state_order = kwargs.pop('reverse_state_order', False);
16+
// self.reverse_state_order = kwargs.pop('reverse_state_order', False)
17+
// if self.reverse_state_order:
18+
// logging.warning('reverse_state_order=True in StackedRNNCells will soon '
19+
// 'be deprecated. Please update the code to work with the '
20+
// 'natural order of states if you rely on the RNN states, '
21+
// 'eg RNN(return_state=True).')
22+
// super(StackedRNNCells, self).__init__(**kwargs)
23+
// A non-empty message tells the caller which feature is missing; the original
// threw with an empty string.
throw new NotImplementedException("StackedRNNCells is not implemented yet.");
24+
}
25+
26+
/// <summary>
/// Implements <c>RNNArgs.IRnnArgCell.state_size</c>. Not ported yet: reading it always
/// throws <c>NotImplementedException</c>. The commented-out Python that follows this
/// property shows the reference behaviour.
/// </summary>
public object state_size
27+
{
28+
get => throw new NotImplementedException();
29+
}
30+
31+
//@property
32+
//def state_size(self) :
33+
// return tuple(c.state_size for c in
34+
// (self.cells[::- 1] if self.reverse_state_order else self.cells))
35+
36+
// @property
37+
// def output_size(self) :
38+
// if getattr(self.cells[-1], 'output_size', None) is not None:
39+
// return self.cells[-1].output_size
40+
// elif _is_multiple_state(self.cells[-1].state_size) :
41+
// return self.cells[-1].state_size[0]
42+
// else:
43+
// return self.cells[-1].state_size
44+
45+
// def get_initial_state(self, inputs= None, batch_size= None, dtype= None) :
46+
// initial_states = []
47+
// for cell in self.cells[::- 1] if self.reverse_state_order else self.cells:
48+
// get_initial_state_fn = getattr(cell, 'get_initial_state', None)
49+
// if get_initial_state_fn:
50+
// initial_states.append(get_initial_state_fn(
51+
// inputs=inputs, batch_size=batch_size, dtype=dtype))
52+
// else:
53+
// initial_states.append(_generate_zero_filled_state_for_cell(
54+
// cell, inputs, batch_size, dtype))
55+
56+
// return tuple(initial_states)
57+
58+
// def call(self, inputs, states, constants= None, training= None, ** kwargs):
59+
// # Recover per-cell states.
60+
// state_size = (self.state_size[::- 1]
61+
// if self.reverse_state_order else self.state_size)
62+
// nested_states = nest.pack_sequence_as(state_size, nest.flatten(states))
63+
64+
// # Call the cells in order and store the returned states.
65+
// new_nested_states = []
66+
// for cell, states in zip(self.cells, nested_states) :
67+
// states = states if nest.is_nested(states) else [states]
68+
//# TF cell does not wrap the state into list when there is only one state.
69+
// is_tf_rnn_cell = getattr(cell, '_is_tf_rnn_cell', None) is not None
70+
// states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
71+
// if generic_utils.has_arg(cell.call, 'training'):
72+
// kwargs['training'] = training
73+
// else:
74+
// kwargs.pop('training', None)
75+
// # Use the __call__ function for callable objects, eg layers, so that it
76+
// # will have the proper name scopes for the ops, etc.
77+
// cell_call_fn = cell.__call__ if callable(cell) else cell.call
78+
// if generic_utils.has_arg(cell.call, 'constants'):
79+
// inputs, states = cell_call_fn(inputs, states,
80+
// constants= constants, ** kwargs)
81+
// else:
82+
// inputs, states = cell_call_fn(inputs, states, ** kwargs)
83+
// new_nested_states.append(states)
84+
85+
// return inputs, nest.pack_sequence_as(state_size,
86+
// nest.flatten(new_nested_states))
87+
88+
// @tf_utils.shape_type_conversion
89+
// def build(self, input_shape) :
90+
// if isinstance(input_shape, list) :
91+
// input_shape = input_shape[0]
92+
// for cell in self.cells:
93+
// if isinstance(cell, Layer) and not cell.built:
94+
// with K.name_scope(cell.name):
95+
// cell.build(input_shape)
96+
// cell.built = True
97+
// if getattr(cell, 'output_size', None) is not None:
98+
// output_dim = cell.output_size
99+
// elif _is_multiple_state(cell.state_size) :
100+
// output_dim = cell.state_size[0]
101+
// else:
102+
// output_dim = cell.state_size
103+
// input_shape = tuple([input_shape[0]] +
104+
// tensor_shape.TensorShape(output_dim).as_list())
105+
// self.built = True
106+
107+
// def get_config(self) :
108+
// cells = []
109+
// for cell in self.cells:
110+
// cells.append(generic_utils.serialize_keras_object(cell))
111+
// config = {'cells': cells
112+
//}
113+
//base_config = super(StackedRNNCells, self).get_config()
114+
// return dict(list(base_config.items()) + list(config.items()))
115+
116+
// @classmethod
117+
// def from_config(cls, config, custom_objects = None):
118+
// from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
119+
// cells = []
120+
// for cell_config in config.pop('cells'):
121+
// cells.append(
122+
// deserialize_layer(cell_config, custom_objects = custom_objects))
123+
// return cls(cells, **config)
124+
}
125+
}

test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@ public void Functional()
3636
var model = keras.Model(inputs, outputs, name: "mnist_model");
3737
model.summary();
3838
}
39-
39+
4040
/// <summary>
4141
/// Custom layer test, used in Dueling DQN
4242
/// </summary>
@@ -45,10 +45,10 @@ public void TensorFlowOpLayer()
4545
{
4646
var layers = keras.layers;
4747
var inputs = layers.Input(shape: 24);
48-
var x = layers.Dense(128, activation:"relu").Apply(inputs);
48+
var x = layers.Dense(128, activation: "relu").Apply(inputs);
4949
var value = layers.Dense(24).Apply(x);
5050
var adv = layers.Dense(1).Apply(x);
51-
51+
5252
var mean = adv - tf.reduce_mean(adv, axis: 1, keepdims: true);
5353
adv = layers.Subtract().Apply((adv, mean));
5454
var outputs = layers.Add().Apply((value, adv));
@@ -105,9 +105,14 @@ public void Dense()
105105
}
106106

107107
// Applies a 4-unit SimpleRNN layer to a random float32 array of shape (32, 10, 8) and
// expects the output shape to equal (32, 4). The test carries [Ignore], so it is not
// executed; presumably because the layer is still a stub in this commit (confirm).
[TestMethod]
108+
[Ignore]
108109
public void SimpleRNN()
109110
{
110-
111+
var inputs = np.random.rand(32, 10, 8).astype(np.float32);
112+
var simple_rnn = keras.layers.SimpleRNN(4);
113+
var output = simple_rnn.Apply(inputs);
114+
Assert.AreEqual((32, 4), output.shape);
111115
}
116+
112117
}
113118
}

0 commit comments

Comments
 (0)