Skip to content

Commit c3dd96b

Browse files
committed
Use tensor.Id instead of GetHashCode.
1 parent 9ff09c4 commit c3dd96b

File tree

4 files changed

+28
-33
lines changed

4 files changed

+28
-33
lines changed

src/TensorFlowNET.Core/Binding.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ namespace Tensorflow
44
{
55
public static partial class Binding
66
{
7-
[DebuggerNonUserCode]
7+
[DebuggerHidden]
88
public static tensorflow tf { get; } = New<tensorflow>();
99

1010
/// <summary>

src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
namespace Tensorflow.Eager
66
{
7-
public partial class EagerTensor : Tensor
7+
public partial class EagerTensor
88
{
99
public EagerTensor() : base(IntPtr.Zero)
1010
{
@@ -48,8 +48,8 @@ public EagerTensor Resolve()
4848
if (_handle == IntPtr.Zero)
4949
_handle = c_api.TFE_TensorHandleResolve(EagerTensorHandle, tf.Status.Handle);
5050

51-
//print($"new Tensor {Id} {_handle.ToString("x16")}");
52-
//print($"new TensorHandle {Id} {EagerTensorHandle.ToString("x16")}");
51+
// print($"New TensorHandle {Id} 0x{_handle.ToString("x16")}");
52+
// print($"New EagerTensorHandle {Id} {EagerTensorHandle}");
5353

5454
return this;
5555
}
@@ -96,14 +96,14 @@ protected override void DisposeManagedResources()
9696
{
9797
base.DisposeManagedResources();
9898

99-
//print($"deleting DeleteTensorHandle {Id} {EagerTensorHandle.ToString("x16")}");
99+
// print($"Delete EagerTensorHandle {Id} {EagerTensorHandle}");
100100
EagerTensorHandle.Dispose();
101101
}
102102

103103
protected override void DisposeUnmanagedResources(IntPtr handle)
104104
{
105105
base.DisposeUnmanagedResources(handle);
106-
//print($"deleting DeleteTensorHandle {Id} {_handle.ToString("x16")}");
106+
// print($"Delete TensorHandle {Id} 0x{_handle.ToString("x16")}");
107107
}
108108
}
109109
}

src/TensorFlowNET.Keras/Engine/Functional.cs

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@ public partial class Functional : Model
2323
List<KerasHistory> _output_coordinates;
2424
public string[] NetworkNodes { get; set; }
2525

26-
Dictionary<int, int> tensor_usage_count;
27-
public Dictionary<int, int> TensorUsageCount => tensor_usage_count;
26+
Dictionary<long, int> tensor_usage_count;
2827

2928
public Functional(Tensors inputs, Tensors outputs, string name = null)
3029
: base(new ModelArgs
@@ -38,7 +37,7 @@ public Functional(Tensors inputs, Tensors outputs, string name = null)
3837
_output_layers = new List<ILayer>();
3938
_input_coordinates = new List<KerasHistory>();
4039
_output_coordinates = new List<KerasHistory>();
41-
tensor_usage_count = new Dictionary<int, int>();
40+
tensor_usage_count = new Dictionary<long, int>();
4241
if (this is Sequential)
4342
return;
4443
_init_graph_network(inputs, outputs);
@@ -116,33 +115,33 @@ void _set_output_names()
116115

117116
void ComputeTensorUsageCount()
118117
{
119-
var available_tensors = inputs.Select(x => x.GetHashCode()).ToList();
118+
var available_tensors = inputs.Select(x => x.Id).ToList();
120119
var depth_keys = NodesByDepth.Keys.OrderBy(x => x).Reverse().Skip(1).ToArray();
121120
foreach (var depth in depth_keys)
122121
{
123122
foreach (var node in NodesByDepth[depth])
124123
{
125-
var input_tensors = node.KerasInputs.Select(x => x.GetHashCode()).ToArray();
124+
var input_tensors = node.KerasInputs.Select(x => x.Id).ToArray();
126125
if (input_tensors.issubset(available_tensors))
127126
{
128127
foreach (var tensor in node.KerasInputs)
129128
{
130-
if (!tensor_usage_count.ContainsKey(tensor.GetHashCode()))
131-
tensor_usage_count[tensor.GetHashCode()] = 0;
132-
tensor_usage_count[tensor.GetHashCode()] += 1;
129+
if (!tensor_usage_count.ContainsKey(tensor.Id))
130+
tensor_usage_count[tensor.Id] = 0;
131+
tensor_usage_count[tensor.Id] += 1;
133132
}
134133

135134
foreach (var output_tensor in node.Outputs)
136-
available_tensors.Add(output_tensor.GetHashCode());
135+
available_tensors.Add(output_tensor.Id);
137136
}
138137
}
139138
}
140139

141140
foreach (var tensor in outputs)
142141
{
143-
if (!tensor_usage_count.ContainsKey(tensor.GetHashCode()))
144-
tensor_usage_count[tensor.GetHashCode()] = 0;
145-
tensor_usage_count[tensor.GetHashCode()] += 1;
142+
if (!tensor_usage_count.ContainsKey(tensor.Id))
143+
tensor_usage_count[tensor.Id] = 0;
144+
tensor_usage_count[tensor.Id] += 1;
146145
}
147146
}
148147

@@ -316,12 +315,11 @@ Tensors run_internal_graph(Tensors inputs, bool training = false, Tensors mask =
316315
input_t.KerasMask = masks[i];
317316
}
318317

319-
var tensor_dict = new Dictionary<int, Queue<Tensor>>();
318+
var tensor_dict = new Dictionary<long, Queue<Tensor>>();
320319
foreach (var (x, y) in zip(this.inputs, inputs))
321320
{
322321
var y1 = conform_to_reference_input(y, x);
323-
var x_id = x.GetHashCode();
324-
tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y1));
322+
tensor_dict[x.Id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x.Id]).Select(x => y1));
325323
}
326324

