File: kernel.h

package info (click to toggle)
pytorch 1.7.1-7
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 80,340 kB
  • sloc: cpp: 670,830; python: 343,991; ansic: 67,845; asm: 5,503; sh: 2,924; java: 2,888; xml: 266; makefile: 244; ruby: 148; yacc: 144; objc: 51; lex: 44
file content (136 lines) | stat: -rw-r--r-- 3,342 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136

#pragma once

#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
#include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>
#include <torch/csrc/jit/codegen/cuda/utils.h>

#include <cstddef>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

namespace torch {
namespace jit {
namespace fuser {

//! Summary of interesting facts about the kernel
//!
//! TODO(kir): const node ptrs
//!
//! Summary of interesting facts about the kernel
//!
//! Plain aggregate filled in during kernel analysis (see
//! Kernel::analyze()); consumed by code generation to decide what
//! runtime support (sync, smem, RNG, reduction helpers) to emit.
//!
//! TODO(kir): const node ptrs
//!
struct KernelSummary {
  //! Write-After-Read (WAR) synchronization barriers, keyed by an
  //! integer index (presumably the position/ID of the expression the
  //! sync guards — confirm against the lowering pass that fills this)
  std::unordered_map<size_t, kir::Sync*> war_hazard_syncs;

  //! List of global buffers
  std::vector<kir::Allocate*> global_allocations;

  //! List of dynamic shared memory buffers
  std::vector<kir::Allocate*> dynamic_smem_allocations;

  //! List of static shared memory buffers
  std::vector<kir::Allocate*> static_smem_allocations;

  //! Indicate the need to generate random numbers
  //! (i.e. the generated kernel must initialize RNG state)
  bool is_stochastic = false;

  //! Do we have any block reductions?
  bool has_block_reductions = false;

  //! Do we have any grid reductions?
  bool has_grid_reductions = false;

  //! Do we have any block broadcasts?
  bool has_block_broadcasts = false;

  //! Largest shared memory buffer base type
  //! (DataType::Null means no shared memory buffers were found)
  DataType largest_smem_data_type = DataType::Null;
};

//! Container for a lowered Kernel IR
//!
//! TODO(kir): currently, it is just pointing to nodes owned
//!  by a Fusion object. The goal is to have the Kernel object
//!  own the Kernel IR nodes
//!
//! Container for a lowered Kernel IR
//!
//! Lifecycle: construct empty, register inputs/outputs and IR nodes
//! while lowering, then call finalize() exactly once to install the
//! top-level expressions and predicate map and run analysis.
//!
//! TODO(kir): currently, it is just pointing to nodes owned
//!  by a Fusion object. The goal is to have the Kernel object
//!  own the Kernel IR nodes
//!
class TORCH_CUDA_API Kernel final : public NonCopyable {
 public:
  Kernel() = default;

  //! Finalize a kernel definition
  //!
  //! At this point we have a complete kernel definition and we can
  //! run analysis passes to build a KernelSummary
  //!
  void finalize(
      std::vector<Expr*> top_level_exprs,
      ThreadPredicateMap predicate_map);

  //! Register input as an input of the kernel
  //! (appends; does not check for duplicates)
  void addInput(Val* input) {
    inputs_.push_back(input);
  }

  //! Register output as an output of the kernel
  //! (appends; does not check for duplicates)
  void addOutput(Val* output) {
    outputs_.push_back(output);
  }

  //! Kernel inputs, in registration order
  const auto& inputs() const {
    return inputs_;
  }

  //! Kernel outputs, in registration order
  const auto& outputs() const {
    return outputs_;
  }

  //! Top-level expressions installed by finalize()
  const auto& topLevelExprs() const {
    return top_level_exprs_;
  }

  //! Analysis results; meaningful only after finalize()
  const KernelSummary& summary() const {
    return summary_;
  }

  //! NOTE(review): dereferences predicate_map_, which looks null until
  //! finalize() installs it — calling this earlier would be UB; confirm
  //! callers only use it post-finalize
  const ThreadPredicateMap& predicateMap() const {
    return *predicate_map_;
  }

  //! Register a new Kernel IR node
  //!
  //! \note This is a specialized helper for kir::IrBuilder, not
  //!   intended for general use
  //!
  void registerIrNode(std::unique_ptr<Statement> node) {
    ir_nodes_.push_back(std::move(node));
  }

 private:
  // Analyze the kernel IR and caches the summary of interesting data
  void analyze();

 private:
  // Kernel IR nodes (owned here; see class TODO about full ownership)
  std::vector<std::unique_ptr<Statement>> ir_nodes_;

  // Map from value to its definition expression
  std::unordered_map<const Val*, Expr*> definitions_;

  // Top level expressions (non-owning; set by finalize())
  std::vector<Expr*> top_level_exprs_;

  // Kernel inputs and outputs (non-owning)
  std::vector<Val*> inputs_;
  std::vector<Val*> outputs_;

  // Summary of interesting kernel data (built by analyze())
  KernelSummary summary_;

  // Predicate map
  // TODO(kir): consider a simpler, kernel IR based version
  std::unique_ptr<ThreadPredicateMap> predicate_map_;
};

} // namespace fuser
} // namespace jit
} // namespace torch