Skip to content

Commit a8f6308

Browse files
authored
[Runtime] Share resource between GPUCompatibleCPUDevice. (#285)
1 parent 24e37d7 commit a8f6308

File tree

2 files changed

+25
-7
lines changed

2 files changed

+25
-7
lines changed

tensorflow/core/common_runtime/direct_session.cc

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -274,12 +274,10 @@ class DirectSessionFactory : public SessionFactory {
274274
ResourceMgr* shared_rmgr = new ResourceMgr("localhost");
275275
DeviceResourceMgrMap dev_rmgr_map;
276276
std::string dev_prefix("/job:localhost/replica:0/task:0");
277-
for (int i = 0; i < session_num; ++i) {
278-
std::string dev_name = dev_prefix + "/device:CPU:" + std::to_string(i);
279-
dev_rmgr_map.device_rmgr_map[dev_name] = shared_rmgr;
280-
dev_name = dev_prefix + "/device:cpu:" + std::to_string(i);
281-
dev_rmgr_map.device_rmgr_map[dev_name] = shared_rmgr;
282-
}
277+
dev_rmgr_map.device_rmgr_map[dev_prefix+"/device:CPU:0"] = shared_rmgr;
278+
dev_rmgr_map.device_rmgr_map[dev_prefix+"/device:cpu:0"] = shared_rmgr;
279+
dev_rmgr_map.device_rmgr_map["/device:CPU:0"] = shared_rmgr;
280+
dev_rmgr_map.device_rmgr_map["/device:cpu:0"] = shared_rmgr;
283281

284282
std::vector<std::unique_ptr<Device>> devices;
285283
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(

tensorflow/core/common_runtime/gpu/gpu_device_factory.cc

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,19 @@ class GPUCompatibleCPUDevice : public ThreadPoolDevice {
9090
options.config.gpu_options().force_gpu_compatible();
9191
}
9292
}
93+
GPUCompatibleCPUDevice(const SessionOptions& options, const string& name,
94+
Bytes memory_limit, const DeviceLocality& locality,
95+
Allocator* allocator,
96+
const DeviceResourceMgrMap* dev_rmgr_map)
97+
: ThreadPoolDevice(options, name, memory_limit,
98+
locality, allocator, dev_rmgr_map),
99+
numa_node_(locality.numa_node()) {
100+
if (options.config.has_gpu_options()) {
101+
force_gpu_compatible_ =
102+
options.config.gpu_options().force_gpu_compatible();
103+
}
104+
}
105+
93106
~GPUCompatibleCPUDevice() override {}
94107

95108
Allocator* GetAllocator(AllocatorAttributes attr) override {
@@ -118,6 +131,12 @@ class GPUCompatibleCPUDeviceFactory : public DeviceFactory {
118131

119132
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
120133
std::vector<std::unique_ptr<Device>>* devices) override {
134+
return CreateDevices(options, name_prefix, devices, nullptr);
135+
}
136+
137+
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
138+
std::vector<std::unique_ptr<Device>>* devices,
139+
const DeviceResourceMgrMap* dev_rmgr_map) override {
121140
int n = 1;
122141
auto iter = options.config.device_count().find("CPU");
123142
if (iter != options.config.device_count().end()) {
@@ -133,7 +152,8 @@ class GPUCompatibleCPUDeviceFactory : public DeviceFactory {
133152
locality.set_numa_node(numa_node);
134153
devices->push_back(absl::make_unique<GPUCompatibleCPUDevice>(
135154
options, name, Bytes(256 << 20), DeviceLocality(),
136-
ProcessState::singleton()->GetCPUAllocator(numa_node)));
155+
ProcessState::singleton()->GetCPUAllocator(numa_node),
156+
dev_rmgr_map));
137157
}
138158

139159
return Status::OK();

0 commit comments

Comments
 (0)