327325
var depth_keys = NodesByDepth.Keys.OrderBy(x => x).Reverse().ToArray();
@@ -347,13 +345,10 @@ Tensors run_internal_graph(Tensors inputs, bool training = false, Tensors mask =
347345
}
348346
}
349347

350-
var output_tensors = new List<Tensor>();
348+
var output_tensors = new Tensors();
351349

352350
foreach (var x in outputs)
353-
{
354-
var x_id = x.GetHashCode();
355-
output_tensors.append(tensor_dict[x_id].Dequeue());
356-
}
351+
output_tensors.Add(tensor_dict[x.Id].Dequeue());
357352

358353
return output_tensors;
359354
}

src/TensorFlowNET.Keras/Engine/Node.cs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@ public partial class Node : INode
4242
public List<Tensor> KerasInputs { get; set; } = new List<Tensor>();
4343
public ILayer Layer { get; set; }
4444
public bool is_input => args.InputTensors == null;
45-
public int[] FlatInputIds { get; set; }
46-
public int[] FlatOutputIds { get; set; }
45+
public long[] FlatInputIds { get; set; }
46+
public long[] FlatOutputIds { get; set; }
4747
bool _single_positional_tensor_passed => KerasInputs.Count() == 1;
48-
Dictionary<int, int> _keras_inputs_ids_and_indices = new Dictionary<int, int>();
48+
Dictionary<int, long> _keras_inputs_ids_and_indices = new Dictionary<int, long>();
4949
public INode[] ParentNodes
5050
{
5151
get
@@ -70,7 +70,7 @@ public Node(Layer layer, NodeArgs args)
7070
KerasInputs.AddRange(args.InputTensors);
7171

7272
foreach (var (i, ele) in enumerate(KerasInputs))
73-
_keras_inputs_ids_and_indices[i] = ele.GetHashCode();
73+
_keras_inputs_ids_and_indices[i] = ele.Id;
7474

7575
// Wire up Node to Layers.
7676
layer.InboundNodes.Add(this);
@@ -89,16 +89,16 @@ public Node(Layer layer, NodeArgs args)
8989
tensor.KerasHistory = new KerasHistory(layer, node_index, i, tensor);
9090

9191
// Cached for performance.
92-
FlatInputIds = KerasInputs.Select(x => x.GetHashCode()).ToArray();
93-
FlatOutputIds = Outputs.Select(x => x.GetHashCode()).ToArray();
92+
FlatInputIds = KerasInputs.Select(x => x.Id).ToArray();
93+
FlatOutputIds = Outputs.Select(x => x.Id).ToArray();
9494
}
9595

9696
/// <summary>
9797
/// Maps Keras Tensors to computed Tensors using `tensor_dict`.
9898
/// </summary>
9999
/// <param name="tensor_dict"></param>
100100
/// <returns></returns>
101-
public Tensors MapArguments(Dictionary<int, Queue<Tensor>> tensor_dict)
101+
public Tensors MapArguments(Dictionary<long, Queue<Tensor>> tensor_dict)
102102
{
103103
if (_single_positional_tensor_passed)
104104
{

0 commit comments

Comments
 (0)