File: ir_cloner.h

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (130 lines) | stat: -rw-r--r-- 3,811 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#pragma once

#include <c10/macros/Export.h>
#include <torch/csrc/jit/codegen/cuda/dispatch.h>
#include <torch/csrc/jit/codegen/cuda/ir_builder.h>

#include <unordered_map>
#include <vector>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

class IrContainer;

//! Clones nodes from an existing Fusion
//!
//! \warning IrCloner machinery is a specialized helper for implementing
//!   Fusion copy operations and the limited scope of RecomputeTv below.
//!   It is not intended for any other uses.
//!
class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch {
  friend class Statement;
  friend class IrBuilder;

 public:
  //! \param container The destination container that receives the clones
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  explicit IrCloner(IrContainer* container);

  //! Clones a single statement into the destination container.
  //!
  //! NOTE(review): the original -> clone map kept in clones_map_ suggests
  //! that cloning the same statement twice yields the same clone; confirm
  //! against the implementation.
  Statement* clone(const Statement* statement);

  //! Typed convenience wrapper around clone(const Statement*).
  //!
  //! \returns nullptr when \p node is null, otherwise the clone of \p node
  //!   cast back to T
  template <class T>
  T* clone(const T* node) {
    if (node == nullptr) {
      return nullptr;
    }
    return clone(node->template as<Statement>())->template as<T>();
  }

  //! Clones every element of \p nodes, preserving their order. Null
  //! entries map to nullptr via the single-node overload above.
  //!
  //! \returns a new vector holding the clones, index-aligned with \p nodes
  template <class T>
  std::vector<T*> clone(const std::vector<T*>& nodes) {
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    std::vector<T*> copies;
    copies.reserve(nodes.size());
    for (T* node : nodes) {
      copies.push_back(clone(node));
    }
    return copies;
  }

  //! \returns the destination container the clones are created in
  IrContainer* container() const {
    return ir_container_;
  }

 protected:
  //! Associates \p src with its \p clone.
  //! NOTE(review): presumably records the pair in clones_map_; confirm
  //! against the implementation.
  void registerClone(const Statement* src, Statement* clone);

  // Dispatch entry points. The dispatch interface cannot return values, so
  // each handler stores its result in clone_ (see below).
  void handle(const Statement*) override;
  void handle(const Val*) override;
  void handle(const Expr*) override;

  void handle(const TensorDomain*) override;
  void handle(const TensorView*) override;
  void handle(const IterDomain*) override;

  void handle(const Bool*) override;
  void handle(const Double*) override;
  void handle(const Int*) override;
  void handle(const ComplexDouble*) override;
  void handle(const NamedScalar*) override;

  void handle(const ARangeOp*) override;
  void handle(const UnaryOp*) override;
  void handle(const BinaryOp*) override;
  void handle(const TernaryOp*) override;
  void handle(const RNGOp*) override;
  void handle(const BroadcastOp*) override;
  void handle(const ReductionOp*) override;
  void handle(const GroupedReductionOp*) override;
  void handle(const WelfordOp*) override;
  void handle(const LoadStoreOp*) override;
  void handle(const MmaOp*) override;
  void handle(const TransposeOp*) override;
  void handle(const ExpandOp*) override;
  void handle(const ShiftOp*) override;
  void handle(const GatherOp*) override;
  void handle(const ViewAsScalar*) override;
  void handle(const ViewOp*) override;

  void handle(const Split*) override;
  void handle(const Merge*) override;
  void handle(const Swizzle2D*) override;

 protected:
  // We keep track of the original -> clone map so we don't
  // duplicate clones of the same object if referenced multiple times
  std::unordered_map<const Statement*, Statement*> clones_map_;

 private:
  // The destination Fusion container
  IrContainer* ir_container_ = nullptr;

  // The dispatch interface doesn't allow returning values from
  // individual `handle()` methods, so they are storing the
  // result here
  Statement* clone_ = nullptr;

  // Builder to make all the new nodes
  IrBuilder builder_;
};

// Replicates all expressions used to generate the provided TensorView. Does not
// replicate inputs. Does not replicate scalar values. In other words, the value
// provided will be recomputed from the inputs of the fusion.
class RecomputeTv : private IrCloner {
 public:
  // Replicates the expressions that produce `tv` (see the class comment for
  // what is and is not replicated). This is the only public entry point; the
  // constructor is private.
  // NOTE(review): presumably returns the replicated counterpart of `tv`;
  // confirm against the implementation.
  static TensorView* recompute(TensorView* tv);

 private:
  // NOTE(review): presumably `exprs` are the expressions to replicate into
  // `fusion`; confirm against the implementation.
  RecomputeTv(Fusion* fusion, std::vector<Expr*> exprs);

  // Overrides IrCloner's TensorDomain handling for recomputation.
  void handle(const TensorDomain*) final;

  // Target fusion. Defaults to nullptr so the pointer is never left
  // indeterminate; presumably set from the constructor's `fusion` argument.
  Fusion* fusion_ = nullptr;
};

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch