File: inline_propagator.h

#pragma once

#include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
#include <torch/csrc/jit/codegen/cuda/maxinfo_propagator.h>
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>

#include <unordered_set>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

class TORCH_CUDA_CU_API MaxPosCalculator {
  ComputeAtMode mode_ = ComputeAtMode::Standard;

  // Root domains in a producer that are unmappable to any of its consumers
  std::unordered_set<IterDomain*> unmappable_dims_;

  // User-specified IterDomains that should not be inlined; used in schedulers
  // to avoid inlining trivial reductions
  std::unordered_set<IterDomain*> uninlinable_ids_;

  // Iterate through all TVs and collect the dimensions of each TV that don't
  // map to all its consumer TVs.
  void buildUnmappableDims();

  // Utility function that returns whether an id of tv is a valid iter domain
  // to inline within. This is used in getMaxPos{PasC,CasP}. Different
  // combinations of the bool arguments are used when checking the max position
  // of PasC, CasP, or a max "self" position.
  bool isAllowedID(
      IterDomain* id,
      TensorView* tv,
      bool allow_reduction,
      bool allow_vectorize,
      bool allow_unmappable) const;

 public:
  // Returns the maximum position within which tv can be inlined.
  size_t getMaxPosSelf(
      TensorView* tv,
      bool allow_reduction,
      bool allow_vectorize,
      bool allow_unmappable) const;

  // Returns the maximum position at which producer can be inlined into
  // consumer, given the set ComputeAtMode.
  size_t getMaxProducerPosFromConsumer(
      TensorView* producer,
      TensorView* consumer) const;

  MaxPosCalculator(
      ComputeAtMode mode,
      std::unordered_set<IterDomain*> uninlinable_ids = {});
};
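
// Illustrative usage sketch (not part of the original header). The names
// `producer` and `consumer` below are hypothetical TensorViews taken from an
// already-built fusion; only the public interface declared above is used:
//
//   MaxPosCalculator calc(ComputeAtMode::Standard);
//   size_t self_pos = calc.getMaxPosSelf(
//       producer,
//       /*allow_reduction=*/false,
//       /*allow_vectorize=*/false,
//       /*allow_unmappable=*/false);
//   size_t producer_pos =
//       calc.getMaxProducerPosFromConsumer(producer, consumer);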

// Propagate inline position to the `selected` tensors in the DAG. If `selected`
// is not specified or empty, then propagate to the entire DAG.
class TORCH_CUDA_CU_API InlinePropagator
    : public MaxInfoSpanningTree::Propagator {
  // Checks producers and consumers to find the maximum position in tv that can
  // be shared in both directions.
  size_t getMaxPosAll(TensorView* tv, bool check_siblings = true);

  // We use mapped_reference_pos_ to keep track of the outer-axes information of
  // the reference tensor. That is, mapped_reference_pos_[tv] answers the
  // question "Which outer axes in tv are shared with the specified reference
  // tensor's outer axes?". However, when we actually set the CA position of tv,
  // we might not want to set it to mapped_reference_pos_[tv] because we don't
  // want to inline certain things (such as vectorized dimensions, innermost
  // broadcasting, etc.).
  std::unordered_map<TensorView*, size_t> mapped_reference_pos_;

  // Actually set the computeAt position. This is not necessarily equal to
  // mapped_reference_pos_[tv] because we don't want to inline certain things.
  void setCAPos(TensorView* tv);

  const MaxPosCalculator max_pos_calc;
  std::unordered_set<TensorView*> selected_;
  std::unordered_set<TensorView*> needs_update_max_producer_;
  TensorView* reference_;
  size_t reference_pos_;
  ComputeAtMode mode_ = ComputeAtMode::Standard;

 public:
  InlinePropagator(
      TensorView* reference,
      int64_t reference_pos,
      ComputeAtMode mode = ComputeAtMode::Standard,
      std::unordered_set<TensorView*> selected = {},
      std::unordered_set<IterDomain*> uninlinable_ids = {});

  InlinePropagator(
      TensorView* reference,
      int64_t reference_pos,
      std::unordered_set<TensorView*> selected)
      : InlinePropagator(
            reference,
            reference_pos,
            ComputeAtMode::Standard,
            selected) {}

  ~InlinePropagator() = default;

  // Actually propagate the inline positions for the inlining pass. Uses the
  // functions above to figure out the position at which to do the propagation.
  virtual void setUp() override;
  virtual void propagateC2P(TensorView* from, TensorView* to) override;
  virtual void propagateP2C(TensorView* from, TensorView* to) override;
  virtual void propagateSibling(TensorView* from, TensorView* to) override;
  virtual void tearDown() override;
};
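
// Illustrative usage sketch (not part of the original header). `reference_tv`
// and `inline_pos` are hypothetical: a scheduled TensorView from an existing
// fusion and the position up to which it should be inlined. The propagator is
// driven by a MaxInfoSpanningTree traversal (see maxinfo_propagator.h), which
// invokes the setUp/propagate*/tearDown callbacks declared above; the concrete
// spanning-tree subclass shown here is assumed to be
// MaxRootDomainInfoSpanningTree:
//
//   InlinePropagator inline_propagator(
//       reference_tv, inline_pos, ComputeAtMode::Standard);
//   MaxRootDomainInfoSpanningTree tree(reference_tv);
//   tree.traverse(&inline_propagator);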

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch