File: knn_cpu.cpp

package info (click to toggle)
pytorch-cluster 1.6.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 648 kB
  • sloc: cpp: 2,076; python: 1,081; sh: 53; makefile: 8
file content (105 lines) | stat: -rw-r--r-- 3,404 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#include "knn_cpu.h"

#include "utils.h"
#include "utils/KDTreeVectorOfVectorsAdaptor.h"
#include "utils/nanoflann.hpp"

/// CPU k-nearest-neighbour search: for every row of `y`, finds up to `k`
/// closest rows of `x` using a nanoflann k-d tree.
///
/// @param x      2-D CPU tensor of reference points, one point per row.
/// @param y      2-D CPU tensor of query points with the same number of
///               columns and the same dtype as `x`.
/// @param ptr_x  Optional 1-D int64 CPU tensor of segment boundaries into the
///               rows of `x`. Segment `b` covers rows
///               `[ptr_x[b], ptr_x[b + 1])`. When absent, all of `x` and `y`
///               are treated as a single segment and `ptr_y` is ignored.
/// @param ptr_y  Segment boundaries into the rows of `y`; required whenever
///               `ptr_x` is given. Queries of segment `b` are matched only
///               against reference points of segment `b`.
/// @param k      Maximum number of neighbours per query point (must be > 0).
///               Fewer are returned when a segment holds fewer than `k`
///               reference points.
/// @param num_workers  Not used by this implementation.
/// @return int64 tensor of shape `[2, num_pairs]`: row 0 holds the query
///         (row of `y`) index, row 1 the matched reference (row of `x`) index.
///
/// Invalid arguments are rejected through the CHECK_CPU / CHECK_INPUT macros
/// from utils.h (NOTE(review): presumably these raise a c10::Error — confirm).
torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y,
                      torch::optional<torch::Tensor> ptr_x,
                      torch::optional<torch::Tensor> ptr_y, int64_t k,
                      int64_t num_workers) {

  CHECK_CPU(x);
  CHECK_INPUT(x.dim() == 2);
  CHECK_CPU(y);
  CHECK_INPUT(y.dim() == 2);
  // Each query row is handed to the tree as a raw pointer and read with x's
  // dimensionality, so both point sets must have the same number of columns.
  CHECK_INPUT(x.size(1) == y.size(1));
  // k sizes the result buffers whose first element address is passed to
  // knnSearch; k <= 0 would leave nothing valid to point at.
  CHECK_INPUT(k > 0);

  if (ptr_x.has_value()) {
    CHECK_CPU(ptr_x.value());
    CHECK_INPUT(ptr_x.value().dim() == 1);
  }
  if (ptr_y.has_value()) {
    CHECK_CPU(ptr_y.value());
    CHECK_INPUT(ptr_y.value().dim() == 1);
  }

  // The raw-pointer arithmetic below assumes dense row-major storage.
  x = x.contiguous();
  y = y.contiguous();

  // Describe the work as a list of segments. Without `ptr_x` there is exactly
  // one segment spanning every row of `x` and `y`.
  const int64_t single_ptr_x[2] = {0, x.size(0)};
  const int64_t single_ptr_y[2] = {0, y.size(0)};
  const int64_t *ptr_x_data = single_ptr_x;
  const int64_t *ptr_y_data = single_ptr_y;
  int64_t num_batches = 1;

  // Keep the contiguous boundary tensors alive for as long as their data
  // pointers are in use.
  torch::Tensor ptr_x_contig;
  torch::Tensor ptr_y_contig;
  if (ptr_x.has_value()) {
    CHECK_INPUT(ptr_y.has_value());
    ptr_x_contig = ptr_x.value().contiguous();
    ptr_y_contig = ptr_y.value().contiguous();
    // Every boundary index read from ptr_x is also read from ptr_y.
    CHECK_INPUT(ptr_y_contig.size(0) >= ptr_x_contig.size(0));
    ptr_x_data = ptr_x_contig.data_ptr<int64_t>();
    ptr_y_data = ptr_y_contig.data_ptr<int64_t>();
    num_batches = ptr_x_contig.size(0) - 1;
  }

  const int64_t dim = x.size(1);

  // Flat list of (x index, y index) pairs. Stored as int64_t because the
  // buffer is later viewed as a kLong tensor; size_t has a different width on
  // 32-bit targets.
  std::vector<int64_t> out_vec;

  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), "knn_cpu", [&] {
    // See: nanoflann/examples/vector_of_vectors_example.cpp
    using vec_t = std::vector<std::vector<scalar_t>>;
    using my_kd_tree_t = KDTreeVectorOfVectorsAdaptor<vec_t, scalar_t>;

    const scalar_t *x_data = x.data_ptr<scalar_t>();
    const scalar_t *y_data = y.data_ptr<scalar_t>();

    // Per-query result buffers, reused across all queries and segments. The
    // distance buffer is required by knnSearch but its values are not used.
    std::vector<size_t> ret_index(k);
    std::vector<scalar_t> out_dist_sqr(k);

    for (int64_t b = 0; b < num_batches; b++) {
      const int64_t x_start = ptr_x_data[b];
      const int64_t x_end = ptr_x_data[b + 1];
      const int64_t y_start = ptr_y_data[b];
      const int64_t y_end = ptr_y_data[b + 1];

      // Reject boundaries that would index outside of x / y.
      CHECK_INPUT(0 <= x_start && x_start <= x_end && x_end <= x.size(0));
      CHECK_INPUT(0 <= y_start && y_start <= y_end && y_end <= y.size(0));

      // Nothing to match in an empty segment; also avoids building a tree
      // over zero points.
      if (x_start == x_end || y_start == y_end)
        continue;

      // Copy this segment's reference points into the vector-of-vectors
      // layout expected by the adaptor.
      vec_t pts(x_end - x_start);
      for (int64_t i = 0; i < x_end - x_start; i++) {
        const scalar_t *row = x_data + (i + x_start) * dim;
        pts[i].assign(row, row + dim);
      }

      // Third argument kept at the original value of 10 (presumably the
      // tree's maximum leaf size — confirm against the adaptor header).
      my_kd_tree_t mat_index(dim, pts, 10);
      mat_index.index->buildIndex();

      for (int64_t i = y_start; i < y_end; i++) {
        const size_t num_matches = mat_index.index->knnSearch(
            y_data + i * dim, static_cast<size_t>(k), &ret_index[0],
            &out_dist_sqr[0]);

        for (size_t j = 0; j < num_matches; j++) {
          // Tree indices are local to the segment; shift back to rows of x.
          out_vec.push_back(x_start + static_cast<int64_t>(ret_index[j]));
          out_vec.push_back(i);
        }
      }
    }
  });

  const int64_t size = static_cast<int64_t>(out_vec.size() / 2);
  // from_blob only borrows out_vec's storage; index_select below produces an
  // owning tensor before out_vec goes out of scope. Selecting rows {1, 0}
  // after the transpose swaps the pair order to (y index, x index).
  auto out = torch::from_blob(out_vec.data(), {size, 2},
                              x.options().dtype(torch::kLong));
  return out.t().index_select(0, torch::tensor({1, 0}));
}