1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163
|
#pragma once
#include <c10/util/Exception.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <deque>
#include <unordered_map>
#include <vector>
namespace torch {
namespace jit {
namespace fuser {
class TensorDomain;
class TensorView;

// Per-TensorView bookkeeping for the computeAt pass. Keeping everything in one
// structure lets the pass hold a single map entry from a TensorView to its
// data.
class ComputeAtData {
 public:
  ComputeAtData() = default;
  ComputeAtData(TensorView* tv);

  // Clear after a given traversal. There will be more than one.
  void clearPass();

  // Makes sure value matches current_traversal_position if
  // current_traversal_position_set is true. If this is not the case we're in
  // an invalid compute_at that would require tensor replication.
  void setPassPosition(unsigned int pos);

  // Returns true if pos is strictly greater than both the original and the
  // new computeAt positions, and greater than or equal to the position seen in
  // the current traversal.
  bool shouldSetComputeAt(unsigned int pos) const {
    return pos > original_compute_at_position &&
        pos > new_compute_at_position && pos >= current_traversal_position;
  }

  // Will return new_compute_at_position, after making sure we cleared out the
  // last pass.
  unsigned int getNewPosition() const;

  // Will make sure we haven't invalidated previous computeAt calls by
  // checking that any axes previously in computeAt are still there.
  void validateNewComputeAt() const;

  // Did we ever compute a value for this TV?
  bool touched() const {
    return touched_;
  }

  // Domain the original computeAt position was set relative to (nullptr if
  // none was recorded).
  TensorDomain* getOriginalDomain() const {
    return original_domain_;
  }

  // If we set computeAt, save the domain so we can reset it after traversal.
  // Traversal state can deviate from the domain we will want to save after the
  // entire computeAt pass.
  void setComputeAtDomain(TensorDomain* td);

  // Return the domain set in setComputeAtDomain (nullptr until one is set).
  TensorDomain* getComputeAtDomain() const {
    return new_compute_at_domain_;
  }

 private:
  // Was the position ever modified?
  bool touched_ = false;

  // Hold onto the provided TensorView
  TensorView* tv_ref_ = nullptr;

  // Did this tv have computeAt set before calling this computeAt pass?
  bool original_has_compute_at_ = false;

  // What was the computeAt position before the computeAt pass started
  unsigned int original_compute_at_position = 0;

  // and what was the previous domain that position was set relative to.
  TensorDomain* original_domain_ = nullptr;

  // Position we can update during a traversal
  unsigned int current_traversal_position = 0;

  // Did this traversal set a position or not yet
  bool current_traversal_position_set = false;

  // Position to update after a traversal
  unsigned int new_compute_at_position = 0;

  // Domain when we actually set computeAt, will set back to this after the
  // pass. Initialized to nullptr so a default-constructed object never hands
  // out an indeterminate pointer from getComputeAtDomain().
  TensorDomain* new_compute_at_domain_ = nullptr;
};
// Driver for a single computeAt pass. The constructor and destructor are
// private and copying is deleted, so run() is the only public member.
class ComputeAt {
 public:
  // Public entry point for the pass; defined out of line.
  // NOTE(review): presumably constructs a ComputeAt from these arguments and
  // invokes runPass() -- confirm against the definition.
  static void run(
      TensorView* _producer,
      TensorView* _consumer,
      unsigned int _consumer_position);

 private:
  // Default member initializers keep these well-defined even if a constructor
  // does not list them explicitly.
  TensorView* producer_ = nullptr;
  TensorView* consumer_ = nullptr;
  unsigned int consumer_position_ = 0;

  // Runs replayPasC and sets producer computeAt settings. Returns
  // producer_compute_at_axis.
  unsigned int backwardComputeAt_impl(
      TensorView* producer,
      TensorView* consumer,
      unsigned int consumer_compute_at_axis);

  // Runs replayCasP and sets producer computeAt settings. Returns
  // consumer_compute_at_axis.
  unsigned int forwardComputeAt_impl(
      TensorView* producer,
      TensorView* consumer,
      unsigned int producer_compute_at_axis);

  // Look through all the use chains of producer. Check if there's a single
  // consumer for all chains at or after the consumer specified in the
  // computeAt call.
  void setCommonConsumer();

  // Propagate backward from consumer to producer; if doing so increases the
  // computeAt position on a tensor, take it.
  void traverseBackward();

  // Traverse from producer to common_consumer if it exists or through all
  // uses of producer.
  void traverseForward();

  // Run the computeAt pass
  void runPass();

  // Set outputs relative to each other if there is not a common consumer
  void setupOutputs();

  // Common consumer if it exists
  TensorView* common_consumer_ = nullptr;

  // Producer use chains set in, used in a few spots.
  std::deque<std::deque<TensorView*>> producer_use_chains_;

  // All we need to know and keep track of for each TensorView in this pass.
  std::unordered_map<TensorView*, ComputeAtData> tv_data;

  ComputeAt(
      TensorView* _producer,
      TensorView* _consumer,
      unsigned int _consumer_position);

  ComputeAt() = delete;
  ~ComputeAt() = default;

  // Non-copyable. The copy constructor is deleted with the canonical
  // const-reference signature, matching the copy assignment below.
  ComputeAt(const ComputeAt&) = delete;
  ComputeAt& operator=(const ComputeAt& other) = delete;
};
} // namespace fuser
} // namespace jit
} // namespace torch
|