CLM enables training of large-scale 3D Gaussian Splatting scenes that exceed GPU memory capacity by also exploiting CPU memory.
By using CLM offloading, your 3DGS training can:
- Train large-scale scenes with 100+ million Gaussians on a single 24GB GPU and 128GB RAM
- Maintain rendering quality with a mathematically identical rendering formula
- Work with existing rendering kernels: We use off-the-shelf rendering kernels from gsplat. Our offloading design is orthogonal to these rendering kernels, making it easy to integrate with your own splatting pipelines.
This codebase provides three modes of memory-efficient training strategies for your reference:
- no_offload: 3DGS training entirely on GPU. We optimize memory usage with engineering tricks on a single GPU. This serves as the baseline showing that CLM's offloading does not affect quality. (Implemented in strategies/no_offload)
- naive_offload: A simple CPU offloading implementation that stores all Gaussian attributes (xyz, etc.) and their optimizer states on CPU, loads parameters onto GPU in each iteration, and offloads gradients back to CPU in each batch. This demonstrates the simplest offloading strategy, though it is slower. (Implemented in strategies/naive_offload)
- clm_offload: Our most sophisticated offloading design, which keeps selection-critical attributes on GPU while offloading the others, together with their optimizer states, to CPU. It minimizes GPU memory usage while maintaining good speed. The code is more complex but highly efficient. (Implemented in strategies/clm_offload)
- Why use CLM-GS?
- How to run CLM?
- Understand CLM implementation and incorporate into your codebase
- Paper
- License
- Reference
The goal of CLM-GS is to solve GPU out-of-memory problems in 3DGS training.
Traditional 3D Gaussian Splatting stores all parameters, optimizer states, and activation states on GPU, which severely limits the scene scale you can reconstruct due to GPU memory constraints (e.g., 24GB on an RTX 4090). When the scene is very large and intricate, the large number of required Gaussians increases memory consumption for parameters and optimizer states linearly. When rendering high-resolution images, activation states also grow larger. As a result, GPU out-of-memory errors become a common issue.
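To make the scaling concrete, here is a rough back-of-the-envelope estimate (assuming the standard 3DGS parameterization of 59 float32 values per Gaussian and an Adam-style optimizer; the exact per-Gaussian layout in this codebase may differ):

```python
# Rough memory estimate per Gaussian (standard 3DGS parameterization, assumed):
#   xyz (3) + rotation (4) + scale (3) + opacity (1) + SH coefficients (48) = 59 floats
params_per_gaussian = 59
bytes_per_float = 4
# parameter + gradient + two Adam moment buffers ~= 4 copies
storage_multiplier = 4

def training_state_gib(num_gaussians: int) -> float:
    """Approximate memory for parameters, gradients, and optimizer states."""
    total = num_gaussians * params_per_gaussian * bytes_per_float * storage_multiplier
    return total / 1024**3

print(f"{training_state_gib(6_000_000):.1f} GiB for 6M Gaussians")      # ~5.3 GiB
print(f"{training_state_gib(102_000_000):.1f} GiB for 102M Gaussians")  # ~89.7 GiB
```

Even before counting activation memory, a 100M-Gaussian scene is far beyond a 24GB GPU, which is why parameters and optimizer states must spill to CPU RAM.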
CLM-GS addresses these memory constraints effectively. The table below compares GPU memory usage and training time across different scenes (102M means 102 million Gaussians) on our RTX 4090 testbed:
| Strategy | Bicycle (6M) | Rubble 4K (10M) | Rubble 4K (28M) | BigCity Aerial (102M) |
|---|---|---|---|---|
| no_offload | 8.21 GB / 734 s | 16.81 GB / 11702 s | OOM | OOM |
| naive_offload | 4.80 GB / 2481 s | 9.32 GB / 22254 s | 19.03 GB / 40820 s | OOM |
| clm_offload | 3.01 GB / 1348 s | 7.05 GB / 12381 s | 13.0 GB / 24757 s | 20.79 GB / 11783 s |
The repository contains submodules, so please check it out with
git clone git@github.com:nyu-systems/CLM-GS.git --recursive
As prerequisites, ensure your machine has Conda, a GPU with a compatible driver, and a CUDA environment installed.
Note: PyTorch version >= 2.6 is required due to the usage of torch.nonzero_static() API.
Create and activate the conda environment:
conda create -n clm_gs python=3.10
conda activate clm_gs
Install PyTorch and related packages (choose a set of Python/PyTorch packages compatible with your system), for example:
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
Install additional dependencies:
pip install tqdm plyfile psutil numba opencv-python scipy matplotlib pandas imageio imageio-ffmpeg requests tabulate
Compile and install the submodules locally:
pip install --no-build-isolation submodules/clm_kernels
pip install --no-build-isolation submodules/cpu-adam
pip install submodules/fast-tsp
pip install --no-build-isolation submodules/gsplat
pip install --no-build-isolation submodules/simple-knn
This repository trains a 3D Gaussian Splatting model using COLMAP-formatted input datasets. A COLMAP-formatted dataset contains a list of images with their corresponding camera poses, as well as an initial sparse point cloud that roughly represents the scene structure. From such a dataset, this repository reconstructs a detailed 3DGS model that captures the scene's intricate details.
The following two COLMAP-formatted example datasets are used throughout this guide:
- Mip360 Dataset: Download from https://jonbarron.info/mipnerf360/
- Rubble 4K Dataset: Download from https://huggingface.co/datasets/HexuZhao/mega_nerf_rubble_colmap/tree/main
No Offload (GPU-Only, for small scenes):
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
python train.py -s <path to COLMAP dataset> --no_offload --bsz 4
Naive Offload (simple offloading for medium scenes):
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
python train.py -s <path to COLMAP dataset> --naive_offload --bsz 4
CLM Offload (recommended for large scenes):
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
python train.py -s <path to COLMAP dataset> --clm_offload --bsz 4
Considerations about the flags for training your own dataset
In short, --no_offload is a GPU-only training baseline against which the two offloading strategies are compared.
--naive_offload is a simple implementation, but it is slow and cannot handle extremely large scenes; --clm_offload is fast and supports even larger Gaussian models.
For detailed experimental setups and performance comparisons, see the "Why use CLM-GS?" section above and the Example Usages section below.
This codebase saves the dataset on disk and loads it on-demand during training to conserve CPU RAM. This is because, for extremely large datasets, you may not be able to decode the entire dataset into CPU RAM, let alone GPU memory. Note that streaming from disk to GPU is slower than streaming from CPU RAM to GPU.
Use --dataset_cache_and_stream_mode to control how images are handled:
Mode 1: "decode_images_in_advance" (Default)
This mode decodes JPG/PNG images into raw byte data when you first train on a dataset at a specific resolution, allowing on-demand streaming during training.
- Storage Location: Ensure decoded images are saved on a local disk rather than a network file system (NFS). Loading from NFS is significantly slower. The default decoded path is --source_path/decode_{args.images}. If --source_path is on an NFS, specify --decode_dataset_path to point to a local disk location.
- Disk Space: The decoded dataset can be very large. Calculate the required space as: Disk Space = (num_images × image_height × image_width × 3) bytes (see the sketch after this list).
- First-Time Setup: Initial decoding takes time, but the decoded dataset can be reused for subsequent training runs on the same scene. If the decoded images path is corrupted or missing, simply remove the folder and rerun the decoding process.
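As a quick illustration of the disk-space formula above (a minimal sketch; the image counts and resolutions are hypothetical and only show the arithmetic):

```python
def decoded_dataset_bytes(num_images: int, height: int, width: int) -> int:
    """Raw decoded size: 3 bytes (RGB) per pixel per image."""
    return num_images * height * width * 3

# Example: 1,000 images at 4K resolution (hypothetical numbers)
size_gib = decoded_dataset_bytes(1000, 2160, 3840) / 1024**3
print(f"~{size_gib:.1f} GiB of local disk space needed")  # ~23.2 GiB
```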
Mode 2: "decode_images_on_demand"
This mode avoids pre-decoding images, saving disk storage space. However, decoding images on the CPU before each rendering pass is slower and consumes additional CPU computation.
For CLM offload mode, you can specify how many Gaussians to pre-allocate in CPU pinned memory:
python train.py -s <path to COLMAP dataset> \
--clm_offload \
--prealloc_capacity 40000000 # Pre-allocate for 40M Gaussians
If you don't specify --prealloc_capacity, the system automatically calculates the maximum number of Gaussians your workstation can support:
Number of Gaussians = (remaining CPU memory × 0.7) / (48 × 4 × 4 bytes)
Rule of thumb: Approximately 8 GB CPU memory per 10 million Gaussians.
Where:
- 48 = number of spherical harmonic coefficients offloaded to CPU RAM
- First 4 = bytes per float32
- Second 4 = storage multiplier (parameter + gradient + 2 optimizer states)
- 0.7 is a conservative multiplier that reserves memory for other workloads. For more aggressive memory allocation, specify --prealloc_capacity explicitly.
Note: --prealloc_capacity is only effective when --clm_offload is enabled.
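The auto-sizing formula above can be checked with a quick sketch (assumptions: available CPU memory is read via psutil, which is in the dependency list; the actual calculation in this codebase may differ in detail):

```python
import psutil

def max_prealloc_gaussians(safety_factor: float = 0.7) -> int:
    """Estimate how many Gaussians fit in the remaining CPU RAM.

    48 SH coefficients * 4 bytes (float32) * 4 copies
    (parameter + gradient + 2 optimizer states) = 768 bytes per Gaussian.
    """
    bytes_per_gaussian = 48 * 4 * 4
    available = psutil.virtual_memory().available
    return int(available * safety_factor / bytes_per_gaussian)

# Rule-of-thumb check: 10M Gaussians * 768 bytes ~= 7.7 GB, i.e. roughly 8 GB per 10M.
print(f"~{max_prealloc_gaussians() / 1e6:.0f}M Gaussians fit on this machine")
```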
This codebase uses microbatch pipelining with gradient accumulation. For each microbatch, we render one image and perform one backpropagation. The --bsz flag controls how many images to process before each optimizer step.
This design choice is important. Without microbatch pipelining, activation memory would grow linearly with batch size. With pipelining, activation memory remains constant at the level needed for rendering a single image.
Learning rate and momentum are scaled according to Grendel-GS rules when increasing --bsz. Currently, --clm_offload supports batch sizes of 4, 8, 16, 32, and 64.
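The microbatch-pipelining idea can be sketched as follows (a simplified illustration only; render_one_image, loss_fn, and the camera objects are placeholders, not this codebase's actual API):

```python
import torch

def train_one_batch(gaussian_params, cameras, render_one_image, loss_fn, optimizer, bsz=4):
    """Microbatch pipelining with gradient accumulation (illustrative only).

    render_one_image and loss_fn stand in for a splatting renderer and an image
    loss; only one image's activations are alive at any time, regardless of bsz.
    """
    optimizer.zero_grad(set_to_none=True)
    for camera in cameras[:bsz]:                          # microbatches, one image at a time
        rendered = render_one_image(gaussian_params, camera)
        loss = loss_fn(rendered, camera.gt_image) / bsz   # scale for accumulation
        loss.backward()                                   # gradients accumulate in .grad
    optimizer.step()                                      # one optimizer step per --bsz images
```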
CLM-specific Command Line Arguments for train.py
- --no_offload: Use GPU-only mode (no parameter offloading). Best for small scenes.
- --naive_offload: Use the simple offloading strategy. Suitable for medium scenes.
- --clm_offload: Use CLM offloading with retention optimization. Best for large scenes.
- --prealloc_capacity: Number of Gaussians to pre-allocate in CPU pinned memory (e.g., 40000000 for 40M). Required for CLM offload with densification.
- --bsz: Batch size using micro-batch pipelining with gradient accumulation. --bsz 4 renders and backpropagates 4 images sequentially before each optimizer step. Images are processed one-by-one rather than simultaneously to reduce activation memory usage.
For the remaining arguments, please follow Gaussian Splatting's original codebase.
This section demonstrates CLM-GS on three different scales of scenes, from small benchmarks to extreme-scale reconstructions. Each example includes detailed reproduction instructions and usage pipelines.
In all examples, we set export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to reduce memory fragmentation in PyTorch's CUDA memory allocator.
All experiments were conducted on:
- Hardware: AMD Ryzen Threadripper PRO 5955WX (16-core), 128GB RAM, NVIDIA RTX 4090 24GB
- Interconnect: PCIe 4.0 x16 (CPU-GPU)
The Mip-NeRF 360 dataset provides standard benchmark scenes for evaluating quality and performance. While these scenes are small enough to fit in GPU memory, they serve as a baseline to verify that CLM offloading maintains quality while reducing memory usage.
The MegaNeRF Rubble scene at 4K resolution represents a real-world large-scale outdoor scene that exceeds standard GPU memory capacity. This example demonstrates CLM's ability to train a real-world large-scale scene from scratch.
The MatrixCity BigCity dataset represents the extreme upper bound of scene reconstruction with synthetic city-scale environments. It serves as a stress test of CLM's ability to handle 100 million Gaussians, requiring 128GB RAM and 24GB GPU memory to train successfully.
This section explains how the three strategies (--no_offload, --naive_offload, and --clm_offload) are implemented and how you can incorporate them into your own codebase. We first explain the common setup shared across all strategies, then detail each strategy's unique implementation.
All three strategies share the following optimizations:
- Memory fragmentation reduction: We set export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True for all training to reduce memory fragmentation in PyTorch's CUDA memory allocator. We observe severe fragmentation in 3DGS training, where torch.cuda.max_memory_reserved() can be 2× larger than torch.cuda.max_memory_allocated(). This occurs because rendering workloads vary significantly between training steps: rendering one image may require 3× more Gaussians than another. Consequently, fragmentation in 3DGS is more severe than in typical neural network training. (A short sketch of how to measure this gap follows this list.)
- Microbatch pipelining: We use microbatch pipelining with gradient accumulation instead of rendering a batch of images simultaneously. For each microbatch, we render one image and perform one backpropagation, one by one. The --bsz flag controls how many images to process before each optimizer step. This design choice is critical: without microbatch pipelining, activation memory would grow linearly with batch size. With pipelining, activation memory remains constant at the level needed for rendering a single image.
- On-demand image loading: We save the dataset on disk and load it on-demand during training to conserve CPU RAM. For extremely large datasets, it may not be possible to decode the entire dataset into CPU RAM, let alone GPU memory. Note that streaming from disk to GPU is slower than streaming from CPU RAM to GPU.
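A minimal sketch of how the fragmentation gap mentioned above can be observed (illustrative only; call it periodically inside any PyTorch CUDA training loop):

```python
import torch

def report_fragmentation(device: int = 0) -> None:
    """Compare peak allocated vs. peak reserved memory; a large gap indicates fragmentation."""
    allocated = torch.cuda.max_memory_allocated(device)
    reserved = torch.cuda.max_memory_reserved(device)
    ratio = reserved / max(allocated, 1)
    print(f"peak allocated: {allocated / 1024**3:.2f} GiB, "
          f"peak reserved: {reserved / 1024**3:.2f} GiB (x{ratio:.2f})")

# With expandable_segments enabled, the reserved/allocated gap should shrink.
```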
The strategies/ folder contains the core implementations. We implement a base Gaussian model in strategies/base_gaussian_model.py and common rendering functions in strategies/base_engine.py. The three strategy folders (strategies/no_offload/, strategies/naive_offload/, and strategies/clm_offload/) inherit these base files and incorporate their respective design details.
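If you want to mirror this structure in your own codebase, the inheritance pattern looks roughly like this (a hypothetical sketch; the class, attribute, and method names below are illustrative, not the actual names used in strategies/):

```python
# Hypothetical layout mirroring strategies/: a shared base model plus one
# subclass per offloading strategy. Names are illustrative only.
import torch

class BaseGaussianModel:
    """Common parameter containers and rendering helpers (cf. base_gaussian_model.py)."""
    def __init__(self, device: str = "cpu"):
        self.xyz = torch.empty(0, 3, device=device)   # positions
        self.sh = torch.empty(0, 48, device=device)   # spherical-harmonic coefficients

    def render(self, camera):
        raise NotImplementedError("each strategy plugs in its own rendering path")

class NoOffloadModel(BaseGaussianModel):
    """Everything (parameters and optimizer states) stays on GPU."""

class NaiveOffloadModel(BaseGaussianModel):
    """All attributes and optimizer states live on CPU; parameters are copied
    to GPU each iteration and gradients are offloaded back each batch."""

class CLMOffloadModel(BaseGaussianModel):
    """Selection-critical attributes stay on GPU; the rest (e.g., SH coefficients)
    are kept in pinned CPU memory together with their optimizer states."""
```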
Each strategy is explained in detail below:
- 📖 Overview of --no_offload Strategy
- 📖 Overview of --naive_offload Strategy
- 📖 Overview of --clm_offload Strategy
Our system design, memory management strategies, and scaling insights are documented in the paper below:
CLM: Removing the GPU Memory Barrier for 3D Gaussian Splatting
Hexu Zhao¹*, Xiwen Min¹*, Xiaoteng Liu¹, Moonjun Gong¹, Yiming Li¹, Ang Li²,³, Saining Xie¹, Jinyang Li¹, Aurojit Panda¹ (* co-first authors)
¹New York University, ²Pacific Northwest National Laboratory, ³University of Washington
@inproceedings{zhao2025clm,
title={CLM: Removing the GPU Memory Barrier for 3D Gaussian Splatting},
author={Hexu Zhao and Xiwen Min and Xiaoteng Liu and Moonjun Gong and Yiming Li and Ang Li and Saining Xie and Jinyang Li and Aurojit Panda},
booktitle={Proceedings of the 2026 International Conference on Architectural Support for Programming Languages and Operating Systems (ASPLOS'26)},
year={2026},
address={Pittsburgh, PA, USA},
url={https://arxiv.org/abs/2511.04951}
}
Please use Black with default settings to format code.
# Install
pip install black
# Format all files
black .
Distributed under the Apache License, Version 2.0. See LICENSE.txt for more information.
- Bernhard Kerbl, Georgios Kopanas, Thomas Leimkühler, and George Drettakis. 3D Gaussian Splatting for Real-Time Radiance Field Rendering. ACM Transactions on Graphics, July 2023. URL: https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/.
