File: partial_split_map.cpp

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (55 lines) | stat: -rw-r--r-- 1,645 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/partial_split_map.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

// Walks every TensorView of the fusion and, for each Split expression
// whose input is one of that TensorView's root domains, records the
// split's start and stop offsets keyed by the root domain.
void PartialSplitMap::build(Fusion* fusion) {
  auto all_tvs = ir_utils::allTvs(fusion);

  for (auto tv : ir_utils::filterByType<TensorView>(all_tvs)) {
    auto tv_exprs = StmtSort::getExprs(
        fusion, {tv->domain()->domain().begin(), tv->domain()->domain().end()});
    const auto& root_ids = tv->getRootDomain();

    for (auto split : ir_utils::filterByType<Split>(tv_exprs)) {
      auto split_input = split->in();

      // Partial split is only allowed with root domains, so a split
      // whose input is not a root domain has nothing to record
      const bool splits_root_domain =
          std::find(root_ids.begin(), root_ids.end(), split_input) !=
          root_ids.end();

      if (splits_root_domain) {
        // insert() is used for both maps, same as before; the start
        // offset is recorded first, then the stop offset
        start_offset_map_.insert({split_input, split->startOffset()});
        stop_offset_map_.insert({split_input, split->stopOffset()});
      }
    }
  }
}

// Looks up the start offset recorded for root_domain. Returns nullptr
// when the start-offset map holds no entry for it.
Val* PartialSplitMap::getStartOffset(IterDomain* root_domain) const {
  const auto entry = start_offset_map_.find(root_domain);
  if (entry != start_offset_map_.end()) {
    return entry->second;
  }
  return nullptr;
}

// Looks up the stop offset recorded for root_domain. Returns nullptr
// when the stop-offset map holds no entry for it.
Val* PartialSplitMap::getStopOffset(IterDomain* root_domain) const {
  const auto entry = stop_offset_map_.find(root_domain);
  if (entry != stop_offset_map_.end()) {
    return entry->second;
  }
  return nullptr;
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch