File: parallel_dimension_map.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (78 lines) | stat: -rw-r--r-- 2,482 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#pragma once

#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>

#include <cstdint>
#include <deque>
#include <string>
#include <unordered_map>
#include <unordered_set>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

//! Maps TID/BID to its dimension. It is by default blockDim/gridDim,
//! but if use of a ParallelType is mapped to a unique constant
//! extent, the constant value is used instead since presumably it's
//! more efficient.
//!
//! The map is populated by build() and then queried through get() and
//! isExact(). Only declarations live in this header; the behavior
//! described below is taken from the declarations and their original
//! comments, and details marked NOTE(review) should be confirmed
//! against the definitions.
class TORCH_CUDA_CU_API ParallelDimensionMap {
 public:
  //! Builds the ParallelType-to-dimension map for the given fusion.
  //! The temporary maps declared at the bottom of this class are only
  //! used while this runs.
  //! NOTE(review): presumably must be called before get()/isExact()
  //! return meaningful results -- confirm against the definition.
  void build(Fusion* fusion);

  //! Returns the dimension of a ParallelType. nullptr is returned if
  //! a ParallelType is unused.
  Val* get(ParallelType pt) const;

  //! True if the dimension of a ParallelType is known to be exact,
  //! i.e., identified to be exactly the same as the extents of the
  //! domains mapped to it (see exact_types_).
  bool isExact(ParallelType pt) const;

  //! Returns a textual representation of this map.
  //! NOTE(review): the output format is defined in the implementation
  //! file and is not visible from this header.
  std::string toString() const;

  //! Symbolically analyze if two extent vals are equal. Static: does
  //! not read any state of a ParallelDimensionMap instance.
  static bool equalDim(Val* dim1, Val* dim2);

 private:
  //! Register the extent of an IterDomain if it's constant.
  //! NOTE(review): presumably records into constant_extent_map_ --
  //! confirm against the definition.
  void registerConstantExtent(IterDomain* id);

  //! Processes one IterDomain while building the map.
  //! NOTE(review): judging by the name, this is expected to be called
  //! for parallelized domains only and to feed concrete_dom_map_ --
  //! confirm against the definition.
  void handleParallelDomain(IterDomain* id);

  //! Populates the dimension entry of `pt` from `dom_set`.
  //! NOTE(review): judging by the name, this handles the case where
  //! `pt` is used by a single CA domain set; what "CA" stands for is
  //! not stated in this header -- confirm against the definition.
  void populateDimensionMapWithSingleCASet(
      ParallelType pt,
      const std::unordered_set<IterDomain*>& dom_set);

  //! Populates the dimension entry of `pt` from `dom_set`.
  //! NOTE(review): judging by the name, this is the counterpart of
  //! populateDimensionMapWithSingleCASet for the case where `pt` is
  //! used by multiple CA domain sets -- confirm against the definition.
  void populateDimensionMapWithMultipleCASet(
      ParallelType pt,
      const std::unordered_set<IterDomain*>& dom_set);

  //! TIDx may need to be marked as non-exact as it may be padded to a
  //! multiple of the warp size.
  void adjustMappingsForWarpPadding();

  //! Returns the concrete IterDomain that `id` is CA-mapped to.
  //! NOTE(review): inferred from the name and from the comment on
  //! constant_extent_map_ ("a CA domain set represented by the concrete
  //! domain") -- confirm against the definition.
  static IterDomain* getCAMappedConcreteDomain(IterDomain* id);

  // Data members follow; the access specifier is repeated only to
  // separate them visually from the private helper methods above.
 private:
  //! Maps from parallel types to dimensions, which are constant if
  //! a unique value is found.
  std::unordered_map<ParallelType, Val*, TypeHash> dim_map_;
  //! Set of parallel types whose dimensions are identified to be
  //! exactly the same as extents of mapped domains.
  std::unordered_set<ParallelType, TypeHash> exact_types_;

  // Below are temporary maps to build the ParallelType-to-dimension
  // map. Only used during build().

  //! Map from a parallel type to a set of concrete domains where the
  //! parallel type is used.
  std::unordered_map<ParallelType, std::unordered_set<IterDomain*>, TypeHash>
      concrete_dom_map_;
  //! Keep track of constant extents found for a CA domain set
  //! represented by the concrete domain. A set is kept per domain
  //! because more than one distinct constant may be found.
  std::unordered_map<IterDomain*, std::unordered_set<int64_t>>
      constant_extent_map_;
};

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch