File: nearest_cuda.cu

package info (click to toggle)
pytorch-cluster 1.6.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 648 kB
  • sloc: cpp: 2,076; python: 1,081; sh: 53; makefile: 8
file content (91 lines) | stat: -rw-r--r-- 2,532 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#include "nearest_cuda.h"

#include <ATen/cuda/CUDAContext.h>

#include "utils.cuh"

#define THREADS 1024

template <typename scalar_t>
__global__ void nearest_kernel(const scalar_t *x, const scalar_t *y,
                               const int64_t *ptr_x, const int64_t *ptr_y,
                               int64_t *out, int64_t batch_size, int64_t dim) {

  const int64_t thread_idx = threadIdx.x;
  const int64_t n_x = blockIdx.x;

  int64_t batch_idx;
  for (int64_t b = 0; b < batch_size; b++) {
    if (n_x >= ptr_x[b] && n_x < ptr_x[b + 1]) {
      batch_idx = b;
      break;
    }
  }

  const int64_t y_start_idx = ptr_y[batch_idx];
  const int64_t y_end_idx = ptr_y[batch_idx + 1];

  __shared__ scalar_t best_dist[THREADS];
  __shared__ int64_t best_dist_idx[THREADS];

  scalar_t best = 1e38;
  int64_t best_idx = 0;
  for (int64_t n_y = y_start_idx + thread_idx; n_y < y_end_idx;
       n_y += THREADS) {
    scalar_t dist = 0;
    for (int64_t d = 0; d < dim; d++) {
      dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
              (x[n_x * dim + d] - y[n_y * dim + d]);
    }

    if (dist < best) {
      best = dist;
      best_idx = n_y;
    }
  }

  best_dist[thread_idx] = best;
  best_dist_idx[thread_idx] = best_idx;

  for (int64_t u = 0; (1 << u) < THREADS; u++) {
    __syncthreads();
    if (thread_idx < (THREADS >> (u + 1))) {
      int64_t idx_1 = (thread_idx * 2) << u;
      int64_t idx_2 = (thread_idx * 2 + 1) << u;
      if (best_dist[idx_1] > best_dist[idx_2]) {
        best_dist[idx_1] = best_dist[idx_2];
        best_dist_idx[idx_1] = best_dist_idx[idx_2];
      }
    }
  }

  __syncthreads();
  if (thread_idx == 0) {
    out[n_x] = best_dist_idx[0];
  }
}

torch::Tensor nearest_cuda(torch::Tensor x, torch::Tensor y,
                           torch::Tensor ptr_x, torch::Tensor ptr_y) {
  CHECK_CUDA(x);
  CHECK_CUDA(y);
  CHECK_CUDA(ptr_x);
  CHECK_CUDA(ptr_y);
  cudaSetDevice(x.get_device());

  x = x.view({x.size(0), -1}).contiguous();
  y = y.view({y.size(0), -1}).contiguous();

  auto out = torch::empty({x.size(0)}, ptr_x.options());

  auto stream = at::cuda::getCurrentCUDAStream();
  auto scalar_type = x.scalar_type();
  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
    nearest_kernel<scalar_t><<<x.size(0), THREADS, 0, stream>>>(
        x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
        ptr_x.data_ptr<int64_t>(), ptr_y.data_ptr<int64_t>(),
        out.data_ptr<int64_t>(), ptr_x.size(0) - 1, x.size(1));
  });

  return out;
}