File: lower_thread_predicate.h

#pragma once

#include <c10/macros/Export.h>

#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
#include <torch/csrc/jit/codegen/cuda/parallel_type_bitmap.h>

#include <unordered_map>
#include <unordered_set>
#include <utility>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

//! Maps each TensorView to its thread predicate information
//! (PredicateInfo below).
//!
//! Map from TensorView to a bit set representing <BIDx, BIDy, BIDz,
//! TIDx, TIDy, TIDz>. If any dependency of a TV had a parallelized
//! reduction, we track it here. This is used during predicate
//! generation to guard uses along that axis. For example, if a
//! reduction is parallelized on TIDx, the reduced value is only valid
//! at threadIdx.x == 0, so any later use of that value in the kernel
//! must carry that predicate. If a reduction parallelized on TIDx is
//! followed by a broadcast on TIDx, the predicate is no longer needed
//! and the bit can be reset accordingly.
//!
//! In addition, if a parallel type is not used by a tensor, having
//! all threads/blocks of that type execute its expression is
//! redundant. That is generally just inefficient rather than
//! incorrect, but when an aliased smem buffer is used as an output,
//! redundant writes can be invalid (see issue #1110).
//! PredicateInfo::redundant_types tracks which parallel types are
//! redundant for each tensor and is used to let only one thread/block
//! of each redundant type execute the expression for that tensor.
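//!
//! As a concrete illustration of the reduction case, the guarded
//! kernel fragment looks roughly like the following (a hand-written
//! sketch, not the exact generated code; T0/T1/T2 are hypothetical
//! tensor names):
//!
//!   // T1 = reduction of T0 over an axis parallelized on TIDx
//!   blockReduce(T1, T0, /*...*/);
//!   // T1 is only valid at threadIdx.x == 0, so any consumer of T1
//!   // must be predicated on that thread:
//!   if (threadIdx.x == 0) {
//!     T2[0] = T1[0] + 1.f;
//!   }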
class TORCH_CUDA_CU_API ThreadPredicateMap {
 public:
  using SourceMap = std::unordered_map<
      ParallelType,
      std::unordered_set<const TensorView*>,
      TypeHash>;

  //! Thread predicate information for each tensor
  struct PredicateInfo {
    // Parallel types where only one thread/block is valid.
    ParallelTypeBitmap limited_types;
    // Parallel types where only one thread/block is enough.
    ParallelTypeBitmap redundant_types;
    // Tracking the use chain of redundant writes:
    //  [Redundant use chain]
    //  A parallel type is a `redundant_consumer_type` only if all of
    //   its propagation use chains terminate with a redundant write
    //   of this type.
    //  A propagation use chain is currently either a reg-to-reg
    //   chain for a shared mem tv, or a reg/smem-to-reg/smem chain
    //   for a global tv.
    // This is complementary information to `redundant_types`: if a
    //  tensor view is redundantly written but not redundantly used by
    //  all consumers (see FusionRedundantPredSync3), a RAW sync needs
    //  to be inserted before reading this redundantly written tensor.
    ParallelTypeBitmap redundant_use_types;
    bool operator==(const PredicateInfo& other) const {
      return limited_types == other.limited_types &&
          redundant_types == other.redundant_types &&
          redundant_use_types == other.redundant_use_types;
    }
  };
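
  // Illustrative example (hypothetical values): for a tensor whose
  // producer was reduced over TIDx and whose expression does not use
  // TIDy at all,
  //   limited_types   = {TIDx}: only threadIdx.x == 0 holds a valid
  //     value, so uses must be predicated on it;
  //   redundant_types = {TIDy}: every threadIdx.y would compute the
  //     same result, so only one needs to execute the expression;
  //   redundant_use_types: see [Redundant use chain] above.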

  using MapType = std::unordered_map<const TensorView*, PredicateInfo>;

  using const_iterator = MapType::const_iterator;

  //! Build a map from each tensor to PredicateInfo.
  void build(Fusion* fusion);

  //! Get the PredicateInfo for a given tensor. If the tensor is an
  //! output of a parallel broadcast, unmask the limited_types_ bit of
  //! the corresponding parallel type, since every thread/block must
  //! join the broadcast operation even though the valid input is only
  //! available on one of them.
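  //!
  //! For example (an illustrative sketch): if tv = broadcast(tv_red)
  //! and the broadcast axis is parallelized on TIDx, the TIDx bit set
  //! by the earlier reduction is cleared here, because every
  //! threadIdx.x joins the block broadcast and receives the valid
  //! value.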
  PredicateInfo getPredicateInfo(const TensorView* tv) const;

  //! Returns a flag set that indicates which parallel types should be
  //! predicated.
  ParallelTypeBitmap getPredicatedParallelTypes(const TensorView* tv) const;

  //! Returns a Bool predicate for a given TensorView.
  Bool* getPredicate(const TensorView* tv) const;

  //! Returns a ParallelTypeBitmap representing which domains need
  //! blockBroadcast.
  //!
  //! Even when a domain is broadcast and parallelized, it does not
  //! need blockBroadcast unless it is predicated by limited_types_.
  ParallelTypeBitmap getParallelBroadcastDomains(const TensorView* tv) const;

  //! Mark tv as updated so that a rebuild of the map recomputes its
  //! predicates and those of its dependents.
  void markAsUpdated(const TensorView* tv);

  void print() const;

  //! Generate a Bool value from PredicateInfo.
  static Bool* getPredicateFromPredicateInfo(
      const ThreadPredicateMap::PredicateInfo& pred_info);

  //! Get the redundant use types of the given expr, see [Redundant use chain]
  ParallelTypeBitmap getRedundantConsumerType(Expr* expr) const;

 private:
  // Update the thread_predicates_ map based on the provided Expr
  void updateBitSet(const Expr*);

  const_iterator find(const TensorView* tv) const;
  const_iterator end() const;

  const PredicateInfo& at(const TensorView* tv) const;
  PredicateInfo& at(const TensorView* tv);

  //! Update a mapping
  bool update(
      const TensorView* tv,
      const ParallelTypeBitmap& limited_types,
      const ParallelTypeBitmap& redundant_types);

  //! Update a mapping
  bool update(const TensorView* tv, const PredicateInfo& pred_and_src);

  //! Back-propagate redundant use chain info once the redundant
  //!  parallel writes have been identified.
  void populateRedundantUseMap(Fusion* fusion);

 private:
  MapType thread_predicates_;
  //! Keep track of updated tensors that need predicates to be computed
  std::unordered_set<const TensorView*> updated_tvs_;
};
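
// Minimal usage sketch (hypothetical call site; the actual lowering
// passes wire this map up elsewhere, so treat this as illustrative
// only):
//
//   ThreadPredicateMap thread_preds;
//   thread_preds.build(fusion);
//   for (auto tv : ir_utils::allTvs(fusion)) {
//     Bool* pred = thread_preds.getPredicate(tv);
//     // guard tv's defining expression with pred
//   }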

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch