File: ir_builder.h

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (148 lines) | stat: -rw-r--r-- 4,709 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#pragma once

#include <c10/core/ScalarType.h>
#include <torch/csrc/lazy/backend/backend_interface.h>
#include <torch/csrc/lazy/core/config.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/tensor.h>
#include <torch/csrc/lazy/core/trie.h>
#include <optional>
#include <vector>

// This file is part of the backend interface, so ops shouldn't be added or
// removed without due process. The exception is the view ops, which will be
// removed soon pending functionalization.

namespace torch::lazy {

// Attempts to fetch an existing node of type T from the trie cache.
// Returns nullptr when IR reuse is disabled (FLAGS_torch_lazy_reuse_ir is
// false); otherwise returns whatever LookupNodeFromTrieCache<T> yields for the
// forwarded arguments.
template <typename T, typename... Args>
NodePtr ReuseNode(Args&&... args) {
  if (!FLAGS_torch_lazy_reuse_ir) {
    return nullptr;
  }
  return LookupNodeFromTrieCache<T>(std::forward<Args>(args)...);
}

// Inserts `node` into the cache returned by TrieCache::Get() when IR reuse is
// enabled (FLAGS_torch_lazy_reuse_ir); does nothing otherwise.
static inline void CacheNode(NodePtr node) {
  if (!FLAGS_torch_lazy_reuse_ir) {
    return;
  }
  TrieCache::Get()->Insert(std::move(node));
}

// Allocates a new T with std::make_shared, forwarding `args` to its
// constructor, and returns it as a NodePtr.
template <typename T, typename... Args>
NodePtr MakeNode(Args&&... args) {
  NodePtr created = std::make_shared<T>(std::forward<Args>(args)...);
  return created;
}

// Returns a node of type T for `args`: first tries ReuseNode<T>, and if that
// yields a null NodePtr, constructs a new node with MakeNode<T> and passes it
// to CacheNode.
//
// The arguments are handed to ReuseNode as lvalues and are only forwarded
// (and therefore possibly moved from) in the MakeNode call. Forwarding them
// in both calls would allow an rvalue argument to be consumed by the lookup
// and then used again, in a moved-from state, for construction.
template <typename T, typename... Args>
NodePtr ReuseOrMakeNode(Args&&... args) {
  NodePtr node = ReuseNode<T>(args...);
  if (!node) {
    node = MakeNode<T>(std::forward<Args>(args)...);
    CacheNode(node);
  }
  return node;
}

// Abstract interface for constructing IR nodes. Every member function is a
// pure virtual, const-qualified factory that returns a NodePtr; no
// implementation is visible in this header. The free functions of the same
// names declared after this struct dispatch to these methods through
// getIrBuilder().
//
// Note: several methods carry default arguments. Default arguments of virtual
// functions are taken from the static type used at the call site, not from
// the overrider, so overriding declarations should repeat the same defaults.
struct IrBuilder {
  // Builds a node from a shared BackendData handle.
  virtual NodePtr MakeDeviceData(
      const std::shared_ptr<BackendData>& data) const = 0;
  // Builds a node from a scalar `value` and its scalar `type`.
  virtual NodePtr MakeScalar(
      const at::Scalar& value,
      const at::ScalarType& type) const = 0;
  // Builds an expand node from `input0`, the target `size`, and the
  // `is_scalar_expand` flag.
  virtual NodePtr MakeExpand(
      const Value& input0,
      const std::vector<int64_t>& size,
      const bool& is_scalar_expand) const = 0;
  // Builds a cast node from `input0` to `dtype`. `stype` is optional and
  // defaults to std::nullopt.
  virtual NodePtr MakeCast(
      const Value& input0,
      const at::ScalarType& dtype,
      const std::optional<at::ScalarType>& stype = std::nullopt) const = 0;
  // Builds a node from a list of input values.
  virtual NodePtr MakeTensorList(const OpList& inputs) const = 0;
  // Builds a node for an arbitrary `op` with the given `operands` and `shape`.
  // `num_outputs` defaults to 1. The default `hash_seed` literal has nine hex
  // digits (more than 32 bits), so the static_cast<uint32_t> truncates it to
  // 0xa2d296e9 before it is converted to hash_t.
  virtual NodePtr MakeGeneric(
      const OpKind& op,
      const OpList& operands,
      const Shape& shape,
      const size_t& num_outputs = 1,
      const hash_t& hash_seed = static_cast<uint32_t>(0x5a2d296e9)) const = 0;

  // dynamic ir nodes
  // (a size node for dimension `dim` of `input`, plus binary add/mul/div
  // nodes taking two Values; exact semantics are defined by implementations)
  virtual NodePtr MakeSizeNode(const Value& input, size_t dim) const = 0;
  virtual NodePtr MakeSizeAdd(const Value& a, const Value& b) const = 0;
  virtual NodePtr MakeSizeMul(const Value& a, const Value& b) const = 0;
  virtual NodePtr MakeSizeDiv(const Value& a, const Value& b) const = 0;

