1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
|
#pragma once
#include <utility>

#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/ts_backend/ts_node.h>
namespace torch {
namespace lazy {
// IR node that carries a shared handle to a BackendData object.
//
// The node shares ownership of the data it was given and exposes it through
// data(). The handle can be swapped afterwards with SetData(), which is what
// makes node reuse (see CanBeReused / Create) possible.
class TORCH_API DeviceData : public TsNode {
 public:
  // Op kind identifying DeviceData nodes.
  static OpKind ClassOpKind() {
    return ltc_device_data;
  }

  // Builds a node around `data`. Prefer Create() so existing IR nodes can be
  // reused (see the note on Create below).
  explicit DeviceData(std::shared_ptr<BackendData> data);

  // A DeviceData node can be reused if the shape matches,
  // but we will substitute the actual data_ pointer under
  // the hood.
  //
  // Both the stored handle and `data` are dereferenced without a null check,
  // so neither may be null when this is called. The argument is only read,
  // hence it is taken by const reference to avoid a reference-count
  // increment/decrement on every query.
  bool CanBeReused(const std::shared_ptr<BackendData>& data) const {
    return data_->shape() == data->shape();
  }

  // Textual description of this node (defined out of line).
  std::string ToString() const override;

  // Returns the backend data handle currently held by this node.
  const std::shared_ptr<BackendData>& data() const {
    return data_;
  }

  // Replaces the backend data handle held by this node. The parameter is a
  // sink: it is taken by value and moved into place, so passing an rvalue
  // transfers ownership without touching the reference count.
  void SetData(std::shared_ptr<BackendData> data) {
    data_ = std::move(data);
  }

  // Downcasts a generic Node pointer to DeviceData.
  // NOTE(review): defined out of line; behaviour for a node of a different
  // kind is not visible from this header -- confirm before relying on it.
  static const DeviceData* Cast(const Node* node);

  // To reuse IR nodes, use this method to create DeviceData nodes
  // instead of calling the constructor directly.
  static NodePtr Create(std::shared_ptr<BackendData> data);

  // Lowers this node into TorchScript ops for `function` using `loctx`
  // (defined out of line).
  TSOpVector Lower(
      std::shared_ptr<torch::jit::GraphFunction> function,
      TSLoweringContext* loctx) const override;

 private:
  // Backend data handle; shared with whoever else holds the same pointer.
  std::shared_ptr<BackendData> data_;
};
} // namespace lazy
} // namespace torch
|