File: transform_iter.h

package info (click to toggle)
pytorch 1.7.1-7
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 80,340 kB
  • sloc: cpp: 670,830; python: 343,991; ansic: 67,845; asm: 5,503; sh: 2,924; java: 2,888; xml: 266; makefile: 244; ruby: 148; yacc: 144; objc: 51; lex: 44
file content (201 lines) | stat: -rw-r--r-- 6,904 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
#pragma once

#include <torch/csrc/WindowsTorchApiMacro.h>

#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>

#include <algorithm>
#include <cstddef>
#include <unordered_map>
#include <utility>
#include <vector>

namespace torch {
namespace jit {
namespace fuser {

// Comparator that orders std::pair<IterDomain*, size_t> (e.g. in a std::set or
// with std::sort) by its size_t member only. The IterDomain* member is ignored,
// so pairs with equal size_t values compare as equivalent: the size_t values
// must be unique among the pairs being ordered.
//
// Defined at namespace scope rather than inside an unnamed namespace. This
// header uses it from an inline member function of an exported class, and an
// unnamed namespace in a header would give every translation unit its own
// distinct comparator type.
struct id_int_lt {
  bool operator()(
      const std::pair<IterDomain*, size_t>& first,
      const std::pair<IterDomain*, size_t>& second) const {
    return first.second < second.second;
  }
};

// Replays the transformation history of a target domain onto another set of
// IterDomains, using the provided map.
//
// target_domain contains the history we want replayed.
//
// id_map maps IterDomains in that history to the IterDomains we want it
// replayed on.
//
// error_on_failure = true will cause the replay to error if we can't replay any
// operation in target_domain's history due to missing IDs in the id_map.
//
// If error_on_failure = false, replay will replay everything it can, and ignore
// operations it can't.
//
// The replay is lazy: getReplay(), getUnorderedLeafIDs() and getLeafIDs() call
// runReplay() themselves while ran_replay is false. The replay can therefore
// happen well after construction. For that reason the target domain is stored
// by value rather than by reference, so it cannot dangle if the constructor
// was handed a temporary vector.
class TORCH_CUDA_API ReplayTransformations : public IterVisitor {
 protected:
  // Domain whose history is replayed. Held by value (a cheap copy of pointers)
  // rather than by reference; see the class comment.
  const std::vector<IterDomain*> target_domain_;
  // Map from IDs in the target domain's history to their replayed
  // counterparts; returned by getReplay().
  std::unordered_map<IterDomain*, IterDomain*> id_map_;
  // Leaf IDs resulting from the replay. The size_t records the order in which
  // each ID was inserted and is used to generate the order of getLeafIDs().
  std::unordered_map<IterDomain*, size_t> leaf_ids_;
  // Leaf IDs in deterministic order; returned by getLeafIDs().
  std::vector<IterDomain*> leaf_vec_;
  // NOTE(review): presumably the running insertion counter for leaf_ids_;
  // it is maintained in the implementation file.
  size_t counter = 0;
  bool error_on_failure_ = true;
  bool ran_replay = false; // Mark if replay has been run
  using IterVisitor::handle;

  // Transform dispatch
  void handle(Expr* e) override;

  // We're going to replay this split operation on the corresponding ID
  void handle(Split* s) override;

  // We're going to replay this merge operation on the corresponding IDs
  void handle(Merge* m) override;

 public:
  ReplayTransformations(
      const std::vector<IterDomain*>& _target_domain,
      std::unordered_map<IterDomain*, IterDomain*> _id_map,
      bool _error_on_failure = true);

  // Runs the replay of the target domain's history. The getters below call
  // this automatically when ran_replay is still false.
  void runReplay();

  // Returns the map from IDs in the target domain's history to their
  // corresponding replayed IDs, running the replay first if needed.
  const std::unordered_map<IterDomain*, IterDomain*>& getReplay() {
    if (!ran_replay)
      runReplay();
    return id_map_;
  }

  // Returns leaf_ids_, running the replay first if needed. The size_t marks
  // the order in which each ID was put into the map; it is part of the
  // structure because it is used to generate the order of getLeafIDs().
  const std::unordered_map<IterDomain*, size_t>& getUnorderedLeafIDs() {
    if (!ran_replay)
      runReplay();
    return leaf_ids_;
  }

