File: grid_cpu.cpp

package info (click to toggle)
pytorch-cluster 1.6.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 648 kB
  • sloc: cpp: 2,076; python: 1,081; sh: 53; makefile: 8
file content (46 lines) | stat: -rw-r--r-- 1,329 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#include "grid_cpu.h"

#include "utils.h"

/// Assigns every point in `pos` to a cell of a regular voxel grid and returns
/// the linear (flattened) cell index of each point.
///
/// @param pos             Point coordinates; viewed as (N, D), where
///                        N = pos.size(0) and D is the product of the
///                        remaining dimensions.
/// @param size            Voxel edge length per dimension; must hold exactly
///                        D elements (any shape, it is flattened).
/// @param optional_start  Grid origin per dimension (D elements). When absent,
///                        the per-dimension minimum of `pos` is used.
/// @param optional_end    Grid upper bound per dimension (D elements). When
///                        absent, the per-dimension maximum of `pos` is used.
/// @return A tensor of shape (N,) and dtype kLong holding each point's linear
///         voxel index, with dimension 0 varying fastest (strides
///         1, n0, n0*n1, ...).
///
/// Inputs are validated with CHECK_CPU / CHECK_INPUT from utils.h.
/// NOTE(review): cell coordinates are obtained by casting to kLong, which
/// truncates toward zero; points outside [start, end] are presumably not
/// expected — confirm with callers.
torch::Tensor grid_cpu(torch::Tensor pos, torch::Tensor size,
                       torch::optional<torch::Tensor> optional_start,
                       torch::optional<torch::Tensor> optional_end) {

  CHECK_CPU(pos);
  CHECK_CPU(size);

  // Validate both optional bounds. Previously `start` was checked twice and
  // `end` never was.
  if (optional_start.has_value())
    CHECK_CPU(optional_start.value());
  if (optional_end.has_value())
    CHECK_CPU(optional_end.value());

  // Treat every point as a row of D coordinates.
  pos = pos.view({pos.size(0), -1});
  const auto dim = pos.size(1);
  CHECK_INPUT(size.numel() == dim);

  // Per-dimension parameters are used as 1-D vectors of length D. This lets
  // the arithmetic, cumprod(0) and narrow(0, ...) below run over dimensions
  // even when the caller passes e.g. a (1, D) or 0-dim tensor.
  size = size.reshape({-1});

  if (!optional_start.has_value())
    optional_start = std::get<0>(pos.min(0));
  else
    CHECK_INPUT(optional_start.value().numel() == dim);

  if (!optional_end.has_value())
    optional_end = std::get<0>(pos.max(0));
  else
    CHECK_INPUT(optional_end.value().numel() == dim);

  auto start = optional_start.value().reshape({-1});
  auto end = optional_end.value().reshape({-1});

  // Shift the points so that the grid origin sits at zero.
  pos = pos - start.unsqueeze(0);

  // Voxel count along each dimension, turned into the linear stride of each
  // dimension: [1, n0, n0*n1, ...] truncated to D entries.
  auto num_voxels = (end - start).true_divide(size).toType(torch::kLong) + 1;
  num_voxels = num_voxels.cumprod(0);
  num_voxels =
      torch::cat({torch::ones({1}, num_voxels.options()), num_voxels}, 0);
  num_voxels = num_voxels.narrow(0, 0, dim);

  // Per-dimension cell coordinate of every point, combined into one index.
  auto out = pos.true_divide(size.view({1, -1})).toType(torch::kLong);
  out *= num_voxels.view({1, -1});
  out = out.sum(1);

  return out;
}