1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
|
#include "graclus_cpu.h"
#include "utils.h"
/// Greedy Graclus-style graph matching over a CSR graph, on the CPU.
///
/// Nodes are visited in a random order (torch::randperm). Each node that is
/// still unmatched when visited is paired with at most one unmatched
/// neighbour, and both nodes receive the smaller of the two node ids as their
/// cluster id. A node with no eligible neighbour becomes a singleton cluster
/// (its own id). Every node therefore ends with a cluster id >= 0.
///
/// @param rowptr          1-D CSR row pointer of length num_nodes + 1
///                        (read as int64).
/// @param col             1-D CSR column indices (read as int64).
///                        NOTE(review): values are assumed to lie in
///                        [0, num_nodes); they are not range-checked here.
/// @param optional_weight Optional 1-D edge weights, one per entry of `col`.
///                        Without weights, a node is matched with its first
///                        unmatched neighbour in CSR order. With weights, it
///                        is matched with the unmatched neighbour of largest
///                        weight; ties go to the later edge, and only edges
///                        with weight >= 0 are eligible.
/// @return Tensor of length num_nodes (same options as `rowptr`) holding each
///         node's cluster id.
torch::Tensor graclus_cpu(torch::Tensor rowptr, torch::Tensor col,
                          torch::optional<torch::Tensor> optional_weight) {
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  CHECK_INPUT(rowptr.dim() == 1 && col.dim() == 1);
  // rowptr must hold at least the leading 0 so that num_nodes >= 0.
  CHECK_INPUT(rowptr.numel() > 0);
  if (optional_weight.has_value()) {
    CHECK_CPU(optional_weight.value());
    CHECK_INPUT(optional_weight.value().dim() == 1);
    CHECK_INPUT(optional_weight.value().numel() == col.numel());
  }

  // The raw-pointer indexing below assumes dense, unit-stride storage;
  // contiguous() is a no-op for tensors that already satisfy this.
  rowptr = rowptr.contiguous();
  col = col.contiguous();

  const int64_t num_nodes = rowptr.numel() - 1;
  // -1 marks "not yet assigned to a cluster".
  auto out = torch::full(num_nodes, -1, rowptr.options());
  auto node_perm = torch::randperm(num_nodes, rowptr.options());

  const auto rowptr_data = rowptr.data_ptr<int64_t>();
  const auto col_data = col.data_ptr<int64_t>();
  const auto node_perm_data = node_perm.data_ptr<int64_t>();
  auto out_data = out.data_ptr<int64_t>();

  if (!optional_weight.has_value()) {
    for (int64_t n = 0; n < num_nodes; n++) {
      const int64_t u = node_perm_data[n];
      if (out_data[u] >= 0) // already matched by an earlier node
        continue;

      // Claim u first so a self-loop (v == u) is skipped by the check below.
      out_data[u] = u;

      for (int64_t e = rowptr_data[u]; e < rowptr_data[u + 1]; e++) {
        const int64_t v = col_data[e];
        if (out_data[v] >= 0)
          continue;
        const int64_t cluster = std::min(u, v);
        out_data[u] = cluster;
        out_data[v] = cluster;
        break; // first unmatched neighbour wins
      }
    }
  } else {
    auto weight = optional_weight.value().contiguous();
    AT_DISPATCH_ALL_TYPES_AND2(
        at::ScalarType::Half, at::ScalarType::BFloat16, weight.scalar_type(),
        "graclus_cpu", [&] {
          const auto weight_data = weight.data_ptr<scalar_t>();
          for (int64_t n = 0; n < num_nodes; n++) {
            const int64_t u = node_perm_data[n];
            if (out_data[u] >= 0) // already matched by an earlier node
              continue;

            // Claim u before scanning (as in the unweighted branch) so a
            // self-loop cannot win the max and leave u unmatched while an
            // unmatched neighbour exists.
            out_data[u] = u;

            int64_t v_max = u;
            // Starting at zero means only edges with weight >= 0 are
            // eligible; `>=` lets a later edge win a tie.
            scalar_t w_max = static_cast<scalar_t>(0.);
            for (int64_t e = rowptr_data[u]; e < rowptr_data[u + 1]; e++) {
              const int64_t v = col_data[e];
              if (out_data[v] >= 0)
                continue;
              if (weight_data[e] >= w_max) {
                v_max = v;
                w_max = weight_data[e];
              }
            }

            // If no neighbour qualified, v_max == u and u stays a singleton.
            const int64_t cluster = std::min(u, v_max);
            out_data[u] = cluster;
            out_data[v_max] = cluster;
          }
        });
  }
  return out;
}
|