File: fps_cuda.cu

#include "fps_cuda.h"

#include <ATen/cuda/CUDAContext.h>

#include "utils.cuh"

#define THREADS 256

// Iterative farthest point sampling over one batch. Each block processes a
// single example: threads keep, for every point, the squared distance to its
// nearest already-picked sample (`dist`), and a shared-memory max-reduction
// selects the point farthest from the current sample set as the next sample.
template <typename scalar_t>
__global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
                           const int64_t *out_ptr, const int64_t *start,
                           scalar_t *dist, int64_t *out, int64_t dim) {

  const int64_t thread_idx = threadIdx.x;
  const int64_t batch_idx = blockIdx.x;

  // `ptr` delimits the points of each example in CSR fashion.
  const int64_t start_idx = ptr[batch_idx];
  const int64_t end_idx = ptr[batch_idx + 1];

  __shared__ scalar_t best_dist[THREADS];
  __shared__ int64_t best_dist_idx[THREADS];

  // Seed this example's output with its (possibly random) start point; `out`
  // holds global point indices.
  if (thread_idx == 0) {
    out[out_ptr[batch_idx]] = start_idx + start[batch_idx];
  }

  // Fill the remaining output slots of this example one sample at a time.
  for (int64_t m = out_ptr[batch_idx] + 1; m < out_ptr[batch_idx + 1]; m++) {
    __syncthreads();
    int64_t old = out[m - 1]; // the most recently picked sample

    scalar_t best = (scalar_t)-1.;
    int64_t best_idx = 0;

    // Stride over this example's points in steps of THREADS: fold the squared
    // distance to the newest sample into each point's running minimum, and
    // track the farthest point seen by this thread.
    for (int64_t n = start_idx + thread_idx; n < end_idx; n += THREADS) {
      scalar_t tmp, dd = (scalar_t)0.;
      for (int64_t d = 0; d < dim; d++) {
        tmp = src[dim * old + d] - src[dim * n + d];
        dd += tmp * tmp;
      }
      dd = min(dist[n], dd);
      dist[n] = dd;
      if (dd > best) {
        best = dd;
        best_idx = n;
      }
    }

    best_dist[thread_idx] = best;
    best_dist_idx[thread_idx] = best_idx;

    // Tree-style max-reduction in shared memory; after the final step,
    // thread 0 holds the largest distance and the corresponding point index.
    for (int64_t i = 1; i < THREADS; i *= 2) {
      __syncthreads();
      if ((thread_idx + i) < THREADS &&
          best_dist[thread_idx] < best_dist[thread_idx + i]) {
        best_dist[thread_idx] = best_dist[thread_idx + i];
        best_dist_idx[thread_idx] = best_dist_idx[thread_idx + i];
      }
    }

    __syncthreads();
    if (thread_idx == 0) {
      out[m] = best_dist_idx[0]; // emit the farthest point as the next sample
    }
  }
}
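
// ---------------------------------------------------------------------------
// Reference sketch (not part of the upstream file): the same selection rule
// in plain single-threaded C++ for one example, useful for sanity-checking
// the kernel. The function name and the guard macro FPS_CUDA_REFERENCE_SKETCH
// are hypothetical; the guard keeps the sketch out of regular builds.
#ifdef FPS_CUDA_REFERENCE_SKETCH
#include <algorithm>
#include <vector>

static std::vector<int64_t> fps_reference(const float *src, int64_t n,
                                          int64_t dim, int64_t n_samples,
                                          int64_t start) {
  std::vector<float> dist(n, 5e4f); // distance to nearest sample so far
  std::vector<int64_t> out = {start};
  while ((int64_t)out.size() < n_samples) {
    int64_t old = out.back(), best_idx = 0;
    float best = -1.f;
    for (int64_t i = 0; i < n; i++) {
      float dd = 0.f;
      for (int64_t d = 0; d < dim; d++) {
        float tmp = src[dim * old + d] - src[dim * i + d];
        dd += tmp * tmp;
      }
      dist[i] = std::min(dist[i], dd);
      if (dist[i] > best) {
        best = dist[i];
        best_idx = i;
      }
    }
    out.push_back(best_idx); // farthest point becomes the next sample
  }
  return out;
}
#endif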

// Farthest point sampling on the GPU. `src` holds the point features, `ptr`
// the CSR-style batch boundaries, and `ratio` the fraction of points to keep
// per example; returns the global indices of the sampled points.
torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
                       torch::Tensor ratio, bool random_start) {

  CHECK_CUDA(src);
  CHECK_CUDA(ptr);
  CHECK_CUDA(ratio);
  CHECK_INPUT(ptr.dim() == 1);
  cudaSetDevice(src.get_device());

  src = src.view({src.size(0), -1}).contiguous();
  ptr = ptr.contiguous();
  auto batch_size = ptr.numel() - 1;

  // Points per example; example i contributes ceil(ratio_i * deg_i) samples,
  // so `out_ptr` holds the CSR-style boundaries of the output tensor.
  auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
  auto out_ptr = deg.toType(ratio.scalar_type()) * ratio;
  out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
  out_ptr = torch::cat({torch::zeros({1}, ptr.options()), out_ptr}, 0);

  // Per-example start offset: either a uniformly random point of the example
  // or simply its first point.
  torch::Tensor start;
  if (random_start) {
    start = torch::rand({batch_size}, src.options());
    start = (start * deg.toType(ratio.scalar_type())).toType(torch::kLong);
  } else {
    start = torch::zeros({batch_size}, ptr.options());
  }

  // Running squared distance of every point to its nearest sample,
  // initialized to a large sentinel value.
  auto dist = torch::full(src.size(0), 5e4, src.options());

  // The last entry of `out_ptr` is the total number of samples; copy it back
  // to the host to size the output tensor. A stack variable avoids the
  // unfreed heap allocation this copy previously used.
  int64_t out_size;
  cudaMemcpy(&out_size, out_ptr[-1].data_ptr<int64_t>(), sizeof(int64_t),
             cudaMemcpyDeviceToHost);
  auto out = torch::empty({out_size}, out_ptr.options());

  // Launch one block per example; the dispatch covers float, double and half.
  auto stream = at::cuda::getCurrentCUDAStream();
  auto scalar_type = src.scalar_type();
  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
    fps_kernel<scalar_t><<<batch_size, THREADS, 0, stream>>>(
        src.data_ptr<scalar_t>(), ptr.data_ptr<int64_t>(),
        out_ptr.data_ptr<int64_t>(), start.data_ptr<int64_t>(),
        dist.data_ptr<scalar_t>(), out.data_ptr<int64_t>(), src.size(1));
  });

  return out;
}
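
// ---------------------------------------------------------------------------
// Usage sketch (not part of the upstream file): a minimal host-side caller,
// assuming the file is linked against libtorch with a CUDA device available.
// The guard macro FPS_CUDA_USAGE_EXAMPLE is hypothetical and only keeps the
// sketch out of regular builds.
#ifdef FPS_CUDA_USAGE_EXAMPLE
#include <iostream>

#include <torch/torch.h>

int main() {
  // Two examples of 6 and 4 three-dimensional points, batched CSR-style.
  auto src = torch::rand({10, 3}, torch::device(torch::kCUDA));
  auto ptr = torch::tensor({0, 6, 10},
                           torch::device(torch::kCUDA).dtype(torch::kLong));
  // Keep half of the points of each example.
  auto ratio = torch::tensor({0.5, 0.5},
                             torch::device(torch::kCUDA).dtype(torch::kFloat));

  // Returns ceil(0.5 * 6) + ceil(0.5 * 4) = 5 global point indices.
  auto out = fps_cuda(src, ptr, ratio, /*random_start=*/true);
  std::cout << out << std::endl;
  return 0;
}
#endif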