1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
|
// NOTE: Like interface.cpp, this file will be copied into AOTInductor
// generated output. This file is intended to keep implementation
// details separate from the implementation of the AOTI public
// interface.
#include <torch/csrc/inductor/aoti_runtime/arrayref_tensor.h>
#include <torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h>
#include <torch/csrc/inductor/aoti_runtime/thread_local.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
namespace torch {
namespace aot_inductor {
template <typename T>
void convert_output_to_handle(
const ArrayRefTensor<T>& output,
AtenTensorHandle& handle) {
handle = output.expensiveCopyToTensor();
}
// Index-expanded worker for convert_outputs_to_handles().
//
// For every index I in `Is`, converts the I-th element of `outputs` and
// stores the result in `output_handles[I]`. `output_handles` must therefore
// be indexable at every I in `Is`.
template <typename... Ts, std::size_t... Is>
void convert_outputs_to_handles_helper(
    const std::tuple<ArrayRefTensor<Ts>...>& outputs,
    AtenTensorHandle* output_handles,
    std::index_sequence<Is...>) {
  // [[maybe_unused]] keeps an empty index pack from leaving the lambda
  // unreferenced.
  [[maybe_unused]] const auto store_one =
      [](const auto& tensor, AtenTensorHandle& slot) {
        convert_output_to_handle(tensor, slot);
      };
  // Comma fold: elements are processed in increasing index order.
  (store_one(std::get<Is>(outputs), output_handles[Is]), ...);
}
// Converts each ArrayRefTensor in `outputs` and stores the results in
// `output_handles`, element i of the tuple going to `output_handles[i]`.
//
// `output_handles` must point to at least sizeof...(Ts) writable handles.
template <typename... Ts>
void convert_outputs_to_handles(
    const std::tuple<ArrayRefTensor<Ts>...>& outputs,
    AtenTensorHandle* output_handles) {
  // One index per tuple element: 0, 1, ..., sizeof...(Ts) - 1.
  using TupleIndices = std::index_sequence_for<Ts...>;
  convert_outputs_to_handles_helper(outputs, output_handles, TupleIndices{});
}
// Fills `input` with an ArrayRefTensor<T> describing the tensor behind
// `handle`.
//
// The data pointer, dimension count, element count, sizes, strides, dtype,
// device type and device index are read through the aoti_torch_get_*
// accessors, each wrapped in AOTI_TORCH_ERROR_CODE_CHECK. The resulting
// ArrayRefTensor<T> is built from the handle's data pointer (with `numel`
// elements), its sizes and strides (with `dim` entries each), and its
// device type and index.
//
// NOTE(review): the dtype is queried but never compared against T, so the
// caller is presumably responsible for passing a handle whose element type
// matches T -- confirm.
template <typename T>
void convert_handle_to_arrayref_tensor(
    AtenTensorHandle handle,
    ArrayRefTensor<T>& input) {
  // Every out-parameter starts from a defined value so that no local is
  // ever left indeterminate.
  void* data_ptr = nullptr;
  AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(handle, &data_ptr));
  int64_t dim = 0;
  AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dim(handle, &dim));
  int64_t numel = 0;
  AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_numel(handle, &numel));
  int64_t* sizes = nullptr;
  AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(handle, &sizes));
  int64_t* strides = nullptr;
  AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(handle, &strides));
  // Retrieved (and error-checked) but not otherwise used; see the note
  // above.
  int32_t dtype = 0;
  AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(handle, &dtype));
  int32_t device_type = 0;
  AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(handle, &device_type));
  int32_t device_index = 0;
  AOTI_TORCH_ERROR_CODE_CHECK(
      aoti_torch_get_device_index(handle, &device_index));
  // void* -> T* is a static_cast conversion; reinterpret_cast is not needed.
  input = ArrayRefTensor<T>(
      MiniArrayRef<T>(static_cast<T*>(data_ptr), numel),
      MiniArrayRef<const int64_t>(sizes, dim),
      MiniArrayRef<const int64_t>(strides, dim),
      device_type,
      device_index);
}
// Index-expanded worker for convert_handles_to_inputs().
//
// For every index I in `Is`, converts `input_handles[I]` into the I-th
// element of `inputs`. `input_handles` must therefore be indexable at every
// I in `Is`.
template <typename... Ts, std::size_t... Is>
void convert_handles_to_inputs_helper(
    AtenTensorHandle* input_handles,
    std::tuple<ArrayRefTensor<Ts>...>& inputs,
    std::index_sequence<Is...>) {
  // [[maybe_unused]] keeps an empty index pack from leaving the lambda
  // unreferenced.
  [[maybe_unused]] const auto load_one =
      [](AtenTensorHandle source, auto& destination) {
        convert_handle_to_arrayref_tensor(source, destination);
      };
  // Comma fold: elements are processed in increasing index order.
  (load_one(input_handles[Is], std::get<Is>(inputs)), ...);
}
// Converts the handles in `input_handles` into the ArrayRefTensors of
// `inputs`, `input_handles[i]` going to element i of the tuple.
//
// `input_handles` must point to at least sizeof...(Ts) handles.
template <typename... Ts>
void convert_handles_to_inputs(
    AtenTensorHandle* input_handles,
    std::tuple<ArrayRefTensor<Ts>...>& inputs) {
  // One index per tuple element: 0, 1, ..., sizeof...(Ts) - 1.
  using TupleIndices = std::index_sequence_for<Ts...>;
  convert_handles_to_inputs_helper(input_handles, inputs, TupleIndices{});
}
// Verifies that `tensor` holds exactly `numel` elements.
//
// Returns normally when `tensor.numel()` equals `numel`; otherwise throws
// std::runtime_error with a message stating the expected and actual
// element counts.
template <typename T>
void assert_numel(const ArrayRefTensor<T>& tensor, uint64_t numel) {
  const auto actual_numel = tensor.numel();
  if (actual_numel == numel) {
    return;
  }
  std::stringstream message;
  message << "incorrect numel for input tensor. expected " << numel << ", got "
          << actual_numel;
  throw std::runtime_error(message.str());
}
} // namespace aot_inductor
} // namespace torch
|