44// This source code is licensed under the BSD-style license found in the
55// LICENSE file in the root directory of this source tree.
66
7+ #pragma once
8+
79#include < torch/types.h>
810#include < memory>
911#include < mutex>
@@ -27,7 +29,7 @@ class Cache {
2729 public:
2830 using element_type = std::unique_ptr<T, D>;
2931
30- Cache (int capacity) : capacity_(capacity) {}
32+ explicit Cache (int capacity) : capacity_(capacity) {}
3133
3234 // Adds an object to the cache if the cache has capacity. Returns true
3335 // if object was added and false otherwise.
@@ -56,8 +58,9 @@ bool Cache<T, D>::addIfCacheHasCapacity(element_type&& obj) {
5658template <typename T, typename D>
5759typename Cache<T, D>::element_type Cache<T, D>::get() {
5860 std::scoped_lock lock (mutex_);
59- if (cache_.empty ())
61+ if (cache_.empty ()) {
6062 return nullptr ;
63+ }
6164
6265 element_type obj = std::move (cache_.back ());
6366 cache_.pop_back ();
@@ -92,7 +95,15 @@ class PerGpuCache {
9295 std::vector<std::unique_ptr<Cache<T, D>>> cache_;
9396};
9497
95- torch::DeviceIndex getNonNegativeDeviceIndex (const torch::Device& device) {
98+ // Note: this function is inline for convenience, not performance. Because the
99+ // rest of this file is template functions, they must all be defined in this
100+ // header. This function is not a template function, and should, in principle,
101+ // be defined in a .cpp file to preserve the One Definition Rule. That's
102+ // annoying for such a small amount of code, so we just inline it. If this file
103+ // grows, and there are more such functions, we should break them out into a
104+ // .cpp file.
105+ inline torch::DeviceIndex getNonNegativeDeviceIndex (
106+ const torch::Device& device) {
96107 torch::DeviceIndex deviceIndex = device.index ();
97108 // For single GPU machines libtorch returns -1 for the device index. So for
98109 // that case we set the device index to 0. That's used in per-gpu cache
0 commit comments