@@ -40,7 +40,12 @@ public virtual NDArray run(Tensor fetches, Dictionary<Tensor, NDArray> feed_dict
4040 return _run ( fetches , feed_dict ) ;
4141 }
4242
43- private NDArray _run ( Tensor fetches , Dictionary < Tensor , NDArray > feed_dict = null )
43+ public virtual NDArray run ( Operation fetches , Dictionary < Tensor , NDArray > feed_dict = null )
44+ {
45+ return _run ( fetches , feed_dict ) ;
46+ }
47+
48+ private NDArray _run < T > ( T fetches , Dictionary < Tensor , NDArray > feed_dict = null )
4449 {
4550 var feed_dict_tensor = new Dictionary < Tensor , NDArray > ( ) ;
4651
@@ -53,7 +58,7 @@ private NDArray _run(Tensor fetches, Dictionary<Tensor, NDArray> feed_dict = nul
5358 }
5459
5560 // Create a fetch handler to take care of the structure of fetches.
56- var fetch_handler = new _FetchHandler ( _graph , fetches , feed_dict_tensor ) ;
61+ var fetch_handler = new _FetchHandler < T > ( _graph , fetches , feed_dict_tensor ) ;
5762
5863 // Run request and get response.
5964 // We need to keep the returned movers alive for the following _do_run().
@@ -65,20 +70,34 @@ private NDArray _run(Tensor fetches, Dictionary<Tensor, NDArray> feed_dict = nul
6570
6671 // We only want to really perform the run if fetches or targets are provided,
6772 // or if the call is a partial run that specifies feeds.
68- var results = _do_run ( final_fetches , feed_dict_tensor ) ;
73+ var results = _do_run ( final_targets . Select ( x => ( Operation ) ( object ) x ) . ToList ( ) , final_fetches , feed_dict_tensor ) ;
6974
7075 return fetch_handler . build_results ( null , results ) ;
7176 }
7277
73- private NDArray [ ] _do_run ( List < Tensor > fetch_list , Dictionary < Tensor , NDArray > feed_dict )
78+ /// <summary>
79+ /// Runs a step based on the given fetches and feeds.
80+ /// </summary>
81+ /// <typeparam name="T"></typeparam>
82+ /// <param name="target_list">A list of operations to be run, but not fetched.</param>
83+ /// <param name="fetch_list"></param>
84+ /// <param name="feed_dict"></param>
85+ /// <returns>
86+ /// A list of numpy ndarrays, corresponding to the elements of
87+ /// `fetch_list`. If the ith element of `fetch_list` contains the
88+ /// name of an operation, the first Tensor output of that operation
89+ /// will be returned for that element.
90+ /// </returns>
91+ private NDArray [ ] _do_run ( List < Operation > target_list , List < Tensor > fetch_list , Dictionary < Tensor , NDArray > feed_dict )
7492 {
7593 var feeds = feed_dict . Select ( x => new KeyValuePair < TF_Output , Tensor > ( x . Key . _as_tf_output ( ) , new Tensor ( x . Value ) ) ) . ToArray ( ) ;
7694 var fetches = fetch_list . Select ( x => x . _as_tf_output ( ) ) . ToArray ( ) ;
95+ var targets = target_list ;
7796
78- return _call_tf_sessionrun ( feeds , fetches ) ;
97+ return _call_tf_sessionrun ( feeds , fetches , target_list ) ;
7998 }
8099
81- private unsafe NDArray [ ] _call_tf_sessionrun ( KeyValuePair < TF_Output , Tensor > [ ] feed_dict , TF_Output [ ] fetch_list )
100+ private unsafe NDArray [ ] _call_tf_sessionrun ( KeyValuePair < TF_Output , Tensor > [ ] feed_dict , TF_Output [ ] fetch_list , List < Operation > target_list )
82101 {
83102 // Ensure any changes to the graph are reflected in the runtime.
84103 _extend_graph ( ) ;
@@ -95,8 +114,8 @@ private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] f
95114 outputs : fetch_list ,
96115 output_values : output_values ,
97116 noutputs : fetch_list . Length ,
98- target_opers : IntPtr . Zero ,
99- ntargets : 0 ,
117+ target_opers : target_list . Select ( f => ( IntPtr ) f ) . ToArray ( ) ,
118+ ntargets : target_list . Count ,
100119 run_metadata : IntPtr . Zero ,
101120 status : status ) ;
102121
0 commit comments