1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55
|
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/partial_split_map.h>

#include <algorithm>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

//! Shared lookup used by getStartOffset / getStopOffset.
//!
//! \param offset_map One of PartialSplitMap's root-domain -> offset maps.
//! \param root_domain Key to look up.
//! \return The recorded offset, or nullptr when root_domain has no entry.
template <typename OffsetMap>
Val* findOffsetOrNull(const OffsetMap& offset_map, IterDomain* root_domain) {
  const auto it = offset_map.find(root_domain);
  return it == offset_map.end() ? nullptr : it->second;
}

} // namespace

//! Records the start and stop offsets of every Split that is applied
//! directly to a root domain of a TensorView in \p fusion.
//!
//! For each tensor returned by ir_utils::allTvs, the Split expressions that
//! StmtSort::getExprs reports for the tensor's domain
//! (tv->domain()->domain()) are inspected. Only splits whose input is one of
//! the tensor's root domains are recorded, keyed by that root IterDomain.
//! insert() is used, so an entry already present for a domain is kept as-is.
void PartialSplitMap::build(Fusion* fusion) {
  // allTvs already returns TensorViews, so no type filtering is required.
  const auto all_tvs = ir_utils::allTvs(fusion);

  for (auto tv : all_tvs) {
    const auto& tv_domain = tv->domain()->domain();
    auto exprs =
        StmtSort::getExprs(fusion, {tv_domain.begin(), tv_domain.end()});

    // Hoisted so the root domain is not re-fetched for every split.
    const auto& root_domains = tv->getRootDomain();

    for (auto split : ir_utils::filterByType<Split>(exprs)) {
      auto root_domain = split->in();

      // Only needs to check root domains as partial split is only
      // allowed with root domains
      if (std::find(root_domains.begin(), root_domains.end(), root_domain) ==
          root_domains.end()) {
        continue;
      }

      start_offset_map_.insert({root_domain, split->startOffset()});
      stop_offset_map_.insert({root_domain, split->stopOffset()});
    }
  }
}

//! \return The start offset recorded by build() for \p root_domain, or
//!         nullptr if no partial split of that root domain was recorded.
Val* PartialSplitMap::getStartOffset(IterDomain* root_domain) const {
  return findOffsetOrNull(start_offset_map_, root_domain);
}

//! \return The stop offset recorded by build() for \p root_domain, or
//!         nullptr if no partial split of that root domain was recorded.
Val* PartialSplitMap::getStopOffset(IterDomain* root_domain) const {
  return findOffsetOrNull(stop_offset_map_, root_domain);
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch
|