File: length_split_op.cc

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (37 lines) | stat: -rw-r--r-- 1,194 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#include "caffe2/operators/length_split_op.h"

namespace caffe2 {

// Register the CPU implementation of the LengthsSplit operator.
// LengthsSplitOp is declared via caffe2/operators/length_split_op.h
// (included above); its kernel is not part of this file.
REGISTER_CPU_OPERATOR(LengthsSplit, LengthsSplitOp<CPUContext>);

// Schema for LengthsSplit. Only metadata is declared here (arity, scalar
// type, user-facing documentation); the splitting logic lives in the
// operator class.
//
// Summary of the documented contract:
//   - Input 0, LENGTHS: INT32 lengths, one per element (M elements).
//   - n_split: number of parts each length is split into. It is supplied
//     either as the "n_split" argument or as the optional input 1; per the
//     input's description, the input overrides the argument.
//   - Output 0, Y: M * n_split values, n_split consecutive entries per
//     element of LENGTHS, each group adding up to its original length.
//     Parts are as equal as possible with larger parts first, and are
//     zero-padded when n_split exceeds the length
//     (e.g. 4 with n_split = 3 -> [2 1 1]; 1 with n_split = 3 -> [1 0 0]).
//
// NOTE(review): the doc string and the Arg/Input/Output descriptions below
// are runtime text stored in the schema, so they are intentionally left
// unchanged here. Cosmetic nits for a follow-up: the two examples use
// different list separators ("[9 4 5]" vs "[2, 1, 2]"), and LENGTHS is
// described as "Mx1" while the doc calls it an "input vector" -- confirm
// the expected rank against the operator implementation before rewording.
OPERATOR_SCHEMA(LengthsSplit)
    // LENGTHS is always required; n_split may optionally be the 2nd input.
    .NumInputs(1, 2)
    .NumOutputs(1)
    // Declares INT32 as the scalar type for type inference (presumably the
    // output dtype -- exact semantics are defined by OpSchema::ScalarType).
    .ScalarType(TensorProto::INT32)
    .SetDoc(R"DOC(
Given input vector LENGTHS, and input n_split, LengthsSplit returns
a single output vector. It "splits" each length into n_split values which add
up to the original length. It will attempt to do equal splits, and if not possible,
it orders larger values first. If the n_split is larger than the length, zero
padding will be applied.

e.g. LENGTHS = [9 4 5]
     n_split = 3
     Y = [3 3 3 2 1 1 2 2 1]

e.g. LENGTHS = [2, 1, 2]
     n_split = 3
     Y = [1 1 0 1 0 0 1 1 0]
)DOC")
    .Arg("n_split", "Number of splits for each element in LENGTHS")
    .Input(0, "LENGTHS", "Mx1 Input tensor denoting INT32 lengths")
    .Input(
        1,
        "n_split",
        "(Optional) Number of splits for each element in LENGTHS (overrides argument)")
    .Output(0, "Y", "(M*n_split)x1 Output vector denoting split lengths");

// No gradient is provided for LengthsSplit; the macro marks it as not yet
// implemented.
// TODO: Write gradient for this when needed
GRADIENT_NOT_IMPLEMENTED_YET(LengthsSplit);

} // namespace caffe2