File: ts_lowering_context.cpp

Package: pytorch 1.13.1+dfsg-4 (Debian bookworm, area: main)
#include <c10/core/ScalarType.h>
#include <torch/csrc/lazy/ts_backend/ts_backend_impl.h>
#include <torch/csrc/lazy/ts_backend/ts_lowering_context.h>
#include <torch/csrc/lazy/ts_backend/ts_node.h>

namespace torch {
namespace lazy {

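// TSLoweringContext lowers lazy IR nodes into a fresh torch::jit::Graph,
// wrapped in a GraphFunction so that individual nodes can emit TorchScript
// ops into it. This constructor only sets up the empty graph.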
TSLoweringContext::TSLoweringContext(
    const std::string& name,
    BackendDevice device)
    : torch::lazy::LoweringContext(name, device),
      graph_(std::make_shared<torch::jit::Graph>()),
      function_(
          std::make_shared<torch::jit::GraphFunction>(name, graph_, nullptr)) {}

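// Variant that also lowers the given post-order traversal immediately,
// carrying over the emission status from a previous lowering.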
TSLoweringContext::TSLoweringContext(
    const std::string& name,
    BackendDevice device,
    c10::ArrayRef<Node*> post_order,
    Util::EmissionMap emit_status)
    : torch::lazy::LoweringContext(name, device, post_order, emit_status),
      graph_(std::make_shared<torch::jit::Graph>()),
      function_(
          std::make_shared<torch::jit::GraphFunction>(name, graph_, nullptr)) {
  for (auto node : post_order) {
    Lower(node);
  }
}

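// Lower a single IR node by delegating to its TsNode::Lower() method and
// record the JIT value produced for each of the node's outputs.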
void TSLoweringContext::Lower(const Node* node) {
  if (auto* tsnode = dynamic_cast<const torch::lazy::TsNode*>(node)) {
    // First, we call the node lowering function, which exists for newly
    // codegenned or refactored nodes
    TSOpVector ops = tsnode->Lower(function_, this);
    CHECK(!ops.empty()) << "Failed to lower: " << *node;
    TORCH_CHECK_EQ(node->num_outputs(), ops.size());
    for (size_t i = 0; i < ops.size(); ++i) {
      AssignOutputOp(torch::lazy::Output(node, i), ops[i]);
    }
  } else {
    throw std::runtime_error(
        "Expected torch::lazy::TsNode but could not dynamic cast");
  }
}

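// Remember which JIT value implements a given IR output, and attach the
// originating Python stack trace (if any) to the producing JIT node as a
// "source" attribute for easier debugging.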
void TSLoweringContext::AssignOutputOp(
    const Output& output,
    torch::jit::Value* op) {
  const TsNode* ts_node = static_cast<const TsNode*>(output.node);
  std::string stack_trace = ts_node->getPythonStacktrace();
  if (!stack_trace.empty()) {
    op->node()->s_(c10::Symbol::attr("source"), stack_trace);
  }
  emitted_outputs_[output] = op;
}

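// Return the graph input corresponding to a piece of backend data, creating
// it on first use. Scalar-backed data gets a Float/Int input type; the order
// in which parameters are requested is recorded in parameter_sequence_.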
torch::jit::Value* TSLoweringContext::GetParameter(BackendDataPtr data) {
  const auto ts_data = std::static_pointer_cast<TSData>(data);
  BackendData::Handle handle = ts_data->GetHandle();
  auto it = parameters_map_.find(handle);
  if (it == parameters_map_.end()) {
    torch::jit::Value* param =
        graph_->addInput(c10::str("p", parameters_.size()));
    if (ts_data->scalar.has_value()) {
      auto scalarType = ts_data->scalar.value().type();
      if (isFloatingType(scalarType)) {
        param->setType(c10::FloatType::get());
      } else if (isIntegralType(scalarType, /*includeBool=*/true)) {
        param->setType(c10::IntType::get());
      } else {
        TORCH_CHECK(
            false, "Unhandled scalar type: ", c10::toString(scalarType));
      }
    }
    it = parameters_map_.emplace(handle, Parameter{param, parameters_.size()})
             .first;
    parameters_.push_back(ts_data);
  }
  parameter_sequence_.push_back(it->second.index);
  return it->second.param;
}

} // namespace lazy
} // namespace torch