From f38d673a66670c50d7718b595bfc6297dee6ee95 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 7 Feb 2024 17:18:15 +0100 Subject: [PATCH 1/4] Make neighbor list compatible with float16 and bfloat16 --- torchmdnet/extensions/neighbors/common.cuh | 25 +++++++++++++++++-- .../neighbors/neighbors_cuda_brute.cuh | 2 +- .../neighbors/neighbors_cuda_cell.cuh | 6 ++--- .../neighbors/neighbors_cuda_shared.cuh | 2 +- 4 files changed, 28 insertions(+), 7 deletions(-) diff --git a/torchmdnet/extensions/neighbors/common.cuh b/torchmdnet/extensions/neighbors/common.cuh index 375d9b5a8..dbf9661b2 100644 --- a/torchmdnet/extensions/neighbors/common.cuh +++ b/torchmdnet/extensions/neighbors/common.cuh @@ -10,6 +10,8 @@ #include #include +using at::BFloat16; +using at::Half; using c10::cuda::CUDAStreamGuard; using c10::cuda::getCurrentCUDAStream; using torch::empty; @@ -23,6 +25,9 @@ using torch::autograd::AutogradContext; using torch::autograd::Function; using torch::autograd::tensor_list; +#define DISPATCH_FOR_ALL_FLOAT_TYPES(...)\ + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, __VA_ARGS__)\ + template using Accessor = torch::PackedTensorAccessor32; @@ -54,6 +59,22 @@ template <> struct vec3 { using type = double3; }; +struct Half3 { + Half x, y, z; +}; + +template <> struct vec3 { + using type = Half3; +}; + +struct BFloat163 { + BFloat16 x, y, z; +}; + +template <> struct vec3 { + using type = BFloat163; +}; + template using scalar3 = typename vec3::type; /* @@ -194,12 +215,12 @@ __device__ auto apply_pbc(scalar3 delta, const KernelAccessor +template __device__ auto compute_distance(scalar3 pos_i, scalar3 pos_j, bool use_periodic, const KernelAccessor& box) { scalar3 delta = {pos_i.x - pos_j.x, pos_i.y - pos_j.y, pos_i.z - pos_j.z}; if (use_periodic) { - delta = apply_pbc(delta, box); + delta = apply_pbc(delta, box); } return delta; } diff --git a/torchmdnet/extensions/neighbors/neighbors_cuda_brute.cuh b/torchmdnet/extensions/neighbors/neighbors_cuda_brute.cuh index 7df3317d4..3f6008c2a 100644 --- a/torchmdnet/extensions/neighbors/neighbors_cuda_brute.cuh +++ b/torchmdnet/extensions/neighbors/neighbors_cuda_brute.cuh @@ -100,7 +100,7 @@ forward_brute(const Tensor& positions, const Tensor& batch, const Tensor& in_box const uint64_t num_all_pairs = num_atoms * (num_atoms - 1UL) / 2UL; const uint64_t num_threads = 128; const uint64_t num_blocks = std::max((num_all_pairs + num_threads - 1UL) / num_threads, 1UL); - AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "get_neighbor_pairs_forward", [&]() { + DISPATCH_FOR_ALL_FLOAT_TYPES(positions.scalar_type(), "get_neighbor_pairs_forward", [&]() { PairListAccessor list_accessor(list); auto box = triclinic::get_box_accessor(box_vectors, use_periodic); const scalar_t cutoff_upper_ = cutoff_upper.to(); diff --git a/torchmdnet/extensions/neighbors/neighbors_cuda_cell.cuh b/torchmdnet/extensions/neighbors/neighbors_cuda_cell.cuh index 193a79898..36db3d4a5 100644 --- a/torchmdnet/extensions/neighbors/neighbors_cuda_cell.cuh +++ b/torchmdnet/extensions/neighbors/neighbors_cuda_cell.cuh @@ -116,7 +116,7 @@ static auto sortAtomsByCellIndex(const Tensor& positions, const Tensor& box_size const int threads = 128; const int blocks = (num_atoms + threads - 1) / threads; auto stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "assignHash", [&] { + DISPATCH_FOR_ALL_FLOAT_TYPES(positions.scalar_type(), "assignHash", [&] { scalar_t cutoff_ = cutoff.to(); scalar3 box_size_ = {box_size[0][0].item(), box_size[1][1].item(), @@ -229,7 +229,7 @@ CellList constructCellList(const Tensor& positions, const Tensor& batch, const T cl.sorted_batch = batch.index_select(0, cl.sorted_indices); // Step 3 int3 cell_dim; - AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "computeCellDim", [&] { + DISPATCH_FOR_ALL_FLOAT_TYPES(positions.scalar_type(), "computeCellDim", [&] { scalar_t cutoff_ = cutoff.to(); scalar3 box_size_ = {box_size[0][0].item(), box_size[1][1].item(), @@ -368,7 +368,7 @@ forward_cell(const Tensor& positions, const Tensor& batch, const Tensor& in_box_ const auto stream = getCurrentCUDAStream(positions.get_device()); { // Traverse the cell list to find the neighbors const CUDAStreamGuard guard(stream); - AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "forward", [&] { + DISPATCH_FOR_ALL_FLOAT_TYPES(positions.scalar_type(), "forward", [&] { const scalar_t cutoff_upper_ = cutoff_upper.to(); TORCH_CHECK(cutoff_upper_ > 0, "Expected cutoff_upper to be positive"); const scalar_t cutoff_lower_ = cutoff_lower.to(); diff --git a/torchmdnet/extensions/neighbors/neighbors_cuda_shared.cuh b/torchmdnet/extensions/neighbors/neighbors_cuda_shared.cuh index 9c4523f50..cbbdb9b7b 100644 --- a/torchmdnet/extensions/neighbors/neighbors_cuda_shared.cuh +++ b/torchmdnet/extensions/neighbors/neighbors_cuda_shared.cuh @@ -104,7 +104,7 @@ forward_shared(const Tensor& positions, const Tensor& batch, const Tensor& in_bo const auto stream = getCurrentCUDAStream(positions.get_device()); PairList list(num_pairs, positions.options(), loop, include_transpose, use_periodic); const CUDAStreamGuard guard(stream); - AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "get_neighbor_pairs_shared_forward", [&]() { + DISPATCH_FOR_ALL_FLOAT_TYPES(positions.scalar_type(), "get_neighbor_pairs_shared_forward", [&]() { const scalar_t cutoff_upper_ = cutoff_upper.to(); const scalar_t cutoff_lower_ = cutoff_lower.to(); auto box = triclinic::get_box_accessor(box_vectors, use_periodic); From 4434dd9adf9178469faad6275844773ebb9f5da2 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 9 Feb 2024 16:55:44 +0100 Subject: [PATCH 2/4] Replace frobenius_norm by norm --- torchmdnet/extensions/neighbors/neighbors_cpu.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/extensions/neighbors/neighbors_cpu.cpp b/torchmdnet/extensions/neighbors/neighbors_cpu.cpp index 86e0578aa..4848fa5ad 100644 --- a/torchmdnet/extensions/neighbors/neighbors_cpu.cpp +++ b/torchmdnet/extensions/neighbors/neighbors_cpu.cpp @@ -95,7 +95,7 @@ forward(const Tensor& positions, const Tensor& batch, const Tensor& in_box_vecto deltas.index_put_({Slice(), 0}, deltas.index({Slice(), 0}) - scale1 * box_vectors.index({pair_batch, 0, 0})); } - distances = frobenius_norm(deltas, 1); + distances = torch::linalg::norm(deltas, c10::nullopt, 1, false, c10::nullopt); mask = (distances < cutoff_upper) * (distances >= cutoff_lower); neighbors = neighbors.index({Slice(), mask}); deltas = deltas.index({mask, Slice()}); From 72272f5c0ab91d5da35fb4c08fe34295244d6525 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 9 Feb 2024 16:56:12 +0100 Subject: [PATCH 3/4] Remove old sqrt overloads --- torchmdnet/extensions/neighbors/common.cuh | 8 -------- torchmdnet/extensions/neighbors/neighbors_cuda_brute.cuh | 2 +- torchmdnet/extensions/neighbors/neighbors_cuda_cell.cuh | 2 +- torchmdnet/extensions/neighbors/neighbors_cuda_shared.cuh | 2 +- 4 files changed, 3 insertions(+), 11 deletions(-) diff --git a/torchmdnet/extensions/neighbors/common.cuh b/torchmdnet/extensions/neighbors/common.cuh index dbf9661b2..6717d4c27 100644 --- a/torchmdnet/extensions/neighbors/common.cuh +++ b/torchmdnet/extensions/neighbors/common.cuh @@ -39,14 +39,6 @@ inline Accessor get_accessor(const Tensor& tensor) { return tensor.packed_accessor32(); }; -template __device__ __forceinline__ scalar_t sqrt_(scalar_t x){}; -template <> __device__ __forceinline__ float sqrt_(float x) { - return ::sqrtf(x); -}; -template <> __device__ __forceinline__ double sqrt_(double x) { - return ::sqrt(x); -}; - template struct vec3 { using type = void; }; diff --git a/torchmdnet/extensions/neighbors/neighbors_cuda_brute.cuh b/torchmdnet/extensions/neighbors/neighbors_cuda_brute.cuh index 3f6008c2a..bd6a53f77 100644 --- a/torchmdnet/extensions/neighbors/neighbors_cuda_brute.cuh +++ b/torchmdnet/extensions/neighbors/neighbors_cuda_brute.cuh @@ -40,7 +40,7 @@ __global__ void forward_kernel_brute(uint32_t num_all_pairs, const Accessor= cutoff_lower2) { - const scalar_t r2 = sqrt_(distance2); + const scalar_t r2 = ::sqrt(distance2); addAtomPairToList(list, row, column, delta, r2, list.include_transpose); } } diff --git a/torchmdnet/extensions/neighbors/neighbors_cuda_cell.cuh b/torchmdnet/extensions/neighbors/neighbors_cuda_cell.cuh index 36db3d4a5..987905253 100644 --- a/torchmdnet/extensions/neighbors/neighbors_cuda_cell.cuh +++ b/torchmdnet/extensions/neighbors/neighbors_cuda_cell.cuh @@ -270,7 +270,7 @@ __device__ void addNeighborPair(PairListAccessor& list, const int i, c const int ni = max(i, j); const int nj = min(i, j); const scalar_t delta_sign = (ni == i) ? scalar_t(1.0) : scalar_t(-1.0); - const scalar_t distance = sqrt_(distance2); + const scalar_t distance = ::sqrt(distance2); delta = {delta_sign * delta.x, delta_sign * delta.y, delta_sign * delta.z}; addAtomPairToList(list, ni, nj, delta, distance, requires_transpose); } diff --git a/torchmdnet/extensions/neighbors/neighbors_cuda_shared.cuh b/torchmdnet/extensions/neighbors/neighbors_cuda_shared.cuh index cbbdb9b7b..4a5f248f7 100644 --- a/torchmdnet/extensions/neighbors/neighbors_cuda_shared.cuh +++ b/torchmdnet/extensions/neighbors/neighbors_cuda_shared.cuh @@ -62,7 +62,7 @@ __global__ void forward_kernel_shared(uint32_t num_atoms, const Accessor= cutoff_lower2) { const bool requires_transpose = list.include_transpose && !(cur_j == id); - const auto distance = sqrt_(distance2); + const scalar_t distance = ::sqrt(distance2); addAtomPairToList(list, id, cur_j, delta, distance, requires_transpose); } } From 5eef9653796ca63d4b31f94ed2c98333294b09d7 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 9 Feb 2024 16:56:26 +0100 Subject: [PATCH 4/4] format --- .../neighbors/neighbors_cuda_shared.cuh | 31 ++++++++++--------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/torchmdnet/extensions/neighbors/neighbors_cuda_shared.cuh b/torchmdnet/extensions/neighbors/neighbors_cuda_shared.cuh index 4a5f248f7..d459da119 100644 --- a/torchmdnet/extensions/neighbors/neighbors_cuda_shared.cuh +++ b/torchmdnet/extensions/neighbors/neighbors_cuda_shared.cuh @@ -104,21 +104,22 @@ forward_shared(const Tensor& positions, const Tensor& batch, const Tensor& in_bo const auto stream = getCurrentCUDAStream(positions.get_device()); PairList list(num_pairs, positions.options(), loop, include_transpose, use_periodic); const CUDAStreamGuard guard(stream); - DISPATCH_FOR_ALL_FLOAT_TYPES(positions.scalar_type(), "get_neighbor_pairs_shared_forward", [&]() { - const scalar_t cutoff_upper_ = cutoff_upper.to(); - const scalar_t cutoff_lower_ = cutoff_lower.to(); - auto box = triclinic::get_box_accessor(box_vectors, use_periodic); - TORCH_CHECK(cutoff_upper_ > 0, "Expected \"cutoff\" to be positive"); - constexpr int BLOCKSIZE = 64; - const int num_blocks = std::max((num_atoms + BLOCKSIZE - 1) / BLOCKSIZE, 1); - const int num_threads = BLOCKSIZE; - const int num_tiles = num_blocks; - PairListAccessor list_accessor(list); - forward_kernel_shared<<>>( - num_atoms, get_accessor(positions), get_accessor(batch), - cutoff_lower_ * cutoff_lower_, cutoff_upper_ * cutoff_upper_, list_accessor, num_tiles, - box); - }); + DISPATCH_FOR_ALL_FLOAT_TYPES( + positions.scalar_type(), "get_neighbor_pairs_shared_forward", [&]() { + const scalar_t cutoff_upper_ = cutoff_upper.to(); + const scalar_t cutoff_lower_ = cutoff_lower.to(); + auto box = triclinic::get_box_accessor(box_vectors, use_periodic); + TORCH_CHECK(cutoff_upper_ > 0, "Expected \"cutoff\" to be positive"); + constexpr int BLOCKSIZE = 64; + const int num_blocks = std::max((num_atoms + BLOCKSIZE - 1) / BLOCKSIZE, 1); + const int num_threads = BLOCKSIZE; + const int num_tiles = num_blocks; + PairListAccessor list_accessor(list); + forward_kernel_shared<<>>( + num_atoms, get_accessor(positions), get_accessor(batch), + cutoff_lower_ * cutoff_lower_, cutoff_upper_ * cutoff_upper_, list_accessor, + num_tiles, box); + }); return {list.neighbors, list.deltas, list.distances, list.i_curr_pair}; }