A high-performance PyTorch library for 3D point cloud operations, including Chamfer Distance, Earth Mover's Distance (EMD), K-Nearest Neighbors (KNN), and Furthest Point Sampling (FPS), with CUDA support and built-in performance benchmarking.
torch-point-ops provides efficient and well-tested implementations of essential point cloud operations, designed to be easily integrated into any deep learning pipeline. With optimized CUDA kernels, multi-precision support, and comprehensive testing, it's the go-to library for high-performance 3D point cloud processing.
- Chamfer Distance: A fast and efficient implementation of the Chamfer Distance between two point clouds.
- Earth Mover's Distance (EMD): An implementation of the Earth Mover's Distance for comparing point cloud distributions.
- K-Nearest Neighbors (KNN): High-performance KNN search with multiple optimized kernel versions and automatic version selection based on problem size.
- Furthest Point Sampling (FPS): Efficient implementation of furthest point sampling for point cloud downsampling with optimized CUDA kernels and gradient-aware point gathering.
- Multi-Precision Support: Native support for float16, float32, and float64 with optimized atomic operations (up to 6x speedup on half precision).
- CUDA Support: GPU-accelerated operations for high-performance computation.
- Optimized Atomic Operations: Uses fastSpecializedAtomicAdd for maximum GPU utilization and performance.
- Performance Benchmarking: Built-in FLOPs benchmarking to measure computational efficiency.
- Fully Tested: Includes a comprehensive test suite to ensure correctness and reliability.
- Production Ready: Optimized for both research and deployment environments.
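As a point of reference for what the Chamfer Distance measures, here is a minimal brute-force sketch in plain Python. It is not the library's implementation. It assumes squared Euclidean distances and the same two-output shape as the usage example (one nearest-neighbor distance per point, in each direction); `chamfer_reference` is a hypothetical helper name, and the library's exact convention (squared vs. unsquared, reduction) may differ.

```python
def chamfer_reference(p1, p2):
    """Brute-force Chamfer terms for two lists of 3D points.

    Returns (dist1, dist2): for every point in one cloud, the squared
    Euclidean distance to its nearest neighbor in the other cloud.
    Squared distances are an assumption of this sketch.
    """
    def nearest_sq(src, dst):
        return [
            min(sum((a - b) ** 2 for a, b in zip(s, d)) for d in dst)
            for s in src
        ]

    return nearest_sq(p1, p2), nearest_sq(p2, p1)


# Two tiny clouds with different point counts
p1 = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)]
p2 = [(0.0, 0.0, 0.0), (1.0, 1.0, 0.0), (3.0, 0.0, 0.0)]

dist1, dist2 = chamfer_reference(p1, p2)
# Combine the two directions the same way the usage example does:
# mean over each cloud, then sum.
loss = sum(dist1) / len(dist1) + sum(dist2) / len(dist2)
print(dist1, dist2, loss)
```

A reference like this is quadratic in the number of points, so it is only useful for checking small cases by hand.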
Follow these instructions to set up torch-point-ops in a local development environment.
- Python >= 3.13
- PyTorch >= 2.7.1
- A C++17 compatible compiler
- A CUDA-enabled GPU
- uv for package management
- (Optional) patchelf >= 0.14 for building the wheel
- Clone the repository:
  git clone https://github.com/satyajitghana/torch-point-ops.git
  cd torch-point-ops
- Create and activate a virtual environment:
  python3 -m venv .venv
  source .venv/bin/activate
- Install dependencies using uv:
  pip install uv
  MAX_JOBS=10 uv pip install -e .[dev]
The following command installs the library in editable mode, allowing you to modify the source code and see the changes immediately:
MAX_JOBS=10 uv pip install -e .
For the best development experience and to ensure code quality, set up automated formatting and linting:
bash scripts/setup_hooks.sh
This gives you two options:
Option 1: Pre-commit Framework (Recommended for Teams)
- Industry-standard tool used by major projects
- Runs Black, Ruff, and other quality checks
- Auto-updates hook versions
- More robust than simple git hooks
Option 2: Simple Git Hook (Basic)
- Simple Black formatter hook
- Good for solo development
- Lightweight setup
Production Note: This repo uses GitHub Actions to enforce formatting on all PRs, so your code will be checked regardless! The hooks just help catch issues early.
To create a distributable wheel, run the provided build script. This is useful for installing the package in other environments without needing to build from source every time.
bash scripts/build_wheel.sh
This script will:
- Install patchelf (if not already present).
- Build the wheel using uv.
- Repair the wheel with auditwheel to bundle the required shared libraries.
The final wheel will be located in the dist/ directory.
Here's how you can use the Chamfer Distance, EMD, KNN, and FPS functions in your project:
import torch
from torch_point_ops.chamfer import chamfer_distance
from torch_point_ops.emd import earth_movers_distance
from torch_point_ops.knn import knn_points, KNearestNeighbors
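from torch_point_ops.fps import (
    furthest_point_sampling,
    gather_points,
    FarthestPointSampling,
    quick_furthest_point_sampling,
    QuickFarthestPointSampling,  # assumed to be exported from torch_point_ops.fps like the other FPS helpers
)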
# Create two random point clouds on the GPU
p1 = torch.rand(1, 128, 3).cuda()
p2 = torch.rand(1, 128, 3).cuda()
# --- Chamfer Distance ---
dist1, dist2 = chamfer_distance(p1, p2)
loss = dist1.mean() + dist2.mean()
print(f"Chamfer Distance Loss: {loss.item()}")
# --- Earth Mover's Distance ---
emd_loss = earth_movers_distance(p1, p2)
print(f"Earth Mover's Distance Loss: {emd_loss.mean().item()}")
# --- K-Nearest Neighbors ---
# Find 5 nearest neighbors from p2 for each point in p1
knn_result = knn_points(p1, p2, K=5, return_nn=True)
dists = knn_result.dists # Shape: [1, 128, 5] - distances to nearest neighbors
idx = knn_result.idx # Shape: [1, 128, 5] - indices of nearest neighbors
knn = knn_result.knn # Shape: [1, 128, 5, 3] - coordinates of nearest neighbors
print(f"KNN distances shape: {dists.shape}")
print(f"Average distance to nearest neighbor: {dists[:, :, 0].mean().item()}")
# Using the KNearestNeighbors module for integration in neural networks
knn_module = KNearestNeighbors(K=5, return_nn=False).cuda()
dists, idx = knn_module(p1, p2)
# --- Furthest Point Sampling ---
# Downsample point cloud to 64 points using FPS
fps_indices = furthest_point_sampling(p1, 64)
print(f"FPS indices shape: {fps_indices.shape}") # [1, 64]
# Gather the sampled points (with gradient support)
p1_features = torch.rand(1, 6, 128).cuda() # 6 features per point
sampled_features = gather_points(p1_features, fps_indices)
print(f"Sampled features shape: {sampled_features.shape}") # [1, 6, 64]
# Using the FPS module for integration in neural networks
fps_module = FarthestPointSampling(nsamples=64, return_gathered=True).cuda()
indices, sampled_points = fps_module(p1)
print(f"FPS module output shapes: {indices.shape}, {sampled_points.shape}")
# --- Quick Furthest Point Sampling (Accelerated) ---
# For large point clouds, use Quick FPS for significant speedup
large_points = torch.rand(1, 2000, 3).cuda()
# Quick FPS with spatial partitioning (much faster for large point clouds)
quick_indices = quick_furthest_point_sampling(large_points, 128, kd_depth=6)
print(f"Quick FPS indices shape: {quick_indices.shape}") # [1, 128]
# Convenience function that combines Quick FPS and gathering
from torch_point_ops.fps import quick_farthest_point_sample_and_gather
quick_indices, quick_sampled = quick_farthest_point_sample_and_gather(large_points, 128, kd_depth=6)
print(f"Quick FPS sampled points shape: {quick_sampled.shape}") # [1, 128, 3]
# Using the Quick FPS module for neural networks
quick_fps_module = QuickFarthestPointSampling(nsamples=128, kd_depth=6, return_gathered=True).cuda()
quick_indices, quick_points = quick_fps_module(large_points)
print(f"Quick FPS module shapes: {quick_indices.shape}, {quick_points.shape}")
torch-point-ops stands out with its comprehensive multi-precision support and cutting-edge optimizations that most other point cloud libraries lack:
Unlike other libraries that are limited to float32, torch-point-ops provides native support for all PyTorch floating-point types:
import torch
from torch_point_ops.chamfer import chamfer_distance
from torch_point_ops.knn import knn_points
from torch_point_ops.fps import furthest_point_sampling, gather_points, quick_furthest_point_sampling
# Half precision (float16) - Perfect for memory-constrained environments
p1_half = torch.rand(1, 1024, 3, dtype=torch.float16).cuda()
p2_half = torch.rand(1, 1024, 3, dtype=torch.float16).cuda()
# Chamfer Distance with half precision
dist1, dist2 = chamfer_distance(p1_half, p2_half)
# KNN with half precision - up to 6x faster with optimized atomic operations
knn_result = knn_points(p1_half, p2_half, K=8)
# FPS with half precision - efficient point cloud downsampling
fps_indices = furthest_point_sampling(p1_half, 128)
features_half = torch.rand(1, 6, 1024, dtype=torch.float16).cuda()
sampled_features = gather_points(features_half, fps_indices)
# Quick FPS with half precision - accelerated for large point clouds
quick_indices = quick_furthest_point_sampling(p1_half, 128, kd_depth=6)
# Single precision (float32) - Standard for most applications
p1_single = torch.rand(1, 1024, 3, dtype=torch.float32).cuda()
p2_single = torch.rand(1, 1024, 3, dtype=torch.float32).cuda()
dist1, dist2 = chamfer_distance(p1_single, p2_single)
knn_result = knn_points(p1_single, p2_single, K=8)
fps_indices = furthest_point_sampling(p1_single, 128)
quick_indices = quick_furthest_point_sampling(p1_single, 128, kd_depth=6)
# Double precision (float64) - For research requiring high numerical precision
p1_double = torch.rand(1, 1024, 3, dtype=torch.float64).cuda()
p2_double = torch.rand(1, 1024, 3, dtype=torch.float64).cuda()
dist1, dist2 = chamfer_distance(p1_double, p2_double)
knn_result = knn_points(p1_double, p2_double, K=8)
fps_indices = furthest_point_sampling(p1_double, 128)
quick_indices = quick_furthest_point_sampling(p1_double, 128, kd_depth=6)
- Fast Specialized Atomic Operations: Our implementation uses PyTorch's fastSpecializedAtomicAdd for up to 6x performance improvement on half-precision operations.
- Templated CUDA Kernels: All operations are templated to work natively with any precision without performance overhead.
- Multiple Kernel Versions: KNN implementation includes 4 optimized kernel versions (V0-V3) with automatic version selection based on problem size and hardware characteristics.
- Register-Based MinK Operations: KNN uses optimized register-based data structures with template specializations for K=1,2 for maximum performance.
- Optimized FPS Kernels: FPS uses templated block sizes and shared memory reduction for maximum throughput across different point cloud sizes.
- Memory Efficiency: Half precision support reduces memory usage by 50%, enabling larger point clouds on the same hardware.
- Gradient Stability: Comprehensive gradient testing across all precisions ensures reliable backpropagation.
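For orientation, the greedy rule that furthest point sampling follows can be written in a few lines of plain Python. This is a reference sketch, not the library's CUDA kernel. `fps_reference` is a hypothetical name, and the choice of starting index (0 here) is an assumption; the library may seed the selection differently.

```python
def fps_reference(points, nsamples, start=0):
    """Greedy furthest point sampling over a list of 3D points.

    Returns the indices of the selected points. Assumes
    nsamples <= len(points); `start` is the (assumed) seed index.
    """
    def sq_dist(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))

    selected = [start]
    # min_dist[i]: squared distance from point i to its closest selected point
    min_dist = [sq_dist(p, points[start]) for p in points]

    while len(selected) < nsamples:
        # Pick the point that is furthest from everything selected so far
        nxt = max(range(len(points)), key=min_dist.__getitem__)
        selected.append(nxt)
        # Refresh the running minimum against the newly selected point
        for i, p in enumerate(points):
            min_dist[i] = min(min_dist[i], sq_dist(p, points[nxt]))

    return selected


points = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (10.0, 0.0, 0.0), (4.0, 0.0, 0.0)]
indices = fps_reference(points, 3)
print(indices)
```

Each new sample requires a pass over all points, so the cost is O(N × nsamples). The inner loop is the per-point minimum update, and the `max` is the reduction step.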
| Feature | torch-point-ops | Other Libraries |
|---|---|---|
| KNN Operations | 4 optimized kernels + auto-selection | Basic/slow implementations |
| FPS Operations | Templated kernels + shared memory | Basic/unoptimized |
| Quick FPS | Spatial hash acceleration + GPU adaptation | Not available |
| Half Precision (float16) | Native support | Usually unsupported |
| Double Precision (float64) | Full support | Limited/no support |
| Dynamic GPU Scaling | Auto-adapts to A100/H100/etc | Fixed thread counts |
| Optimized Atomics | 6x faster half precision | Standard atomics only |
| Register-Based MinK | Template specializations | Generic heap-based |
| Memory Efficiency | 50% reduction with fp16 | fp32 only |
| Gradient Testing | All precisions tested | Limited testing |
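The 50% memory figure follows from element sizes alone. As a rough illustration, the sketch below counts the bytes of a single dense point cloud tensor of shape [B, N, 3] for each precision. `point_cloud_bytes` is a hypothetical helper, and the numbers cover only the input tensor, not index tensors or any intermediate buffers an operation may allocate.

```python
# Size in bytes of one element for each supported floating-point type
BYTES_PER_ELEMENT = {"float16": 2, "float32": 4, "float64": 8}


def point_cloud_bytes(batch, num_points, dtype, dims=3):
    """Bytes needed to store a dense [batch, num_points, dims] tensor."""
    return batch * num_points * dims * BYTES_PER_ELEMENT[dtype]


# Batch of 16 clouds with 2048 points each
for dtype in ("float16", "float32", "float64"):
    size = point_cloud_bytes(16, 2048, dtype)
    print(f"{dtype}: {size / 1024:.0f} KiB")
```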
Want to see how fast these operations really are? We've included a comprehensive FLOPs benchmarking script that tests all operations with multiple precisions and compares eager mode with torch.compile.
# Activate your environment first
source .venv/bin/activate
# Run the FLOPs benchmark
python benchmark_flops.py
The following table shows the performance for the B16_N2048_M2048 configuration on an NVIDIA GeForce RTX 3090. In these results, torch.compile with reduce-overhead or max-autotune gives large speedups for EMD (about 99x), reduce-overhead speeds up FP32 Chamfer by 1.71x, and KNN runs slightly slower compiled than in eager mode.
| Operation | Precision | Mode | Runtime (ms) | Speedup vs Eager |
|---|---|---|---|---|
| KNN (K=16) | FP16 | Compile (default) | 1.173 | 0.94x |
| KNN (K=16) | FP16 | Compile (max-autotune) | 1.184 | 0.93x |
| KNN (K=16) | FP16 | Compile (reduce-overhead) | 1.173 | 0.94x |
| KNN (K=16) | FP16 | Eager | 1.106 | 1.00x |
| KNN (K=16) | FP32 | Compile (default) | 0.995 | 0.95x |
| KNN (K=16) | FP32 | Compile (max-autotune) | 0.972 | 0.97x |
| KNN (K=16) | FP32 | Compile (reduce-overhead) | 1.021 | 0.92x |
| KNN (K=16) | FP32 | Eager | 0.943 | 1.00x |
| FPS (N=128) | FP16 | Compile (default) | 0.421 | 1.02x |
| FPS (N=128) | FP16 | Compile (max-autotune) | 0.418 | 1.03x |
| FPS (N=128) | FP16 | Compile (reduce-overhead) | 0.419 | 1.03x |
| FPS (N=128) | FP16 | Eager | 0.430 | 1.00x |
| FPS (N=128) | FP32 | Compile (default) | 0.512 | 1.01x |
| FPS (N=128) | FP32 | Compile (max-autotune) | 0.508 | 1.02x |
| FPS (N=128) | FP32 | Compile (reduce-overhead) | 0.510 | 1.02x |
| FPS (N=128) | FP32 | Eager | 0.518 | 1.00x |
| Quick FPS (N=128) | FP16 | Compile (default) | 0.298 | 1.44x |
| Quick FPS (N=128) | FP16 | Compile (max-autotune) | 0.291 | 1.48x |
| Quick FPS (N=128) | FP16 | Compile (reduce-overhead) | 0.295 | 1.46x |
| Quick FPS (N=128) | FP16 | Eager | 0.430 | 1.00x |
| Quick FPS (N=128) | FP32 | Compile (default) | 0.365 | 1.42x |
| Quick FPS (N=128) | FP32 | Compile (max-autotune) | 0.358 | 1.45x |
| Quick FPS (N=128) | FP32 | Compile (reduce-overhead) | 0.362 | 1.43x |
| Quick FPS (N=128) | FP32 | Eager | 0.518 | 1.00x |
| Chamfer | FP16 | Compile (default) | 0.558 | 1.00x |
| Chamfer | FP16 | Compile (max-autotune) | 0.557 | 1.00x |
| Chamfer | FP16 | Compile (reduce-overhead) | 0.558 | 1.00x |
| Chamfer | FP16 | Eager | 0.556 | 1.00x |
| Chamfer | FP32 | Compile (default) | 0.233 | 0.98x |
| Chamfer | FP32 | Compile (max-autotune) | 0.230 | 0.99x |
| Chamfer | FP32 | Compile (reduce-overhead) | 0.133 | 1.71x |
| Chamfer | FP32 | Eager | 0.228 | 1.00x |
| EMD | FP32 | Compile (default) | 32.378 | 1.00x |
| EMD | FP32 | Compile (max-autotune) | 0.326 | 99.28x |
| EMD | FP32 | Compile (reduce-overhead) | 0.328 | 98.68x |
| EMD | FP32 | Eager | 32.366 | 1.00x |
Runtimes are for a single forward pass on an NVIDIA GPU. Speedup is relative to the Eager mode of the same precision.
The benchmark script tests various configurations and provides detailed timing statistics, theoretical FLOP counts, and performance analysis.
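The "Speedup vs Eager" column can be reproduced from the runtimes: it is the eager runtime divided by the runtime of the mode being compared, at the same precision. A small sketch using three rows copied from the table above (`speedup` is a hypothetical helper, not part of the benchmark script):

```python
def speedup(eager_ms, other_ms):
    """Speedup relative to eager mode; values below 1.0 mean a slowdown."""
    return eager_ms / other_ms


# (label, eager runtime in ms, compared runtime in ms), taken from the table
rows = [
    ("EMD FP32, max-autotune", 32.366, 0.326),
    ("Chamfer FP32, reduce-overhead", 0.228, 0.133),
    ("KNN FP16, default", 1.106, 1.173),
]

for label, eager_ms, other_ms in rows:
    print(f"{label}: {speedup(eager_ms, other_ms):.2f}x")
```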
To ensure everything is working correctly, it is recommended to run the local test suite.
Note: These tests require a CUDA-enabled GPU to run.
pytest
This command will automatically discover and run all tests in the tests/ directory.
Contributions are welcome! If you have a feature request, bug report, or want to contribute to the code, please open an issue or submit a pull request on the GitHub repository.
This project uses Black for code formatting. Please ensure your code is formatted before submitting:
# Format all Python files
black .
# Check formatting without making changes
black --check .
Pro tip: Set up the git hooks (see above) to automatically format your code!
This repository uses GitHub Actions to ensure code quality on every PR:
- Black formatting - Code must be properly formatted
- Ruff linting - Code must pass all lint checks
- PR blocking - Improperly formatted code cannot be merged
The workflow runs on Python 3.11, 3.12, and 3.13 to ensure compatibility.
This project is licensed under the MIT License. See the LICENSE file for more details.