Skip to content

Commit 9f0c1e5

Browse files
committed
fix _compute_fans when more than 4 dimensions.
1 parent 6129940 commit 9f0c1e5

File tree

7 files changed

+72
-8
lines changed

7 files changed

+72
-8
lines changed

src/TensorFlowNET.Core/Keras/Engine/Layer.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,5 +57,21 @@ protected virtual void build(TensorShape input_shape)
5757
{
5858

5959
}
60+
61+
protected virtual void add_weight(string name,
62+
int[] shape,
63+
TF_DataType dtype = TF_DataType.DtInvalid,
64+
IInitializer initializer = null,
65+
bool? trainable = null,
66+
Func<string, int[], TF_DataType, IInitializer, bool, RefVariable> getter = null)
67+
{
68+
_add_variable_with_custom_getter(name,
69+
shape,
70+
dtype: dtype,
71+
getter: getter,
72+
overwrite: true,
73+
initializer: initializer,
74+
trainable: trainable.Value);
75+
}
6076
}
6177
}

src/TensorFlowNET.Core/Keras/Layers/Conv.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,11 @@ protected override void build(TensorShape input_shape)
5353
int channel_axis = data_format == "channels_first" ? 1 : -1;
5454
int input_dim = input_shape.Dimensions[input_shape.NDim - 1];
5555
var kernel_shape = new int[] { kernel_size[0], kernel_size[1], input_dim, filters };
56-
add_weight();
56+
add_weight(name: "kernel",
57+
shape: kernel_shape,
58+
initializer: kernel_initializer,
59+
trainable: true,
60+
dtype: _dtype);
5761
}
5862
}
5963
}

src/TensorFlowNET.Core/Layers/Layer.cs

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,11 @@ public Tensor __call__(Tensor inputs,
6868
throw new NotImplementedException("");
6969
}
7070

71-
protected virtual void add_weight()
71+
protected virtual void add_weight(string name,
72+
int[] shape,
73+
TF_DataType dtype = TF_DataType.DtInvalid,
74+
IInitializer initializer = null,
75+
bool? trainable = null)
7276
{
7377
var default_graph = ops.get_default_graph();
7478
Graph init_graph = null;
@@ -84,7 +88,9 @@ protected virtual void add_weight()
8488
existing_variables = variables.global_variables().ToArray();
8589
}
8690

87-
var dtype = TF_DataType.TF_FLOAT;
91+
if(dtype == TF_DataType.DtInvalid)
92+
dtype = TF_DataType.TF_FLOAT;
93+
8894
_set_scope();
8995
var reuse = built || (_reuse != null && _reuse.Value);
9096
Python.with(tf.variable_scope(_scope,
@@ -94,8 +100,19 @@ protected virtual void add_weight()
94100
_current_scope = scope;
95101
Python.with(ops.name_scope(_name_scope()), delegate
96102
{
97-
98-
103+
base.add_weight(name,
104+
shape,
105+
dtype: dtype,
106+
initializer: initializer,
107+
trainable: trainable,
108+
getter: (name1, shape1, dtype1, initializer1, trainable1) =>
109+
{
110+
return tf.get_variable(name1,
111+
shape: new TensorShape(shape1),
112+
dtype: dtype1,
113+
initializer: initializer1,
114+
trainable: trainable1);
115+
});
99116
});
100117
});
101118
}

src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Linq;
34
using System.Text;
45

56
namespace Tensorflow.Operations.Initializers
@@ -64,7 +65,16 @@ public Tensor call(TensorShape shape, TF_DataType dtype)
6465
if (shape.Length == 2)
6566
return (shape[0], shape[1]);
6667
else
67-
throw new NotImplementedException("VarianceScaling._compute_fans");
68+
{
69+
// Assuming convolution kernels (2D, 3D, or more).
70+
// kernel shape: (..., input_depth, depth)
71+
int receptive_field_size = 1;
72+
foreach (var dim in shape.Take(2))
73+
receptive_field_size *= dim;
74+
var fan_in = shape[shape.Length - 2] * receptive_field_size;
75+
var fan_out = shape[shape.Length - 1] * receptive_field_size;
76+
return (fan_in, fan_out);
77+
}
6878
}
6979

7080
public virtual object get_config()

src/TensorFlowNET.Core/Train/Checkpointable/CheckpointableBase.cs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,22 @@
44

55
namespace Tensorflow
66
{
7-
public class CheckpointableBase
7+
public abstract class CheckpointableBase
88
{
9+
/// <summary>
10+
/// Restore-on-create for a variable be saved with this `Checkpointable`.
11+
/// </summary>
12+
/// <returns></returns>
13+
protected virtual RefVariable _add_variable_with_custom_getter(string name,
14+
int[] shape,
15+
TF_DataType dtype = TF_DataType.TF_FLOAT,
16+
IInitializer initializer = null,
17+
Func<string, int[], TF_DataType, IInitializer, bool, RefVariable> getter = null,
18+
bool overwrite = false,
19+
bool trainable = false)
20+
{
21+
var new_variable = getter(name, shape, dtype, initializer, trainable);
22+
throw new NotImplementedException("_add_variable_with_custom_getter");
23+
}
924
}
1025
}

src/TensorFlowNET.Core/Variables/VariableScope.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ public RefVariable get_variable(_VariableStore var_store,
4848
shape: shape,
4949
dtype: dtype,
5050
initializer: initializer,
51+
reuse: resue,
5152
trainable: trainable,
5253
synchronization: synchronization,
5354
aggregation: aggregation);

src/TensorFlowNET.Core/Variables/_VariableStore.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ public RefVariable get_variable(string name,
2424
TensorShape shape = null,
2525
TF_DataType dtype = TF_DataType.TF_FLOAT,
2626
object initializer = null, // IInitializer or Tensor
27+
bool? reuse = null,
2728
bool? trainable = null,
2829
bool validate_shape = true,
2930
VariableSynchronization synchronization = VariableSynchronization.AUTO,
@@ -100,7 +101,7 @@ private RefVariable _get_single_variable(string name,
100101
VariableSynchronization synchronization = VariableSynchronization.AUTO,
101102
VariableAggregation aggregation = VariableAggregation.NONE)
102103
{
103-
bool initializing_from_value = true;
104+
bool initializing_from_value = false;
104105
if (use_resource == null)
105106
use_resource = false;
106107

0 commit comments

Comments
 (0)