INSTALL.md (33 changes: 15 additions & 18 deletions)
@@ -1,14 +1,15 @@
## Installation

### Requirements:
-- PyTorch 1.3 (1.4 may cause some errors.)
+- python >= 3.8
+- PyTorch == 2.5.1
- torchvision from master
- cocoapi
- yacs
- matplotlib
- GCC >= 4.9
- OpenCV
-- CUDA >= 9.2
+- CUDA >= 12.4


### Option 1: Step-by-step installation
@@ -18,54 +19,50 @@
# for that, check that `which conda`, `which pip` and `which python` point to the
# right paths. From a clean conda env, this is what you need to do

-conda create --name MEGA -y python=3.7
+conda create --name MEGA -y python=3.12
source activate MEGA

# this installs the right pip and dependencies for the fresh python
conda install ipython pip

# mega and coco api dependencies
-pip install ninja yacs cython matplotlib tqdm opencv-python scipy
+pip install build wheel installer ninja yacs cython matplotlib tqdm opencv-python scipy numpy

# follow PyTorch installation in https://pytorch.org/get-started/locally/
-# we give the instructions for CUDA 10.0
-conda install pytorch=1.3.0 torchvision cudatoolkit=10.0 -c pytorch
+pip install torch torchvision torchaudio

export INSTALL_DIR=$PWD

# install pycocotools
cd $INSTALL_DIR
git clone https://github.com/cocodataset/cocoapi.git
cd cocoapi/PythonAPI
-python setup.py build_ext install
+python -m build --wheel --no-isolation
+python -m installer dist/*.whl

# install cityscapesScripts
cd $INSTALL_DIR
git clone https://github.com/mcordts/cityscapesScripts.git
cd cityscapesScripts/
-python setup.py build_ext install
+python -m build --wheel --no-isolation
+python -m installer dist/*.whl

# install apex
cd $INSTALL_DIR
git clone https://github.com/NVIDIA/apex.git
cd apex
-python setup.py install --cuda_ext --cpp_ext
+python -m build --wheel --no-isolation
+python -m installer dist/*.whl

# install PyTorch Detection
cd $INSTALL_DIR
git clone https://github.com/Scalsol/mega.pytorch.git
cd mega.pytorch

-# the following will install the lib with
-# symbolic links, so that you can modify
-# the files if you want and won't need to
-# re-build it
-python setup.py build develop

-pip install 'pillow<7.0.0'
+python -m build --wheel --no-isolation
+python -m installer dist/*.whl

unset INSTALL_DIR

-# or if you are on macOS
-# MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py build develop
-```
+```
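
After the steps above, a quick interpreter session confirms the upgraded toolchain. This is a minimal sketch, assuming the compiled extension ends up importable as `mega_core._C` (the usual layout for maskrcnn-benchmark derivatives such as this repo); the exact module path is not shown in this diff.

```python
# Minimal post-install sanity check for the upgraded environment.
import torch
import torchvision  # noqa: F401

print(torch.__version__)          # expect 2.5.1 per the new requirements
print(torch.version.cuda)         # expect a 12.x build to match CUDA >= 12.4
print(torch.cuda.is_available())  # True when a compatible GPU and driver are present

# Importing the compiled extension proves the custom CUDA ops actually built.
# NOTE: `mega_core._C` is an assumption about the package layout.
from mega_core import _C  # noqa: F401
```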
demo/predictor.py (2 changes: 1 addition & 1 deletion)
@@ -608,7 +608,7 @@ def overlay_class_names(self, image, predictions):
x, y = box[:2]
s = template.format(label, score)
cv2.putText(
-image, s, (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2
+image, s, (int(x), int(y)), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2
)

return image
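
The `int(...)` casts are not cosmetic: OpenCV 4.x rejects floating-point pixel coordinates for `org` with a "Can't parse 'org'" error, where older releases tolerated them, and box coordinates here arrive as floats. A standalone repro of the failure and the fix:

```python
import cv2
import numpy as np

img = np.zeros((100, 200, 3), dtype=np.uint8)
x, y = np.float32(10), np.float32(40)  # box coords typically come out as floats

# cv2.putText(img, "label", (x, y), ...)  # OpenCV 4.x: error, "Can't parse 'org'"
cv2.putText(img, "label", (int(x), int(y)),
            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
```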
mega_core/csrc/cuda/ROIAlign_cuda.cu (19 changes: 9 additions & 10 deletions)
@@ -1,10 +1,9 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

-#include <THC/THC.h>
-#include <THC/THCAtomics.cuh>
-#include <THC/THCDeviceUtils.cuh>
+#include <ATen/ceil_div.h>
+#include <ATen/cuda/ThrustAllocator.h>
+#include <ATen/cuda/Atomic.cuh>

// TODO make it in a common file
#define CUDA_1D_KERNEL_LOOP(i, n) \
@@ -272,11 +271,11 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input,
auto output_size = num_rois * pooled_height * pooled_width * channels;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-dim3 grid(std::min(THCCeilDiv((long)output_size, 512L), 4096L));
+dim3 grid(std::min(at::ceil_div((long)output_size, 512L), 4096L));
dim3 block(512);

if (output.numel() == 0) {
-THCudaCheck(cudaGetLastError());
+AT_CUDA_CHECK(cudaGetLastError());
return output;
}

@@ -294,7 +293,7 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input,
rois.contiguous().data<scalar_t>(),
output.data<scalar_t>());
});
-THCudaCheck(cudaGetLastError());
+AT_CUDA_CHECK(cudaGetLastError());
return output;
}

@@ -317,12 +316,12 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,

cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-dim3 grid(std::min(THCCeilDiv((long)grad.numel(), 512L), 4096L));
+dim3 grid(std::min(at::ceil_div((long)grad.numel(), 512L), 4096L));
dim3 block(512);

// handle possibly empty gradients
if (grad.numel() == 0) {
-THCudaCheck(cudaGetLastError());
+AT_CUDA_CHECK(cudaGetLastError());
return grad_input;
}

@@ -341,6 +340,6 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,
grad_input.data<scalar_t>(),
rois.contiguous().data<scalar_t>());
});
-THCudaCheck(cudaGetLastError());
+AT_CUDA_CHECK(cudaGetLastError());
return grad_input;
}
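
The pattern in this file (and the two below) is mechanical: the removed `THC/THC*` headers no longer exist in modern PyTorch, their utilities having moved into ATen, so `THCCeilDiv` becomes `at::ceil_div`, `THCudaCheck` becomes `AT_CUDA_CHECK`, and atomics come from `ATen/cuda/Atomic.cuh`. Once rebuilt, the kernel can be spot-checked against torchvision's reference op. A hedged sketch: the `_C.roi_align_forward` name and argument order are assumed from maskrcnn-benchmark-style bindings, not shown in this diff.

```python
# Numerical spot-check of the rebuilt CUDA ROIAlign against torchvision.
import torch
from torchvision.ops import roi_align
from mega_core import _C  # assumed module path

x = torch.randn(1, 3, 32, 32, device="cuda")
rois = torch.tensor([[0.0, 4.0, 4.0, 20.0, 20.0]], device="cuda")  # (batch_idx, x1, y1, x2, y2)

# Assumed binding: (input, rois, spatial_scale, pooled_h, pooled_w, sampling_ratio)
ours = _C.roi_align_forward(x, rois, 1.0, 7, 7, 2)
ref = roi_align(x, rois, output_size=(7, 7), spatial_scale=1.0, sampling_ratio=2)
print(torch.allclose(ours, ref, atol=1e-5))
```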
mega_core/csrc/cuda/ROIPool_cuda.cu (20 changes: 9 additions & 11 deletions)
@@ -1,11 +1,9 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

-#include <THC/THC.h>
-#include <THC/THCAtomics.cuh>
-#include <THC/THCDeviceUtils.cuh>
-
+#include <ATen/ceil_div.h>
+#include <ATen/cuda/ThrustAllocator.h>
+#include <ATen/cuda/Atomic.cuh>

// TODO make it in a common file
#define CUDA_1D_KERNEL_LOOP(i, n) \
@@ -126,11 +124,11 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,

cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-dim3 grid(std::min(THCCeilDiv((long)output_size, 512L), 4096L));
+dim3 grid(std::min(at::ceil_div((long)output_size, 512L), 4096L));
dim3 block(512);

if (output.numel() == 0) {
-THCudaCheck(cudaGetLastError());
+AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(output, argmax);
}

@@ -148,7 +146,7 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,
output.data<scalar_t>(),
argmax.data<int>());
});
-THCudaCheck(cudaGetLastError());
+AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(output, argmax);
}

@@ -173,12 +171,12 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,

cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-dim3 grid(std::min(THCCeilDiv((long)grad.numel(), 512L), 4096L));
+dim3 grid(std::min(at::ceil_div((long)grad.numel(), 512L), 4096L));
dim3 block(512);

// handle possibly empty gradients
if (grad.numel() == 0) {
-THCudaCheck(cudaGetLastError());
+AT_CUDA_CHECK(cudaGetLastError());
return grad_input;
}

@@ -197,6 +195,6 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,
grad_input.data<scalar_t>(),
rois.contiguous().data<scalar_t>());
});
-THCudaCheck(cudaGetLastError());
+AT_CUDA_CHECK(cudaGetLastError());
return grad_input;
}
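
Identical substitutions to ROIAlign_cuda.cu. For a reference point after rebuilding, torchvision ships `roi_pool`; note that it returns only the pooled tensor, while this kernel's forward also returns the `argmax` indices consumed by its backward. A small sketch of the reference side:

```python
# Reference ROIPool via torchvision, for comparing the rebuilt kernel's output.
import torch
from torchvision.ops import roi_pool

x = torch.randn(1, 3, 32, 32, device="cuda")
rois = torch.tensor([[0.0, 4.0, 4.0, 20.0, 20.0]], device="cuda")
ref = roi_pool(x, rois, output_size=(7, 7), spatial_scale=1.0)
print(ref.shape)  # torch.Size([1, 3, 7, 7])
```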
mega_core/csrc/cuda/SigmoidFocalLoss_cuda.cu (20 changes: 9 additions & 11 deletions)
@@ -4,11 +4,9 @@
// cyfu@cs.unc.edu
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

-#include <THC/THC.h>
-#include <THC/THCAtomics.cuh>
-#include <THC/THCDeviceUtils.cuh>
-
+#include <ATen/ceil_div.h>
+#include <ATen/cuda/ThrustAllocator.h>
+#include <ATen/cuda/Atomic.cuh>
#include <cfloat>

// TODO make it in a common file
@@ -117,12 +115,12 @@ at::Tensor SigmoidFocalLoss_forward_cuda(
auto losses_size = num_samples * logits.size(1);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-dim3 grid(std::min(THCCeilDiv((long)losses_size, 512L), 4096L));
+dim3 grid(std::min(at::ceil_div((long)losses_size, 512L), 4096L));

dim3 block(512);

if (losses.numel() == 0) {
-THCudaCheck(cudaGetLastError());
+AT_CUDA_CHECK(cudaGetLastError());
return losses;
}

@@ -137,7 +135,7 @@ at::Tensor SigmoidFocalLoss_forward_cuda(
num_samples,
losses.data<scalar_t>());
});
-THCudaCheck(cudaGetLastError());
+AT_CUDA_CHECK(cudaGetLastError());
return losses;
}

@@ -162,11 +160,11 @@ at::Tensor SigmoidFocalLoss_backward_cuda(
auto d_logits_size = num_samples * logits.size(1);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-dim3 grid(std::min(THCCeilDiv((long)d_logits_size, 512L), 4096L));
+dim3 grid(std::min(at::ceil_div((long)d_logits_size, 512L), 4096L));
dim3 block(512);

if (d_logits.numel() == 0) {
-THCudaCheck(cudaGetLastError());
+AT_CUDA_CHECK(cudaGetLastError());
return d_logits;
}

@@ -183,7 +181,7 @@ at::Tensor SigmoidFocalLoss_backward_cuda(
d_logits.data<scalar_t>());
});

-THCudaCheck(cudaGetLastError());
+AT_CUDA_CHECK(cudaGetLastError());
return d_logits;
}
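
Same migration once more. For spot-checking, `torchvision.ops.sigmoid_focal_loss` computes per-element focal loss from binary targets; the kernel here instead takes integer class labels, so a comparison needs a one-hot conversion first (that mapping, including how background labels are encoded, is an assumption, not shown in this diff). Reference side only:

```python
# Reference focal-loss values via torchvision, from one-hot binary targets.
import torch
import torch.nn.functional as F
from torchvision.ops import sigmoid_focal_loss

logits = torch.randn(8, 80, device="cuda")
labels = torch.randint(0, 80, (8,), device="cuda")   # integer class ids
one_hot = F.one_hot(labels, num_classes=80).float()  # integer labels -> binary targets

loss = sigmoid_focal_loss(logits, one_hot, alpha=0.25, gamma=2.0, reduction="none")
print(loss.shape)  # torch.Size([8, 80])
```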