  // Virtual destructor so implementations can be destroyed through an
  // IrBuilder pointer.
  virtual ~IrBuilder() = default;
};

// Convenience wrapper: forwards `data` to MakeDeviceData on the builder
// returned by getIrBuilder().
static inline NodePtr MakeDeviceData(const std::shared_ptr<BackendData>& data) {
  const auto* builder = getIrBuilder();
  return builder->MakeDeviceData(data);
}
// Convenience wrapper: forwards `value` and `type` to MakeScalar on the
// builder returned by getIrBuilder().
static inline NodePtr MakeScalar(
    const at::Scalar& value,
    const at::ScalarType& type) {
  const auto* builder = getIrBuilder();
  return builder->MakeScalar(value, type);
}
// Convenience wrapper: forwards all three arguments to MakeExpand on the
// builder returned by getIrBuilder().
static inline NodePtr MakeExpand(
    const Value& input0,
    const std::vector<int64_t>& size,
    const bool& is_scalar_expand) {
  const auto* builder = getIrBuilder();
  return builder->MakeExpand(input0, size, is_scalar_expand);
}
// Convenience wrapper: forwards `input0`, `dtype` and the optional `stype`
// (std::nullopt by default) to MakeCast on the builder returned by
// getIrBuilder().
static inline NodePtr MakeCast(
    const Value& input0,
    const at::ScalarType& dtype,
    const std::optional<at::ScalarType>& stype = std::nullopt) {
  const auto* builder = getIrBuilder();
  return builder->MakeCast(input0, dtype, stype);
}
// Convenience wrapper: forwards `inputs` to MakeTensorList on the builder
// returned by getIrBuilder().
static inline NodePtr MakeTensorList(const OpList& inputs) {
  const auto* builder = getIrBuilder();
  return builder->MakeTensorList(inputs);
}
// Convenience wrapper: forwards every argument to MakeGeneric on the builder
// returned by getIrBuilder(). The defaults for `num_outputs` and `hash_seed`
// mirror those declared on IrBuilder::MakeGeneric.
static inline NodePtr MakeGeneric(
    const OpKind& op,
    const OpList& operands,
    const Shape& shape,
    const size_t& num_outputs = 1,
    const hash_t& hash_seed = static_cast<uint32_t>(0x5a2d296e9)) {
  const auto* builder = getIrBuilder();
  return builder->MakeGeneric(op, operands, shape, num_outputs, hash_seed);
}

// dynamic ir nodes
// Convenience wrapper: forwards `input` and `dim` to MakeSizeNode on the
// builder returned by getIrBuilder().
static inline NodePtr MakeSizeNode(const Value& input, size_t dim) {
  const auto* builder = getIrBuilder();
  return builder->MakeSizeNode(input, dim);
}
// Convenience wrapper: forwards `a` and `b` to MakeSizeAdd on the builder
// returned by getIrBuilder().
static inline NodePtr MakeSizeAdd(const Value& a, const Value& b) {
  const auto* builder = getIrBuilder();
  return builder->MakeSizeAdd(a, b);
}
// Convenience wrapper: forwards `a` and `b` to MakeSizeMul on the builder
// returned by getIrBuilder(). It must dispatch to MakeSizeMul, not to the
// neighbouring MakeSizeAdd, which would build the wrong kind of node.
static inline NodePtr MakeSizeMul(const Value& a, const Value& b) {
  return getIrBuilder()->MakeSizeMul(a, b);
}
// Convenience wrapper: forwards `a` and `b` to MakeSizeDiv on the builder
// returned by getIrBuilder().
static inline NodePtr MakeSizeDiv(const Value& a, const Value& b) {
  const auto* builder = getIrBuilder();
  return builder->MakeSizeDiv(a, b);
}

// Converts a c10::SymInt into an IR Value with output index 0.
//  - When maybe_as_int() yields a concrete integer, the Value wraps a scalar
//    node of type at::kLong created through MakeScalar.
//  - Otherwise the SymInt's underlying node implementation is downcast to
//    torch::lazy::SymNodeImpl and its node_ member is wrapped.
// NOTE(review): the dynamic_cast result is dereferenced without a null check,
// so this presumably relies on callers only passing SymInts backed by
// torch::lazy::SymNodeImpl -- confirm.
inline Value GetSymIntValue(const c10::SymInt& a) {
  const auto concrete = a.maybe_as_int();
  if (concrete) {
    return Value(MakeScalar(*concrete, at::kLong), 0);
  }
  auto* lazy_sym_node =
      dynamic_cast<torch::lazy::SymNodeImpl*>(a.toSymNodeImplUnowned());
  return Value(lazy_sym_node->node_, 0);
}

// Converts a SymInt array into a vector of int64_t by calling guard_int on
// each element, preserving element order.
// TODO: this should return Value
inline std::vector<int64_t> GetSymIntArrayRefValue(c10::SymIntArrayRef arr) {
  std::vector<int64_t> r;
  // The final size is known up front; reserve once to avoid regrowth.
  r.reserve(arr.size());
  for (const auto& a : arr) {
    r.emplace_back(a.guard_int(__FILE__, __LINE__));
  }
  return r;
}

} // namespace torch::lazy