#pragma once
#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/mma_type.h>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
//! Starting-point parameter set for the matmul scheduler.
//! Bundles the scheduling decisions (tile sizes, mma configuration,
//! double buffering, async operand loads) consumed by scheduleMatmul.
class MatmulParam {
public:
// An MmaBuilder must be supplied up front; mma_builder has no default.
MatmulParam(MmaBuilder builder) : mma_builder(builder) {}
//! Options controlling double buffering of the shared-memory
//! operand buffers.
struct DoubleBufferOptions {
//! Double buffer the writes into shared memory.
bool double_buffer_smem_write = false;
//! Double buffer the reads out of shared memory.
bool double_buffer_smem_read = false;
//! Number of buffering stages for shared memory (2 == classic
//! double buffering; higher values presumably mean multi-stage
//! circular buffering — confirm against the scheduler).
int smem_double_buffer_stage = 2;
};
//! (Ampere+) Use cp.async to load operands.
bool async_gmem_load_operands = false;
//! Specifies the tiling hierarchy on block,
//! warp, and instruction levels.
MatMulTileOptions tile_sizes;
//! Parameters for configuring mma ops.
MmaBuilder mma_builder;
//! Specify which tensor we double buffer.
DoubleBufferOptions double_buffer_options;
};
//! Prototype auto scheduling function.
//! Currently only support a pure matmul with no
//! fused prolog or epilog.
//!
//! \param c_tv result tensor of the matmul (presumably C = A * B;
//!   confirm operand roles against the implementation)
//! \param a_tv first operand tensor
//! \param b_tv second operand tensor
//! \param params scheduling configuration; passed by non-const
//!   reference, so the scheduler may modify it
//!
//! TODO:
//! - will support a range of fusions in a follow up
//! - will formalize scheduling decisions into
//! matmul params data structure.
TORCH_CUDA_CU_API void scheduleMatmul(
TensorView* c_tv,
TensorView* a_tv,
TensorView* b_tv,
MatmulParam& params);
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch