File: predicate_compute.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (198 lines) | stat: -rw-r--r-- 6,013 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
#pragma once

#include <torch/csrc/jit/codegen/cuda/index_compute.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
#include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

class PredicateCompute {
 public:
  // Returns a Bool* predicate for `expr`. Inputs are the expression, a list
  // of kir::ForLoop nodes, an existing thread predicate and the requested
  // PredicateType. The definition lives outside this header.
  //
  // NOTE(review): `loops` is presumably the loop nest enclosing `expr` --
  // confirm against the definition.
  //
  // Historical note kept from the original header:
  // ignore_internal_syncthread_ops will prevent creation of predicates on
  // block/grid broadcast/reduce as these have syncthread calls within them
  // so all threads need to execute the function.
  //
  // NOTE(review): no parameter named ignore_internal_syncthread_ops appears
  // in this signature, so the note above looks stale -- confirm whether the
  // described behavior is now selected some other way (e.g. via pred_type).
  static Bool* getInlinePredicate(
      const Expr* expr,
      const std::vector<kir::ForLoop*>& loops,
      Bool* thread_pred,
      PredicateType pred_type);
};

//! Parallelized domains may need to be predicated with threading
//! indices and IterDomain extents. For example, if a domain is
//! parallelized by TIDx, when TIDx is not exact, i.e., it can be
//! larger than the extents of domains parallelized by TIDx,
//! threadIdx.x may be larger than the IterDomain extent. This can be
//! harmless for Local tensors; however, it would
//! result in out-of-bounds access for Shared tensors as they are
//! allocated based on tensor shapes rather than threading
//! dimensions.
class ParallelizedDomainPredicate {
 public:
  //! Predicate information for parallelized domains
  class PredicateInfo {
   public:
    explicit PredicateInfo(ParallelType pt) : pt_(pt) {}

    //! Adds a domain that is parallized by the same paralell type
    bool addDomain(IterDomain* id);

    const std::vector<IterDomain*>& ids() const {
      return ids_;
    }

    //! Generates a predicate Val from predicate information
    Bool* getPredicate() const;

   private:
    ParallelType pt_;
    //! Domains parallelized by the same parallel type
    std::vector<IterDomain*> ids_;
  };

  //! Returns a predicate Val for parallelied domains of an expression.
  static Bool* getPredicate(
      const Expr* expr,
      const std::vector<kir::ForLoop*>& loops);

  //! Returns predicate information for parallelied domains of an
  //! expression.
  static std::unordered_map<ParallelType, PredicateInfo, TypeHash>
  getPredicateMap(
      const Expr* expr,
      const std::vector<kir::ForLoop*>& loops,
      kir::ForLoop* unswitched_loop = nullptr);
};

//! Keys to identify unique unswitch predicates. Just consists of a
//! predicated concrete domain if not parallelized. If parallelized,
//! pick one for each different parallelization. When the same
//! parallel type is used for different concrete domains, they are
//! considered different predicates and are included in the unswitch
//! condition lists.
class UnswitchPredicateKey {
 public:
  //! Default key; defined out of line. predicated_concrete_id_ starts as
  //! nullptr through its in-class initializer.
  UnswitchPredicateKey();

  //! Builds a key from a predicated consumer domain, the consumer tensor
  //! and the predicated concrete domain. Defined out of line.
  UnswitchPredicateKey(
      IterDomain* predicated_consumer_id,
      TensorView* consumer_tv,
      IterDomain* predicated_concrete_id);

  //! Two keys are equal when both the predicated concrete domain and the
  //! per-parallel-type concrete domains match.
  bool operator==(const UnswitchPredicateKey& other) const {
    if (predicated_concrete_id_ != other.predicated_concrete_id_) {
      return false;
    }
    return parallel_concrete_ids_ == other.parallel_concrete_ids_;
  }

  //! Predicated concrete domain of this key
  const auto& predicatedId() const {
    return predicated_concrete_id_;
  }

  //! Map from parallel type to its concrete domain
  const auto& parallelConcreteIds() const {
    return parallel_concrete_ids_;
  }

  //! Concrete domain stored for pt, or nullptr when pt has no entry.
  IterDomain* parallelId(ParallelType pt) const {
    const auto found = parallel_concrete_ids_.find(pt);
    return found != parallel_concrete_ids_.end() ? found->second : nullptr;
  }

  //! String form of the key; defined out of line.
  std::string toString() const;

