#pragma once
#include <c10/util/Exception.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <algorithm>
#include <vector>
namespace torch {
namespace jit {
namespace fuser {
/*
 * compute_at is a relative property between two TensorViews which marks at what
 * iteration domain we're going to generate a tensor to be consumed by another.
 * For example, if we have:
 * T2[I, J, K] = T1[I, J, K] * 2.0
 * and then we call:
 * T2.split(axis = 0, factor = ...)
 * we get:
 * T2[Io, Ii, J, K] = T1[I, J, K] * 2.0
 * where Io is the outer axis from the split and Ii is the inner axis. If we
 * then call T1.compute_at(T2, axis=1), we would expect to have:
 * T2[Io, Ii, J, K] = T1[Io, Ii, J, K] * 2.0
 * which would produce the following loop nest structure:
*
 * for(io : Io)
 *  for(ii : Ii)
 *   for(j : J)
 *    for(k : K)
 *     //produce T1:
 *     T1[io, ii, j, k] = ...
 *  for(ii : Ii)
 *   for(j : J)
 *    for(k : K)
 *     //consume T1, produce T2
 *     T2[io, ii, j, k] = T1[io, ii, j, k] * 2.0
*
* This file provides the replay function that allows us to construct T1's
* domain from T2 at a desired level (compute_at_axis) without modifying any
* unnecessary parts of the domain.
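 *
 * A minimal sketch of the schedule above in call form (the arrow spelling of
 * split/compute_at here just mirrors the pseudocode in this comment; it is an
 * assumption, not a verbatim API reference):
 *
 *   // T2[I, J, K] = T1[I, J, K] * 2.0
 *   T2->split(0, factor);   // axis = 0 -> T2[Io, Ii, J, K]
 *   T1->compute_at(T2, 1);  // produce T1 inside T2's Io loop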
*
* EXAMPLES:
*
* ANOTHER ITER EXAMPLE:
* T2[I, J, K] = T1[I, J, K] * 2.0
* T2.split(axis = 0, factor = ...)
* T2[Io, Ii, J, K] = T1[I, J, K] * 2.0
* T2.split(axis = 2, factor = ...)
* T2[Io, Ii, Jo, Ji, K] = T1[I, J, K] * 2.0
* T1.compute_at(T2, axis=1)
* T2[Io, Ii, Jo, Ji, K] = T1[Io, Ii, J, K] * 2.0
*
 * Note: compute_at axis positions:
 * T2[ 0 Io, 1 Ii, 2 Jo, 3 Ji, 4 K 5 ] // 5 is fully inlined, 0 is at the
 * "root", which means completely separate loop nests.
*
 * for(io : Io)
 *  for(ii : Ii)
 *   for(j : J)
 *    for(k : K)
 *     //produce T1, this is the view that replay generates:
 *     T1[io, ii, j, k] = ...
 *  for(ii : Ii)
 *   for(jo : Jo)
 *    for(ji : Ji)
 *     for(k : K)
 *      //consume T1, produce T2
 *      T2[io, ii, jo, ji, k] = T1[io, ii, jo, ji, k] * 2.0
 * //consumer view on T1 will be produced at a later stage.
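 *
 * In call form this schedule would be (same assumed spelling as above):
 *
 *   T2->split(0, factor0);  // T2[Io, Ii, J, K]
 *   T2->split(2, factor1);  // T2[Io, Ii, Jo, Ji, K]
 *   T1->compute_at(T2, 1);  // only the axis-0 split is replayed onto T1,
 *                           // so T1 keeps an unsplit J as shown above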
*
*
* SIMPLE REDUCTION EXAMPLE:
* T1[I, J, K] = ...
 * T2[I, R, K] = T1[I, J, K] // .sum(axis = 1), we reduce on R/J to produce T2[I, K]
 * T2.split(axis = 0, factor = ...)
 * T2[Io, Ii, R, K] = T1[I, J, K]
 * T1.compute_at(T2, axis=3)
 * T2[Io, Ii, R, K] = T1[Io, Ii, J, K]
*
 * for(io : Io)
 *  for(ii : Ii)
 *   for(k : K)
 *    T2[io, ii, k] = init
 *   for(r : R)
 *    for(k : K)
 *     //produce T1:
 *     T1[io, ii, r, k] = ...
 *     //consume T1 produce T2:
 *     T2[io, ii, k] += T1[io, ii, r, k]
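 *
 * The same loop nest rendered as plain C++, as a sketch (extents Io, Ii, R, K,
 * flat arrays, and the produce(...) helper are all assumptions made only to
 * show the interleaving concretely):
 *
 *   for (int io = 0; io < Io; ++io)
 *    for (int ii = 0; ii < Ii; ++ii) {
 *     for (int k = 0; k < K; ++k)
 *      T2[(io * Ii + ii) * K + k] = 0.0f;         // init
 *     for (int r = 0; r < R; ++r)
 *      for (int k = 0; k < K; ++k) {
 *       float t1 = produce(io, ii, r, k);         // produce T1[io, ii, r, k]
 *       T2[(io * Ii + ii) * K + k] += t1;         // consume T1, produce T2
 *      }
 *    }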
*
*
* REDUCTION EXAMPLE RESULTING IN AN ERROR:
 * T1[I, R, K] = ... // R is a reduction domain, we reduce on R to produce T1[I, K]
 * T2[I, K] = T1[I, K]
*
 * for(i : I)
 *  for(k : K)
 *   T1[i, k] = init
 *  for(r : R)
 *   for(k : K)
 *    T1[i, k] += ...[i, r, k]
 * for(i : I)
 *  for(k : K)
 *   T2[i, k] = T1[i, k]
*
* T1.compute_at(T2, axis=2)
 * This should be an error, or a warning with the call changed to:
* T1.compute_at(T2, axis=1)
* The error is because the kernel would have to be:
*
 * for(i : I)
 *  T1[i, k] = init
 *  for(r : R)
 *   for(k : K)
 *    T1[i, k] += ...[i, r, k]
 *  for(k : K)
 *   T2[i, k] = T1[i, k]
*
 * This kernel is invalid: T1[i, k] = init references k outside of any k loop.
 * If we did not reject the request (or fall back to axis=1), we would produce
 * incorrect results.
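 *
 * For contrast, the legal T1.compute_at(T2, axis=1) version (a sketch,
 * following the loop nest conventions above) would be:
 *
 * for(i : I)
 *  for(k : K)
 *   T1[i, k] = init
 *  for(r : R)
 *   for(k : K)
 *    T1[i, k] += ...[i, r, k]
 *  for(k : K)
 *   T2[i, k] = T1[i, k]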
*
*/
class TensorDomain;
class TensorView;
class TORCH_CUDA_API TransformReplay {
public:
// Replay producer as consumer, returns {producer, producer_compute_at_axis}.
static std::pair<TensorDomain*, unsigned int> replayPasC(
const TensorDomain* producer,
const TensorDomain* consumer,
int consumer_compute_at_axis);
// Replay producer as consumer, returns {producer, producer_compute_at_axis}.
static std::pair<TensorView*, unsigned int> replayPasC(
TensorView* producer,
TensorView* consumer,
int consumer_compute_at_axis);
// Replay consumer as producer, returns {consumer, consumer_compute_at_axis}.
static std::pair<TensorDomain*, unsigned int> replayCasP(
const TensorDomain* consumer,
const TensorDomain* producer,
int producer_compute_at_axis);
// Replay consumer as producer, returns {consumer, consumer_compute_at_axis}.
static std::pair<TensorView*, unsigned int> replayCasP(
TensorView* consumer,
TensorView* producer,
int producer_compute_at_axis);
// Self replay: replay the transformations of self on top of new_self_root.
static TensorDomain* fullSelfReplay(
const TensorDomain* new_self_root,
const TensorDomain* self);
};
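
// A sketch of how the replay entry points might be driven (the surrounding
// setup of `producer` and `consumer` is hypothetical; only the signatures
// declared above are used):
//
//   // Replay producer as consumer at consumer's compute_at axis 1:
//   std::pair<TensorView*, unsigned int> r =
//       TransformReplay::replayPasC(producer, consumer, 1);
//   TensorView* replayed_producer = r.first;   // producer replayed as consumer
//   unsigned int producer_ca_axis = r.second;  // producer's compute_at axis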
} // namespace fuser
} // namespace jit
} // namespace torch