Skip to content

Commit 0511614

Browse files
committed
ops.colocate_with
1 parent 444cc42 commit 0511614

File tree

5 files changed

+36
-1
lines changed

5 files changed

+36
-1
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public partial class Graph
8+
{
9+
public void _colocate_with_for_gradient(Operation op, int? gradient_uid, bool ignore_existing = false)
10+
{
11+
12+
}
13+
}
14+
}

src/TensorFlowNET.Core/Operations/OpDefLibrary.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,13 @@ public Operation _apply_op_helper(string op_type_name, string name = "", Diction
7272
}
7373
else
7474
{
75-
input_types.Add(value.dtype);
75+
var base_type = value.dtype;
76+
// base type
77+
if ((int)value.dtype > 100)
78+
{
79+
base_type = (TF_DataType)Enum.Parse(typeof(TF_DataType), ((int)value.dtype - 100).ToString());
80+
}
81+
input_types.Add(base_type);
7682
}
7783
}
7884
}

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ public InputList inputs
124124
}
125125
}
126126

127+
private TF_DataType[] _input_types => _inputs._inputs.Select(x => x.dtype).ToArray();
128+
127129
private NodeDef _node_def;
128130
public NodeDef node_def
129131
{

src/TensorFlowNET.Core/Variables/RefVariable.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ private void _init_from_args(object initial_value,
9999
}
100100
else
101101
{
102+
ops.colocate_with(_initializer_op);
103+
102104
_snapshot = gen_array_ops.identity(_variable, name = "read");
103105
}
104106

src/TensorFlowNET.Core/ops.py.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,5 +185,16 @@ public static int uid()
185185
{
186186
return uid_number++;
187187
}
188+
189+
public static void colocate_with(Operation op, bool ignore_existing = false)
190+
{
191+
_colocate_with_for_gradient(op, null, ignore_existing);
192+
}
193+
194+
private static void _colocate_with_for_gradient(Operation op, int? gradient_uid, bool ignore_existing = false)
195+
{
196+
var default_graph = get_default_graph();
197+
default_graph._colocate_with_for_gradient(op, gradient_uid, ignore_existing);
198+
}
188199
}
189200
}

0 commit comments

Comments
 (0)