File: sparse_matrix_reshape_op.h

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (130 lines) | stat: -rw-r--r-- 4,305 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
/**
 * Copyright (c) 2016-present, Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef CAFFE2_OPERATORS_SPARSE_MATRIX_RESHAPE_H_
#define CAFFE2_OPERATORS_SPARSE_MATRIX_RESHAPE_H_

#include <limits>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

/**
 * Remaps the (row, column) indices of the non-zero entries of a 2-D sparse
 * matrix in COO form from `old_shape` to `new_shape`. Each entry keeps its
 * row-major linear offset (row * cols + col).
 *
 * Arguments (both required, each exactly two integers):
 *   old_shape: [rows, cols]. `cols` must be positive. `rows` may be -1, in
 *              which case the total element count is not validated.
 *   new_shape: [rows, cols]. When old_shape[0] is known, at most one entry
 *              may be -1 (it is inferred), and the product must equal
 *              old_shape[0] * old_shape[1]. When old_shape[0] is -1,
 *              new_shape[1] must be positive and new_shape[0] is not checked.
 *
 * Inputs:
 *   0: old_col - 1-D tensor of int64_t column indices.
 *   1: old_row - 1-D tensor of int row indices, same length as old_col.
 * Outputs:
 *   0: new_col - int64_t column indices in the new shape.
 *   1: new_row - int row indices in the new shape.
 */
template <class Context>
class SparseMatrixReshapeOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  SparseMatrixReshapeOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws) {
    CAFFE_ENFORCE(
        OperatorBase::HasArgument("old_shape"),
        "Argument `old_shape` is missing.");
    CAFFE_ENFORCE(
        OperatorBase::HasArgument("new_shape"),
        "Argument `new_shape` is missing.");

    vector<int64_t> old_shape =
        OperatorBase::GetRepeatedArgument<int64_t>("old_shape");
    vector<int64_t> new_shape =
        OperatorBase::GetRepeatedArgument<int64_t>("new_shape");

    CAFFE_ENFORCE(
        old_shape.size() == 2,
        "Argument `old_shape` must contain exactly two integers.");
    CAFFE_ENFORCE(
        new_shape.size() == 2,
        "Argument `new_shape` must contain exactly two integers.");

    CAFFE_ENFORCE(
        old_shape[1] > 0,
        "The second dimension in argument `old_shape` must be positive.");

    old_stride_ = old_shape[1];

    if (old_shape[0] == -1) {
      // The total element count is unknown, so only the new row width can be
      // validated. NOTE(review): new_shape[0] is intentionally left unchecked
      // here, matching the previous behavior.
      CAFFE_ENFORCE(
          new_shape[1] > 0,
          "The second dimension in `new_shape` must be positive.");
    } else {
      CAFFE_ENFORCE(
          old_shape[0] > 0,
          "The first dimension in `old_shape` must be positive.");
      // Guard the product below: signed integer overflow is undefined.
      CAFFE_ENFORCE(
          old_shape[0] <= std::numeric_limits<int64_t>::max() / old_shape[1],
          "Argument `old_shape` describes a matrix that is too large.");

      const int64_t matrix_size = old_shape[0] * old_shape[1];

      if (new_shape[0] == -1) {
        CAFFE_ENFORCE(
            new_shape[1] > 0,
            "Only one dimension in argument `new_shape` can be -1.");
        CAFFE_ENFORCE(
            matrix_size % new_shape[1] == 0,
            "Argument `new_shape` does not agree with `old_shape`.");
      } else {
        CAFFE_ENFORCE(
            new_shape[0] > 0 && (new_shape[1] == -1 || new_shape[1] > 0),
            "Dimensions in argument `new_shape` must be positive or -1.");
        // Compare via division/modulo instead of multiplying
        // new_shape[0] * new_shape[1], which could overflow int64_t for
        // large arguments. For positive values both forms are equivalent.
        CAFFE_ENFORCE(
            matrix_size % new_shape[0] == 0,
            "Argument `new_shape` does not agree with `old_shape`.");
        if (new_shape[1] == -1) {
          new_shape[1] = matrix_size / new_shape[0];
        } else {
          CAFFE_ENFORCE(
              new_shape[1] == matrix_size / new_shape[0],
              "Argument `new_shape` does not agree with `old_shape`.");
        }
      }
    }
    new_stride_ = new_shape[1];
  }

  bool RunOnDevice() override {
    auto& old_col = Input(0);
    CAFFE_ENFORCE(old_col.dim() == 1, "Column index tensor must be 1-D.");
    auto& old_row = Input(1);
    CAFFE_ENFORCE(old_row.dim() == 1, "Row index tensor must be 1-D.");

    const auto nnz = old_col.numel();
    CAFFE_ENFORCE(
        old_row.numel() == nnz,
        "Column and row tensors must have the same size.");

    // Fetch (and dtype-check) the input buffers before creating the outputs.
    // NOTE(review): presumably an output may share its blob with an input
    // when run in place. If Output() then had to re-create the tensor with
    // a different dtype, the input data would be lost. Checking the input
    // dtype first makes that case fail cleanly instead.
    const auto* old_col_data = old_col.template data<int64_t>();
    const auto* old_row_data = old_row.template data<int>();

    auto* new_col = Output(0, {nnz}, at::dtype<int64_t>());
    auto* new_row = Output(1, {nnz}, at::dtype<int>());

    auto* new_col_data = new_col->template mutable_data<int64_t>();
    auto* new_row_data = new_row->template mutable_data<int>();

    for (const auto i : c10::irange(nnz)) {
      const int64_t row = old_row_data[i];
      const int64_t col = old_col_data[i];
      // Out-of-range indices would otherwise map silently onto the wrong
      // linear offset and corrupt the result.
      CAFFE_ENFORCE(
          row >= 0 && col >= 0 && col < old_stride_,
          "Index (", row, ", ", col, ") at position ", i,
          " is out of range for a matrix with ", old_stride_, " columns.");
      // Row-major linear offset; the reshape preserves it, so the new
      // coordinates are its quotient and remainder by the new row width.
      const int64_t offset = row * old_stride_ + col;
      const int64_t new_row_index = offset / new_stride_;
      // The row output is `int`, so reject values that would be truncated.
      CAFFE_ENFORCE(
          new_row_index <= std::numeric_limits<int>::max(),
          "Reshaped row index ", new_row_index,
          " does not fit in the `int` row index output.");
      new_row_data[i] = static_cast<int>(new_row_index);
      new_col_data[i] = offset % new_stride_;
    }

    return true;
  }

 private:
  // Number of columns in the original matrix (old_shape[1]).
  int64_t old_stride_;
  // Number of columns in the reshaped matrix (new_shape[1], possibly
  // inferred in the constructor).
  int64_t new_stride_;
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_SPARSE_MATRIX_RESHAPE_H_