1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
|
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
#include <torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.h>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
// Analyze `fusion`: build the exact root-domain map used to de-duplicate
// concrete domains, seed the broadcast origin map, then traverse the fusion.
//
// Seeding: every broadcast root domain of each TensorView returned by
// `fusion->inputsAndCreated()` is registered as its own (sole) origin.
void ConcretizedBroadcastDomains::build(Fusion* fusion) {
  exact_map_ = std::make_unique<ExactRootDomainMap>(fusion);

  auto seed_vals = fusion->inputsAndCreated();
  for (auto* seed_tv : ir_utils::filterByType<TensorView>(seed_vals)) {
    for (auto* seed_id : seed_tv->getRootDomain()) {
      if (!seed_id->isBroadcast()) {
        continue;
      }
      broadcast_origin_map_.emplace(
          seed_id, std::unordered_set<IterDomain*>({seed_id}));
    }
  }

  traverse(fusion);
}
// Returns true if `id` has an entry in the broadcast-to-concrete map, i.e.
// it was recorded as concretized by at least one domain.
bool ConcretizedBroadcastDomains::isConcretized(IterDomain* id) const {
  return broadcast_to_concrete_map_.count(id) != 0;
}
// Returns true only if `id` is concretized and exactly one concrete domain
// is recorded for it.
bool ConcretizedBroadcastDomains::isUniquelyConcretized(IterDomain* id) const {
  const auto entry = broadcast_to_concrete_map_.find(id);
  if (entry == broadcast_to_concrete_map_.end()) {
    return false;
  }
  return entry->second.size() == 1;
}
// Returns true if `id` is concretized and more than one concrete domain is
// recorded for it. Domains that the exact map considers equivalent are
// stored only once.
bool ConcretizedBroadcastDomains::maybeNonUniquelyConcretized(
    IterDomain* id) const {
  const auto entry = broadcast_to_concrete_map_.find(id);
  if (entry == broadcast_to_concrete_map_.end()) {
    return false;
  }
  return entry->second.size() > 1;
}
// Register every axis newly introduced by this BroadcastOp (flagged in
// getBroadcastDimFlags()) as its own broadcast origin.
void ConcretizedBroadcastDomains::handle(BroadcastOp* bop) {
  auto* out_tv = bop->out()->as<TensorView>();
  const auto& out_root = out_tv->getRootDomain();
  const auto& is_new_bcast = bop->getBroadcastDimFlags();
  for (const auto axis : c10::irange(out_root.size())) {
    if (!is_new_bcast.at(axis)) {
      continue;
    }
    auto* bcast_id = out_root.at(axis);
    broadcast_origin_map_.emplace(
        bcast_id, std::unordered_set<IterDomain*>({bcast_id}));
  }
}
// Per-expression step of the traversal.
//
// First runs IterVisitor::handle(expr). This presumably dispatches to typed
// overloads such as handle(BroadcastOp*); confirm against IterVisitor. Then,
// for each (producer, consumer) TensorView pair of `expr`, maps the
// producer's broadcast domains onto the consumer:
//   - If the consumer ID is neither a broadcast nor a reduction, every origin
//     of the producer broadcast is marked as concretized by that consumer ID.
//   - Otherwise, the origins (plus the consumer ID itself) are forwarded into
//     the consumer ID's own origin set.
// Asserts if a mapped producer broadcast has no recorded origin info.
void ConcretizedBroadcastDomains::handle(Expr* expr) {
  IterVisitor::handle(expr);

  for (auto* producer_tv : ir_utils::filterByType<TensorView>(expr->inputs())) {
    // Collect the broadcast axes of the producer's maybe-rfactor domain.
    // NOTE: this relies on no broadcast axes being merged between the root
    // and rfactor domains, which is not possible at the moment. If that ever
    // changes, root IDs would have to be propagated to rfactor IDs manually.
    std::unordered_set<IterDomain*> bcast_ids;
    for (auto* id : producer_tv->getMaybeRFactorDomain()) {
      if (id->isBroadcast()) {
        bcast_ids.insert(id);
      }
    }
    // A producer without broadcast axes has nothing to propagate.
    if (bcast_ids.empty()) {
      continue;
    }

    for (auto* consumer_tv :
         ir_utils::filterByType<TensorView>(expr->outputs())) {
      const auto p2c =
          PairwiseRootDomainMap(producer_tv, consumer_tv)
              .mapProducerToConsumer(
                  producer_tv->domain(), consumer_tv->domain(), bcast_ids);
      for (const auto& entry : p2c) {
        auto* p_id = entry.first;
        auto* c_id = entry.second;

        auto origin_it = broadcast_origin_map_.find(p_id);
        TORCH_INTERNAL_ASSERT(
            origin_it != broadcast_origin_map_.end(),
            "Broadcast origin info not found for producer broadcast domain: ",
            p_id->toString(),
            " of ",
            producer_tv->toString());
        const auto& origins = origin_it->second;

        // A consumer ID that is still a broadcast, or that is a reduction
        // (i.e., a trivial reduction), does not count as a concretization.
        if (c_id->isBroadcast() || c_id->isReduction()) {
          // Not concretized yet: carry the origins forward to the consumer
          // ID and record the consumer ID itself as an origin too.
          auto& consumer_origins = broadcast_origin_map_[c_id];
          for (auto* origin : origins) {
            consumer_origins.insert(origin);
          }
          consumer_origins.insert(c_id);
        } else {
          // Concretized: every origin of this broadcast is concretized by
          // the consumer ID.
          for (auto* origin : origins) {
            markAsConcretized(origin, c_id);
          }
        }
      }
    }
  }
}
// Record `concrete_root_domain` as a concretization of `broadcast_root_domain`
// and of every IterDomain reachable from it through the outputs of its uses.
// This is a FIFO (breadth-first) worklist walk. A domain whose concrete set
// already holds an exactly-mapped entry is not expanded further, which also
// stops revisits.
void ConcretizedBroadcastDomains::markAsConcretized(
    IterDomain* broadcast_root_domain,
    IterDomain* concrete_root_domain) {
  std::deque<IterDomain*> pending;
  pending.push_back(broadcast_root_domain);
  while (!pending.empty()) {
    IterDomain* current = pending.front();
    pending.pop_front();

    auto& concretizations = broadcast_to_concrete_map_[current];
    if (!insertRootDomainToConcreteDomainSet(
            concrete_root_domain, concretizations)) {
      continue;
    }

    for (auto* use_expr : current->uses()) {
      for (auto* derived_id :
           ir_utils::filterByType<IterDomain>(use_expr->outputs())) {
        pending.push_back(derived_id);
      }
    }
  }
}
// Insert `new_root_id` into `id_set` unless the set already contains an ID
// that exact_map_ considers mapped to it.
// Returns true if the ID was inserted and false if an equivalent was present.
bool ConcretizedBroadcastDomains::insertRootDomainToConcreteDomainSet(
    IterDomain* new_root_id,
    std::unordered_set<IterDomain*>& id_set) {
  for (auto* existing_id : id_set) {
    if (exact_map_->areMapped(new_root_id, existing_id)) {
      return false;
    }
  }
  id_set.emplace(new_root_id);
  return true;
}
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch
|