1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
|
#pragma once
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <functional>
#include <unordered_map>
#include <vector>
// Hoisting common index subexpressions
//
// Class CommonIndexMap is updated during the lowering as new indices
// are inserted. An index is uniquely identified with CommonIndexKey,
// which consists of the concrete ID of the indexed/predicated domain,
// the for-loops used in the index, and the index vals of the use
// for-loops.
//
// Once all indices are inserted into CommonIndexMap, allocations of
// the hoisted indices are inserted by allocateCommonIndices. Note
// that this assumes that the CUDA code generator does not inline a
// scalar Val with allocation (PR #1434).
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
//! Key that identifies a unique indexed domain for index hoisting.
//! Uniqueness is determined by the indexed domain itself, the
//! for-loops, and the index values of those for-loops.
class CommonIndexKey {
 public:
  //! Construct a key with the help of a reference domain.
  //!
  //! \param consumer_indexed_id Indexed consumer domain
  //! \param consumer_td TensorDomain of consumer_indexed_id
  //! \param ref_td Reference domain at the time of indexing
  //! \param ref_index_map Index map of the reference domain
  //! \param loops Loop structure where this id is indexed
  CommonIndexKey(
      IterDomain* consumer_indexed_id,
      TensorDomain* consumer_td,
      TensorDomain* ref_td,
      const std::unordered_map<IterDomain*, Val*>& ref_index_map,
      const std::vector<kir::ForLoop*>& loops);

  //! Duplicate of the constructor above, but without a reference
  //! domain. TODO: Remove other implementation.
  //!
  //! \param consumer_indexed_id Indexed consumer domain
  //! \param consumer_td TensorDomain of consumer_indexed_id
  //! \param loop_domains Resolved vector of iterdomain corresponding to loops
  //! \param loop_index_map Index mapping generated from the loop nest.
  //! \param loops Loop structure where this id is indexed
  CommonIndexKey(
      IterDomain* consumer_indexed_id,
      TensorDomain* consumer_td,
      const std::vector<IterDomain*>& loop_domains,
      const std::unordered_map<IterDomain*, Val*>& loop_index_map,
      const std::vector<kir::ForLoop*>& loops);

  //! Equality comparison between two keys (declared only in this header)
  bool operator==(const CommonIndexKey& other) const;

  //! String representation of this key (declared only in this header)
  std::string toString() const;

  //! Read-only access to the concrete domain of the indexed domain
  const IterDomain* concreteIndexedId() const {
    return concrete_indexed_id_;
  }

  //! Read-only access to the loops used for the index
  const std::vector<kir::ForLoop*>& usedLoops() const {
    return used_loops_;
  }

  //! Read-only access to the loop index vals of the used loops
  const std::vector<Val*>& loopIndexVals() const {
    return loop_index_vals_;
  }

 private:
  // The hash functor reads concrete_indexed_id_ directly.
  friend struct CommonIndexKeyHash;

  //! Concrete domain of indexed domain
  IterDomain* concrete_indexed_id_{nullptr};

  //! Loops used for the index
  std::vector<kir::ForLoop*> used_loops_;

  //! Loop index vals for the used loops
  std::vector<Val*> loop_index_vals_;
};
//! Hash functor for CommonIndexKey.
struct CommonIndexKeyHash {
  std::size_t operator()(const CommonIndexKey& key) const {
    // NOTE: only the concrete indexed domain contributes to the hash.
    // The other fields are deliberately left out as their pointers can
    // be different even when two keys can share the same index.
    return std::hash<const IterDomain*>{}(key.concrete_indexed_id_);
  }
};
//! Map to hold hoisted common indices
class TORCH_CUDA_CU_API CommonIndexMap {
public:
//! Register an indexd consumer domain to hoist
//!
//! Returns a corresponding hoisted index and a flag indicating if a
//! new index is inserted.
//!
//! Consumer domains are used even for producer indexing since
//! producer domains in producer indexing are temporary replay
//! domains.
std::pair<Val*, bool> insert(
IterDomain* indexed_consumer_id,
TensorDomain* consumer_td,
TensorDomain* ref_td,
const std::unordered_map<IterDomain*, Val*>& ref_index_map,
const std::vector<kir::ForLoop*>& loops,
Val* index);
//! Duplicate of above, but without a reference domain. TODO: Remove other
//! implementation.
std::pair<Val*, bool> insert(
IterDomain* indexed_consumer_id,
TensorDomain* consumer_td,
const std::vector<IterDomain*>& loop_domains,
const std::unordered_map<IterDomain*, Val*>& loop_index_map,
const std::vector<kir::ForLoop*>& loops,
Val* index);
const auto& commonIndexMap() const {
return common_index_map_;
}
const auto& useCounts() const {
return use_counts_;
}
private:
//! Utility method to insert a key into common index
//! map. Returns a pair of an IR node and a boolean value.
//! The IR node will be the previously inserted index if
//! the key found a match, or will be the original index
//! if this is new key and the key will be stored.
//! The boolean value will be true if the key is stored,
//! i.e. first time it is inserted.
std::pair<Val*, bool> tryInsertNewIndex(CommonIndexKey key, Val* index);
private:
//! Map to hold hoisted common indices
std::unordered_map<CommonIndexKey, Val*, CommonIndexKeyHash>
common_index_map_;
std::unordered_map<CommonIndexKey, int, CommonIndexKeyHash> use_counts_;
};
//! Insert allocations of hoisted indices. Must be called after
//! collecting all common indices.
//!
//! \param exprs Expressions to process; taken by const reference, so
//!   the argument itself is not modified
//! \return A new vector of expressions. NOTE(review): presumably the
//!   input expressions with the allocations inserted -- confirm
//!   against the definition.
std::vector<Expr*> allocateCommonIndices(const std::vector<Expr*>& exprs);
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch
|