File: predicate_compute.h

package info (click to toggle)
pytorch 1.7.1-7
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 80,340 kB
  • sloc: cpp: 670,830; python: 343,991; ansic: 67,845; asm: 5,503; sh: 2,924; java: 2,888; xml: 266; makefile: 244; ruby: 148; yacc: 144; objc: 51; lex: 44
file content (78 lines) | stat: -rw-r--r-- 2,328 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#pragma once

#include <torch/csrc/jit/codegen/cuda/arith.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>

#include <unordered_map>
#include <vector>

/*
 * Predicate compute takes a TensorView and set of indices. The number of
 * indices and the root of the TensorView are required to have the same number
 * of dimensions. Predicate compute should be run after index compute, and the
 * result of index compute should be used for the indices entry.
 *
 * A vector of Bool values is returned; entry i is the result of the comparison
 * index[i] < get_root(TV)->domain()->axis(i)->size()
 *
 * It is assumed that no predicate is required if index[i] is an index directly
 * from a for loop. This does not remove every unnecessary predicate when
 * static size information is actually available. For example:
 *
 * TV[I].split(4)
 * would produce the code:
 * for(i : I/4)
 *   for(j : 4)
 *     if( i * 4 + j < TV.size(0))
 *       TV[i * 4 + j]...
 *
 * If we had TV.size(0) = 16 at "compile time" then we wouldn't need the
 * predicate. Nevertheless, we will still generate:
 * for(i : 4)
 *   for(j : 4)
 *     if( i * 4 + j < TV.size(0))
 *       TV[i * 4 + j]...
 *
 */

namespace torch {
namespace jit {
namespace fuser {

// Collection of static entry points that build the bounds-check predicates
// described in the comment at the top of this file. The class declares no
// data members and no instance methods in this header.
class PredicateCompute {
 public:
  // Return the series of predicates for `tv`. If an axis doesn't have a
  // predicate, 1 is returned for that axis.
  //
  //   tv          - tensor view whose accesses are being predicated.
  //   indices     - index values for `tv`; per the file comment these should
  //                 be the result of index compute, with the same number of
  //                 dimensions as the root of `tv`.
  //   use_rfactor - NOTE(review): presumably selects the rfactor domain
  //                 instead of the root domain when comparing against axis
  //                 sizes; confirm against the definition.
  static std::vector<kir::Bool*> computePredicates(
      const TensorView* tv,
      const std::vector<Val*>& indices,
      bool use_rfactor);

  // Return a single kir::Bool* predicate for `expr`.
  //
  // NOTE(review): the exact semantics live in the implementation file and are
  // not visible from this header. Presumably `loops` are the for-loops
  // enclosing `expr` and `thread_pred` is an already-built thread predicate
  // to take into account; confirm. `ignore_block_grid_reductions` defaults to
  // true.
  static kir::Bool* getInlinePredicate(
      Expr* expr,
      const std::vector<kir::ForLoop*>& loops,
      kir::Bool* thread_pred,
      bool ignore_block_grid_reductions = true);
};

class TORCH_CUDA_API UnrollPredicate {
 public:
  static kir::Bool* get(
      const std::vector<kir::ForLoop*>& outer_loops,
      kir::ForLoop* unrolled_loop,
      const std::unordered_map<IterDomain*, IterDomain*>& p2c_root_map);

 private:
  UnrollPredicate(
      std::vector<kir::ForLoop*> outer_loops,
      kir::ForLoop* unrolled_loop,
      const std::unordered_map<IterDomain*, IterDomain*>& _p2c_root_map);

  void predicateOn(Expr*);

  void openLoop(kir::ForLoop*);

 private:
  std::unordered_map<IterDomain*, kir::Bool*> predicates_;
  std::vector<kir::ForLoop*> for_loops_;

  const std::unordered_map<IterDomain*, IterDomain*>& p2c_root_map_;
};

} // namespace fuser
} // namespace jit
} // namespace torch