File: saint_cpu.cpp

package info (click to toggle)
pytorch-sparse 0.6.18-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 984 kB
  • sloc: python: 3,646; cpp: 2,444; sh: 54; makefile: 6
file content (49 lines) | stat: -rw-r--r-- 1,459 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#include "saint_cpu.h"

#include "utils.h"

// Extracts the subgraph induced by the node set `idx` from a CSR graph.
//
// Node `idx[i]` of the input graph becomes node `i` of the subgraph. Every
// stored entry (v, w) whose endpoints are both in `idx` is kept.
//
// Args:
//   idx:    1-D int64 CPU tensor of node ids to keep. Ids must lie in
//           [0, rowptr.size(0) - 1); `index_copy_` rejects anything else.
//   rowptr: 1-D int64 CPU CSR row pointer (length = number of rows + 1).
//   row:    Not read by this function. It is kept so the signature stays
//           compatible with existing callers.
//   col:    1-D int64 CPU CSR column indices.
//
// Returns:
//   (row, col, edge_ids), all 1-D int64 tensors of equal length:
//   - row, col: the relabelled row and column ids of each kept entry.
//   - edge_ids: each kept entry's position in the input `col` array.
//   Entries are grouped by new row id in increasing order. Within a row they
//   follow the input's storage order.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
subgraph_cpu(torch::Tensor idx, torch::Tensor rowptr, torch::Tensor row,
             torch::Tensor col) {
  CHECK_CPU(idx);
  CHECK_CPU(rowptr);
  CHECK_CPU(col);

  CHECK_INPUT(idx.dim() == 1);
  CHECK_INPUT(rowptr.dim() == 1);
  CHECK_INPUT(col.dim() == 1);

  // The raw-pointer loops below assume unit stride, so any strided view
  // (for example a sliced `idx`) must be materialized first.
  idx = idx.contiguous();
  rowptr = rowptr.contiguous();
  col = col.contiguous();

  const int64_t num_nodes = rowptr.size(0) - 1;
  const int64_t num_selected = idx.size(0);

  // assoc[v] holds the new id of original node v, or -1 if v is not in `idx`.
  auto assoc = torch::full({num_nodes}, -1, idx.options());
  assoc.index_copy_(0, idx, torch::arange(num_selected, idx.options()));

  const int64_t *idx_data = idx.data_ptr<int64_t>();
  const int64_t *rowptr_data = rowptr.data_ptr<int64_t>();
  const int64_t *col_data = col.data_ptr<int64_t>();
  const int64_t *assoc_data = assoc.data_ptr<int64_t>();

  std::vector<int64_t> rows, cols, indices;

  for (int64_t v_new = 0; v_new < num_selected; v_new++) {
    const int64_t v = idx_data[v_new];
    const int64_t row_start = rowptr_data[v];
    const int64_t row_end = rowptr_data[v + 1];

    for (int64_t j = row_start; j < row_end; j++) {
      const int64_t w = col_data[j];
      // A column id outside [0, num_nodes) cannot be in `idx`, because
      // index_copy_ above bounds-checks it. Skip such ids rather than
      // reading past the end of `assoc`.
      if (w < 0 || w >= num_nodes)
        continue;

      const int64_t w_new = assoc_data[w];
      if (w_new > -1) {
        rows.push_back(v_new);
        cols.push_back(w_new);
        indices.push_back(j);
      }
    }
  }

  // The buffers hold int64_t, so the outputs must be int64 regardless of the
  // dtype of the caller-supplied `row` tensor. After the data_ptr<int64_t>()
  // call above, idx.options() is guaranteed to be int64 on CPU.
  const auto options = idx.options();
  const auto length = static_cast<int64_t>(rows.size());
  auto out_row = torch::from_blob(rows.data(), {length}, options).clone();
  auto out_col = torch::from_blob(cols.data(), {length}, options).clone();
  auto out_idx = torch::from_blob(indices.data(), {length}, options).clone();

  return std::make_tuple(out_row, out_col, out_idx);
}