File: rw_cpu.cpp

package info (click to toggle)
pytorch-sparse 0.6.18-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 984 kB
  • sloc: python: 3,646; cpp: 2,444; sh: 54; makefile: 6
file content (43 lines) | stat: -rw-r--r-- 1,248 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#include "rw_cpu.h"

#include "utils.h"

// Samples one uniform random walk of `walk_length` steps per entry of `start`
// over a graph stored in CSR form.
//
// rowptr      - 1-D int64 CSR row pointer; the neighbours of node v are
//               col[rowptr[v] .. rowptr[v + 1]).
// col         - 1-D int64 CSR column indices (assumed to be valid node ids).
// start       - 1-D int64 tensor with the start node of each walk; every value
//               must lie in [0, rowptr.numel() - 1).
// walk_length - number of steps taken by each walk.
//
// Returns an int64 tensor of shape [start.size(0), walk_length + 1] whose row n
// is the sequence of visited nodes of walk n, beginning with start[n]. At each
// step the next node is drawn uniformly from the current node's neighbours; a
// walk that reaches a node without neighbours stays on that node.
torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
                              torch::Tensor start, int64_t walk_length) {
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  CHECK_CPU(start);

  CHECK_INPUT(rowptr.dim() == 1);
  CHECK_INPUT(col.dim() == 1);
  CHECK_INPUT(start.dim() == 1);

  // The raw-pointer indexing below assumes dense storage; contiguous() is a
  // no-op for tensors that already are.
  rowptr = rowptr.contiguous();
  col = col.contiguous();
  start = start.contiguous();

  const int64_t num_walks = start.size(0);
  // rowptr stores one offset per node plus a trailing end offset.
  const int64_t num_nodes = rowptr.numel() - 1;

  // One uniform sample in [0, 1) per (walk, step).
  auto rand = torch::rand({num_walks, walk_length},
                          start.options().dtype(torch::kFloat));

  const int64_t L = walk_length + 1;
  auto out = torch::full({num_walks, L}, -1, start.options());

  const int64_t *rowptr_data = rowptr.data_ptr<int64_t>();
  const int64_t *col_data = col.data_ptr<int64_t>();
  const int64_t *start_data = start.data_ptr<int64_t>();
  const float *rand_data = rand.data_ptr<float>();
  int64_t *out_data = out.data_ptr<int64_t>();

  for (int64_t n = 0; n < num_walks; n++) {
    int64_t cur = start_data[n];
    // An out-of-range start node would index rowptr out of bounds.
    CHECK_INPUT(cur >= 0 && cur < num_nodes);

    int64_t *walk = out_data + n * L;
    const float *walk_rand = rand_data + n * walk_length;
    walk[0] = cur;

    for (int64_t l = 0; l < walk_length; l++) {
      const int64_t row_start = rowptr_data[cur];
      const int64_t deg = rowptr_data[cur + 1] - row_start;

      if (deg > 0) {
        // float * int64 is evaluated in float; for very large degrees the
        // product can round up to `deg`, so clamp to the last neighbour.
        int64_t idx = static_cast<int64_t>(walk_rand[l] * deg);
        if (idx >= deg)
          idx = deg - 1;
        cur = col_data[row_start + idx];
      }
      // deg == 0: the node has no neighbours, so the walk stays at `cur`
      // instead of reading the next node's adjacency list.
      walk[l + 1] = cur;
    }
  }

  return out;
}