 private:
  //! Predicated concrete domain
  IterDomain* predicated_concrete_id_ = nullptr;
  //! Parallelized concrete domains, one entry per parallel type
  std::unordered_map<ParallelType, IterDomain*, TypeHash>
      parallel_concrete_ids_;
};

//! Hash functor for UnswitchPredicateKey. It is the hasher template
//! argument of the std::unordered_set member
//! UnswitchPredicate::predicated_keys_. The call operator is defined
//! out of line.
struct UnswitchPredicateKeyHash {
  std::size_t operator()(const UnswitchPredicateKey& key) const;
};

//! Produces a single Bool* predicate through the static get() entry
//! point. The constructor and all other members are private, so get() is
//! the only way to use this class from outside.
class TORCH_CUDA_CU_API UnswitchPredicate {
 public:
  //! Returns the predicate for unrolled_loop given outer_loops.
  //! NOTE(review): outer_loops are presumably the loops enclosing
  //! unrolled_loop -- confirm against the definition.
  static Bool* get(
      const std::vector<kir::ForLoop*>& outer_loops,
      kir::ForLoop* unrolled_loop);

 private:
  //! Predicate information for each UnswitchPredicateKey.
  struct MergedPredicates {
    //! Predicate information for the start and stop predicates.
    struct Info {
      //! Most restrictive static predicate. Nullptr if no static
      //! predicate found.
      Bool* static_pred = nullptr;
      //! The offset value of static_pred
      int64_t static_offset = 0;
      //! List of dynamic predicates.
      std::vector<Bool*> dynamic_preds;
    };
    //! Key this merged record belongs to
    UnswitchPredicateKey predicate_key;
    //! Merged information for the start predicate
    Info start;
    //! Merged information for the stop predicate
    Info stop;
  };

  //! Private constructor; outer_loops is taken by value.
  //! NOTE(review): presumably invoked only from get() and used to fill
  //! for_loops_ and unrolled_loop_ -- confirm against the definition.
  UnswitchPredicate(
      std::vector<kir::ForLoop*> outer_loops,
      kir::ForLoop* unrolled_loop);

  //! NOTE(review): presumably records the predicates required by the
  //! given expression -- confirm against the definition.
  void predicateOn(Expr*);

  //! NOTE(review): presumably visits the body of the given loop --
  //! confirm against the definition.
  void openLoop(kir::ForLoop*);

  //! NOTE(review): presumably visits the branches of the given
  //! if-then-else -- confirm against the definition.
  void openIte(kir::IfThenElse*);

  //! Generates the final predicates from the predicated_keys map
  //! NOTE(review): predicated_keys_ is declared below as an
  //! unordered_set, not a map, and merged records are held in
  //! pending_predicates_; the wording above may be stale -- confirm which
  //! container finalize() actually consumes.
  void finalize();

  //! Merge predicates as much as possible. If a predicate offset is
  //! static, only pick the most restrictive one, e.g., the one with the
  //! minimum offset for the start predication.
  //! NOTE(review): is_start presumably selects between the start and stop
  //! handling of merged_predicate_info -- confirm against the definition.
  void mergeUnswitchPredicateOffsets(
      Bool* predicate,
      Val* offset,
      MergedPredicates::Info& merged_predicate_info,
      bool is_start);

  //! Adds new predicates for parallelized domains
  void addParallelizedDomainPredicates(Expr*);

  // Data members follow; the access specifier below repeats the one above.
 private:
  //! Track which iter domains have been predicated
  std::unordered_set<UnswitchPredicateKey, UnswitchPredicateKeyHash>
      predicated_keys_;

  //! The predicates that have been recorded but not yet finalized
  std::vector<MergedPredicates> pending_predicates_;

  //! Track which parallelized domains have been predicated
  std::unordered_map<
      ParallelType,
      ParallelizedDomainPredicate::PredicateInfo,
      TypeHash>
      parallelized_dom_predicates_;

  //! The predicates that have been generated.
  std::vector<Bool*> predicates_;

  //! NOTE(review): presumably the loop nest currently being considered,
  //! seeded from the constructor's outer_loops -- confirm.
  std::vector<kir::ForLoop*> for_loops_;

  //! NOTE(review): presumably the constructor's unrolled_loop argument.
  //! There is no in-class initializer, so the constructor must set it --
  //! confirm.
  kir::ForLoop* unrolled_loop_;
};

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch