File: radius_cpu.cpp

package info (click to toggle)
pytorch-cluster 1.6.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 648 kB
  • sloc: cpp: 2,076; python: 1,081; sh: 53; makefile: 8
file content (108 lines) | stat: -rw-r--r-- 3,572 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#include "radius_cpu.h"

#include "utils.h"
#include "utils/KDTreeVectorOfVectorsAdaptor.h"
#include "utils/nanoflann.hpp"

torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
                         torch::optional<torch::Tensor> ptr_x,
                         torch::optional<torch::Tensor> ptr_y, double r,
                         int64_t max_num_neighbors, int64_t num_workers) {

  CHECK_CPU(x);
  CHECK_INPUT(x.dim() == 2);
  CHECK_CPU(y);
  CHECK_INPUT(y.dim() == 2);

  if (ptr_x.has_value()) {
    CHECK_CPU(ptr_x.value());
    CHECK_INPUT(ptr_x.value().dim() == 1);
  }
  if (ptr_y.has_value()) {
    CHECK_CPU(ptr_y.value());
    CHECK_INPUT(ptr_y.value().dim() == 1);
  }

  std::vector<size_t> out_vec = std::vector<size_t>();

  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), "radius_cpu", [&] {
    // See: nanoflann/examples/vector_of_vectors_example.cpp

    auto x_data = x.data_ptr<scalar_t>();
    auto y_data = y.data_ptr<scalar_t>();
    typedef std::vector<std::vector<scalar_t>> vec_t;
    nanoflann::SearchParams params;
    params.sorted = false;

    if (!ptr_x.has_value()) { // Single example.

      vec_t pts(x.size(0));
      for (int64_t i = 0; i < x.size(0); i++) {
        pts[i].resize(x.size(1));
        for (int64_t j = 0; j < x.size(1); j++) {
          pts[i][j] = x_data[i * x.size(1) + j];
        }
      }

      typedef KDTreeVectorOfVectorsAdaptor<vec_t, scalar_t> my_kd_tree_t;

      my_kd_tree_t mat_index(x.size(1), pts, 10);
      mat_index.index->buildIndex();

      for (int64_t i = 0; i < y.size(0); i++) {
        std::vector<std::pair<size_t, scalar_t>> ret_matches;
        size_t num_matches = mat_index.index->radiusSearch(
            y_data + i * y.size(1), r * r, ret_matches, params);

        for (size_t j = 0; j < std::min(num_matches, (size_t)max_num_neighbors);
             j++) {
          out_vec.push_back(ret_matches[j].first);
          out_vec.push_back(i);
        }
      }

    } else { // Batch-wise.

      auto ptr_x_data = ptr_x.value().data_ptr<int64_t>();
      auto ptr_y_data = ptr_y.value().data_ptr<int64_t>();

      for (int64_t b = 0; b < ptr_x.value().size(0) - 1; b++) {
        auto x_start = ptr_x_data[b], x_end = ptr_x_data[b + 1];
        auto y_start = ptr_y_data[b], y_end = ptr_y_data[b + 1];

        if (x_start == x_end || y_start == y_end)
          continue;

        vec_t pts(x_end - x_start);
        for (int64_t i = 0; i < x_end - x_start; i++) {
          pts[i].resize(x.size(1));
          for (int64_t j = 0; j < x.size(1); j++) {
            pts[i][j] = x_data[(i + x_start) * x.size(1) + j];
          }
        }

        typedef KDTreeVectorOfVectorsAdaptor<vec_t, scalar_t> my_kd_tree_t;

        my_kd_tree_t mat_index(x.size(1), pts, 10);
        mat_index.index->buildIndex();

        for (int64_t i = y_start; i < y_end; i++) {
          std::vector<std::pair<size_t, scalar_t>> ret_matches;
          size_t num_matches = mat_index.index->radiusSearch(
              y_data + i * y.size(1), r * r, ret_matches, params);

          for (size_t j = 0;
               j < std::min(num_matches, (size_t)max_num_neighbors); j++) {
            out_vec.push_back(x_start + ret_matches[j].first);
            out_vec.push_back(i);
          }
        }
      }
    }
  });

  const int64_t size = out_vec.size() / 2;
  auto out = torch::from_blob(out_vec.data(), {size, 2},
                              x.options().dtype(torch::kLong));
  return out.t().index_select(0, torch::tensor({1, 0}));
}