File: ts_lowering_context.cpp

#include <c10/core/ScalarType.h>
#include <torch/csrc/lazy/ts_backend/ts_backend_impl.h>
#include <torch/csrc/lazy/ts_backend/ts_lowering_context.h>
#include <torch/csrc/lazy/ts_backend/ts_node.h>

#include <utility>

namespace torch::lazy {

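// Builds an empty lowering context: a fresh TorchScript graph plus a
// GraphFunction wrapping it, into which lazy nodes will be lowered.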
TSLoweringContext::TSLoweringContext(
    const std::string& name,
    BackendDevice device)
    : torch::lazy::LoweringContext(name, std::move(device)),
      graph_(std::make_shared<torch::jit::Graph>()),
      function_(
          std::make_shared<torch::jit::GraphFunction>(name, graph_, nullptr)) {}

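// Same as above, but immediately lowers the given post-order node list so the
// resulting graph is ready once construction finishes.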
TSLoweringContext::TSLoweringContext(
    const std::string& name,
    BackendDevice device,
    c10::ArrayRef<const Node*> post_order,
    Util::EmissionMap emit_status)
    : torch::lazy::LoweringContext(
          name,
          std::move(device),
          post_order,
          std::move(emit_status)),
      graph_(std::make_shared<torch::jit::Graph>()),
      function_(
          std::make_shared<torch::jit::GraphFunction>(name, graph_, nullptr)) {
  for (auto node : post_order) {
    Lower(node);
  }
}

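// Lowers a single lazy node by delegating to TsNode::Lower and recording the
// JIT value produced for each of the node's outputs.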
void TSLoweringContext::Lower(const Node* node) {
  if (auto* tsnode = dynamic_cast<const torch::lazy::TsNode*>(node)) {
    // First, we call the node lowering function, which exists for newly
    // codegenned or refactored nodes
    TSOpVector ops = tsnode->Lower(function_, this);
    TORCH_CHECK(!ops.empty(), "Failed to lower: ", *node);
    TORCH_CHECK_EQ(node->num_outputs(), ops.size());
    for (size_t i = 0; i < ops.size(); ++i) {
      AssignOutputOp(torch::lazy::Output(node, i), ops[i]);
    }
  } else {
    throw std::runtime_error(
        "Expected torch::lazy::TsNode but could not dynamic cast");
  }
}

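// Records the JIT value that materializes `output`, attaching the Python
// stack trace (when available) as a "source" attribute on the producing JIT
// node.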
void TSLoweringContext::AssignOutputOp(
    const Output& output,
    torch::jit::Value* op) {
  const TsNode* ts_node = static_cast<const TsNode*>(output.node);
  std::string stack_trace = ts_node->getPythonStacktrace();
  if (!stack_trace.empty()) {
    op->node()->s_(c10::Symbol::attr("source"), stack_trace);
  }
  emitted_outputs_[output] = op;
}

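// Returns the graph input corresponding to a piece of backend data, creating
// and type-annotating the input on first use and tracking parameter order.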
torch::jit::Value* TSLoweringContext::GetParameter(const BackendDataPtr& data) {
  const auto ts_data = std::static_pointer_cast<TSData>(data);
  BackendData::Handle handle = ts_data->GetHandle();
  auto it = parameters_map_.find(handle);
  if (it == parameters_map_.end()) {
    torch::jit::Value* param =
        graph_->addInput(c10::str("p", parameters_.size()));
    const auto& scalar = ts_data->scalar;
    if (scalar.has_value()) {
      auto scalarType = scalar.value().type();
      if (isFloatingType(scalarType)) {
        param->setType(c10::FloatType::get());
      } else if (isIntegralType(scalarType, /*includeBool=*/true)) {
        param->setType(c10::IntType::get());
      } else {
        TORCH_CHECK(
            false, "Unhandled scalar type: ", c10::toString(scalarType));
      }
    }
    it = parameters_map_.emplace(handle, Parameter{param, parameters_.size()})
             .first;
    parameters_.push_back(ts_data);
  }
  parameter_sequence_.push_back(it->second.index);
  return it->second.param;
}

} // namespace torch::lazy
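
// A minimal usage sketch (illustrative only, not part of the original file):
// a backend would typically construct the context from a post-order traversal
// of the lazy IR and then read back the built TorchScript graph, e.g.
//
//   auto ctx = std::make_unique<torch::lazy::TSLoweringContext>(
//       "LazyGraph", device, post_order, std::move(emit_status));
//   std::shared_ptr<torch::jit::Graph> graph = ctx->graph();
//
// Here `device`, `post_order`, and `emit_status` are assumed to come from the
// lazy-tensor tracing machinery, and `graph()` is the accessor declared in
// ts_lowering_context.h.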