File: ts_node.cpp

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (105 lines) | stat: -rw-r--r-- 2,925 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#include <torch/csrc/lazy/core/debug_util.h>
#include <torch/csrc/lazy/ts_backend/ts_node.h>

namespace {
// Returns the first user-level Python frame as reported by
// torch::lazy::GetFirstUserFrameInPython(), but only when the
// LTC_ENABLE_SOURCE_INFO environment variable is set. The variable is looked
// up once (on the first call) and only its presence is tested; its value is
// never inspected. When it is not set, an empty string is returned.
std::string GetFirstUserFrameInPythonIfEnabled() {
  static const bool source_info_enabled =
      std::getenv("LTC_ENABLE_SOURCE_INFO") != nullptr;
  if (source_info_enabled) {
    return torch::lazy::GetFirstUserFrameInPython();
  }
  return std::string();
}
} // namespace

namespace torch {
namespace lazy {

// Folds every operand and every shape into a single hash, starting from
// `seed`.
//
// - A null operand contributes the constant kNullOpt.
// - A non-null operand contributes its shapeHash() when `bakeInSizes` is
//   true, and its hash() otherwise.
// - Each shape contributes shape.hash(bakeInSizes).
//
// Operands are combined first, in order, followed by the shapes, in order.
hash_t OperandHashes(
    const OpList& operands,
    const c10::ArrayRef<Shape>& shapes,
    const hash_t& seed,
    bool bakeInSizes) {
  hash_t combined = seed;
  for (const auto& input : operands) {
    if (input) {
      combined = HashCombine(
          combined, bakeInSizes ? input.shapeHash() : input.hash());
    } else {
      combined = HashCombine(combined, static_cast<uint64_t>(kNullOpt));
    }
  }
  for (const auto& out_shape : shapes) {
    combined = HashCombine(combined, out_shape.hash(bakeInSizes));
  }
  return combined;
}

// Primary constructor: forwards op, operands, shapes and num_outputs to the
// Node base class, then computes the two hashes cached on this node.
//
// The caller-supplied `hash_seed` is first mixed with the op's own hash.
// `shape_hash_` is always computed with sizes baked in. `dag_hash_` is
// recomputed without baked-in sizes only when enableDynamicShape() returns
// true; otherwise it simply reuses `shape_hash_`.
TsNode::TsNode(
    OpKind op,
    OpList operands,
    std::vector<Shape>&& shapes,
    size_t num_outputs,
    hash_t hash_seed)
    : Node(op, operands, std::move(shapes), num_outputs) {
  const hash_t seeded = HashCombine(op.hash(), hash_seed);
  // `this->shapes()` is the accessor; the plain name `shapes` is the
  // (already moved-from) constructor parameter.
  shape_hash_ =
      OperandHashes(operands, this->shapes(), seeded, /*bakeInSizes=*/true);
  if (enableDynamicShape()) {
    dag_hash_ =
        OperandHashes(operands, this->shapes(), seeded, /*bakeInSizes=*/false);
  } else {
    dag_hash_ = shape_hash_;
  }
}

// Constructs a node whose shape is produced by `shape_fn` rather than passed
// in directly. Delegates to the primary constructor with an empty shape list
// (so the hashes are computed there before any shape is attached), and then
// registers the shape through addComputedShape(shape_fn).
TsNode::TsNode(
    OpKind op,
    OpList operands,
    const std::function<Shape()>& shape_fn,
    size_t num_outputs,
    hash_t hash_seed)
    : TsNode(op, operands, std::vector<Shape>(), num_outputs, hash_seed) {
  this->addComputedShape(shape_fn);
}

// Constructs a node with operands but no shapes: delegates to the primary
// constructor with an empty shape list.
TsNode::TsNode(OpKind op, OpList operands, size_t num_outputs, hash_t hash_seed)
    : TsNode(
          op,
          operands,
          /*shapes=*/std::vector<Shape>(),
          num_outputs,
          hash_seed) {}

TsNode::TsNode(OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed)
    : TsNode(op, {}, {std::move(shape)}, num_outputs, hash_seed) {}

// Returns the DAG hash cached in `dag_hash_` when the node was constructed.
hash_t TsNode::hash() const {
  return this->dag_hash_;
}

// Returns the sizes-baked-in hash cached in `shape_hash_` when the node was
// constructed.
hash_t TsNode::shapeHash() const {
  return this->shape_hash_;
}

const std::string TsNode::getPythonStacktrace() const {
  return GetFirstUserFrameInPythonIfEnabled();
}

TensorList::TensorList(OpList values)
    : TsNode(
          /*op=*/ClassOpKind(),
          /*operands=*/values,
          /*shapes=*/std::vector<Shape>(),
          /*num_outputs=*/1,
          /*hash_seed=*/kHashSeed) {}

// Lowers this TensorList into a single list-construction node in the graph of
// `function`.
//
// Each operand is resolved to its already-lowered torch::jit::Value through
// `loctx->GetOutputOp`, the values are gathered in operand order, and a list
// node is inserted into the graph. The list's element type is taken from the
// first value, so at least one operand is required (enforced with CHECK).
//
// Returns a TSOpVector holding the single output of the inserted list node.
TSOpVector TensorList::Lower(
    std::shared_ptr<torch::jit::GraphFunction> function,
    TSLoweringContext* loctx) const {
  const auto& inputs = operands();
  // Fail before doing any work: tensor_list[0] is dereferenced below.
  CHECK(!inputs.empty());

  std::vector<torch::jit::Value*> tensor_list;
  // The final size is known up front; avoid repeated reallocation.
  tensor_list.reserve(inputs.size());
  for (const torch::lazy::Output& operand : inputs) {
    tensor_list.emplace_back(loctx->GetOutputOp(operand));
  }

  auto graph = function->graph();
  auto listnode =
      graph->insertNode(graph->createList(tensor_list[0]->type(), tensor_list));
  return {listnode->output()};
}

} // namespace lazy
} // namespace torch