Skip to content

Commit 08862d2

Browse files
hcurEsther2013
authored andcommitted
replace usage of ops.name_scope() with tf_with()
1 parent c233e70 commit 08862d2

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

src/TensorFlowNET.Core/Operations/math_ops.cs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -232,18 +232,19 @@ public static Tensor mul_no_nan<Tx, Ty>(Tx x, Ty y, string name = null)
232232

233233
public static Tensor real(Tensor input, string name = null)
234234
{
235-
using (var name_ = ops.name_scope(name, "Real", new [] {input}))
235+
return tf_with(ops.name_scope(name, "Real", new [] {input}), scope =>
236236
{
237+
// name = scope;
237238
input = ops.convert_to_tensor(input, name: "input");
238239
if (input.dtype.is_complex())
239240
{
240241
var real_dtype = input.dtype.real_dtype();
241-
return real(input, name: name);
242+
return real(input, name: scope);
242243
} else
243244
{
244245
return input;
245246
}
246-
}
247+
});
247248
}
248249

249250
/// <summary>
@@ -317,11 +318,11 @@ public static Tensor reduce_std(Tensor input_tensor, int[] axis = null, bool kee
317318
name = "reduce_std";
318319
// else {name = name;}
319320

320-
using (ops.name_scope(name))
321+
return tf_with(ops.name_scope(name, "reduce_std", new [] {input_tensor}), scope =>
321322
{
322323
var variance = reduce_variance(input_tensor, axis: axis, keepdims: keepdims);
323324
return gen_math_ops.sqrt(variance);
324-
}
325+
});
325326
}
326327

327328
public static Tensor reduce_variance(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
@@ -330,7 +331,7 @@ public static Tensor reduce_variance(Tensor input_tensor, int[] axis = null, boo
330331
name = "reduce_variance";
331332
// else {name = name;}
332333

333-
using (ops.name_scope(name))
334+
return tf_with(ops.name_scope(name, "reduce_variance", new [] {input_tensor}), scope =>
334335
{
335336
var means = reduce_mean(input_tensor, axis: axis, keepdims: true);
336337
if (means.dtype.is_integer())
@@ -348,7 +349,7 @@ public static Tensor reduce_variance(Tensor input_tensor, int[] axis = null, boo
348349
squared_deviations = gen_math_ops.square(diff);
349350
}
350351
return reduce_mean(squared_deviations, axis: axis, keepdims: keepdims);
351-
}
352+
});
352353
}
353354

354355
public static Tensor sigmoid<T>(T x, string name = null)

0 commit comments

Comments
 (0)