1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69
|
#pragma once
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/shape_inference.h>
#include <torch/csrc/lazy/generated/LazyNonNativeIr.h>
#include <torch/csrc/lazy/ts_backend/dynamic_ir.h>
#include <torch/csrc/lazy/ts_backend/ops/device_data.h>
#include <torch/csrc/lazy/ts_backend/ops/generic.h>
#include <torch/csrc/lazy/ts_backend/ts_node.h>
namespace torch::lazy {
struct TorchScriptIrBuilder : IrBuilder {
NodePtr MakeDeviceData(
const std::shared_ptr<BackendData>& data) const override {
return DeviceData::Create(data);
}
// TODO: Scalar node is not currently used by ts_backend. Enable reusing
// Scalar node later if needed.
NodePtr MakeScalar(const at::Scalar& value, const at::ScalarType& type)
const override {
return MakeNode<Scalar>(value, type);
}
NodePtr MakeExpand(
const Value& input0,
const std::vector<int64_t>& size,
const bool& is_scalar_expand) const override {
return ReuseOrMakeNode<Expand>(input0, size, is_scalar_expand);
}
NodePtr MakeCast(
const Value& input0,
const at::ScalarType& dtype,
const std::optional<at::ScalarType>& stype =
std::nullopt) const override {
return ReuseOrMakeNode<Cast>(input0, dtype, stype);
}
NodePtr MakeTensorList(const OpList& inputs) const override {
return ReuseOrMakeNode<TensorList>(inputs);
}
// Generic needs cleanup
NodePtr MakeGeneric(
const OpKind& op,
const OpList& operands,
const Shape& shape,
const size_t& num_outputs = 1,
const hash_t& hash_seed =
static_cast<uint32_t>(0x5a2d296e9)) const override {
return MakeNode<Generic>(op, operands, shape, num_outputs, hash_seed);
}
// dynamic ir nodes
// TODO: verify if IR node reusing works for Dynamic shape ops
NodePtr MakeSizeNode(const Value& input, size_t dim) const override {
return MakeNode<SizeNode>(input, dim);
}
NodePtr MakeSizeAdd(const Value& a, const Value& b) const override {
return MakeNode<SizeAdd>(a, b);
}
NodePtr MakeSizeMul(const Value& a, const Value& b) const override {
return MakeNode<SizeMul>(a, b);
}
NodePtr MakeSizeDiv(const Value& a, const Value& b) const override {
return MakeNode<SizeDiv>(a, b);
}
};
} // namespace torch::lazy
|