#pragma once
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/shape_inference.h>
#include <torch/csrc/lazy/generated/LazyNonNativeIr.h>
#include <torch/csrc/lazy/ts_backend/dynamic_ir.h>
#include <torch/csrc/lazy/ts_backend/ops/device_data.h>
#include <torch/csrc/lazy/ts_backend/ops/generic.h>
#include <torch/csrc/lazy/ts_backend/ts_node.h>
namespace torch {
namespace lazy {
struct TorchScriptIrBuilder : IrBuilder {
  NodePtr MakeDeviceData(
      const std::shared_ptr<BackendData>& data) const override {
    return DeviceData::Create(data);
  }
  // TODO: The Scalar node is not currently used by ts_backend, so it is built
  // with MakeNode rather than ReuseOrMakeNode. Enable reuse of the Scalar node
  // later if needed.
  NodePtr MakeScalar(const at::Scalar& value, const at::ScalarType& type)
      const override {
    return MakeNode<Scalar>(value, type);
  }
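  // Sketch (hypothetical names, simplified): roughly, MakeNode always builds a
  // new node, while ReuseOrMakeNode first tries to reuse an equivalent node
  // built earlier and only constructs (and caches) a new one on a miss. The
  // map keyed on a caller-supplied hash is an illustrative stand-in, not the
  // actual LTC reuse mechanism.
  //
  //   #include <cstdint>
  //   #include <memory>
  //   #include <unordered_map>
  //
  //   struct SketchNode { std::uint64_t hash; };
  //   using SketchNodePtr = std::shared_ptr<SketchNode>;
  //
  //   SketchNodePtr ReuseOrMakeSketch(
  //       std::unordered_map<std::uint64_t, SketchNodePtr>& cache,
  //       std::uint64_t hash) {
  //     auto it = cache.find(hash);
  //     if (it != cache.end()) {
  //       return it->second;  // Hit: hand back the existing node.
  //     }
  //     auto node = std::make_shared<SketchNode>(SketchNode{hash});
  //     cache.emplace(hash, node);  // Miss: build and remember the node.
  //     return node;
  //   }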
  NodePtr MakeExpand(
      const Value& input0,
      const std::vector<int64_t>& size,
      const bool& is_scalar_expand) const override {
    return ReuseOrMakeNode<Expand>(input0, size, is_scalar_expand);
  }
  NodePtr MakeView(const Value& input0, const std::vector<int64_t>& output_size)
      const override {
    return ReuseOrMakeNode<View>(input0, output_size);
  }
  NodePtr MakeCast(
      const Value& input0,
      const at::ScalarType& dtype,
      const c10::optional<at::ScalarType>& stype =
          c10::nullopt) const override {
    return ReuseOrMakeNode<Cast>(input0, dtype, stype);
  }
  NodePtr MakeTensorList(const OpList& inputs) const override {
    return ReuseOrMakeNode<TensorList>(inputs);
  }
  // Generic needs cleanup
  NodePtr MakeGeneric(
      const OpKind& op,
      const OpList& operands,
      const Shape& shape,
      const size_t& num_outputs = 1,
      const hash_t& hash_seed =
          static_cast<uint32_t>(0x5a2d296e9)) const override {
    return MakeNode<Generic>(op, operands, shape, num_outputs, hash_seed);
  }
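  // Note on the default hash_seed of MakeGeneric: the literal 0x5a2d296e9
  // needs 35 bits, so static_cast<uint32_t> keeps only its low 32 bits. A
  // compile-time check of that, assuming the standard fixed-width uint32_t
  // (unsigned conversion is defined modulo 2^32):
  //
  //   #include <cstdint>
  //
  //   static_assert(static_cast<std::uint32_t>(0x5a2d296e9) == 0xa2d296e9u,
  //                 "the high bits of the seed literal are discarded");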
  // View op nodes
  NodePtr MakeAsStridedViewUpdate(
      const Value& input0,
      const Value& input1,
      const std::vector<int64_t>& size,
      const std::vector<int64_t>& stride,
      const int64_t& storage_offset) const override {
    return ReuseOrMakeNode<AsStridedViewUpdate>(
        input0, input1, size, stride, storage_offset);
  }
  NodePtr MakeAsStrided(
      const Value& input0,
      const std::vector<int64_t>& size,
      const std::vector<int64_t>& stride,
      const int64_t& storage_offset) const override {
    return ReuseOrMakeNode<AsStrided>(input0, size, stride, storage_offset);
  }
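  // Sketch (hypothetical helper, not the LTC lowering): the as_strided
  // addressing rule. Output element (i_0, ..., i_{k-1}) reads storage element
  // storage_offset + sum_j i_j * stride[j]. This lists the storage indices a
  // view touches, in row-major output order.
  //
  //   #include <cstddef>
  //   #include <cstdint>
  //   #include <utility>
  //   #include <vector>
  //
  //   std::vector<int64_t> AsStridedIndices(const std::vector<int64_t>& size,
  //                                         const std::vector<int64_t>& stride,
  //                                         int64_t storage_offset) {
  //     std::vector<int64_t> out{storage_offset};
  //     // Expand one dimension at a time; the last dimension varies fastest.
  //     for (std::size_t d = 0; d < size.size(); ++d) {
  //       std::vector<int64_t> next;
  //       for (int64_t base : out) {
  //         for (int64_t i = 0; i < size[d]; ++i) {
  //           next.push_back(base + i * stride[d]);
  //         }
  //       }
  //       out = std::move(next);
  //     }
  //     return out;
  //   }
  //
  //   // AsStridedIndices({2, 2}, {1, 2}, 0) == {0, 2, 1, 3}: a transposed
  //   // view of a contiguous 2x2 buffer.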
  NodePtr MakeDiagonalViewUpdate(
      const Value& input0,
      const Value& input1,
      const int64_t& offset,
      const int64_t& dim1,
      const int64_t& dim2) const override {
    return ReuseOrMakeNode<DiagonalViewUpdate>(
        input0, input1, offset, dim1, dim2);
  }
  NodePtr MakeDiagonal(
      const Value& input0,
      const int64_t& offset,
      const int64_t& dim1,
      const int64_t& dim2) const override {
    return ReuseOrMakeNode<Diagonal>(input0, offset, dim1, dim2);
  }
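  // Sketch (hypothetical helper, 2-D case with dim1 = 0 and dim2 = 1 assumed):
  // what the `offset` argument of a diagonal means. A positive offset starts
  // above the main diagonal, a negative one below it.
  //
  //   #include <cstdint>
  //   #include <vector>
  //
  //   std::vector<int64_t> Diagonal2D(
  //       const std::vector<std::vector<int64_t>>& m, int64_t offset) {
  //     const int64_t rows = static_cast<int64_t>(m.size());
  //     const int64_t cols = rows == 0 ? 0 : static_cast<int64_t>(m[0].size());
  //     int64_t r = offset >= 0 ? 0 : -offset;  // First row on the diagonal.
  //     int64_t c = offset >= 0 ? offset : 0;   // First column on the diagonal.
  //     std::vector<int64_t> out;
  //     for (; r < rows && c < cols; ++r, ++c) {
  //       out.push_back(m[r][c]);
  //     }
  //     return out;
  //   }
  //
  //   // Diagonal2D({{1, 2, 3}, {4, 5, 6}}, 1) == {2, 6}
  //   // Diagonal2D({{1, 2, 3}, {4, 5, 6}}, -1) == {4}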
  NodePtr MakeNarrowViewUpdate(
      const Value& input0,
      const Value& input1,
      const std::vector<int64_t>& base_indices) const override {
    return ReuseOrMakeNode<NarrowViewUpdate>(input0, input1, base_indices);
  }
  NodePtr MakeNarrow(
      const Value& input0,
      const std::vector<int64_t>& base_indices,
      const std::vector<int64_t>& sizes) const override {
    return ReuseOrMakeNode<Narrow>(input0, base_indices, sizes);
  }
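  // Sketch (hypothetical helpers, 2-D only): the pairing of a view op with its
  // ViewUpdate. Narrow copies out a block of extent `sizes` starting at
  // `base_indices`; NarrowViewUpdate, assuming input0 is the base tensor and
  // input1 the updated block, writes the block back at the same offset.
  //
  //   #include <cstddef>
  //   #include <cstdint>
  //   #include <vector>
  //
  //   using Matrix = std::vector<std::vector<int64_t>>;
  //
  //   Matrix Narrow2D(const Matrix& in, const std::vector<int64_t>& base,
  //                   const std::vector<int64_t>& sizes) {
  //     Matrix out(sizes[0], std::vector<int64_t>(sizes[1]));
  //     for (int64_t i = 0; i < sizes[0]; ++i) {
  //       for (int64_t j = 0; j < sizes[1]; ++j) {
  //         out[i][j] = in[base[0] + i][base[1] + j];
  //       }
  //     }
  //     return out;
  //   }
  //
  //   Matrix NarrowViewUpdate2D(Matrix dst, const Matrix& src,
  //                             const std::vector<int64_t>& base) {
  //     for (std::size_t i = 0; i < src.size(); ++i) {
  //       for (std::size_t j = 0; j < src[i].size(); ++j) {
  //         dst[base[0] + i][base[1] + j] = src[i][j];
  //       }
  //     }
  //     return dst;  // The base with the block overwritten.
  //   }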
  NodePtr MakePermute(const Value& input0, const std::vector<int64_t>& dims)
      const override {
    return ReuseOrMakeNode<Permute>(input0, dims);
  }
  NodePtr MakeResize(const Value& input0, const std::vector<int64_t>& size)
      const override {
    return ReuseOrMakeNode<Resize>(input0, size);
  }
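  // Sketch (hypothetical helper): how `dims` of a permute maps input sizes to
  // output sizes, with output dimension i taking input dimension dims[i]. A
  // resize, by contrast, simply takes `size` as its new shape.
  //
  //   #include <cstddef>
  //   #include <cstdint>
  //   #include <vector>
  //
  //   std::vector<int64_t> PermuteShape(const std::vector<int64_t>& in,
  //                                     const std::vector<int64_t>& dims) {
  //     std::vector<int64_t> out(dims.size());
  //     for (std::size_t i = 0; i < dims.size(); ++i) {
  //       out[i] = in[dims[i]];
  //     }
  //     return out;
  //   }
  //
  //   // PermuteShape({2, 3, 4}, {2, 0, 1}) == {4, 2, 3}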
  NodePtr MakeSelectViewUpdate(
      const Value& input0,
      const Value& input1,
      const int64_t& dim,
      const int64_t& start,
      const int64_t& end,
      const int64_t& stride) const override {
    return ReuseOrMakeNode<SelectViewUpdate>(
        input0, input1, dim, start, end, stride);
  }
  NodePtr MakeSelect(
      const Value& input0,
      const int64_t& dim,
      const int64_t& start,
      const int64_t& end,
      const int64_t& stride) const override {
    return ReuseOrMakeNode<Select>(input0, dim, start, end, stride);
  }
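  // Sketch (hypothetical helper, assuming slice-like semantics with
  // start <= end and stride > 0): which positions along `dim` a Select with
  // (start, end, stride) keeps, and the resulting extent.
  //
  //   #include <cstdint>
  //   #include <vector>
  //
  //   std::vector<int64_t> SelectIndices(int64_t start, int64_t end,
  //                                      int64_t stride) {
  //     std::vector<int64_t> idx;
  //     for (int64_t i = start; i < end; i += stride) {
  //       idx.push_back(i);
  //     }
  //     // idx.size() == (end - start + stride - 1) / stride
  //     return idx;
  //   }
  //
  //   // SelectIndices(1, 8, 3) == {1, 4, 7}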
  NodePtr MakeSqueeze(const Value& input0, const int& dim) const override {
    return ReuseOrMakeNode<Squeeze>(input0, dim);
  }
  NodePtr MakeUnsqueeze(const Value& input0, const int& dim) const override {
    return ReuseOrMakeNode<Unsqueeze>(input0, dim);
  }
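  // Sketch (hypothetical helpers): the shape effect of squeeze and unsqueeze
  // for a non-negative, in-range `dim`. How the real nodes treat negative
  // `dim` values is not covered here.
  //
  //   #include <cstdint>
  //   #include <vector>
  //
  //   std::vector<int64_t> UnsqueezeShape(std::vector<int64_t> shape, int dim) {
  //     shape.insert(shape.begin() + dim, int64_t{1});  // New size-1 axis.
  //     return shape;
  //   }
  //
  //   std::vector<int64_t> SqueezeShape(std::vector<int64_t> shape, int dim) {
  //     if (shape[dim] == 1) {
  //       shape.erase(shape.begin() + dim);  // Only size-1 axes are removed.
  //     }
  //     return shape;
  //   }
  //
  //   // UnsqueezeShape({2, 3}, 1) == {2, 1, 3}
  //   // SqueezeShape({2, 1, 3}, 1) == {2, 3}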
  // Dynamic IR nodes
  // TODO: Verify whether IR node reuse works for dynamic-shape ops; until
  // then, these nodes are always built with MakeNode.
  NodePtr MakeSizeNode(const Value& input, size_t dim) const override {
    return MakeNode<SizeNode>(input, dim);
  }
  NodePtr MakeSizeAdd(const Value& a, const Value& b) const override {
    return MakeNode<SizeAdd>(a, b);
  }
  NodePtr MakeSizeMul(const Value& a, const Value& b) const override {
    return MakeNode<SizeMul>(a, b);
  }
  NodePtr MakeSizeDiv(const Value& a, const Value& b) const override {
    return MakeNode<SizeDiv>(a, b);
  }
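  // Sketch (hypothetical names, not the LTC classes): SizeNode, SizeAdd,
  // SizeMul and SizeDiv compose arithmetic over tensor dimensions whose values
  // may only be known at run time. Modeled here as closures evaluated against
  // a concrete shape; division is C++ integer division in this model.
  //
  //   #include <cstddef>
  //   #include <cstdint>
  //   #include <functional>
  //   #include <vector>
  //
  //   using Dims = std::vector<int64_t>;
  //   using SizeExpr = std::function<int64_t(const Dims&)>;
  //
  //   SizeExpr SizeOf(std::size_t dim) {
  //     return [dim](const Dims& d) { return d[dim]; };
  //   }
  //   SizeExpr Add(SizeExpr a, SizeExpr b) {
  //     return [a, b](const Dims& d) { return a(d) + b(d); };
  //   }
  //   SizeExpr Mul(SizeExpr a, SizeExpr b) {
  //     return [a, b](const Dims& d) { return a(d) * b(d); };
  //   }
  //   SizeExpr Div(SizeExpr a, SizeExpr b) {
  //     return [a, b](const Dims& d) { return a(d) / b(d); };
  //   }
  //
  //   // Add(SizeOf(0), Mul(SizeOf(0), SizeOf(1)))(Dims{4, 6}) == 28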
};
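// Sketch (hypothetical names): the pattern this header follows. A
// backend-agnostic interface declares virtual factory methods, and each
// backend supplies a subclass whose overrides construct that backend's node
// types, so callers only hold a reference or pointer to the interface.
//
//   #include <memory>
//   #include <string>
//
//   struct SketchNode { std::string op; };
//   using SketchNodePtr = std::shared_ptr<SketchNode>;
//
//   struct SketchIrBuilder {
//     virtual ~SketchIrBuilder() = default;
//     virtual SketchNodePtr MakeView() const = 0;
//   };
//
//   struct SketchTsIrBuilder : SketchIrBuilder {
//     SketchNodePtr MakeView() const override {
//       return std::make_shared<SketchNode>(SketchNode{"ts::view"});
//     }
//   };
//
//   // SketchTsIrBuilder ts;
//   // const SketchIrBuilder& builder = ts;
//   // builder.MakeView()->op == "ts::view"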
} // namespace lazy
} // namespace torch