File: index_info.h

package info (click to toggle)
pytorch-scatter 2.1.2-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,128 kB
  • sloc: python: 1,574; cpp: 1,379; sh: 58; makefile: 13
file content (63 lines) | stat: -rw-r--r-- 1,655 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#pragma once

#include "../extensions.h"

// Capacity of the fixed-size `sizes` / `strides` arrays held by TensorInfo and
// of the temporary buffers filled in getTensorInfo. The TensorInfo constructor
// asserts `dims < MAX_TENSORINFO_DIMS`.
#define MAX_TENSORINFO_DIMS 25

// Plain descriptor of a strided tensor: a raw element pointer plus per-dimension
// extents and strides stored in fixed-capacity arrays.
//
// Only the first `dims` entries of `sizes` and `strides` are written by the
// constructor; the remaining entries are left uninitialized.
template <typename scalar_t> struct TensorInfo {
  // p   : pointer to the first element (stored as-is, not owned here).
  // dim : number of dimensions; asserted to be < MAX_TENSORINFO_DIMS.
  // sz  : per-dimension extents, at least `dim` readable entries.
  // st  : per-dimension strides, at least `dim` readable entries.
  TensorInfo(scalar_t *p, int dim, int sz[MAX_TENSORINFO_DIMS],
             int st[MAX_TENSORINFO_DIMS])
      : data(p), dims(dim) {
    // Reject ranks that would not fit before touching the arrays.
    AT_ASSERT(dims < MAX_TENSORINFO_DIMS);

    int axis = 0;
    while (axis < dim) {
      sizes[axis] = sz[axis];
      strides[axis] = st[axis];
      ++axis;
    }
  }

  scalar_t *data;
  int dims;
  int sizes[MAX_TENSORINFO_DIMS];
  int strides[MAX_TENSORINFO_DIMS];
};

// Builds a TensorInfo<scalar_t> describing `tensor`: its data pointer, its
// number of dimensions, and each dimension's size and stride.
//
// Sizes and strides are stored as `int`; values that do not fit in an `int`
// are not detected here.
//
// The rank is asserted to be < MAX_TENSORINFO_DIMS (the same condition the
// TensorInfo constructor enforces) *before* the local buffers are filled.
template <typename scalar_t>
TensorInfo<scalar_t> getTensorInfo(const torch::Tensor &tensor) {
  int sizes[MAX_TENSORINFO_DIMS];
  int strides[MAX_TENSORINFO_DIMS];

  const int dims = static_cast<int>(tensor.dim());

  // The constructor's own check runs only after the loop below, which would be
  // too late: a tensor of higher rank would already have written past the end
  // of the fixed-size stack arrays.
  AT_ASSERT(dims < MAX_TENSORINFO_DIMS);

  for (int i = 0; i < dims; ++i) {
    sizes[i] = static_cast<int>(tensor.size(i));
    strides[i] = static_cast<int>(tensor.stride(i));
  }

  return TensorInfo<scalar_t>(tensor.data_ptr<scalar_t>(), dims, sizes,
                              strides);
}

// Maps a linear element index to a storage offset for the tensor described by
// `info`.
template <typename scalar_t> struct IndexToOffset {
  // Decomposes `idx` into per-dimension coordinates, walking from the last
  // dimension to the first, and accumulates coordinate * stride. The result is
  // in the same units as `info.strides`.
  static inline int get(int idx, const TensorInfo<scalar_t> &info) {
    int result = 0;
    int remaining = idx;
    int d = info.dims;
    while (d-- > 0) {
      const int extent = info.sizes[d];
      result += (remaining % extent) * info.strides[d];
      remaining /= extent;
    }
    return result;
  }
};

// Variant of the index-to-offset mapping in which the last dimension is
// treated as having `sizes[last] - 1` addressable positions instead of
// `sizes[last]`.
// NOTE(review): presumably the tensor is a pointer-style index with one more
// entry than the number of segments -- confirm against callers.
template <typename scalar_t> struct IndexPtrToOffset {
  // Returns the offset (in the units of `info.strides`) for linear index
  // `idx`. The last dimension is handled first with its reduced extent; the
  // remaining dimensions are then unravelled from the innermost outwards.
  static inline int get(int idx, const TensorInfo<scalar_t> &info) {
    const int last = info.dims - 1;
    const int span = info.sizes[last] - 1;

    int result = (idx % span) * info.strides[last];
    int remaining = idx / span;

    for (int d = last - 1; d >= 0; --d) {
      const int extent = info.sizes[d];
      result += (remaining % extent) * info.strides[d];
      remaining /= extent;
    }
    return result;
  }
};