Skip to content

Commit 6794925

Browse files
committed
override graph
1 parent 1ef2ec1 commit 6794925

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ public class _ElementFetchMapper : _FetchMapper
2828
{
2929
private Func<List<NDArray>, object> _contraction_fn;
3030

31-
public _ElementFetchMapper(object[] fetches, Func<List<NDArray>, object> contraction_fn)
31+
public _ElementFetchMapper(object[] fetches, Func<List<NDArray>, object> contraction_fn, Graph graph = null)
3232
{
33-
var g = ops.get_default_graph();
33+
var g = graph ?? ops.get_default_graph();
3434

3535
foreach(var fetch in fetches)
3636
{

src/TensorFlowNET.Core/Sessions/_FetchHandler.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ public class _FetchHandler
3434

3535
public _FetchHandler(Graph graph, object fetches, Dictionary<object, object> feeds = null, Action feed_handles = null)
3636
{
37-
_fetch_mapper = _FetchMapper.for_fetch(fetches);
37+
_fetch_mapper = _FetchMapper.for_fetch(fetches, graph: graph);
3838
foreach(var fetch in _fetch_mapper.unique_fetches())
3939
{
4040
switch (fetch)

src/TensorFlowNET.Core/Sessions/_FetchMapper.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public class _FetchMapper
2525
{
2626
protected List<ITensorOrOperation> _unique_fetches = new List<ITensorOrOperation>();
2727
protected List<int[]> _value_indices = new List<int[]>();
28-
public static _FetchMapper for_fetch(object fetch)
28+
public static _FetchMapper for_fetch(object fetch, Graph graph = null)
2929
{
3030
var fetches = fetch.GetType().IsArray ? (object[])fetch : new object[] { fetch };
3131

@@ -34,7 +34,7 @@ public static _FetchMapper for_fetch(object fetch)
3434
if (fetch.GetType().IsArray)
3535
return new _ListFetchMapper(fetches);
3636
else
37-
return new _ElementFetchMapper(fetches, (List<NDArray> fetched_vals) => fetched_vals[0]);
37+
return new _ElementFetchMapper(fetches, (List<NDArray> fetched_vals) => fetched_vals[0], graph: graph);
3838
}
3939

4040
public virtual NDArray[] build_results(List<NDArray> values)

0 commit comments

Comments
 (0)