  // Returns all terminating IDs that resulted from the replay, running the
  // replay first if needed. The order is deterministic from run to run, but
  // otherwise carries no specific meaning.
  const std::vector<IterDomain*>& getLeafIDs() {
    if (!ran_replay)
      runReplay();
    return leaf_vec_;
  }
};

/*
 * Motivation:
 *
 * Consider the following program:
 *
 * T1[I0, R1] = T0[I0, I1]
 * T2[I0] = T1[I0, R1i]
 *
 * T1->split(1, factor)
 * T1->rFactor(2)
 *
 * T4[I0, R1orf, I1irf] = T0[I0, I1]
 * T1[I0, R1i] = T4[I0, R1orf, I1irf]
 * T2[I0] = T1[I0, R1i]
 *
 * There's an issue when we call replayCasP on
 * T4[I0, R1o, I1i] = T0[I0, I1]
 *
 * This would try to replay T4 as T0, and it could include the rfactor domains.
 * For example, suppose we compute T0 inline with T4. Because of the way
 * computeAt is set up, this would call replayPasC(T0, T4, -1) and then
 * replayCasP(T4, T0, -1).
 *
 * We might assume that the only way we will hit this is if we call
 * T4->computeAt(T0...), so it might be safe to assume that the right
 * transformations would be replayed. However, we want to preserve the rfactor
 * domain. Since this would replay T4 at its root, it would produce IterDomains
 * that don't correspond to those in the rfactor domain. It is also unclear
 * whether this assumption is correct.
 *
 * Therefore, we assume it is not correct. We validate here that replaying a
 * domain transforms it in a way consistent with any defined rfactor domains.
 * We then update the replay map so that rfactor roots are mapped to
 * intermediate IterDomains in the target, and start the replay from there.
 *
 *
 * SHORT DESCRIPTION:
 *
 * This class will validate/do the above. It will also run through the
 * transformations in target_domain according to replay_map. If equal
 * transformations already exist in replay_domain's history, we will not redo
 * those transformations. Instead, we update replay_map to reflect forwarding
 * the existing transformations. This latter part is the "best effort" replay,
 * though rfactor replay and validation are included here as well.
 *
 * Given an Expr in target_domain, check whether its inputs are in replay_map.
 * If so, check whether the mapped domains in replay_map are recorded as being
 * transformed by an equivalent operation in replay_domain's history. If so,
 * "forward" the operation: update replay_map so that the output(s) of the
 * target_domain expr map to the corresponding output(s) of the equivalent expr
 * in replay_domain's history.
 *
 * replay_map maps root IDs in the history of target_domain to root IDs in the
 * history of replay_domain.
 */

class TORCH_CUDA_API BestEffortReplay {
 private:
  std::unordered_map<IterDomain*, IterDomain*> id_map_;
  std::unordered_map<IterDomain*, size_t> leaf_ids_;
  size_t counter = 0;

 public:
  // replay_map: mapping of target root domains to corresponding
  // replay root domains
  BestEffortReplay(
      const std::vector<IterDomain*>& replay_domain,
      const std::vector<IterDomain*>& target_domain,
      std::unordered_map<IterDomain*, IterDomain*> replay_map,
      bool forward_bcast_mismatch = false);

  // Return iter domain map from target_domain IDs to their "replayed"
  // replay_domain IDs. If not in map, was not replayed.
  const std::unordered_map<IterDomain*, IterDomain*>& getReplay() const {
    return id_map_;
  }

  // ids in replay that did not have matching transforms in target_domain
  const std::unordered_map<IterDomain*, size_t>& getUnorderedLeafIDs() {
    return leaf_ids_;
  }

  // Returned ordered set of IDs in getUnorderedLeafIDs
  std::vector<IterDomain*> getLeafIDs() {
    std::set<std::pair<IterDomain*, size_t>, id_int_lt> ordered_set;
    for (auto entry : leaf_ids_)
      ordered_set.emplace(entry);

    std::vector<IterDomain*> leaf_vec_;
    leaf_vec_.resize(ordered_set.size());
    std::transform(
        ordered_set.begin(),
        ordered_set.end(),
        leaf_vec_.begin(),
        [](std::pair<IterDomain*, size_t> entry) { return entry.first; });
    return leaf_vec_;
  }

  // Find the first position i where td1[i] is not the same as td2[i]. "Same"
  // means the DAG and input IDs to generate td1[i] and td2[i] are the same.
  static int findFirstMismatchedID(
      const TensorDomain* td1,
      const TensorDomain* td2);
};

} // namespace fuser
} // namespace jit
} // namespace torch