1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
|
#include "rw_cpu.h"

#include <limits>
#include <random>

#include <ATen/Parallel.h>

#include "utils.h"
// Draws `numel` first-order (uniform, p = q = 1) random walks of `walk_length`
// steps over a CSR graph.
//
//   rowptr, col - CSR adjacency: the out-edges of node v are the entries
//                 col[rowptr[v] .. rowptr[v + 1]).
//   start       - `numel` start nodes.
//   n_out       - output, row-major [numel, walk_length + 1]: visited nodes,
//                 starting with start[n].
//   e_out       - output, row-major [numel, walk_length]: index into `col` of
//                 the edge taken at each step, or -1 when the current node has
//                 no out-edges (the walk then stays on that node).
//
// All random numbers are drawn from torch's generator before the parallel
// region, so worker threads never share RNG state.
void uniform_sampling(const int64_t *rowptr, const int64_t *col,
                      const int64_t *start, int64_t *n_out, int64_t *e_out,
                      const int64_t numel, const int64_t walk_length) {
  if (walk_length == 0) {
    // Each walk is only its start node (n_out has a single column). Returning
    // here also avoids the division by zero in the grain-size computation.
    for (int64_t n = 0; n < numel; n++)
      n_out[n] = start[n];
    return;
  }

  // One uniform [0, 1) draw per (walk, step).
  auto rand_vals = torch::rand({numel, walk_length});
  const float *rand_data = rand_vals.data_ptr<float>();

  int64_t grain_size = at::internal::GRAIN_SIZE / walk_length;
  at::parallel_for(0, numel, grain_size, [&](int64_t begin, int64_t end) {
    for (int64_t n = begin; n < end; n++) {
      int64_t n_cur = start[n];
      n_out[n * (walk_length + 1)] = n_cur;

      for (int64_t l = 0; l < walk_length; l++) {
        const int64_t row_start = rowptr[n_cur];
        const int64_t deg = rowptr[n_cur + 1] - row_start;
        int64_t e_cur = -1; // -1: no edge taken, the walk stays in place.
        if (deg > 0) {
          int64_t idx =
              static_cast<int64_t>(rand_data[n * walk_length + l] * deg);
          // Cheap guard so float rounding can never yield an offset outside
          // this node's adjacency slice.
          if (idx >= deg)
            idx = deg - 1;
          e_cur = row_start + idx;
          n_cur = col[e_cur];
        }
        n_out[n * (walk_length + 1) + (l + 1)] = n_cur;
        e_out[n * walk_length + l] = e_cur;
      }
    }
  });
}
// Returns true iff `w` appears among the CSR out-neighbours of `v`, i.e. in
// col[rowptr[v] .. rowptr[v + 1]). Linear scan over that slice; `col` is not
// assumed to be sorted.
inline bool is_neighbor(const int64_t *rowptr, const int64_t *col, int64_t v,
                        int64_t w) {
  const int64_t *cursor = col + rowptr[v];
  const int64_t *const stop = col + rowptr[v + 1];
  while (cursor != stop) {
    if (*cursor == w)
      return true;
    ++cursor;
  }
  return false;
}
// See: https://louisabraham.github.io/articles/node2vec-sampling.html
void rejection_sampling(const int64_t *rowptr, const int64_t *col,
int64_t *start, int64_t *n_out, int64_t *e_out,
const int64_t numel, const int64_t walk_length,
const double p, const double q) {
double max_prob = fmax(fmax(1. / p, 1.), 1. / q);
double prob_0 = 1. / p / max_prob;
double prob_1 = 1. / max_prob;
double prob_2 = 1. / q / max_prob;
int64_t grain_size = at::internal::GRAIN_SIZE / walk_length;
at::parallel_for(0, numel, grain_size, [&](int64_t begin, int64_t end) {
for (auto n = begin; n < end; n++) {
int64_t t = start[n], v, x, e_cur, row_start, row_end;
n_out[n * (walk_length + 1)] = t;
row_start = rowptr[t], row_end = rowptr[t + 1];
if (row_end - row_start == 0) {
e_cur = -1;
v = t;
} else {
e_cur = row_start + (rand() % (row_end - row_start));
v = col[e_cur];
}
n_out[n * (walk_length + 1) + 1] = v;
e_out[n * walk_length] = e_cur;
for (auto l = 1; l < walk_length; l++) {
row_start = rowptr[v], row_end = rowptr[v + 1];
if (row_end - row_start == 0) {
e_cur = -1;
x = v;
} else if (row_end - row_start == 1) {
e_cur = row_start;
x = col[e_cur];
} else {
while (true) {
e_cur = row_start + (rand() % (row_end - row_start));
x = col[e_cur];
auto r = ((double)rand() / (RAND_MAX)); // [0, 1)
if (x == t && r < prob_0)
break;
else if (is_neighbor(rowptr, col, x, t) && r < prob_1)
break;
else if (r < prob_2)
break;
}
}
n_out[n * (walk_length + 1) + (l + 1)] = x;
e_out[n * walk_length + l] = e_cur;
t = v;
v = x;
}
}
});
}
// Samples one random walk of `walk_length` steps for every node in `start`
// over the CSR graph (`rowptr`, `col`), all given as 1-D int64 CPU tensors.
//
// Returns (n_out, e_out):
//   n_out - [start.size(0), walk_length + 1] visited nodes, column 0 = start.
//   e_out - [start.size(0), walk_length] index into `col` of each traversed
//           edge, or -1 where the walk was stuck on a node without out-edges.
//
// p == q == 1 uses plain uniform sampling. Any other (p, q) uses node2vec
// rejection sampling; both must be strictly positive.
std::tuple<torch::Tensor, torch::Tensor>
random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
                int64_t walk_length, double p, double q) {
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  CHECK_CPU(start);
  CHECK_INPUT(rowptr.dim() == 1);
  CHECK_INPUT(col.dim() == 1);
  CHECK_INPUT(start.dim() == 1);
  CHECK_INPUT(walk_length >= 0);
  // Non-positive (or NaN) p/q turns the node2vec acceptance probabilities into
  // NaN or 0, and rejection sampling would then never accept a candidate.
  CHECK_INPUT(p > 0. && q > 0.);

  // The samplers index raw data pointers with unit stride.
  rowptr = rowptr.contiguous();
  col = col.contiguous();
  start = start.contiguous();

  auto n_out = torch::empty({start.size(0), walk_length + 1}, start.options());
  auto e_out = torch::empty({start.size(0), walk_length}, start.options());

  auto rowptr_data = rowptr.data_ptr<int64_t>();
  auto col_data = col.data_ptr<int64_t>();
  auto start_data = start.data_ptr<int64_t>();
  auto n_out_data = n_out.data_ptr<int64_t>();
  auto e_out_data = e_out.data_ptr<int64_t>();

  if (walk_length == 0) {
    // Nothing to sample: each walk is just its start node. Handled here so
    // the samplers are never invoked with a zero-length walk.
    n_out.select(1, 0).copy_(start);
  } else if (p == 1. && q == 1.) {
    uniform_sampling(rowptr_data, col_data, start_data, n_out_data, e_out_data,
                     start.numel(), walk_length);
  } else {
    rejection_sampling(rowptr_data, col_data, start_data, n_out_data,
                       e_out_data, start.numel(), walk_length, p, q);
  }
  return std::make_tuple(n_out, e_out);
}
|