@@ -90,6 +90,19 @@ class GPUCompatibleCPUDevice : public ThreadPoolDevice {
9090 options.config .gpu_options ().force_gpu_compatible ();
9191 }
9292 }
93+ GPUCompatibleCPUDevice (const SessionOptions& options, const string& name,
94+ Bytes memory_limit, const DeviceLocality& locality,
95+ Allocator* allocator,
96+ const DeviceResourceMgrMap* dev_rmgr_map)
97+ : ThreadPoolDevice(options, name, memory_limit,
98+ locality, allocator, dev_rmgr_map),
99+ numa_node_(locality.numa_node()) {
100+ if (options.config .has_gpu_options ()) {
101+ force_gpu_compatible_ =
102+ options.config .gpu_options ().force_gpu_compatible ();
103+ }
104+ }
105+
93106 ~GPUCompatibleCPUDevice () override {}
94107
95108 Allocator* GetAllocator (AllocatorAttributes attr) override {
@@ -118,6 +131,12 @@ class GPUCompatibleCPUDeviceFactory : public DeviceFactory {
118131
119132 Status CreateDevices (const SessionOptions& options, const string& name_prefix,
120133 std::vector<std::unique_ptr<Device>>* devices) override {
134+ return CreateDevices (options, name_prefix, devices, nullptr );
135+ }
136+
137+ Status CreateDevices (const SessionOptions& options, const string& name_prefix,
138+ std::vector<std::unique_ptr<Device>>* devices,
139+ const DeviceResourceMgrMap* dev_rmgr_map) override {
121140 int n = 1 ;
122141 auto iter = options.config .device_count ().find (" CPU" );
123142 if (iter != options.config .device_count ().end ()) {
@@ -133,7 +152,8 @@ class GPUCompatibleCPUDeviceFactory : public DeviceFactory {
133152 locality.set_numa_node (numa_node);
134153 devices->push_back (absl::make_unique<GPUCompatibleCPUDevice>(
135154 options, name, Bytes (256 << 20 ), DeviceLocality (),
136- ProcessState::singleton ()->GetCPUAllocator (numa_node)));
155+ ProcessState::singleton ()->GetCPUAllocator (numa_node),
156+ dev_rmgr_map));
137157 }
138158
139159 return Status::OK ();
0 commit comments