1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
|
#pragma once
#include <torch/csrc/lazy/backend/backend_interface.h>
namespace torch {
namespace lazy {
/// Backend data for the TorchScript lazy backend.
///
/// The constructor chosen decides which payload an instance carries:
///   - a scalar: `scalar` is set and no tensor is held;
///   - a tensor: `data_` holds the given `at::Tensor`;
///   - nothing: a placeholder that only records shape and device
///     (presumably filled in later through `Assign()` — confirm with callers).
class TORCH_API TSData : public torch::lazy::BackendData {
 public:
  /// Wraps a scalar. The recorded shape has the scalar's type and no
  /// dimensions (rank 0).
  TSData(const at::Scalar& scalar, const torch::lazy::BackendDevice& device)
      : torch::lazy::BackendData(device, torch::lazy::Shape(scalar.type(), {})),
        scalar(scalar) {}

  /// Wraps an existing tensor, recorded with the given shape and device.
  TSData(
      const at::Tensor& data,
      const torch::lazy::Shape& shape,
      const torch::lazy::BackendDevice& device)
      : torch::lazy::BackendData(device, shape), data_(data) {}

  /// Creates an empty placeholder. HasValue() stays false until a tensor is
  /// assigned.
  TSData(
      const torch::lazy::Shape& shape,
      const torch::lazy::BackendDevice& device)
      : torch::lazy::BackendData(device, shape) {}

  /// The handle is this object's address. Handles are therefore only
  /// meaningful while the object is alive.
  Handle GetHandle() override {
    return reinterpret_cast<int64_t>(this);
  }

  /// Replaces this object's payload with the payload of `data`.
  ///
  /// `data` must really be a TSData, because the downcast below is unchecked.
  /// Both payload fields are copied so they stay consistent with each other.
  /// Copying only the tensor would leave a stale `scalar` on the target, or
  /// drop the scalar of a scalar-backed source. Any reader that prefers
  /// `scalar` when it is set would then see the wrong value.
  void Assign(const torch::lazy::BackendData& data) override {
    const auto& other = static_cast<const TSData&>(data);
    data_ = other.data_;
    scalar = other.scalar;
  }

  /// True only when a tensor is held. A scalar-only instance reports false
  /// even though `scalar` is set.
  /// NOTE(review): confirm callers rely on this scalar/HasValue split.
  bool HasValue() const override {
    return data_.defined();
  }

  /// Returns the held tensor. The result is not `defined()` when no tensor is
  /// held (scalar or placeholder instances).
  at::Tensor data() const {
    return data_;
  }

  /// Set when this object was built from, or assigned from, a scalar.
  c10::optional<at::Scalar> scalar;

 private:
  /// Tensor payload. Undefined for scalar and placeholder instances.
  at::Tensor data_;
};
/// Returns the TorchScript lazy backend implementation.
/// NOTE(review): ownership and lifetime of the returned pointer are decided in
/// the implementation file. It is presumably a static instance that callers
/// must not delete; confirm there.
TORCH_API auto GetTSBackendImpl() -> torch::lazy::BackendImplInterface*;

/// Initializes the TorchScript lazy backend.
/// NOTE(review): presumably this registers the implementation returned by
/// GetTSBackendImpl() as the active backend; confirm in the implementation.
TORCH_API auto InitTorchScriptBackend() -> void;
} // namespace lazy
} // namespace torch
|