#include <torch/csrc/jit/codegen/cuda/arith.h>
#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/ops/alias.h>
#include <torch/csrc/jit/codegen/cuda/transform_view.h>
#include <torch/csrc/jit/codegen/cuda/type_promotion.h>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
namespace {
//! Transform TensorView according to keep, merge, and split transformations.
//! Trivial reduction and broadcast transformations are handled separately.
//! It is recommended to use the composite ops view function, which will call
//! the analyzeView function to generate the appropriate transformations.
//!
//! For example:
//! original sizes = [2, 10, 40]
//! new_size = [2, 10, 2, 20]
//! auto analysis = analyzeView(TV0, original_sizes, new_sizes)
//! auto TV1 = TV0->view(analysis.transforms);
//!
//! Transforms = [(Keep I0), (Keep I1), (Split I2 by 2)]
//! Before: TV0[I0, I1, I2]
//! After: TV0[I0, I1, 2, ceilDiv(I2, 2)]
//!
//! orig_tv is the tensor view originally coming in from user for the view
//! operation. This is the tensor view all of the view analysis is relative to.
//! View might be doing trivial reductions before sending into the view
//! operation, so we want the actual input to the view operation to be
//! potentially after the original view operation.
//! Create the view output TensorView and wire up the ViewOp.
//!
//! The output's rfactor domain is derived from orig_tv's domain transformed
//! per view_analysis, while the ViewOp itself consumes post_reduce_tv (which
//! may be orig_tv after trivial reductions).
TensorView* applyViewTransforms(
    TensorView* orig_tv,
    TensorView* post_reduce_tv,
    const AnalyzeViewResult& view_analysis) {
  // The rfactor domain may not be changed once computeAt is in place.
  TORCH_INTERNAL_ASSERT(
      !post_reduce_tv->hasComputeAt(),
      "Cannot modify rfactor domain after compute at has been set.");
  TORCH_INTERNAL_ASSERT(
      post_reduce_tv->nDims() > 0, "Tried to view a 0-dim TensorView");
  TORCH_CHECK(
      !post_reduce_tv->domain()->hasRFactor(),
      "Cannot call view on the same TensorView twice.");
  TORCH_INTERNAL_ASSERT(!view_analysis.transforms.empty());

  // Build the output tensor from the original tensor's (viewed) domain.
  auto* out_tv = IrBuilder::create<TensorView>(
      orig_tv->container(),
      orig_tv->domain()->view(view_analysis),
      orig_tv->getDataType().value());

  // Connect producer and consumer through a ViewOp expression.
  IrBuilder::create<ViewOp>(orig_tv->container(), out_tv, post_reduce_tv);
  return out_tv;
}
} // namespace
//! Reinterpret x as dtype. Only same-width dtype pairs are supported, in
//! which case this lowers to a bit-cast; identical dtypes are a no-op.
TensorView* view(TensorView* x, DataType dtype) {
  // Same dtype: nothing to do.
  if (x->getDataType() == dtype) {
    return x;
  }

  const auto in_type = x->getDataType().value();
  // Equal element widths allow a pure bit reinterpretation.
  if (dataTypeSize(in_type) == dataTypeSize(dtype)) {
    return bitCastOp(dtype, x);
  }

  // TODO: support view(dtype) for dtypes where input_size != newsize
  TORCH_INTERNAL_ASSERT(false, "Unsupported reinterpret casting view");
}
//! Reshape x from original_sizes to new_sizes. Lowers to (in order):
//! trivial reductions for dropped size-1 dims, keep/merge/split rfactor
//! transforms, then broadcasts for inserted size-1 dims.
TensorView* view(
    TensorView* x,
    const std::vector<int64_t>& original_sizes,
    const std::vector<int64_t>& new_sizes) {
  TORCH_INTERNAL_ASSERT(
      TensorDomain::noReductions(x->getMaybeRFactorDomain()).size() ==
      original_sizes.size());
  TORCH_INTERNAL_ASSERT(
      !original_sizes.empty(), "No support for 0-dim tensors in view support.");

  auto view_analysis = analyzeView(x, original_sizes, new_sizes);

  // Step 1: drop size-1 input dimensions via trivial reductions.
  TensorView* tv = x;
  if (!view_analysis.trivial_reduction_axes.empty()) {
    tv = sum(
        x,
        view_analysis.trivial_reduction_axes,
        false /* keep_dim */,
        x->getDataType().value());
  }

  // Step 2: apply the keep/merge/split rfactor transformations.
  if (!view_analysis.transforms.empty()) {
    tv = applyViewTransforms(x, tv, view_analysis);
  }

  // Step 3: insert requested broadcast dimensions, if any axis is flagged.
  const bool needs_broadcast = std::any_of(
      view_analysis.broadcast_axes.begin(),
      view_analysis.broadcast_axes.end(),
      [](bool b) { return b; });
  if (needs_broadcast) {
    tv = broadcast(tv, view_analysis.broadcast_axes);
  }
  return tv;
}
//! Flatten the dimensions [start_dim, end_dim] of x into a single dimension.
//! Negative dims count from the end. Returns x unchanged when
//! start_dim == end_dim (nothing to merge).
TensorView* flatten(TensorView* x, int64_t start_dim, int64_t end_dim) {
  auto inp_domain = TensorDomain::noReductions(x->getMaybeRFactorDomain());
  // Cast once so all bounds checks below compare signed values; comparing
  // int64_t against size_t directly mixes signed/unsigned arithmetic.
  const auto ndims = static_cast<int64_t>(inp_domain.size());
  if (start_dim < 0) {
    start_dim += ndims;
  }
  if (end_dim < 0) {
    end_dim += ndims;
  }
  TORCH_CHECK(
      start_dim >= 0 && start_dim < ndims,
      "Invalid start_dim ",
      start_dim);
  TORCH_CHECK(end_dim >= 0 && end_dim < ndims, "Invalid end_dim ", end_dim);
  TORCH_CHECK(start_dim <= end_dim, "start_dim must be <= end_dim");
  // Single-dimension range: flattening is a no-op.
  if (start_dim == end_dim) {
    return x;
  }
  auto out = IrBuilder::create<TensorView>(
      x->container(),
      x->domain()->flatten(start_dim, end_dim),
      x->getDataType().value());
  // Pass the container explicitly, consistent with the TensorView creation
  // above and with applyViewTransforms.
  IrBuilder::create<ViewOp>(x->container(), out, x);
  return out;
}
//! Remove all size-1 dimensions of x via trivial reductions.
//! sizes must give the concrete extent of every (non-reduction) dimension.
//! Returns x unchanged when there is nothing to squeeze.
TensorView* squeeze(TensorView* x, const std::vector<int64_t>& sizes) {
  const auto ndims = static_cast<int>(x->domain()->noReductions().size());
  TORCH_INTERNAL_ASSERT(
      // Cast avoids a signed/unsigned comparison between int and size_t.
      ndims == static_cast<int>(sizes.size()),
      "Invalid sizes for squeeze: ",
      sizes,
      ". Input tensor: ",
      x->toString());
  // Collect every size-1 axis; those are removed by a trivial sum().
  std::vector<int> trivial_reduction_axes;
  // Iterate over int indices directly so push_back does not narrow size_t.
  for (const auto idx : c10::irange(ndims)) {
    if (sizes[idx] == 1) {
      trivial_reduction_axes.push_back(idx);
    }
  }
  if (trivial_reduction_axes.empty()) {
    return x;
  }
  return sum(
      x, trivial_reduction_axes, false /* keep_dim */, x->getDataType().value());
}
//! Remove dimension dim of x if its extent is 1; otherwise return a set()
//! (identity) copy. Negative dim counts from the end.
TensorView* squeeze(TensorView* x, const std::vector<int64_t>& sizes, int dim) {
  const auto ndims = static_cast<int>(x->domain()->noReductions().size());
  TORCH_INTERNAL_ASSERT(
      // Cast avoids a signed/unsigned comparison between int and size_t.
      ndims == static_cast<int>(sizes.size()),
      "Invalid sizes for squeeze: ",
      sizes,
      ". Input tensor: ",
      x->toString());
  if (dim < 0) {
    dim = ndims + dim;
  }
  TORCH_INTERNAL_ASSERT(
      dim >= 0 && dim < ndims,
      "Invalid position to squeeze: ",
      dim,
      ". Input tensor: ",
      x->toString());
  if (sizes[dim] == 1) {
    // Size-1 dimension: drop it with a trivial reduction.
    return sum(x, {dim}, false /* keep_dim */, x->getDataType().value());
  }
  // Not squeezable; still produce a new tensor via an identity set().
  return set(x);
}
//! Insert a broadcast dimension at position dim. Negative dim wraps so that
//! dim == -1 appends a trailing broadcast axis (valid range is [0, ndims]).
TensorView* unsqueeze(TensorView* x, int dim) {
  const auto ndims = static_cast<int>(x->domain()->noReductions().size());
  // The "+ 1" lets -1 refer to the slot after the last existing dimension.
  if (dim < 0) {
    dim = ndims + dim + 1;
  }
  TORCH_INTERNAL_ASSERT(
      dim >= 0 && dim <= ndims,
      "Invalid position to unsqueeze: ",
      dim,
      ". Input tensor: ",
      x->toString());
  // Flag exactly one broadcast axis at the requested position.
  std::vector<bool> is_broadcast_dim(ndims + 1, false);
  is_broadcast_dim[dim] = true;
  return broadcast(x, is_broadcast_dim);
}
//! Permute the dimensions of x according to new2old (output axis i takes
//! input axis new2old[i]; negative entries count from the end). An empty
//! mapping returns an identity set() copy.
TensorView* permute(TensorView* x, const std::vector<int64_t>& new2old) {
  if (new2old.size() == 0) {
    return set(x);
  }
  auto inp_domain = TensorDomain::noReductions(x->getMaybeRFactorDomain());
  std::vector<IterDomain*> out_domain(inp_domain.size());

  // normalizeNew2Old resolves negative axis values (and validates the
  // mapping); the raw new2old entries may be negative, so inp_domain must be
  // indexed with the normalized mapping only. Previously the raw new2old[i]
  // was used here, which is out of range for negative entries.
  auto normalized_new2old =
      ir_utils::normalizeNew2Old(new2old, inp_domain.size());

  for (const auto i : c10::irange(out_domain.size())) {
    auto in_id = inp_domain[normalized_new2old[i]];
    out_domain[i] = in_id->cloneWithoutRFactor();
  }

  TensorView* out_tensor = IrBuilder::create<TensorView>(
      IrBuilder::create<TensorDomain>(
          out_domain, std::vector<bool>(out_domain.size(), true)),
      x->getDataType().value());
  IrBuilder::create<TransposeOp>(out_tensor, x, normalized_new2old);
  return out_tensor;
}
//! Swap dimensions dim0 and dim1 of x (negative dims count from the end)
//! by building the corresponding new2old permutation and calling permute().
TensorView* transpose(TensorView* x, int64_t dim0, int64_t dim1) {
  const auto ndims = static_cast<int>(x->domain()->noReductions().size());
  if (dim0 < 0) {
    dim0 = ndims + dim0;
  }
  if (dim1 < 0) {
    dim1 = ndims + dim1;
  }
  // Valid dimension indices are [0, ndims). The previous "<= ndims" bound
  // admitted dim == ndims, which produced an out-of-range entry in new2old
  // and out-of-bounds indexing inside permute().
  TORCH_CHECK(
      dim0 >= 0 && dim0 < ndims, "Invalid transpose dimension 0: ", dim0);
  TORCH_CHECK(
      dim1 >= 0 && dim1 < ndims, "Invalid transpose dimension 1: ", dim1);
  // Identity mapping with dim0 and dim1 exchanged.
  std::vector<int64_t> new2old(ndims);
  for (const auto i : c10::irange(ndims)) {
    if (i == dim0) {
      new2old[i] = dim1;
    } else if (i == dim1) {
      new2old[i] = dim0;
    } else {
      new2old[i] = i;
    }
  }
  return permute(x, new2old);
}
//! Matrix-style transpose for tensors with at most 2 dimensions.
//! 0-D and 1-D inputs are returned unchanged; 2-D inputs swap their axes.
TensorView* transpose(TensorView* x) {
  const auto ndims = static_cast<int>(x->domain()->noReductions().size());
  TORCH_CHECK(
      ndims <= 2,
      "Expected a tensor with <= 2 dimensions, but it has ",
      ndims,
      "D.");
  // Transposition is the identity below 2 dimensions.
  if (ndims < 2) {
    return x;
  }
  // 2-D: exchange the two axes.
  return transpose(x, 0, 1);
}
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch