File tree Expand file tree Collapse file tree 11 files changed +428
-18
lines changed
TensorFlowNET.Keras/Layers
test/TensorFlowNET.Keras.UnitTest/Layers Expand file tree Collapse file tree 11 files changed +428
-18
lines changed Original file line number Diff line number Diff line change 1+ using Newtonsoft . Json ;
2+ using System ;
3+ using System . Collections . Generic ;
4+ using System . Text ;
5+ using Tensorflow . NumPy ;
6+
7+ namespace Tensorflow . Keras . ArgsDefinition
8+ {
9+ public class BidirectionalArgs : AutoSerializeLayerArgs
10+ {
11+ [ JsonProperty ( "layer" ) ]
12+ public ILayer Layer { get ; set ; }
13+ [ JsonProperty ( "merge_mode" ) ]
14+ public string ? MergeMode { get ; set ; }
15+ [ JsonProperty ( "backward_layer" ) ]
16+ public ILayer BackwardLayer { get ; set ; }
17+ public NDArray Weights { get ; set ; }
18+ }
19+
20+ }
Original file line number Diff line number Diff line change @@ -5,5 +5,10 @@ public class LSTMArgs : RNNArgs
55 // TODO: maybe change the `RNNArgs` and implement this class.
66 public bool UnitForgetBias { get ; set ; }
77 public int Implementation { get ; set ; }
8+
9+ public LSTMArgs Clone ( )
10+ {
11+ return ( LSTMArgs ) MemberwiseClone ( ) ;
12+ }
813 }
914}
Original file line number Diff line number Diff line change @@ -40,5 +40,10 @@ public class RNNArgs : AutoSerializeLayerArgs
4040 public bool ZeroOutputForMask { get ; set ; } = false ;
4141 [ JsonProperty ( "recurrent_dropout" ) ]
4242 public float RecurrentDropout { get ; set ; } = .0f ;
43+
44+ public RNNArgs Clone ( )
45+ {
46+ return ( RNNArgs ) MemberwiseClone ( ) ;
47+ }
4348 }
4449}
Original file line number Diff line number Diff line change 1+ using Newtonsoft . Json ;
2+ using System ;
3+ using System . Collections . Generic ;
4+ using System . Runtime . CompilerServices ;
5+ using System . Text ;
6+
7+
8+ namespace Tensorflow . Keras . ArgsDefinition
9+ {
10+ public class WrapperArgs : AutoSerializeLayerArgs
11+ {
12+ [ JsonProperty ( "layer" ) ]
13+ public ILayer Layer { get ; set ; }
14+
15+ public WrapperArgs ( ILayer layer )
16+ {
17+ Layer = layer ;
18+ }
19+
20+ public static implicit operator WrapperArgs ( BidirectionalArgs args )
21+ => new WrapperArgs ( args . Layer ) ;
22+ }
23+
24+ }
Original file line number Diff line number Diff line change @@ -258,7 +258,19 @@ public IRnnCell GRUCell(
258258 float dropout = 0f ,
259259 float recurrent_dropout = 0f ,
260260 bool reset_after = true ) ;
261-
261+
262+ /// <summary>
263+ /// Bidirectional wrapper for RNNs.
264+ /// </summary>
265+ /// <param name="layer">`keras.layers.RNN` instance, such as `keras.layers.LSTM` or `keras.layers.GRU`</param>
266+ /// automatically.</param>
267+ /// <returns></returns>
268+ public ILayer Bidirectional (
269+ ILayer layer ,
270+ string merge_mode = "concat" ,
271+ NDArray weights = null ,
272+ ILayer backward_layer = null ) ;
273+
262274 public ILayer Subtract ( ) ;
263275 }
264276}
Original file line number Diff line number Diff line change @@ -908,6 +908,20 @@ public IRnnCell GRUCell(
908908 ResetAfter = reset_after
909909 } ) ;
910910
911+ public ILayer Bidirectional (
912+ ILayer layer ,
913+ string merge_mode = "concat" ,
914+ NDArray weights = null ,
915+ ILayer backward_layer = null )
916+ => new Bidirectional ( new BidirectionalArgs
917+ {
918+ Layer = layer ,
919+ MergeMode = merge_mode ,
920+ Weights = weights ,
921+ BackwardLayer = backward_layer
922+ } ) ;
923+
924+
911925 /// <summary>
912926 ///
913927 /// </summary>
Original file line number Diff line number Diff line change 1+ using System ;
2+ using System . Collections . Generic ;
3+ using System . Diagnostics ;
4+ using System . Text ;
5+ using Tensorflow . Keras . ArgsDefinition ;
6+ using Tensorflow . Keras . Saving ;
7+
8+ namespace Tensorflow . Keras . Layers
9+ {
10+ /// <summary>
11+ /// Abstract wrapper base class. Wrappers take another layer and augment it in various ways.
12+ /// Do not use this class as a layer, it is only an abstract base class.
13+ /// Two usable wrappers are the `TimeDistributed` and `Bidirectional` wrappers.
14+ /// </summary>
15+ public abstract class Wrapper : Layer
16+ {
17+ public ILayer _layer ;
18+ public Wrapper ( WrapperArgs args ) : base ( args )
19+ {
20+ _layer = args . Layer ;
21+ }
22+
23+ public virtual void Build ( KerasShapesWrapper input_shape )
24+ {
25+ if ( ! _layer . Built )
26+ {
27+ _layer . build ( input_shape ) ;
28+ }
29+ built = true ;
30+ }
31+
32+ }
33+ }
You can’t perform that action at this time.
0 commit comments