Skip to content

Commit 0ca9485

Browse files
committed
as_base_type #136
1 parent 0511614 commit 0ca9485

File tree

5 files changed

+14
-11
lines changed

5 files changed

+14
-11
lines changed

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
using System.Linq;
44
using System.Runtime.InteropServices;
55
using System.Text;
6-
using TF_DataType = Tensorflow.DataType;
76

87
namespace Tensorflow
98
{

src/TensorFlowNET.Core/Operations/OpDefLibrary.cs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,8 @@ public Operation _apply_op_helper(string op_type_name, string name = "", Diction
7272
}
7373
else
7474
{
75-
var base_type = value.dtype;
76-
// base type
77-
if ((int)value.dtype > 100)
78-
{
79-
base_type = (TF_DataType)Enum.Parse(typeof(TF_DataType), ((int)value.dtype - 100).ToString());
80-
}
75+
var base_type = value.dtype.as_base_dtype();
76+
8177
input_types.Add(base_type);
8278
}
8379
}
@@ -151,7 +147,7 @@ public Operation _apply_op_helper(string op_type_name, string name = "", Diction
151147

152148
public DataType _MakeType(TF_DataType v, AttrDef attr_def)
153149
{
154-
return v.as_datatype_enum();
150+
return v.as_base_dtype().as_datatype_enum();
155151
}
156152
}
157153
}

src/TensorFlowNET.Core/Tensors/dtypes.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ public static Type as_numpy_datatype(this TF_DataType type)
2424
throw new NotImplementedException("as_numpy_datatype failed");
2525
}
2626
}
27+
2728
public static TF_DataType as_dtype(Type type)
2829
{
2930
TF_DataType dtype = TF_DataType.DtInvalid;
@@ -62,5 +63,12 @@ public static DataType as_datatype_enum(this TF_DataType type)
6263

6364
return dtype;
6465
}
66+
67+
public static TF_DataType as_base_dtype(this TF_DataType type)
68+
{
69+
return (int)type > 100 ?
70+
(TF_DataType)Enum.Parse(typeof(TF_DataType), ((int)type - 100).ToString()) :
71+
type;
72+
}
6573
}
6674
}

src/TensorFlowNET.Core/Variables/RefVariable.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ public partial class RefVariable : VariableV1
1616

1717
private Operation _initializer_op;
1818
public Operation initializer => _initializer_op;
19-
public Operation op => _initializer_op;
19+
public Operation op => _variable.op;
2020

2121
public string name => _variable.name;
2222

@@ -77,7 +77,7 @@ private void _init_from_args(object initial_value,
7777

7878
var shape = _initial_value.shape;
7979
dtype = _initial_value.dtype;
80-
_variable = gen_state_ops.variable_v2(shape, dtype, name);
80+
_variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), name);
8181
}
8282

8383
// Manually overrides the variable's shape with the initial value's.

test/TensorFlowNET.UnitTest/VariableTest.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ public void ScalarVar()
2929
[TestMethod]
3030
public void Add()
3131
{
32-
var x = tf.Variable(0, name: "x");
32+
var x = tf.Variable(10, name: "x");
3333

3434
var model = tf.global_variables_initializer();
3535

0 commit comments

Comments
 (0)