using System;
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras.Layers
{
    public class StackedRNNCells : Layer, RNNArgs.IRnnArgCell
    {
        public IList<RnnCell> Cells { get; set; }

        public StackedRNNCells(StackedRNNCellsArgs args) : base(args)
        {
            Cells = args.Cells;
            // Porting of the reference Keras constructor below is still pending,
            // so this layer is a stub for now.
            //Cells.reverse_state_order = kwargs.pop('reverse_state_order', False);
            // self.reverse_state_order = kwargs.pop('reverse_state_order', False)
            // if self.reverse_state_order:
            //     logging.warning('reverse_state_order=True in StackedRNNCells will soon '
            //                     'be deprecated. Please update the code to work with the '
            //                     'natural order of states if you rely on the RNN states, '
            //                     'eg RNN(return_state=True).')
            // super(StackedRNNCells, self).__init__(**kwargs)
            throw new NotImplementedException();
        }

        public object state_size
        {
            get => throw new NotImplementedException();
        }

        // Reference Keras (Python) implementation, kept as porting notes:
        //
        // @property
        // def state_size(self):
        //     return tuple(c.state_size for c in
        //                  (self.cells[::-1] if self.reverse_state_order else self.cells))
        //
        // @property
        // def output_size(self):
        //     if getattr(self.cells[-1], 'output_size', None) is not None:
        //         return self.cells[-1].output_size
        //     elif _is_multiple_state(self.cells[-1].state_size):
        //         return self.cells[-1].state_size[0]
        //     else:
        //         return self.cells[-1].state_size
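        // A minimal commented-out C# sketch of how `state_size` could be ported, assuming
        // RnnCell exposes a `state_size` member (an unverified assumption about the
        // TensorFlow.NET API, which is why the sketch stays in a comment). The
        // reverse_state_order option is not modelled in this port yet, so the cells are
        // taken in their natural order.
        //
        // public object state_size
        // {
        //     get
        //     {
        //         var sizes = new List<object>();
        //         foreach (var cell in Cells)       // natural order only
        //             sizes.Add(cell.state_size);   // assumed member, see note above
        //         return sizes;
        //     }
        // }
        //
        // `output_size` would analogously return the last cell's output_size when it is
        // defined, and otherwise fall back to (the first entry of) its state_size, exactly
        // as in the Python reference above.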

        // def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        //     initial_states = []
        //     for cell in self.cells[::-1] if self.reverse_state_order else self.cells:
        //         get_initial_state_fn = getattr(cell, 'get_initial_state', None)
        //         if get_initial_state_fn:
        //             initial_states.append(get_initial_state_fn(
        //                 inputs=inputs, batch_size=batch_size, dtype=dtype))
        //         else:
        //             initial_states.append(_generate_zero_filled_state_for_cell(
        //                 cell, inputs, batch_size, dtype))
        //
        //     return tuple(initial_states)
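        // Commented-out sketch of a possible get_initial_state port. It assumes each cell
        // can produce its own initial state via a `get_initial_state` method; that member
        // name is illustrative, not the current TensorFlow.NET API, so the code stays
        // commented out:
        //
        // public Tensor[] get_initial_state(Tensor inputs = null, Tensor batch_size = null,
        //                                   TF_DataType dtype = TF_DataType.DtInvalid)
        // {
        //     var initial_states = new List<Tensor>();
        //     foreach (var cell in Cells)
        //     {
        //         // Ask the cell for its own initial state; a full port would fall back to a
        //         // zero-filled state when the cell does not implement get_initial_state.
        //         initial_states.Add(cell.get_initial_state(inputs, batch_size, dtype));
        //     }
        //     return initial_states.ToArray();
        // }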

        // def call(self, inputs, states, constants=None, training=None, **kwargs):
        //     # Recover per-cell states.
        //     state_size = (self.state_size[::-1]
        //                   if self.reverse_state_order else self.state_size)
        //     nested_states = nest.pack_sequence_as(state_size, nest.flatten(states))
        //
        //     # Call the cells in order and store the returned states.
        //     new_nested_states = []
        //     for cell, states in zip(self.cells, nested_states):
        //         states = states if nest.is_nested(states) else [states]
        //         # TF cell does not wrap the state into list when there is only one state.
        //         is_tf_rnn_cell = getattr(cell, '_is_tf_rnn_cell', None) is not None
        //         states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
        //         if generic_utils.has_arg(cell.call, 'training'):
        //             kwargs['training'] = training
        //         else:
        //             kwargs.pop('training', None)
        //         # Use the __call__ function for callable objects, eg layers, so that it
        //         # will have the proper name scopes for the ops, etc.
        //         cell_call_fn = cell.__call__ if callable(cell) else cell.call
        //         if generic_utils.has_arg(cell.call, 'constants'):
        //             inputs, states = cell_call_fn(inputs, states,
        //                                           constants=constants, **kwargs)
        //         else:
        //             inputs, states = cell_call_fn(inputs, states, **kwargs)
        //         new_nested_states.append(states)
        //
        //     return inputs, nest.pack_sequence_as(state_size,
        //                                          nest.flatten(new_nested_states))
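        // Commented-out sketch of the call path described above: the cells are invoked in
        // order, the output of cell i becoming the input of cell i + 1, and each cell's new
        // state is collected. The method signature and the per-cell invocation helper are
        // assumptions about the eventual port, not the current TensorFlow.NET API:
        //
        // protected Tensors Call(Tensors inputs, List<Tensors> states)
        // {
        //     var new_states = new List<Tensors>();
        //     for (int i = 0; i < Cells.Count; i++)
        //     {
        //         // Each cell consumes the running `inputs` plus its own slice of the state;
        //         // its output becomes the input of the next cell in the stack.
        //         var (output, state) = call_cell(Cells[i], inputs, states[i]);   // hypothetical helper
        //         inputs = output;
        //         new_states.Add(state);
        //     }
        //     // A full port would also repack `new_states` into the incoming state structure.
        //     return inputs;
        // }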

        // @tf_utils.shape_type_conversion
        // def build(self, input_shape):
        //     if isinstance(input_shape, list):
        //         input_shape = input_shape[0]
        //     for cell in self.cells:
        //         if isinstance(cell, Layer) and not cell.built:
        //             with K.name_scope(cell.name):
        //                 cell.build(input_shape)
        //                 cell.built = True
        //         if getattr(cell, 'output_size', None) is not None:
        //             output_dim = cell.output_size
        //         elif _is_multiple_state(cell.state_size):
        //             output_dim = cell.state_size[0]
        //         else:
        //             output_dim = cell.state_size
        //         input_shape = tuple([input_shape[0]] +
        //                             tensor_shape.TensorShape(output_dim).as_list())
        //     self.built = True
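        // Commented-out sketch of a possible build() port: each wrapped cell is built
        // against the running input shape, and the shape is then advanced to
        // (batch, output_dim) so the next cell sees the correct input width. The
        // `Built`/`build` members on Layer and `output_size` on the cell are assumed
        // names that mirror the Python reference above, not a verified API:
        //
        // protected void build(TensorShape input_shape)
        // {
        //     foreach (var cell in Cells)
        //     {
        //         if (cell is Layer layer && !layer.Built)
        //             layer.build(input_shape);
        //         // A full port would fall back to state_size (or its first entry)
        //         // when the cell does not define output_size.
        //         var output_dim = cell.output_size;
        //         input_shape = new TensorShape(input_shape.dims[0], output_dim);
        //     }
        //     built = true;
        // }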

        // def get_config(self):
        //     cells = []
        //     for cell in self.cells:
        //         cells.append(generic_utils.serialize_keras_object(cell))
        //     config = {'cells': cells}
        //     base_config = super(StackedRNNCells, self).get_config()
        //     return dict(list(base_config.items()) + list(config.items()))
        //
        // @classmethod
        // def from_config(cls, config, custom_objects=None):
        //     from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
        //     cells = []
        //     for cell_config in config.pop('cells'):
        //         cells.append(
        //             deserialize_layer(cell_config, custom_objects=custom_objects))
        //     return cls(cells, **config)
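        // Commented-out sketch of the serialization round-trip described above. The
        // serialize helper is a placeholder for whatever the C# port ends up using; only
        // the overall shape (a per-cell config list merged into the base layer config,
        // then rebuilt cell by cell in from_config) comes from the Python reference:
        //
        // public Dictionary<string, object> get_config()
        // {
        //     var cells = new List<object>();
        //     foreach (var cell in Cells)
        //         cells.Add(serialize_cell(cell));   // placeholder helper
        //     // A full port would merge this with the base layer's config, and
        //     // from_config would deserialize each entry back into an RnnCell.
        //     return new Dictionary<string, object> { ["cells"] = cells };
        // }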
    }
}