File: ir_builder.h

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (153 lines) | stat: -rw-r--r-- 5,294 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
#pragma once

#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/shape_inference.h>
#include <torch/csrc/lazy/generated/LazyNonNativeIr.h>
#include <torch/csrc/lazy/ts_backend/dynamic_ir.h>
#include <torch/csrc/lazy/ts_backend/ops/device_data.h>
#include <torch/csrc/lazy/ts_backend/ops/generic.h>
#include <torch/csrc/lazy/ts_backend/ts_node.h>

namespace torch {
namespace lazy {

// IrBuilder implementation for the TorchScript lazy backend (ts_backend).
// Every factory below is a thin forwarder: it passes its arguments, unchanged
// and in the same order, to one of three construction helpers:
//   * DeviceData::Create(...)  - only for MakeDeviceData.
//   * ReuseOrMakeNode<T>(...)  - the shape/view ops. Judging by the name, this
//                                may return an existing equivalent node instead
//                                of allocating a new one (TODO confirm against
//                                the helper's definition).
//   * MakeNode<T>(...)         - Scalar, Generic and the dynamic-shape Size*
//                                nodes, which (per the TODOs below) do not go
//                                through node reuse.
struct TorchScriptIrBuilder : IrBuilder {
  // ---------------------------------------------------------------------------
  // Core nodes
  // ---------------------------------------------------------------------------

  // Wraps the given backend data handle in a node via DeviceData::Create.
  NodePtr MakeDeviceData(
      const std::shared_ptr<BackendData>& data) const override {
    return DeviceData::Create(data);
  }

  // Builds a Scalar node holding `value` with element type `type`.
  // TODO: Scalar node is not currently used by ts_backend. Enable reusing
  // Scalar node later if needed.
  NodePtr MakeScalar(
      const at::Scalar& value,
      const at::ScalarType& type) const override {
    return MakeNode<Scalar>(value, type);
  }

  // Expand node for `input0` to `size`; `is_scalar_expand` is forwarded as-is.
  NodePtr MakeExpand(
      const Value& input0,
      const std::vector<int64_t>& size,
      const bool& is_scalar_expand) const override {
    return ReuseOrMakeNode<Expand>(input0, size, is_scalar_expand);
  }

  // View node reshaping `input0` to `output_size`.
  NodePtr MakeView(
      const Value& input0,
      const std::vector<int64_t>& output_size) const override {
    return ReuseOrMakeNode<View>(input0, output_size);
  }

  // Cast node converting `input0` to `dtype`; `stype` is optional and defaults
  // to "not provided".
  NodePtr MakeCast(
      const Value& input0,
      const at::ScalarType& dtype,
      const c10::optional<at::ScalarType>& stype =
          c10::nullopt) const override {
    return ReuseOrMakeNode<Cast>(input0, dtype, stype);
  }

  // TensorList node grouping the given operand list.
  NodePtr MakeTensorList(const OpList& inputs) const override {
    return ReuseOrMakeNode<TensorList>(inputs);
  }

  // Generic node for an arbitrary `op` with an explicit output `shape`.
  // TODO: Generic needs cleanup.
  // NOTE: 0x5a2d296e9 does not fit in 32 bits, so static_cast<uint32_t>
  // truncates the default seed to 0xa2d296e9. It is kept verbatim because
  // changing it would change the seed every default-seeded Generic node
  // receives. It presumably mirrors the default declared on
  // IrBuilder::MakeGeneric (verify).
  NodePtr MakeGeneric(
      const OpKind& op,
      const OpList& operands,
      const Shape& shape,
      const size_t& num_outputs = 1,
      const hash_t& hash_seed =
          static_cast<uint32_t>(0x5a2d296e9)) const override {
    return MakeNode<Generic>(op, operands, shape, num_outputs, hash_seed);
  }

  // ---------------------------------------------------------------------------
  // View op nodes (all routed through ReuseOrMakeNode)
  // ---------------------------------------------------------------------------

  // Update counterpart of MakeAsStrided: takes a second operand `input1`.
  NodePtr MakeAsStridedViewUpdate(
      const Value& input0,
      const Value& input1,
      const std::vector<int64_t>& size,
      const std::vector<int64_t>& stride,
      const int64_t& storage_offset) const override {
    return ReuseOrMakeNode<AsStridedViewUpdate>(
        input0, input1, size, stride, storage_offset);
  }

