#include <torch/csrc/jit/codegen/cuda/scheduler/matmul.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/mma_utils.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
namespace {
// Move the broadcast axes to the left within the specified number of
// innermost dimensions, e.g. (when number_of_inner_pos == 3):
// [... I0, B, I1] -> [... B, I0, I1]
// Should probably only be used to order the innermost mnk axes.
void moveInnerBroadcastLeft(TensorView* tv, int number_of_inner_pos = 3) {
TORCH_INTERNAL_ASSERT(tv->nDims() >= number_of_inner_pos);
std::vector<int> broadcast_pos;
std::vector<int> nonbroadcast_pos;
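// Partition the innermost `number_of_inner_pos` axes into broadcast and
// non-broadcast groups, addressing them with negative (from-the-right)
// positions so the outer dimensions are left untouched.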
for (auto i : c10::irange(number_of_inner_pos)) {
auto axis_idx = i - number_of_inner_pos;
auto id = tv->axis(axis_idx);
if (id->isBroadcast()) {
broadcast_pos.push_back(axis_idx);
} else {
nonbroadcast_pos.push_back(axis_idx);
}
}
auto combined_pos_vec = broadcast_pos;
combined_pos_vec.insert(
combined_pos_vec.end(), nonbroadcast_pos.begin(), nonbroadcast_pos.end());
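// Desired order within the innermost window: all broadcast ids first,
// then the non-broadcast ids, preserving relative order inside each
// group. Build a map from the old (negative) positions to the new ones.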
std::unordered_map<int, int> order_map;
for (auto i : c10::irange(number_of_inner_pos)) {
order_map[combined_pos_vec.at(i)] = i - number_of_inner_pos;
}
// Apply ordering.
tv->reorder(order_map);
}
} // namespace
void scheduleMatmul(
TensorView* c,
TensorView* a,
TensorView* b,
MatmulParam& params) {
// Unpack from params.
auto& mma_builder = params.mma_builder;
auto& gemm_tile = params.tile_sizes;
// The current tensor naming convention is listed below for reference.
// It is temporary and will change over time; in fact the whole body
// of this function will eventually become a set of utility functions
// for the different sections of matmul (fusion) kernels, each with
// its own build-out to do.
//
// Current naming convention:
//
// operands assumed in global memory : a, b
//
// registers staging global load : ar, br (short for a/b read)
//
// shared mem cache of operands : acw_smem, bcw_smem (short for a/b
// cache_write smem)
//
// registers at shared memory load output : acr, bcr (short for a/b cache
// read)
//
// register tensor input to the actual mma op: ab, bb (short for a/b
// broadcasted)
//
// accumulator register: cc (short for c cache)
//
// result in global memory: c
// Currently we only support a, b, and c as fusion inputs/outputs,
// i.e. no prolog or epilog fusion yet.
TORCH_CHECK(
c->isFusionOutput() && a->isFusionInput() && b->isFusionInput(),
"not supporting matmul fusion yet");
TORCH_CHECK(c->definition() && c->definition()->isA<MmaOp>());
mma_builder.configureMma(c);
// TODO:
// Beyond this point, mma_builder really just becomes a populated
// list of parameters that describe the mma swizzles that should
// be annotated on the tensor domain. Conceptually the mma builder
// object should be separated into two parts, one as a scheduler
// utility and the other as matmul heuristic parameters, which we
// are starting to build out.
// Set up register and shared memory stages:
// TODO: this section goes to a separate matmul util,
// and needs more configurability.
// Set up the accumulator register.
auto cc = c->cacheBefore();
// Get the input to the mma op.
auto mma = dynamic_cast<MmaOp*>(cc->definition());
TORCH_INTERNAL_ASSERT(mma != nullptr);
auto ab = mma->inA()->as<TensorView>();
auto bb = mma->inB()->as<TensorView>();
// Get exact configurations from mma builder.
mma_builder.accumulatorTv(cc);
auto mma_options = mma_builder.build();
// Staging register for global memory load
TensorView *ar = a, *br = b;
if (!params.async_gmem_load_operands) {
ar = a->cacheAfter();
br = b->cacheAfter();
}
// TODO:
// Significant build out needed here
// for more flexibility and data type support.
// Shared memory
TensorView* acw_smem = nullptr;
TensorView* bcw_smem = nullptr;
// Shared memory read
TensorView* acr = nullptr;
TensorView* bcr = nullptr;
// The two paths differ because the Volta swizzle needs to involve
// the broadcast dimensions that are concretized at the mma, while
// the Ampere ones should be applied before the broadcast op to be
// able to use cp.async.
// TODO:
// Also a few additional parameters should be introduced
// to control this stage of scheduling.
if (isVolta(mma_options.macro)) {
acw_smem = ab->cacheAfter();
bcw_smem = bb->cacheAfter();
// Cache again to be able to vectorize.
acw_smem = acw_smem->cacheAfter();
bcw_smem = bcw_smem->cacheAfter();
acr = acw_smem->cacheAfter();
bcr = bcw_smem->cacheAfter();
if (params.double_buffer_options.double_buffer_smem_read) {
// Provide another copy op between the double-buffered
// smem load register and the actual mma ops to avoid
// complications in double-buffered fragment iteration.
ab = acr->cacheAfter();
bb = bcr->cacheAfter();
} else {
ab = acr;
bb = bcr;
}
} else {
// Use cp.async as requested in scheduler params.
c10::optional<LoadStoreOpType> load_op = c10::nullopt;
if (params.async_gmem_load_operands) {
load_op = LoadStoreOpType::CpAsync;
}
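// When requested, cp.async copies global memory directly into shared
// memory (Ampere+), so no ar/br register staging was created above;
// otherwise the copy is staged through the ar/br registers.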
acw_smem = ar->cacheAfter(load_op);
bcw_smem = br->cacheAfter(load_op);
acr = acw_smem->cacheAfter(
mma_builder.operand(MmaOptions::Operand::A).ldMatrix());
bcr = bcw_smem->cacheAfter(
mma_builder.operand(MmaOptions::Operand::B).ldMatrix());
}
// Make a CTA tile
// ------------------------------------------------------------------
scheduler_utils::matmul_utils::canonicalizeMmaTvOrdering(cc);
// [... M,N,K]
scheduler_utils::matmul_utils::makeTile(cc, gemm_tile.cta_tile.toVector());
// [Mo, No, Ko, Mi, Ni, Ki]
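// For example, with a (hypothetical) 128x128x32 cta_tile, M and N would
// each be split by 128 and K by 32, giving the inter-tile ids
// (Mo, No, Ko) and the intra-tile ids (Mi, Ni, Ki) above.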
// Propagate tiling globally
scheduler_utils::transformPropagateToAllFrom(cc, -1);
// Schedule warp tile
scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(cc, gemm_tile);
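// After the warp-tile split, cc's domain is the 11-axis layout
// [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)] annotated again in the
// parallelization section below.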
// Propagate warp tile to main loop and epilog/output tvs
scheduler_utils::BoundedDirectionalTransformPropagator::bothWays(
cc, -1, {acw_smem, bcw_smem}, {c});
// Schedule prolog:
// TODO: this section goes to a separate matmul util,
// and needs more configurability.
// ------------------------------------------------------------------
scheduler_utils::matmul_utils::orderTiledConcreteIdAsRoot(acw_smem);
// [... M, K]
acw_smem->merge(-2);
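// Merging M and K into one contiguous id lets scheduleContiguousVectorLoad
// split it for a vectorized gmem->smem copy (the Vectorize parallel type
// itself is applied later, in the parallelization section). The same is
// done for the B operand below.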
scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
acw_smem, gemm_tile, 8, false);
// [... N, K]
scheduler_utils::matmul_utils::orderTiledConcreteIdAsRoot(bcw_smem);
bcw_smem->merge(-2);
scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
bcw_smem, gemm_tile, 8, false);
// Propagate prolog tensors:
// propagate up the DAG and propagate the parallel type.
scheduler_utils::BoundedDirectionalTransformPropagator::backward(
acw_smem,
-1,
{a},
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());
scheduler_utils::BoundedDirectionalTransformPropagator::backward(
bcw_smem,
-1,
{b},
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());
// Set computeAt; set up the loop nesting structure on the kernel.
// TODO: this section goes to a separate matmul util,
// and needs more configurability.
// ------------------------------------------------------------------
// CTA tile:
// Swizzle block tiles:
c->swizzle(Swizzle2DType::ZShape, 0, 1, SwizzleMode::Loop);
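// The loop-mode ZShape swizzle of the (Mo, No) axes only reorders how
// CTA tiles are traversed (block rasterization for locality); it does
// not change c's memory layout.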
a->computeAt(c, 2);
b->computeAt(c, 2);
// Prolog:
a->computeAt(cc, 3);
b->computeAt(cc, 3);
// Main Loop:
acr->computeAt(cc, -6);
bcr->computeAt(cc, -6);
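// The computeAt positions are relative to the consumer's scheduled
// domain: 2 and 3 inline the operand loads inside the (Mo, No) block
// tile and (Mo, No, Ko) loops respectively, while -6 counts from the
// right of cc's [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)] domain.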
// Add mma swizzle:
// TODO: this section goes to a separate matmul util,
// and needs more configurability.
// ------------------------------------------------------------------
if (isTuring(mma_options.macro) || isAmpere(mma_options.macro)) {
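// ab/bb still carry the broadcast id introduced for the mma, e.g.
// [..., M, B, K] for the A operand; push it left so the non-broadcast
// (m/n, k) ids sit innermost when the mma swizzle below is applied.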
moveInnerBroadcastLeft(ab);
moveInnerBroadcastLeft(bb);
}
ab->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());
bb->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());
// Propagate mma input swizzle up the DAG
// to all the tensors before mma op and after shared mem read.
scheduler_utils::BoundedDirectionalTransformPropagator::backward(
ab,
-1,
{acw_smem},
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());
scheduler_utils::BoundedDirectionalTransformPropagator::backward(
bb,
-1,
{bcw_smem},
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());
cc->applyMmaSwizzle(
mma_builder.operand(MmaOptions::Operand::Accumulator).build());
// Set memory type:
acw_smem->setMemoryType(MemoryType::Shared);
bcw_smem->setMemoryType(MemoryType::Shared);
// Set parallelization:
// TODO: this section goes to a separate matmul util,
// and needs more configurability.
// ------------------------------------------------------------------
// Vectorize smem stores/loads:
acw_smem->axis(-1)->parallelize(ParallelType::Vectorize);
bcw_smem->axis(-1)->parallelize(ParallelType::Vectorize);
acr->axis(-1)->parallelize(ParallelType::Vectorize);
bcr->axis(-1)->parallelize(ParallelType::Vectorize);
// 0 1 2 3 4 5 6 7 8 9 10
// [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)]
cc->axis(0)->parallelize(ParallelType::BIDx);
cc->axis(1)->parallelize(ParallelType::BIDy);
cc->axis(3)->parallelize(ParallelType::TIDz);
cc->axis(4)->parallelize(ParallelType::TIDy);
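// (Mo, No) are bound to the block indices and (Mwo, Nwo) to the z/y
// thread dimensions, i.e. one warp tile per (threadIdx.z, threadIdx.y);
// TIDx is expected to come from the mma swizzles on the fragment ids.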
// Propagate mma output swizzle and parallelization down the DAG
if (params.double_buffer_options.double_buffer_smem_write) {
TORCH_CHECK(
params.double_buffer_options.smem_double_buffer_stage > 1,
"Invalid buffer stage config")
if (params.double_buffer_options.smem_double_buffer_stage > 2) {
TORCH_CHECK(
params.async_gmem_load_operands,
"Circular buffer only supports async load");
}
acw_smem->circularBuffer(
params.double_buffer_options.smem_double_buffer_stage);
bcw_smem->circularBuffer(
params.double_buffer_options.smem_double_buffer_stage);
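// With 2 stages this is plain double buffering of the smem writes;
// deeper stages keep multiple prefetches in flight, which is why
// stage > 2 requires cp.async (checked above).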
}
if (params.double_buffer_options.double_buffer_smem_read) {
acr->doubleBuffer();
bcr->doubleBuffer();
}
scheduler_utils::BoundedDirectionalTransformPropagator::forward(
cc,
-1,
{c},
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType()
.propagateToBoundary());
}
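// A rough usage sketch (illustrative only; the tile sizes and option
// spellings below are assumptions based on how the scheduler params are
// consumed above, not a definitive API reference):
//
//   MatMulTileOptions gemm_tile;
//   gemm_tile.cta_tile = GemmTile(128, 128, 32);
//   gemm_tile.warp_tile = GemmTile(64, 64, 32);
//   gemm_tile.instruction_tile = GemmTile(16, 8, 16);
//
//   MmaBuilder mma_builder(
//       MmaOptions::MacroType::Ampere_16_8_16, gemm_tile);
//   mma_builder.layout(MmaOptions::MmaInputLayout::TN);
//
//   MatmulParam params(mma_builder);
//   params.tile_sizes = gemm_tile;
//   params.async_gmem_load_operands = true;
//   params.double_buffer_options.double_buffer_smem_write = true;
//   params.double_buffer_options.smem_double_buffer_stage = 3;
//
//   // a, b: fusion inputs in global memory; c: the MmaOp output
//   // registered as a fusion output.
//   scheduleMatmul(c, a, b, params);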
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch