File: decomposition_registry_util.cpp

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (77 lines) | stat: -rw-r--r-- 2,122 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77

/**
 * @generated
 * This is an auto-generated file. Please do not modify it by hand.
 * To re-generate, please run:
 * cd ~/pytorch && python torchgen/decompositions/gen_jit_decompositions.py
 */
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/decomposition_registry_util.h>
#include <torch/csrc/jit/runtime/operator.h>

namespace torch {
namespace jit {

const std::string decomp_funcs =
    R"(def var_decomposition(input: Tensor,
    dim: Optional[List[int]]=None,
    correction: Optional[int]=None,
    keepdim: bool=False) -> Tensor:
  if torch.__is__(dim, None):
    dim0 = annotate(List[int], [])
  else:
    dim0 = unchecked_cast(List[int], dim)
  if torch.eq(torch.len(dim0), 0):
    n = torch.numel(input)
  else:
    n0 = 1
    for _0 in range(torch.len(dim0)):
      dim_i = dim0[_0]
      n1 = torch.mul(n0, (torch.size(input))[dim_i])
      n0 = n1
    n = n0
  mean = torch.mean(input, dim0, True)
  sub = torch.sub(input, mean)
  sq = torch.mul(sub, sub)
  sum = torch.sum(sq, dim0, keepdim)
  if torch.__isnot__(correction, None):
    correction0 = unchecked_cast(int, correction)
    n2 = torch.sub(n, correction0)
  else:
    n2 = n
  return torch.div(sum, n2)

def var(input: Tensor,
    unbiased: bool=True) -> Tensor:
  if unbiased:
    _0 = 1
  else:
    _0 = 0
  n = torch.numel(input)
  mean = torch.mean(input, annotate(List[int], []), True)
  sub = torch.sub(input, mean)
  sq = torch.mul(sub, sub)
  sum = torch.sum(sq, annotate(List[int], []))
  n0 = torch.sub(n, _0)
  return torch.div(sum, n0)

)";

const std::string& GetSerializedDecompositions() {
  return decomp_funcs;
}

const OperatorMap<std::string>& GetDecompositionMapping() {
  // clang-format off
 static const OperatorMap<std::string> decomposition_mapping {
    {"aten::var.correction(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False) -> (Tensor)", "var_decomposition"},
    {"aten::var(Tensor self, bool unbiased=True) -> (Tensor)", "var"},
  };
  // clang-format on

  return decomposition_mapping;
}

} // namespace jit
} // namespace torch