File: grid_cuda.cu

package info (click to toggle)
pytorch-cluster 1.6.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 648 kB
  • sloc: cpp: 2,076; python: 1,081; sh: 53; makefile: 8
file content (72 lines) | stat: -rw-r--r-- 2,286 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include "grid_cuda.h"

#include <ATen/cuda/CUDAContext.h>

#include "utils.cuh"

// Threads per CUDA block for kernel launches.
#define THREADS 1024
// Ceiling division: how many THREADS-sized blocks are needed to cover N items.
// Both the argument and the whole expansion are parenthesized so the macro
// keeps its meaning when N is an expression or when BLOCKS(...) is embedded
// in a larger expression (e.g. `x / BLOCKS(n)`).
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

// Maps each point to a flattened grid-cell index.
//
// One thread processes one point. For every dimension d, the point's offset
// from start[d] is divided by size[d] and truncated to an integer cell
// coordinate. The coordinates are folded into one index, where the extent of
// dimension d is (int64_t)((end[d] - start[d]) / size[d]) + 1.
//
//   pos   : numel * D values, read as pos[i * D + d]
//   size  : D values, cell size per dimension
//   start : D values, lower bound per dimension
//   end   : D values, upper bound per dimension
//   out   : numel values, receives the cell index of each point
//   D     : number of dimensions per point
//   numel : number of points
//
// Expects a 1D launch with at least numel threads in total; threads whose
// index is >= numel do nothing.
template <typename scalar_t>
__global__ void grid_kernel(const scalar_t *pos, const scalar_t *size,
                            const scalar_t *start, const scalar_t *end,
                            int64_t *out, int64_t D, int64_t numel) {
  // blockIdx.x and blockDim.x are 32-bit unsigned; widen before multiplying
  // so the global index cannot wrap around for very large launches.
  const int64_t thread_idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  if (thread_idx < numel) {
    // c accumulates the flattened cell index, k is the running stride
    // (product of the extents of all previous dimensions).
    int64_t c = 0, k = 1;
    for (int64_t d = 0; d < D; d++) {
      scalar_t p = pos[thread_idx * D + d] - start[d];
      c += (int64_t)(p / size[d]) * k;
      k *= (int64_t)((end[d] - start[d]) / size[d]) + 1;
    }
    out[thread_idx] = c;
  }
}

// Computes, for every row of `pos`, the flattened index of the grid cell it
// falls into, by launching grid_kernel on the current CUDA stream.
//
//   pos            : tensor viewed as (pos.size(0), -1); one point per row
//   size           : tensor with one cell size per column of the viewed pos
//   optional_start : optional lower bound per column; when absent, the
//                    column-wise minimum of pos is used
//   optional_end   : optional upper bound per column; when absent, the
//                    column-wise maximum of pos is used
//
// Returns an int64 tensor with pos.size(0) entries, created with pos's
// options (same device).
//
// CHECK_CUDA / CHECK_INPUT come from utils.cuh; presumably they raise when a
// tensor is not on a CUDA device / when the condition is false (confirm
// against that header).
torch::Tensor grid_cuda(torch::Tensor pos, torch::Tensor size,
                        torch::optional<torch::Tensor> optional_start,
                        torch::optional<torch::Tensor> optional_end) {
  CHECK_CUDA(pos);
  CHECK_CUDA(size);
  cudaSetDevice(pos.get_device());

  if (optional_start.has_value()) {
    CHECK_CUDA(optional_start.value());
  }
  // Validate `end` as well; previously `start` was checked twice and `end`
  // was never checked.
  if (optional_end.has_value()) {
    CHECK_CUDA(optional_end.value());
  }

  // The kernel indexes raw pointers, so every input must be contiguous.
  pos = pos.view({pos.size(0), -1}).contiguous();
  size = size.contiguous();

  CHECK_INPUT(size.numel() == pos.size(1));

  if (!optional_start.has_value())
    optional_start = std::get<0>(pos.min(0));
  else {
    optional_start = optional_start.value().contiguous();
    CHECK_INPUT(optional_start.value().numel() == pos.size(1));
  }

  if (!optional_end.has_value())
    optional_end = std::get<0>(pos.max(0));
  else {
    // Normalize and validate the user-supplied `end` itself (this branch
    // used to operate on `start` by mistake, leaving `end` possibly
    // non-contiguous and of unchecked length).
    optional_end = optional_end.value().contiguous();
    CHECK_INPUT(optional_end.value().numel() == pos.size(1));
  }

  auto start = optional_start.value();
  auto end = optional_end.value();

  auto out = torch::empty({pos.size(0)}, pos.options().dtype(torch::kLong));

  auto stream = at::cuda::getCurrentCUDAStream();
  // One thread per point; scalar_t is selected from pos's dtype.
  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, pos.scalar_type(), "_", [&] {
    grid_kernel<scalar_t><<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
        pos.data_ptr<scalar_t>(), size.data_ptr<scalar_t>(),
        start.data_ptr<scalar_t>(), end.data_ptr<scalar_t>(),
        out.data_ptr<int64_t>(), pos.size(1), out.numel());
  });

  return out;
}