  // As-strided view of `input0` described by size/stride/storage_offset.
  NodePtr MakeAsStrided(
      const Value& input0,
      const std::vector<int64_t>& size,
      const std::vector<int64_t>& stride,
      const int64_t& storage_offset) const override {
    return ReuseOrMakeNode<AsStrided>(input0, size, stride, storage_offset);
  }

  // Update counterpart of MakeDiagonal: takes a second operand `input1`.
  NodePtr MakeDiagonalViewUpdate(
      const Value& input0,
      const Value& input1,
      const int64_t& offset,
      const int64_t& dim1,
      const int64_t& dim2) const override {
    return ReuseOrMakeNode<DiagonalViewUpdate>(
        input0, input1, offset, dim1, dim2);
  }

  // Diagonal view of `input0` selected by offset and the dim1/dim2 pair.
  NodePtr MakeDiagonal(
      const Value& input0,
      const int64_t& offset,
      const int64_t& dim1,
      const int64_t& dim2) const override {
    return ReuseOrMakeNode<Diagonal>(input0, offset, dim1, dim2);
  }

  // Update counterpart of MakeNarrow: takes a second operand `input1`.
  NodePtr MakeNarrowViewUpdate(
      const Value& input0,
      const Value& input1,
      const std::vector<int64_t>& base_indices) const override {
    return ReuseOrMakeNode<NarrowViewUpdate>(input0, input1, base_indices);
  }

  // Narrow view of `input0` starting at `base_indices` with extent `sizes`.
  NodePtr MakeNarrow(
      const Value& input0,
      const std::vector<int64_t>& base_indices,
      const std::vector<int64_t>& sizes) const override {
    return ReuseOrMakeNode<Narrow>(input0, base_indices, sizes);
  }

  // Permute node reordering the dimensions of `input0` according to `dims`.
  NodePtr MakePermute(
      const Value& input0,
      const std::vector<int64_t>& dims) const override {
    return ReuseOrMakeNode<Permute>(input0, dims);
  }

  // Resize node for `input0` with target `size`.
  NodePtr MakeResize(
      const Value& input0,
      const std::vector<int64_t>& size) const override {
    return ReuseOrMakeNode<Resize>(input0, size);
  }

  // Update counterpart of MakeSelect: takes a second operand `input1`.
  NodePtr MakeSelectViewUpdate(
      const Value& input0,
      const Value& input1,
      const int64_t& dim,
      const int64_t& start,
      const int64_t& end,
      const int64_t& stride) const override {
    return ReuseOrMakeNode<SelectViewUpdate>(
        input0, input1, dim, start, end, stride);
  }

  // Select node over `dim` of `input0` using start/end/stride.
  NodePtr MakeSelect(
      const Value& input0,
      const int64_t& dim,
      const int64_t& start,
      const int64_t& end,
      const int64_t& stride) const override {
    return ReuseOrMakeNode<Select>(input0, dim, start, end, stride);
  }

  // Squeeze node for `input0` at `dim`.
  NodePtr MakeSqueeze(
      const Value& input0,
      const int& dim) const override {
    return ReuseOrMakeNode<Squeeze>(input0, dim);
  }

  // Unsqueeze node for `input0` at `dim`.
  NodePtr MakeUnsqueeze(
      const Value& input0,
      const int& dim) const override {
    return ReuseOrMakeNode<Unsqueeze>(input0, dim);
  }

  // ---------------------------------------------------------------------------
  // Dynamic-shape IR nodes (built with MakeNode, no reuse)
  // TODO: verify if IR node reusing works for Dynamic shape ops
  // ---------------------------------------------------------------------------

  // SizeNode for dimension `dim` of `input`.
  NodePtr MakeSizeNode(
      const Value& input,
      size_t dim) const override {
    return MakeNode<SizeNode>(input, dim);
  }

  // SizeAdd node combining operands `a` and `b`.
  NodePtr MakeSizeAdd(
      const Value& a,
      const Value& b) const override {
    return MakeNode<SizeAdd>(a, b);
  }

  // SizeMul node combining operands `a` and `b`.
  NodePtr MakeSizeMul(
      const Value& a,
      const Value& b) const override {
    return MakeNode<SizeMul>(a, b);
  }

  // SizeDiv node combining operands `a` and `b`.
  NodePtr MakeSizeDiv(
      const Value& a,
      const Value& b) const override {
    return MakeNode<SizeDiv>(a, b);
  }
};

} // namespace lazy
} // namespace torch