File: dynamic_ir.cpp

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (114 lines) | stat: -rw-r--r-- 3,113 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#include <torch/csrc/lazy/ts_backend/dynamic_ir.h>

// Downcasts the node behind `output` to a DimensionNode.
//
// Callers in this file dereference the result immediately (e.g. to call
// getStaticValue() / isSymbolic()), so a failed cast is reported here with
// TORCH_CHECK rather than surfacing later as a null-pointer dereference.
static const torch::lazy::DimensionNode* DimCast(torch::lazy::Output output) {
  const auto* dim_node =
      dynamic_cast<const torch::lazy::DimensionNode*>(output.node);
  TORCH_CHECK(
      dim_node != nullptr,
      "Expected a DimensionNode operand, but the operand node is not one");
  return dim_node;
}

namespace torch {
namespace lazy {

TSOpVector SizeNode::Lower(
    std::shared_ptr<torch::jit::GraphFunction> function,
    TSLoweringContext* loctx) const {
  std::vector<torch::jit::NamedValue> arguments;
  std::vector<torch::jit::NamedValue> kwarguments;
  arguments.reserve(2);
  auto index = loctx->graph()->insertConstant(static_cast<int64_t>(this->dim_));
  arguments.emplace_back(loctx->GetOutputOp(operand(0)));
  arguments.emplace_back(index);
  torch::lazy::TSOpVector size_out =
      torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments);
  TORCH_CHECK_EQ(size_out.size(), 1);
  return size_out;
}

SizeNode::SizeNode(Value input, size_t dim)
    : TsNode(
          OpKind{c10::Symbol::fromQualString("aten::size")},
          {input},
          std::vector<Shape>{},
          1,
          MHash(dim)),
      dim_(dim){};

int64_t SizeNode::getStaticValue() const {
  return dynamic_cast<const TsNode*>(operand(0).node)->shape(0).size(dim_);
}
bool SizeNode::isSymbolic() const {
  auto symbolic_vec =
      dynamic_cast<const TsNode*>(operand(0).node)->shape(0).is_symbolic();
  if (!symbolic_vec.has_value()) {
    return true;
  }
  return symbolic_vec->at(dim_);
}

std::string SizeNode::ToString() const {
  return "SizeNode";
}

SizeAdd::SizeAdd(Value a, Value b)
    : TsNode(
          OpKind{c10::Symbol::fromQualString("aten::add")},
          {a, b},
          std::vector<Shape>{},
          1){};

int64_t SizeAdd::getStaticValue() const {
  return DimCast(operand(0))->getStaticValue() +
      DimCast(operand(1))->getStaticValue();
}

bool SizeAdd::isSymbolic() const {
  return DimCast(operand(0))->isSymbolic() || DimCast(operand(1))->isSymbolic();
}

std::string SizeAdd::ToString() const {
  return "SizeAdd";
}

SizeMul::SizeMul(Value a, Value b)
    : TsNode(
          OpKind{c10::Symbol::fromQualString("aten::mul")},
          {a, b},
          std::vector<Shape>{},
          1){};

int64_t SizeMul::getStaticValue() const {
  return DimCast(operand(0))->getStaticValue() *
      DimCast(operand(1))->getStaticValue();
}

bool SizeMul::isSymbolic() const {
  return DimCast(operand(0))->isSymbolic() || DimCast(operand(1))->isSymbolic();
}

std::string SizeMul::ToString() const {
  return "SizeMul";
}

SizeDiv::SizeDiv(Value a, Value b)
    : TsNode(
          OpKind{c10::Symbol::fromQualString("aten::div")},
          {a, b},
          std::vector<Shape>{},
          1){};

int64_t SizeDiv::getStaticValue() const {
  TORCH_CHECK(
      DimCast(operand(1))->getStaticValue() != 0,
      "Can't divide a dimension by zero");
  return DimCast(operand(0))->getStaticValue() /
      DimCast(operand(1))->getStaticValue();
}

bool SizeDiv::isSymbolic() const {
  return DimCast(operand(0))->isSymbolic() || DimCast(operand(1))->isSymbolic();
}

std::string SizeDiv::ToString() const {
  return "SizeDiv";
}

} // namespace lazy
} // namespace torch