File: non_divisible_split.cpp

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (169 lines) | stat: -rw-r--r-- 5,349 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
#include <torch/csrc/jit/codegen/cuda/non_divisible_split.h>

#include <algorithm>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

// Collects the splits of every non-input TensorView in `fusion` that may be
// non-divisible, recording them either for predication or for run-time
// validation. When a GpuLower is active, redundant entries are then pruned.
void NonDivisibleSplitInfo::build(Fusion* fusion) {
  // Named local so the filtered view below iterates over live storage.
  const auto used_vals = fusion->usedMathVals();

  for (auto tv : ir_utils::filterByType<TensorView>(used_vals)) {
    // Fusion inputs carry no transformations worth checking.
    if (tv->isFusionInput()) {
      continue;
    }

    const auto& leaf_ids = tv->domain()->domain();
    const std::vector<Val*> traversal_targets(
        leaf_ids.begin(), leaf_ids.end());

    // Reachability is tracked per tensor, so start each traversal from an
    // empty set; current_tv_ tells the Split handler which tensor owns the
    // splits it records.
    current_tv_ = tv;
    clearReachability();
    traverseFrom(fusion, traversal_targets);
    current_tv_ = nullptr;
  }

  // removeRedundancy() queries GpuLower's compute-at map, so it only runs
  // when there is an active GpuLower.
  const bool has_active_lowering = GpuLower::current() != nullptr;
  if (has_active_lowering) {
    removeRedundancy();
  }
}

// Visitor hook for Split expressions. If the split's input is reachable from
// an inner domain and the split cannot be proven divisible, it is recorded
// for run-time validation (vectorized inner output) or for predication
// (everything else). Reachability is then propagated to the outputs.
void NonDivisibleSplitInfo::handle(Split* split) {
  // Broadcast inputs are skipped entirely, including reachability
  // propagation.
  if (split->in()->isBroadcast()) {
    return;
  }

  // Becomes true when this split will be guarded, either by a predicate or
  // by a run-time divisibility check.
  bool is_protected = false;

  const bool needs_protection = isReachableFromInnerDomains(split->in()) &&
      getMaybeNonDivisibleExtent(split) != nullptr;

  if (needs_protection) {
    // Predication cannot make a vectorized output safe; vectorized
    // outputs require real divisibility.
    TORCH_INTERNAL_ASSERT(
        split->outer()->getParallelType() != ParallelType::Vectorize);

    const bool inner_is_vectorized =
        split->inner()->getParallelType() == ParallelType::Vectorize;
    if (inner_is_vectorized) {
      // Divisibility must be checked at run time.
      splits_to_validate_.insert(split);
    } else {
      // Not proven divisible: guard the owning tensor with a predicate.
      splits_to_predicate_[current_tv_].push_back(split);
    }
    is_protected = true;
  }

  propagateReachability(split, is_protected);
}

// Returns true if `id` has been recorded in inner_domains_ during the
// current traversal.
bool NonDivisibleSplitInfo::isReachableFromInnerDomains(IterDomain* id) const {
  return inner_domains_.count(id) != 0;
}

// Forgets every domain recorded as reachable, so that a new traversal starts
// from an empty set.
void NonDivisibleSplitInfo::clearReachability() {
  inner_domains_.erase(inner_domains_.begin(), inner_domains_.end());
}

// Records which outputs of `split` must be tracked as reachable from an inner
// domain. `is_protected` indicates the split itself is already predicated or
// validated at run time.
void NonDivisibleSplitInfo::propagateReachability(
    Split* split,
    bool is_protected) {
  // The inner output is always tracked so that its descendants are checked.
  inner_domains_.insert(split->inner());

  // The outer output inherits the input's reachability. A protected split
  // already guards against any effect coming through its outer descendants,
  // so in that case they need no further tracking.
  const bool track_outer =
      !is_protected && isReachableFromInnerDomains(split->in());
  if (track_outer) {
    inner_domains_.insert(split->outer());
  }
}

// Returns nullptr when `split` is statically known to be divisible. Otherwise
// returns the extent of the ceiling-divided output (outer for an inner split,
// inner for an outer split), which may be non-divisible.
Val* NonDivisibleSplitInfo::getMaybeNonDivisibleExtent(Split* split) const {
  ExpressionEvaluator evaluator(split->fusion());
  auto known_in_extent = evaluator.evaluate(split->in()->extent());
  auto known_factor = evaluator.evaluate(split->factor());

  if (known_factor.has_value()) {
    // A factor of 1 divides anything, even an input extent that is unknown.
    if (known_factor.value() == 1) {
      return nullptr;
    }
    // Both sides known: divisibility can be decided directly.
    if (known_in_extent.has_value() &&
        known_in_extent.value() % known_factor.value() == 0) {
      return nullptr;
    }
  }

  IterDomain* const ceildiv_id =
      split->innerSplit() ? split->outer() : split->inner();
  return ceildiv_id->extent();
}

// Visitor hook for Merge expressions. Merges are never recorded for
// predication or validation; they only forward reachability.
void NonDivisibleSplitInfo::handle(Merge* merge) {
  this->propagateReachability(merge);
}

// Marks the merge output as reachable when its outer input is. The inner
// input is ignored because its index is computed as a remainder and so never
// exceeds its extent, whereas the outer input's index may.
void NonDivisibleSplitInfo::propagateReachability(Merge* merge) {
  const bool outer_is_reachable = isReachableFromInnerDomains(merge->outer());
  if (!outer_is_reachable) {
    return;
  }
  inner_domains_.insert(merge->out());
}

void NonDivisibleSplitInfo::removeRedundancy() {
  auto gpu_lower = GpuLower::current();
  TORCH_INTERNAL_ASSERT(gpu_lower != nullptr);

  std::unordered_set<IterDomain*> split_to_validate_outer;
  for (auto it = splits_to_validate_.begin();
       it != splits_to_validate_.end();) {
    auto outer_concrete = gpu_lower->caMap()->getConcreteMappedID(
        (*it)->outer(), IdMappingMode::EXACT);
    auto new_domain = split_to_validate_outer.insert(outer_concrete).second;
    if (!new_domain) {
      it = splits_to_validate_.erase(it);
    } else {
      ++it;
    }
  }

  // If validated by runtime checks, no need to predicate
  for (auto& kv : splits_to_predicate_) {
    auto& splits = kv.second;
    for (auto it = splits.begin(); it != splits.end();) {
      // If the outer domain is mapped with the outer domain of any
      // validated domain, it is safe to omit the predicate for the
      // split.
      Split* split_to_predicate = *it;
      if (std::any_of(
              splits_to_validate_.begin(),
              splits_to_validate_.end(),
              [&](Split* split_to_validate) {
                return gpu_lower->caMap()->areMapped(
                    split_to_validate->outer(),
                    split_to_predicate->outer(),
                    IdMappingMode::EXACT);
              })) {
        it = splits.erase(it);
      } else {
        ++it;
      }
    }
  }
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch