@@ -20,6 +20,8 @@ limitations under the License.
2020using Tensorflow . Eager ;
2121using static Tensorflow . Binding ;
2222using Google . Protobuf ;
23+ using Tensorflow . Device ;
24+ using System . Collections . Generic ;
2325
2426namespace Tensorflow . Contexts
2527{
@@ -30,6 +32,7 @@ public sealed partial class Context
3032 {
3133 ContextDevicePlacementPolicy _device_policy ;
3234 bool _log_device_placement ;
35+ Dictionary < PhysicalDevice , bool > _memory_growth_map = new Dictionary < PhysicalDevice , bool > ( ) ;
3336
3437 public void log_device_placement ( bool enable )
3538 {
@@ -38,5 +41,53 @@ public void log_device_placement(bool enable)
3841 _log_device_placement = enable ;
3942 // _thread_local_data.function_call_options = null;
4043 }
44+
45+ public bool get_memory_growth ( string device_type )
46+ {
47+ foreach ( var map in _memory_growth_map )
48+ {
49+ if ( map . Key . DeviceType == device_type )
50+ return map . Value ;
51+ }
52+ return false ;
53+ }
54+
55+ public void set_memory_growth ( PhysicalDevice device , bool enable )
56+ {
57+ _memory_growth_map [ device ] = enable ;
58+ }
59+
60+ public PhysicalDevice [ ] list_physical_devices ( string device_type = null )
61+ {
62+ using var opts = c_api . TFE_NewContextOptions ( ) ;
63+ using var ctx = c_api . TFE_NewContext ( opts , tf . Status . Handle ) ;
64+ using var devices = c_api . TFE_ContextListDevices ( ctx , tf . Status . Handle ) ;
65+ tf . Status . Check ( true ) ;
66+
67+ int num_devices = c_api . TF_DeviceListCount ( devices ) ;
68+ var results = new List < PhysicalDevice > ( ) ;
69+ for ( int i = 0 ; i < num_devices ; ++ i )
70+ {
71+ var dev_type = c_api . StringPiece ( c_api . TF_DeviceListType ( devices , i , tf . Status . Handle ) ) ;
72+ tf . Status . Check ( true ) ;
73+
74+ if ( dev_type . StartsWith ( "XLA" ) )
75+ continue ;
76+
77+ if ( device_type == null || dev_type == device_type )
78+ {
79+ var dev_name = c_api . TF_DeviceListName ( devices , i , tf . Status . Handle ) ;
80+ tf . Status . Check ( true ) ;
81+
82+ results . Add ( new PhysicalDevice
83+ {
84+ DeviceName = dev_name ,
85+ DeviceType = dev_type
86+ } ) ;
87+ }
88+ }
89+
90+ return results . ToArray ( ) ;
91+ }
4192 }
4293}
0 commit comments