File: rw_cuda.cu

package info (click to toggle)
pytorch-sparse 0.6.18-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 984 kB
  • sloc: python: 3,646; cpp: 2,444; sh: 54; makefile: 6
file content (53 lines) | stat: -rw-r--r-- 1,854 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#include "rw_cuda.h"

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>

#include "utils.cuh"

// Number of threads per block used for every kernel launch in this file.
#define THREADS 1024
// Number of blocks needed to cover N work items with THREADS threads each
// (ceiling division). Both the argument and the whole expansion are
// parenthesized so the macro composes safely inside larger expressions and
// with arguments that contain low-precedence operators.
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

// Simulates one uniform random walk per thread over a graph given in CSR form.
//
// Thread `thread_idx` (flat 1D index) starts at node `start[thread_idx]` and
// takes `walk_length` steps. At each step it moves to one of the current
// node's neighbours, i.e. one of col[rowptr[cur]] .. col[rowptr[cur + 1] - 1],
// selected by the pre-drawn sample rand[l * numel + thread_idx] (samples are
// expected to lie in [0, 1)). A node without neighbours is a dead end: the
// walk stays on that node for the remaining steps.
//
// Parameters:
//   rowptr      - CSR row pointers; rowptr[v] .. rowptr[v + 1] delimit the
//                 neighbours of node v inside `col`.
//   col         - CSR column indices (neighbour node ids).
//   start       - `numel` start nodes, one per walk.
//   rand        - walk_length * numel samples, indexed [step * numel + walk].
//   out         - (walk_length + 1) * numel visited nodes, indexed
//                 [step * numel + walk]; step 0 is the start node.
//   walk_length - number of steps per walk.
//   numel       - number of walks.
//
// Expects a 1D launch with at least `numel` threads; surplus threads exit
// without touching memory.
__global__ void uniform_random_walk_kernel(const int64_t *rowptr,
                                           const int64_t *col,
                                           const int64_t *start,
                                           const float *rand, int64_t *out,
                                           int64_t walk_length, int64_t numel) {
  // Widen before multiplying so the flat index cannot wrap in 32-bit
  // unsigned arithmetic on very large grids.
  const int64_t thread_idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  if (thread_idx >= numel)
    return;

  int64_t cur = start[thread_idx];
  out[thread_idx] = cur;

  for (int64_t l = 0; l < walk_length; l++) {
    const int64_t row_start = rowptr[cur];
    const int64_t degree = rowptr[cur + 1] - row_start;

    // Only move when the node actually has neighbours; otherwise
    // col[row_start] would belong to a different node (or lie past the end
    // of `col` for the last row).
    if (degree > 0) {
      int64_t offset = int64_t(rand[l * numel + thread_idx] * degree);
      // The degree is rounded when converted to float, so for very large
      // degrees the product may reach `degree`; keep the pick inside the row.
      if (offset >= degree)
        offset = degree - 1;
      cur = col[row_start + offset];
    }

    out[(l + 1) * numel + thread_idx] = cur;
  }
}

// Runs one uniform random walk of `walk_length` steps from every node in
// `start` on the CSR graph described by `rowptr` / `col`.
//
// Parameters:
//   rowptr      - 1D CUDA tensor of CSR row pointers (read as int64).
//   col         - 1D CUDA tensor of CSR column indices (read as int64).
//   start       - 1D CUDA tensor of start nodes (read as int64).
//   walk_length - number of steps per walk; must be non-negative.
//
// Returns a contiguous tensor of shape [start.size(0), walk_length + 1] with
// the options of `start`; row i holds the nodes visited by the walk that
// begins at start[i], the start node included.
//
// Raises through CHECK_CUDA / CHECK_INPUT on invalid arguments, and through
// the c10 CUDA checks if selecting the device or launching the kernel fails.
//
// Side effect: makes the device of `rowptr` the current CUDA device.
torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
                               torch::Tensor start, int64_t walk_length) {
  CHECK_CUDA(rowptr);
  CHECK_CUDA(col);
  CHECK_CUDA(start);
  C10_CUDA_CHECK(cudaSetDevice(rowptr.get_device()));

  CHECK_INPUT(rowptr.dim() == 1);
  CHECK_INPUT(col.dim() == 1);
  CHECK_INPUT(start.dim() == 1);
  CHECK_INPUT(walk_length >= 0);

  // The kernel indexes the raw buffers densely, so strided views (e.g. a
  // slice with a step) must be materialized first.
  rowptr = rowptr.contiguous();
  col = col.contiguous();
  start = start.contiguous();

  const int64_t numel = start.numel();

  // One uniform sample per (step, walk) pair, drawn up front.
  auto rand = torch::rand({walk_length, numel},
                          start.options().dtype(torch::kFloat));
  // Stored step-major so that consecutive threads write consecutive
  // addresses; transposed to walk-major on return.
  auto out = torch::full({walk_length + 1, numel}, -1, start.options());

  // A grid with zero blocks is an invalid launch configuration, so skip the
  // launch entirely when there is nothing to walk from.
  if (numel > 0) {
    auto stream = at::cuda::getCurrentCUDAStream();
    uniform_random_walk_kernel<<<BLOCKS(numel), THREADS, 0, stream>>>(
        rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
        start.data_ptr<int64_t>(), rand.data_ptr<float>(),
        out.data_ptr<int64_t>(), walk_length, numel);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }

  return out.t().contiguous();
}