@@ -95,30 +95,16 @@ class PerGpuCache {
9595 std::vector<std::unique_ptr<Cache<T, D>>> cache_;
9696};
9797
98- // Note: this function is inline for convenience, not performance. Because the
99- // rest of this file is template functions, they must all be defined in this
100- // header. This function is not a template function, and should, in principle,
101- // be defined in a .cpp file to preserve the One Definition Rule. That's
102- // annoying for such a small amount of code, so we just inline it. If this file
103- // grows, and there are more such functions, we should break them out into a
104- // .cpp file.
105- inline torch::DeviceIndex getNonNegativeDeviceIndex (
106- const torch::Device& device) {
107- torch::DeviceIndex deviceIndex = device.index ();
108- // For single GPU machines libtorch returns -1 for the device index. So for
109- // that case we set the device index to 0. That's used in per-gpu cache
110- // implementation and during initialization of CUDA and FFmpeg contexts
111- // which require non negative indices.
112- deviceIndex = std::max<at::DeviceIndex>(deviceIndex, 0 );
113- TORCH_CHECK (deviceIndex >= 0 , " Device index out of range" );
114- return deviceIndex;
115- }
98+ // Forward declaration of getDeviceIndex which exists in CUDACommon.h
99+ // This avoids circular dependency between Cache.h and CUDACommon.cpp which also
100+ // needs to include Cache.h
101+ int getDeviceIndex (const torch::Device& device);
116102
117103template <typename T, typename D>
118104bool PerGpuCache<T, D>::addIfCacheHasCapacity(
119105 const torch::Device& device,
120106 element_type&& obj) {
121- torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex (device);
107+ int deviceIndex = getDeviceIndex (device);
122108 TORCH_CHECK (
123109 static_cast <size_t >(deviceIndex) < cache_.size (),
124110 " Device index out of range" );
@@ -128,7 +114,7 @@ bool PerGpuCache<T, D>::addIfCacheHasCapacity(
128114template <typename T, typename D>
129115typename PerGpuCache<T, D>::element_type PerGpuCache<T, D>::get(
130116 const torch::Device& device) {
131- torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex (device);
117+ int deviceIndex = getDeviceIndex (device);
132118 TORCH_CHECK (
133119 static_cast <size_t >(deviceIndex) < cache_.size (),
134120 " Device index out of range" );
0 commit comments