#pragma once
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
#include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>
#include <torch/csrc/jit/codegen/cuda/utils.h>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
namespace torch {
namespace jit {
namespace fuser {
//! Summary of interesting facts about the kernel
//!
//! TODO(kir): const node ptrs
//!
struct KernelSummary {
  //! Map of Write-After-Read (WAR) synchronization barriers, keyed by
  //! expression position
  std::unordered_map<size_t, kir::Sync*> war_hazard_syncs;
//! List of global buffers
std::vector<kir::Allocate*> global_allocations;
//! List of dynamic shared memory buffers
std::vector<kir::Allocate*> dynamic_smem_allocations;
//! List of static shared memory buffers
std::vector<kir::Allocate*> static_smem_allocations;
//! Indicate the need to generate random numbers
bool is_stochastic = false;
//! Do we have any block reductions?
bool has_block_reductions = false;
//! Do we have any grid reductions?
bool has_grid_reductions = false;
//! Do we have any block broadcasts?
bool has_block_broadcasts = false;
//! Largest shared memory buffer base type
DataType largest_smem_data_type = DataType::Null;
};
//! Container for a lowered Kernel IR
//!
//! TODO(kir): currently, it is just pointing to nodes owned
//! by a Fusion object. The goal is to have the Kernel object
//! own the Kernel IR nodes
//!
class TORCH_CUDA_API Kernel final : public NonCopyable {
public:
Kernel() = default;
//! Finalize a kernel definition
//!
//! At this point we have a complete kernel definition and we can
//! run analysis passes to build a KernelSummary
//!
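//! Example (illustrative sketch; `in_tv`, `out_tv`, `top_level_exprs`
//! and `pred_map` are hypothetical values produced by earlier lowering
//! passes):
//!
//! \code
//!   Kernel kernel;
//!   kernel.addInput(in_tv);
//!   kernel.addOutput(out_tv);
//!   kernel.finalize(std::move(top_level_exprs), std::move(pred_map));
//!   const KernelSummary& summary = kernel.summary();
//! \endcode
//!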
void finalize(
std::vector<Expr*> top_level_exprs,
ThreadPredicateMap predicate_map);
  //! Register \p input as a kernel input
  void addInput(Val* input) {
    inputs_.push_back(input);
  }

  //! Register \p output as a kernel output
  void addOutput(Val* output) {
    outputs_.push_back(output);
  }
  //! Kernel inputs
  const auto& inputs() const {
    return inputs_;
  }

  //! Kernel outputs
  const auto& outputs() const {
    return outputs_;
  }

  //! Top-level expressions of the lowered kernel
  const auto& topLevelExprs() const {
    return top_level_exprs_;
  }

  //! Summary of interesting kernel data (valid after finalize())
  const KernelSummary& summary() const {
    return summary_;
  }

  //! Thread predicate map (valid after finalize())
  const ThreadPredicateMap& predicateMap() const {
    return *predicate_map_;
  }
//! Register a new Kernel IR node
//!
//! \note This is a specialized helper for kir::IrBuilder, not
//! intended for general use
//!
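//! Example (sketch; `new_node` is a hypothetical
//! `std::unique_ptr<Statement>` created by kir::IrBuilder):
//!
//! \code
//!   kernel->registerIrNode(std::move(new_node));
//! \endcode
//!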
void registerIrNode(std::unique_ptr<Statement> node) {
ir_nodes_.push_back(std::move(node));
}
private:
  // Analyzes the kernel IR and caches a summary of interesting data
void analyze();
private:
// Kernel IR nodes
std::vector<std::unique_ptr<Statement>> ir_nodes_;
// Map from value to its definition expression
std::unordered_map<const Val*, Expr*> definitions_;
// Top level expressions
std::vector<Expr*> top_level_exprs_;
// Kernel inputs and outputs
std::vector<Val*> inputs_;
std::vector<Val*> outputs_;
// Summary of interesting kernel data
KernelSummary summary_;
// Predicate map
// TODO(kir): consider a simpler, kernel IR based version
std::unique_ptr<ThreadPredicateMap> predicate_map_;
};
} // namespace fuser
} // namespace jit
} // namespace torch