File: sampler_cpu.cpp

package info (click to toggle)
pytorch-cluster 1.6.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 648 kB
  • sloc: cpp: 2,076; python: 1,081; sh: 53; makefile: 8
file content (46 lines) | stat: -rw-r--r-- 1,604 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#include "sampler_cpu.h"

#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <unordered_set>
#include <vector>

#include "utils.h"

// Draws a pseudo-random index in [0, n) from `rand()`; requires n > 0.
//
// A single `rand()` call only produces values in [0, RAND_MAX], and RAND_MAX
// is only guaranteed to be at least 32767. Using `rand() % n` for a larger `n`
// would make the upper indices unreachable (and would stall a rejection loop
// that asks for more distinct values than `rand()` can produce). For such `n`
// four 15-bit chunks are concatenated into a 60-bit value first; for small `n`
// the plain `rand() % n` draw is kept unchanged.
static int64_t random_neighbor_index(int64_t n) {
  if (n - 1 <= static_cast<int64_t>(RAND_MAX))
    return rand() % n;

  uint64_t bits = 0;
  for (int chunk = 0; chunk < 4; chunk++)
    bits = (bits << 15) | static_cast<uint64_t>(rand() & 0x7fff);
  return static_cast<int64_t>(bits % static_cast<uint64_t>(n));
}

/// @brief Samples, without replacement, positions from the ranges that
///        `rowptr` assigns to each entry of `start`.
///
/// For every `start[i]` the half-open range
/// `[rowptr[start[i]], rowptr[start[i] + 1])` is looked up and a number of
/// distinct positions inside it is drawn.
/// NOTE(review): `rowptr` looks like a CSR row-pointer array and the returned
/// positions like edge ids -- confirm against the caller.
///
/// @param start  1-D tensor of indices into `rowptr`; read as int64. Values
///               are not bounds-checked against `rowptr`.
/// @param rowptr 1-D tensor of range offsets; read as int64.
/// @param count  Number of positions to draw per entry. If smaller than 1,
///               `factor` determines the number instead.
/// @param factor Used only when `count < 1`: the number drawn is
///               `ceil(factor * num_neighbors)`.
/// @return A 1-D tensor (created with `start.options()`) holding the sampled
///         positions of all entries, concatenated in the order of `start`.
///         At most `num_neighbors` positions are returned per entry.
torch::Tensor neighbor_sampler_cpu(torch::Tensor start, torch::Tensor rowptr,
                                   int64_t count, double factor) {

  // The raw-pointer indexing below assumes densely packed storage.
  start = start.contiguous();
  rowptr = rowptr.contiguous();

  const auto *start_data = start.data_ptr<int64_t>();
  const auto *rowptr_data = rowptr.data_ptr<int64_t>();
  const int64_t num_start = start.size(0);

  std::vector<int64_t> e_ids;
  for (int64_t i = 0; i < num_start; i++) {
    const int64_t row_start = rowptr_data[start_data[i]];
    const int64_t row_end = rowptr_data[start_data[i] + 1];
    const int64_t num_neighbors = row_end - row_start;

    // Computed in double precision: `float` cannot represent every count
    // above 2^24 exactly.
    const double num_neighbors_fp = static_cast<double>(num_neighbors);

    int64_t size = count;
    if (count < 1)
      size = static_cast<int64_t>(std::ceil(factor * num_neighbors_fp));
    if (size > num_neighbors)
      size = num_neighbors;

    // If the number of requested samples is close to the number of available
    // neighbors, rejection sampling would collide too often, so a prefix of a
    // random permutation is taken instead. Otherwise random indices are
    // collected in a set until enough distinct ones were found.
    if (size < 0.7 * num_neighbors_fp) {
      // `size < 0.7 * num_neighbors` implies `num_neighbors > 0` whenever the
      // loop body runs, so the helper's precondition holds. A non-positive
      // `size` simply skips the loop.
      std::unordered_set<int64_t> picked;
      while (static_cast<int64_t>(picked.size()) < size)
        picked.insert(random_neighbor_index(num_neighbors) + row_start);
      e_ids.insert(e_ids.end(), picked.begin(), picked.end());
    } else {
      auto perm = torch::randperm(num_neighbors, start.options());
      const auto *perm_data = perm.data_ptr<int64_t>();
      for (int64_t j = 0; j < size; j++) {
        e_ids.push_back(perm_data[j] + row_start);
      }
    }
  }

  // `from_blob` only wraps the vector's memory; `clone` copies it so that the
  // result stays valid after `e_ids` is destroyed.
  const int64_t length = static_cast<int64_t>(e_ids.size());
  return torch::from_blob(e_ids.data(), {length}, start.options()).clone();
}