diff --git a/torchmdnet/extensions/neighbors/common.cuh b/torchmdnet/extensions/neighbors/common.cuh
index 375d9b5a8..6717d4c27 100644
--- a/torchmdnet/extensions/neighbors/common.cuh
+++ b/torchmdnet/extensions/neighbors/common.cuh
@@ -10,6 +10,8 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 
+using at::BFloat16;
+using at::Half;
 using c10::cuda::CUDAStreamGuard;
 using c10::cuda::getCurrentCUDAStream;
 using torch::empty;
@@ -23,6 +25,9 @@
 using torch::autograd::AutogradContext;
 using torch::autograd::Function;
 using torch::autograd::tensor_list;
 
+#define DISPATCH_FOR_ALL_FLOAT_TYPES(...)                                                         \
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, __VA_ARGS__)
+
 template <typename scalar_t, int num_dims>
 using Accessor = torch::PackedTensorAccessor32<scalar_t, num_dims, torch::RestrictPtrTraits>;
 
@@ -34,14 +39,6 @@ inline Accessor<scalar_t, num_dims> get_accessor(const Tensor& tensor) {
     return tensor.packed_accessor32<scalar_t, num_dims, torch::RestrictPtrTraits>();
 };
 
-template <typename scalar_t> __device__ __forceinline__ scalar_t sqrt_(scalar_t x){};
-template <> __device__ __forceinline__ float sqrt_(float x) {
-    return ::sqrtf(x);
-};
-template <> __device__ __forceinline__ double sqrt_(double x) {
-    return ::sqrt(x);
-};
-
 template <typename scalar_t> struct vec3 {
     using type = void;
 };
@@ -54,6 +51,22 @@
 template <> struct vec3<double> {
     using type = double3;
 };
 
+struct Half3 {
+    Half x, y, z;
+};
+
+template <> struct vec3<Half> {
+    using type = Half3;
+};
+
+struct BFloat163 {
+    BFloat16 x, y, z;
+};
+
+template <> struct vec3<BFloat16> {
+    using type = BFloat163;
+};
+
 template <typename scalar_t> using scalar3 = typename vec3<scalar_t>::type;
 
 /*
@@ -194,12 +207,12 @@ __device__ auto apply_pbc(scalar3<scalar_t> delta, const KernelAccessor
-template <class scalar_t>
+template <typename scalar_t>
 __device__ auto compute_distance(scalar3<scalar_t> pos_i, scalar3<scalar_t> pos_j, bool use_periodic,
                                  const KernelAccessor<scalar_t, 2>& box) {
     scalar3<scalar_t> delta = {pos_i.x - pos_j.x, pos_i.y - pos_j.y, pos_i.z - pos_j.z};
     if (use_periodic) {
-        delta = apply_pbc<scalar_t>(delta, box);
+        delta = apply_pbc(delta, box);
     }
     return delta;
 }
diff --git a/torchmdnet/extensions/neighbors/neighbors_cpu.cpp b/torchmdnet/extensions/neighbors/neighbors_cpu.cpp
index 86e0578aa..4848fa5ad 100644
--- a/torchmdnet/extensions/neighbors/neighbors_cpu.cpp
+++ b/torchmdnet/extensions/neighbors/neighbors_cpu.cpp
@@ -95,7 +95,7 @@ forward(const Tensor& positions, const Tensor& batch, const Tensor& in_box_vecto
         deltas.index_put_({Slice(), 0}, deltas.index({Slice(), 0}) -
                                             scale1 * box_vectors.index({pair_batch, 0, 0}));
     }
-    distances = frobenius_norm(deltas, 1);
+    distances = torch::linalg::norm(deltas, c10::nullopt, 1, false, c10::nullopt);
     mask = (distances < cutoff_upper) * (distances >= cutoff_lower);
     neighbors = neighbors.index({Slice(), mask});
     deltas = deltas.index({mask, Slice()});
diff --git a/torchmdnet/extensions/neighbors/neighbors_cuda_brute.cuh b/torchmdnet/extensions/neighbors/neighbors_cuda_brute.cuh
index 7df3317d4..bd6a53f77 100644
--- a/torchmdnet/extensions/neighbors/neighbors_cuda_brute.cuh
+++ b/torchmdnet/extensions/neighbors/neighbors_cuda_brute.cuh
@@ -40,7 +40,7 @@ __global__ void forward_kernel_brute(uint32_t num_all_pairs, const Accessor<sca
         const scalar_t distance2 = delta.x * delta.x + delta.y * delta.y + delta.z * delta.z;
         if (distance2 < cutoff_upper2 && distance2 >= cutoff_lower2) {
-            const scalar_t r2 = sqrt_(distance2);
+            const scalar_t r2 = ::sqrt(distance2);
             addAtomPairToList(list, row, column, delta, r2, list.include_transpose);
         }
     }
@@ -100,7 +100,7 @@ forward_brute(const Tensor& positions, const Tensor& batch, const Tensor& in_box
     const uint64_t num_all_pairs = num_atoms * (num_atoms - 1UL) / 2UL;
     const uint64_t num_threads = 128;
     const uint64_t num_blocks = std::max((num_all_pairs + num_threads - 1UL) / num_threads, 1UL);
-    AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "get_neighbor_pairs_forward", [&]() {
+    DISPATCH_FOR_ALL_FLOAT_TYPES(positions.scalar_type(), "get_neighbor_pairs_forward", [&]() {
         PairListAccessor<scalar_t> list_accessor(list);
         auto box = triclinic::get_box_accessor<scalar_t>(box_vectors, use_periodic);
         const scalar_t cutoff_upper_ = cutoff_upper.to<scalar_t>();
diff --git a/torchmdnet/extensions/neighbors/neighbors_cuda_cell.cuh b/torchmdnet/extensions/neighbors/neighbors_cuda_cell.cuh
index 193a79898..987905253 100644
--- a/torchmdnet/extensions/neighbors/neighbors_cuda_cell.cuh
+++ b/torchmdnet/extensions/neighbors/neighbors_cuda_cell.cuh
@@ -116,7 +116,7 @@ static auto sortAtomsByCellIndex(const Tensor& positions, const Tensor& box_size
     const int threads = 128;
     const int blocks = (num_atoms + threads - 1) / threads;
     auto stream = at::cuda::getCurrentCUDAStream();
-    AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "assignHash", [&] {
+    DISPATCH_FOR_ALL_FLOAT_TYPES(positions.scalar_type(), "assignHash", [&] {
         scalar_t cutoff_ = cutoff.to<scalar_t>();
         scalar3<scalar_t> box_size_ = {box_size[0][0].item<scalar_t>(),
                                        box_size[1][1].item<scalar_t>(),
@@ -229,7 +229,7 @@ CellList constructCellList(const Tensor& positions, const Tensor& batch, const T
     cl.sorted_batch = batch.index_select(0, cl.sorted_indices);
     // Step 3
     int3 cell_dim;
-    AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "computeCellDim", [&] {
+    DISPATCH_FOR_ALL_FLOAT_TYPES(positions.scalar_type(), "computeCellDim", [&] {
         scalar_t cutoff_ = cutoff.to<scalar_t>();
         scalar3<scalar_t> box_size_ = {box_size[0][0].item<scalar_t>(),
                                        box_size[1][1].item<scalar_t>(),
@@ -270,7 +270,7 @@ __device__ void addNeighborPair(PairListAccessor<scalar_t>& list, const int i, c
     const int ni = max(i, j);
     const int nj = min(i, j);
     const scalar_t delta_sign = (ni == i) ? scalar_t(1.0) : scalar_t(-1.0);
-    const scalar_t distance = sqrt_(distance2);
+    const scalar_t distance = ::sqrt(distance2);
     delta = {delta_sign * delta.x, delta_sign * delta.y, delta_sign * delta.z};
     addAtomPairToList(list, ni, nj, delta, distance, requires_transpose);
 }
@@ -368,7 +368,7 @@ forward_cell(const Tensor& positions, const Tensor& batch, const Tensor& in_box_
     const auto stream = getCurrentCUDAStream(positions.get_device());
     { // Traverse the cell list to find the neighbors
         const CUDAStreamGuard guard(stream);
-        AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "forward", [&] {
+        DISPATCH_FOR_ALL_FLOAT_TYPES(positions.scalar_type(), "forward", [&] {
             const scalar_t cutoff_upper_ = cutoff_upper.to<scalar_t>();
             TORCH_CHECK(cutoff_upper_ > 0, "Expected cutoff_upper to be positive");
             const scalar_t cutoff_lower_ = cutoff_lower.to<scalar_t>();
diff --git a/torchmdnet/extensions/neighbors/neighbors_cuda_shared.cuh b/torchmdnet/extensions/neighbors/neighbors_cuda_shared.cuh
index 9c4523f50..d459da119 100644
--- a/torchmdnet/extensions/neighbors/neighbors_cuda_shared.cuh
+++ b/torchmdnet/extensions/neighbors/neighbors_cuda_shared.cuh
@@ -62,7 +62,7 @@ __global__ void forward_kernel_shared(uint32_t num_atoms, const Accessor<scalar
             const scalar_t distance2 = delta.x * delta.x + delta.y * delta.y + delta.z * delta.z;
             if (distance2 < cutoff_upper2 && distance2 >= cutoff_lower2) {
                 const bool requires_transpose = list.include_transpose && !(cur_j == id);
-                const auto distance = sqrt_(distance2);
+                const scalar_t distance = ::sqrt(distance2);
                 addAtomPairToList(list, id, cur_j, delta, distance, requires_transpose);
             }
         }
@@ -104,21 +104,22 @@ forward_shared(const Tensor& positions, const Tensor& batch, const Tensor& in_bo
     const auto stream = getCurrentCUDAStream(positions.get_device());
     PairList list(num_pairs, positions.options(), loop, include_transpose, use_periodic);
     const CUDAStreamGuard guard(stream);
-    AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "get_neighbor_pairs_shared_forward", [&]() {
-        const scalar_t cutoff_upper_ = cutoff_upper.to<scalar_t>();
-        const scalar_t cutoff_lower_ = cutoff_lower.to<scalar_t>();
-        auto box = triclinic::get_box_accessor<scalar_t>(box_vectors, use_periodic);
-        TORCH_CHECK(cutoff_upper_ > 0, "Expected \"cutoff\" to be positive");
-        constexpr int BLOCKSIZE = 64;
-        const int num_blocks = std::max((num_atoms + BLOCKSIZE - 1) / BLOCKSIZE, 1);
-        const int num_threads = BLOCKSIZE;
-        const int num_tiles = num_blocks;
-        PairListAccessor<scalar_t> list_accessor(list);
-        forward_kernel_shared<<<num_blocks, num_threads, 0, stream>>>(
-            num_atoms, get_accessor<scalar_t, 2>(positions), get_accessor<int64_t, 1>(batch),
-            cutoff_lower_ * cutoff_lower_, cutoff_upper_ * cutoff_upper_, list_accessor, num_tiles,
-            box);
-    });
+    DISPATCH_FOR_ALL_FLOAT_TYPES(
+        positions.scalar_type(), "get_neighbor_pairs_shared_forward", [&]() {
+            const scalar_t cutoff_upper_ = cutoff_upper.to<scalar_t>();
+            const scalar_t cutoff_lower_ = cutoff_lower.to<scalar_t>();
+            auto box = triclinic::get_box_accessor<scalar_t>(box_vectors, use_periodic);
+            TORCH_CHECK(cutoff_upper_ > 0, "Expected \"cutoff\" to be positive");
+            constexpr int BLOCKSIZE = 64;
+            const int num_blocks = std::max((num_atoms + BLOCKSIZE - 1) / BLOCKSIZE, 1);
+            const int num_threads = BLOCKSIZE;
+            const int num_tiles = num_blocks;
+            PairListAccessor<scalar_t> list_accessor(list);
+            forward_kernel_shared<<<num_blocks, num_threads, 0, stream>>>(
+                num_atoms, get_accessor<scalar_t, 2>(positions), get_accessor<int64_t, 1>(batch),
+                cutoff_lower_ * cutoff_lower_, cutoff_upper_ * cutoff_upper_, list_accessor,
+                num_tiles, box);
+        });
     return {list.neighbors, list.deltas, list.distances, list.i_curr_pair};
 }
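
Note on the dispatch pattern above (editor's sketch, not part of the patch): the removed sqrt_ wrapper existed only because AT_DISPATCH_FLOATING_TYPES covered just float and double. Once Half and BFloat16 are dispatched as well, plain ::sqrt suffices, because c10::Half and c10::BFloat16 convert implicitly to float in device code and the result converts back on assignment to scalar_t. The minimal, self-contained sketch below shows the same DISPATCH_FOR_ALL_FLOAT_TYPES idiom in isolation; norm_kernel and vector_norm are hypothetical names, not taken from the repository.

#include <ATen/Dispatch.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>

// Same macro as in common.cuh: dispatch over float, double, Half and BFloat16.
#define DISPATCH_FOR_ALL_FLOAT_TYPES(...)                                                          \
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, __VA_ARGS__)

template <typename scalar_t>
__global__ void norm_kernel(const scalar_t* __restrict__ x, scalar_t* __restrict__ out, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        // scalar_t can be float, double, at::Half or at::BFloat16. The half
        // types reach ::sqrt through their implicit conversion to float, which
        // is why the patch can drop the sqrt_ float/double specializations.
        out[i] = ::sqrt(x[i] * x[i]);
    }
}

// Hypothetical host wrapper mirroring the dispatch sites touched by the patch.
torch::Tensor vector_norm(const torch::Tensor& x) {
    auto out = torch::empty_like(x);
    const int n = static_cast<int>(x.numel());
    const int threads = 128;
    const int blocks = (n + threads - 1) / threads;
    auto stream = c10::cuda::getCurrentCUDAStream();
    // One dispatch now covers all four floating types; each call site in the
    // patch changes only the macro name, not the lambda body.
    DISPATCH_FOR_ALL_FLOAT_TYPES(x.scalar_type(), "vector_norm", [&] {
        norm_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            x.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), n);
    });
    return out;
}

Defining the macro once in common.cuh, rather than repeating AT_DISPATCH_FLOATING_TYPES_AND2 everywhere, keeps each of the five dispatch sites in the patch a one-token change and leaves a single place to add or remove supported scalar types later.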