File: compute_at.h

package info (click to toggle)
pytorch 1.7.1-7
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 80,340 kB
  • sloc: cpp: 670,830; python: 343,991; ansic: 67,845; asm: 5,503; sh: 2,924; java: 2,888; xml: 266; makefile: 244; ruby: 148; yacc: 144; objc: 51; lex: 44
file content (163 lines) | stat: -rw-r--r-- 4,849 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
#pragma once

#include <c10/util/Exception.h>
#include <torch/csrc/WindowsTorchApiMacro.h>

#include <deque>
#include <unordered_map>
#include <vector>

namespace torch {
namespace jit {
namespace fuser {

class TensorDomain;
class TensorView;

// Per-TensorView bookkeeping for the computeAt pass. Keeping all of it in one
// structure lets the pass hold a single map entry from a TensorView to its
// data.
class ComputeAtData {
 public:
  ComputeAtData() = default;
  ComputeAtData(TensorView* tv);

  // Clear after a given traversal. There will be more than one.
  void clearPass();

  // Makes sure value matches current_traversal_position if
  // current_traversal_position_set is true. If this is not the case we're in
  // an invalid compute_at that would require tensor replication.
  void setPassPosition(unsigned int pos);

  // Returns true only when pos is strictly greater than both the original and
  // the new computeAt positions, and is at least the position recorded for the
  // current traversal.
  bool shouldSetComputeAt(unsigned int pos) const {
    return pos > original_compute_at_position &&
        pos > new_compute_at_position && pos >= current_traversal_position;
  }

  // Will return new_compute_at_position, after making sure we cleared out the
  // last pass.
  unsigned int getNewPosition() const;

  // Will make sure we haven't invalidated previous computeAt calls by
  // checking that any axes previously in computeAt are still there.
  void validateNewComputeAt() const;

  // Did we ever compute a value for this TV?
  bool touched() const {
    return touched_;
  }

  // Returns the stored original domain pointer (nullptr on a
  // default-constructed object).
  TensorDomain* getOriginalDomain() const {
    return original_domain_;
  }

  // If we set computeAt, save the domain so we can reset it after traversal.
  // Traversal state can deviate from the domain we will want to save after the
  // entire computeAt pass.
  void setComputeAtDomain(TensorDomain* td);

  // Returns the stored computeAt domain pointer (see setComputeAtDomain);
  // nullptr on a default-constructed object.
  TensorDomain* getComputeAtDomain() const {
    return new_compute_at_domain_;
  }

 private:
  // Was the position ever modified?
  bool touched_ = false;

  // Hold onto the provided TensorView
  TensorView* tv_ref_ = nullptr;

  // Did this tv have computeAt set before calling this computeAt pass?
  bool original_has_compute_at_ = false;

  // What was the computeAt position before the computeAt pass started
  unsigned int original_compute_at_position = 0;

  // and what was the previous domain that position was set relative to.
  TensorDomain* original_domain_ = nullptr;

  // Position we can update during a traversal
  unsigned int current_traversal_position = 0;

  // Did this traversal set a position or not yet
  bool current_traversal_position_set = false;

  // Position to update after a traversal
  unsigned int new_compute_at_position = 0;

  // Domain when we actually set computeAt, will set back to this after the
  // pass. Explicitly defaulted to nullptr, like the other pointer members, so
  // a default-initialized object never exposes an indeterminate pointer
  // through getComputeAtDomain().
  TensorDomain* new_compute_at_domain_ = nullptr;
};

// Driver for a single computeAt pass between a producer and a consumer
// TensorView. The constructor and destructor are private and copying is
// disabled; the only public entry point is the static run().
class ComputeAt {
 public:
  // Entry point of the pass for the given producer/consumer pair and consumer
  // position.
  static void run(
      TensorView* _producer,
      TensorView* _consumer,
      unsigned int _consumer_position);

 private:
  TensorView* producer_;
  TensorView* consumer_;
  unsigned int consumer_position_;

  // Runs replayPasC and sets producer computeAt settings. Returns
  // producer_compute_at_axis.
  unsigned int backwardComputeAt_impl(
      TensorView* producer,
      TensorView* consumer,
      unsigned int consumer_compute_at_axis);

  // Runs replayCasP and sets producer computeAt settings. Returns
  // consumer_compute_at_axis.
  // NOTE(review): the "sets producer" wording mirrors backwardComputeAt_impl;
  // confirm against the implementation whether consumer settings are meant.
  unsigned int forwardComputeAt_impl(
      TensorView* producer,
      TensorView* consumer,
      unsigned int producer_compute_at_axis);

  // Look through all the use chains of producer. Check if there's a single
  // consumer for all chains at or after the consumer specified in the computeAt
  // call.
  void setCommonConsumer();

  // Propagate backward from consumer to producer, check if it increases the
  // computeAt position on tensors, if so take it!
  void traverseBackward();

  // Traverse from producer to common_consumer if it exists or through all uses
  // of producer
  void traverseForward();

  // Run the computeAt pass
  void runPass();

  // Set outputs relative to each other if there is not a common consumer
  void setupOutputs();

  // Common consumer if it exists
  TensorView* common_consumer_ = nullptr;

  // Use chains of the producer, stored here since they are used in a few
  // spots.
  std::deque<std::deque<TensorView*>> producer_use_chains_;

  // All we need to know and keep track of for each TensorView in this pass.
  std::unordered_map<TensorView*, ComputeAtData> tv_data;

  ComputeAt(
      TensorView* _producer,
      TensorView* _consumer,
      unsigned int _consumer_position);

  ComputeAt() = delete;
  ~ComputeAt() = default;
  // Non-copyable. Both deleted operations take a const reference so the copy
  // constructor matches the copy-assignment operator's signature.
  ComputeAt(const ComputeAt&) = delete;
  ComputeAt& operator=(const ComputeAt& other) = delete;
};

} // namespace fuser
} // namespace jit
} // namespace torch