#include <torch/csrc/lazy/ts_backend/tensor_aten_ops.h>
#include <ATen/InferSize.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/lazy/core/helpers.h>
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/ir_util.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
#include <torch/csrc/lazy/core/metrics.h>
#include <torch/csrc/lazy/core/ops/arithmetic_ir_ops.h>
#include <torch/csrc/lazy/core/ops/utils.h>
#include <torch/csrc/lazy/core/tensor.h>
#include <torch/csrc/lazy/core/util.h>
#include <torch/csrc/lazy/generated/LazyIr.h>
#include <algorithm>
#include <functional>
#include <optional>

namespace torch::lazy {
namespace {

// to enable operator+-*/ for Value
using namespace torch::lazy;
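
// Returns `input` expanded to `target_shape`; if the shapes already match,
// the value is returned unchanged.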
torch::lazy::Value MaybeExpand(
    const torch::lazy::Value& input,
    const torch::lazy::Shape& target_shape) {
  if (input.shape().sizes() == target_shape.sizes()) {
    return input;
  }
  return torch::lazy::MakeExpand(
      input,
      target_shape.sizes().vec(),
      /*is_scalar_expand=*/false);
}
} // namespace

//////////////////////////////////////////////////////////////////////////////
// ATEN operators follow here, listed in alphabetical order.
//////////////////////////////////////////////////////////////////////////////
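// Fills `input` in place with `value`, expanding the scalar to the tensor's
// shape on the tensor's device.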
void fill_(torch::lazy::LazyTensorPtr& input, const at::Scalar& value) {
  torch::lazy::Value constant =
      torch::lazy::LazyGraphExecutor::Get()->GetIrValueForExpandedScalar(
          value, input->shape(), input->GetDevice());
  input->SetInPlaceIrValue(std::move(constant));
}
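
// Copies `src` into `input`. On the same device the copy stays in the lazy
// graph (casting and/or expanding when dtypes or shapes differ); across
// devices the source is first materialized to an at::Tensor and then written
// back into `input`.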
void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src) {
  if (input->GetDevice() == src->GetDevice()) {
    torch::lazy::Value copy_value;
    if (input->dtype() == src->dtype()) {
      copy_value = src->GetIrValue();
    } else {
      copy_value = torch::lazy::MakeCast(
          src->GetIrValue(), input->dtype(), src->dtype());
    }
    input->SetIrValue(MaybeExpand(copy_value, input->shape()));
  } else {
    auto input_shape = input->shape();
    at::Tensor src_tensor = src->ToTensor(/*detached=*/true);
    if (src_tensor.sizes() != input_shape.Get().sizes()) {
      src_tensor = src_tensor.expand(input_shape.Get().sizes().vec());
    }
    input->UpdateFromTensor(src_tensor, /*sync=*/false);
  }
}
} // namespace torch::lazy