1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174
|
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>
#include "cuda_helpers.h"
namespace vision {
namespace ops {
namespace {
int const threadsPerBlock = sizeof(unsigned long long) * 8;
template <typename T>
__device__ inline bool devIoU(
T const* const a,
T const* const b,
const float threshold) {
T left = max(a[0], b[0]), right = min(a[2], b[2]);
T top = max(a[1], b[1]), bottom = min(a[3], b[3]);
T width = max(right - left, (T)0), height = max(bottom - top, (T)0);
using acc_T = at::acc_type<T, /*is_cuda=*/true>;
acc_T interS = (acc_T)width * height;
acc_T Sa = ((acc_T)a[2] - a[0]) * (a[3] - a[1]);
acc_T Sb = ((acc_T)b[2] - b[0]) * (b[3] - b[1]);
return (interS / (Sa + Sb - interS)) > threshold;
}
template <typename T>
__global__ void nms_kernel_impl(
int n_boxes,
double iou_threshold,
const T* dev_boxes,
unsigned long long* dev_mask) {
const int row_start = blockIdx.y;
const int col_start = blockIdx.x;
if (row_start > col_start)
return;
const int row_size =
min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
const int col_size =
min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
__shared__ T block_boxes[threadsPerBlock * 4];
if (threadIdx.x < col_size) {
block_boxes[threadIdx.x * 4 + 0] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 0];
block_boxes[threadIdx.x * 4 + 1] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 1];
block_boxes[threadIdx.x * 4 + 2] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 2];
block_boxes[threadIdx.x * 4 + 3] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 3];
}
__syncthreads();
if (threadIdx.x < row_size) {
const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
const T* cur_box = dev_boxes + cur_box_idx * 4;
int i = 0;
unsigned long long t = 0;
int start = 0;
if (row_start == col_start) {
start = threadIdx.x + 1;
}
for (i = start; i < col_size; i++) {
if (devIoU<T>(cur_box, block_boxes + i * 4, iou_threshold)) {
t |= 1ULL << i;
}
}
const int col_blocks = ceil_div(n_boxes, threadsPerBlock);
dev_mask[cur_box_idx * col_blocks + col_start] = t;
}
}
at::Tensor nms_kernel(
const at::Tensor& dets,
const at::Tensor& scores,
double iou_threshold) {
TORCH_CHECK(dets.is_cuda(), "dets must be a CUDA tensor");
TORCH_CHECK(scores.is_cuda(), "scores must be a CUDA tensor");
TORCH_CHECK(
dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D");
TORCH_CHECK(
dets.size(1) == 4,
"boxes should have 4 elements in dimension 1, got ",
dets.size(1));
TORCH_CHECK(
scores.dim() == 1,
"scores should be a 1d tensor, got ",
scores.dim(),
"D");
TORCH_CHECK(
dets.size(0) == scores.size(0),
"boxes and scores should have same number of elements in ",
"dimension 0, got ",
dets.size(0),
" and ",
scores.size(0))
at::cuda::CUDAGuard device_guard(dets.device());
if (dets.numel() == 0) {
return at::empty({0}, dets.options().dtype(at::kLong));
}
auto order_t = std::get<1>(
scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true));
auto dets_sorted = dets.index_select(0, order_t).contiguous();
int dets_num = dets.size(0);
const int col_blocks = ceil_div(dets_num, threadsPerBlock);
at::Tensor mask =
at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong));
dim3 blocks(col_blocks, col_blocks);
dim3 threads(threadsPerBlock);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
dets_sorted.scalar_type(), "nms_kernel", [&] {
nms_kernel_impl<scalar_t><<<blocks, threads, 0, stream>>>(
dets_num,
iou_threshold,
dets_sorted.data_ptr<scalar_t>(),
(unsigned long long*)mask.data_ptr<int64_t>());
});
at::Tensor mask_cpu = mask.to(at::kCPU);
unsigned long long* mask_host =
(unsigned long long*)mask_cpu.data_ptr<int64_t>();
std::vector<unsigned long long> remv(col_blocks);
memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
at::Tensor keep =
at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));
int64_t* keep_out = keep.data_ptr<int64_t>();
int num_to_keep = 0;
for (int i = 0; i < dets_num; i++) {
int nblock = i / threadsPerBlock;
int inblock = i % threadsPerBlock;
if (!(remv[nblock] & (1ULL << inblock))) {
keep_out[num_to_keep++] = i;
unsigned long long* p = mask_host + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) {
remv[j] |= p[j];
}
}
}
AT_CUDA_CHECK(cudaGetLastError());
return order_t.index(
{keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)
.to(order_t.device(), keep.scalar_type())});
}
} // namespace
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel));
}
} // namespace ops
} // namespace vision
|