Skip to content

Commit 5ee46e4

Browse files
committed
tf.while_loop #348
1 parent ded16ea commit 5ee46e4

File tree

19 files changed

+150
-80
lines changed

19 files changed

+150
-80
lines changed
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*****************************************************************************
2+
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using System;
18+
using System.Runtime.InteropServices;
19+
20+
namespace Tensorflow
21+
{
22+
public partial class c_api
23+
{
24+
/// <summary>
25+
/// Specify the device for `desc`. Defaults to empty, meaning unconstrained.
26+
/// </summary>
27+
/// <param name="desc"></param>
28+
/// <param name="device"></param>
29+
[DllImport(TensorFlowLibName)]
30+
public static extern void TF_SetDevice(IntPtr desc, string device);
31+
}
32+
}

src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ public _ControlDependenciesController(Graph graph, List<ITensorOrOperation> cont
6969
_new_stack = false;
7070
}
7171

72-
_seen_nodes = new List<ITensorOrOperation>();
72+
_seen_nodes = new List<ITensorOrOperation>();
73+
_old_stack = null;
74+
_old_control_flow_context = null;
7375
}
7476

7577
public void add_op(ITensorOrOperation op)

src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ protected override void build(TensorShape input_shape)
139139
built = true;
140140
}
141141

142-
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null)
142+
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
143143
{
144144
Tensor outputs = null;
145145

src/TensorFlowNET.Core/Keras/Layers/Conv.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ protected override void build(TensorShape input_shape)
108108
built = true;
109109
}
110110

111-
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null)
111+
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
112112
{
113113
var outputs = _convolution_op.__call__(inputs, kernel);
114114
if (use_bias)

src/TensorFlowNET.Core/Keras/Layers/Dense.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ protected override void build(TensorShape input_shape)
7272
built = true;
7373
}
7474

75-
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null)
75+
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
7676
{
7777
Tensor outputs = null;
7878
var rank = inputs.rank;

src/TensorFlowNET.Core/Keras/Layers/Embedding.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ protected override void build(TensorShape input_shape)
5050
built = true;
5151
}
5252

53-
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null)
53+
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
5454
{
5555
var dtype = inputs.dtype;
5656
if (dtype != tf.int32 && dtype != tf.int64)

src/TensorFlowNET.Core/Keras/Layers/Layer.cs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ public class Layer : AutoTrackable
5252
protected InputSpec input_spec;
5353
protected bool supports_masking;
5454
protected List<VariableV1> _trainable_weights;
55+
protected List<VariableV1> _non_trainable_weights;
5556
private string _name;
5657
public string name => _name;
5758
protected string _base_name;
@@ -84,6 +85,7 @@ public Layer(bool trainable = true,
8485

8586
_init_set_name(name);
8687
_trainable_weights = new List<VariableV1>();
88+
_non_trainable_weights = new List<VariableV1>();
8789
_compute_previous_mask = false;
8890
_updates = new List<Operation>();
8991

@@ -103,6 +105,7 @@ public Layer(bool trainable = true,
103105

104106
public (Tensor, Tensor) __call__(Tensor[] inputs,
105107
Tensor training = null,
108+
Tensor state = null,
106109
VariableScope scope = null)
107110
{
108111
var input_list = inputs;
@@ -139,7 +142,9 @@ public Layer(bool trainable = true,
139142
// overridden).
140143
_maybe_build(inputs[0]);
141144

142-
(input, outputs) = call(inputs[0], training: training);
145+
(input, outputs) = call(inputs[0],
146+
training: training,
147+
state: state);
143148
(input, outputs) = _set_connectivity_metadata_(input, outputs);
144149
_handle_activity_regularization(inputs[0], outputs);
145150
_set_mask_metadata(inputs[0], outputs, null);
@@ -173,7 +178,7 @@ private Tensor compute_mask(Tensor inputs, Tensor mask = null)
173178
return null;
174179
}
175180

176-
protected virtual (Tensor, Tensor) call(Tensor inputs, Tensor training = null)
181+
protected virtual (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
177182
{
178183
return (inputs, inputs);
179184
}
@@ -233,7 +238,10 @@ protected virtual VariableV1 add_weight(string name,
233238
initializer: initializer,
234239
trainable: trainable.Value);
235240
//backend.track_variable(variable);
236-
_trainable_weights.Add(variable);
241+
if (trainable == true)
242+
_trainable_weights.Add(variable);
243+
else
244+
_non_trainable_weights.Add(variable);
237245

238246
return variable;
239247
}

src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public Pooling2D(IPoolFunction pool_function,
4343
this.input_spec = new InputSpec(ndim: 4);
4444
}
4545

46-
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null)
46+
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
4747
{
4848
int[] pool_shape;
4949
if (data_format == "channels_last")

src/TensorFlowNET.Core/Layers/Layer.cs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ public Layer(bool trainable = true,
4343

4444
// Avoid an incorrect lint error
4545
_trainable_weights = new List<VariableV1>();
46+
_non_trainable_weights = new List<VariableV1>();
4647
this.built = false;
4748
_keras_style = false;
4849
}
@@ -54,6 +55,7 @@ public virtual (Tensor, Tensor) apply(Tensor inputs, Tensor training = null)
5455

5556
public (Tensor, Tensor) __call__(Tensor inputs,
5657
Tensor training = null,
58+
Tensor state = null,
5759
VariableScope scope = null)
5860
{
5961
_set_scope(scope);
@@ -76,7 +78,9 @@ public virtual (Tensor, Tensor) apply(Tensor inputs, Tensor training = null)
7678
{
7779
_current_scope = scope2;
7880
// Actually call layer
79-
outputs = base.__call__(new Tensor[] { inputs }, training: training);
81+
outputs = base.__call__(new Tensor[] { inputs },
82+
training: training,
83+
state: state);
8084
});
8185

8286

@@ -121,6 +125,11 @@ protected virtual VariableV1 add_weight(string name,
121125
Graph init_graph = null;
122126
VariableV1[] existing_variables = null;
123127

128+
if (synchronization == VariableSynchronization.OnRead)
129+
trainable = false;
130+
else if (!trainable.HasValue)
131+
trainable = true;
132+
124133
if (default_graph.building_function)
125134
{
126135
throw new NotImplementedException("add_weight");

src/TensorFlowNET.Core/Operations/BasicRNNCell.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ protected override void build(TensorShape inputs_shape)
6666
built = true;
6767
}
6868

69-
protected override (Tensor, Tensor) call(Tensor inputs, Tensor state = null)
69+
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
7070
{
7171
// Most basic RNN: output = new_state = act(W * input + U * state + B).
7272
var concat = array_ops.concat(new[] { inputs, state }, 1);

0 commit comments

Comments
 (0)