#pragma once
#include <ATen/Tensor.h>
#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/shape.h>
#include <torch/csrc/lazy/core/tensor.h>
#include <atomic>
namespace torch {
namespace lazy {
struct IrBuilder;
/**
* Work in progress - don't treat this as a stable interface yet!
*/
class TORCH_API BackendImplInterface {
public:
virtual ~BackendImplInterface() = default;
/**
* Initialization/Teardown
* */
// No-op by default. Allows custom functionality to be exposed through
// extension bindings.
virtual void InitializeAtenBindings() const {}
virtual void PrepareToExit() const = 0;
/**
* Configuration
* */
virtual void SetRngSeed(size_t seed) const = 0;
/**
* IR Tracing
* */
virtual const IrBuilder* GetIrBuilder() const = 0;
virtual bool ShouldSyncTensor(const LazyTensorPtr tensor) const;
/**
* Data Transfer
* */
virtual BackendDataPtr MakeComputationDataFromTensor(
const at::Tensor& tensor,
const Shape& shape,
const BackendDevice& device) const = 0;
virtual BackendDataPtr MakeComputationDataFromScalar(
const at::Scalar& scalar,
const torch::lazy::BackendDevice& device) const = 0;
virtual BackendDataPtr CreateDataPlaceholder(
const BackendDevice& device,
const Shape& shape) const = 0;
// Gets backend data if the node is a device data node; otherwise returns
// nullptr.
virtual BackendDataPtr GetComputationDataFromNode(Node*) const = 0;
virtual at::Tensor MakeTensorFromComputationData(
const BackendDataPtr data,
c10::optional<at::ScalarType> logical_scalar_type) const = 0;
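// A hedged sketch (not part of this interface) of the upload/download round
// trip these methods enable, using getBackend() from the bottom of this
// header; device choice is illustrative:
//
//   const BackendImplInterface* backend = getBackend();
//   at::Tensor cpu = at::ones({2, 2});
//   // Resolve the virtual lazy device to a concrete backend device.
//   BackendDevice device =
//       backend->GetBackendDevice(c10::Device(c10::kLazy, 0));
//   // Upload the host tensor to the backend...
//   BackendDataPtr data = backend->MakeComputationDataFromTensor(
//       cpu, Shape(cpu.scalar_type(), cpu.sizes()), device);
//   // ...and download it back to an aten tensor.
//   at::Tensor round_trip =
//       backend->MakeTensorFromComputationData(data, cpu.scalar_type());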
/**
* Lowering, Compilation, Execution
* */
virtual std::unique_ptr<LoweringContext> CreateLoweringContext(
const std::string& name,
BackendDevice device,
c10::ArrayRef<torch::lazy::Node*> post_order,
Util::EmissionMap emit_status) const = 0;
virtual std::unique_ptr<LoweringContext> CreateLoweringContext(
const std::string& name,
BackendDevice device) const = 0;
// TODO(whc) need to keep this?
virtual std::vector<std::string> GetCompilationDevices(
const std::string& device,
c10::ArrayRef<std::string> devices) const = 0;
virtual std::vector<ComputationPtr> Compile(
std::vector<ComputationPtr> instances) const = 0;
virtual std::vector<BackendDataPtr> ExecuteComputation(
torch::lazy::ComputationPtr computation,
c10::ArrayRef<BackendDataPtr> arguments,
const BackendDevice& device) const = 0;
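// A hedged sketch of the compile-then-execute flow. `computation`,
// `arguments`, and `device` are hypothetical placeholders; in practice the
// computation and its arguments come from a LoweringContext built over a
// traced graph:
//
//   std::vector<ComputationPtr> compiled = backend->Compile({computation});
//   std::vector<BackendDataPtr> outputs = backend->ExecuteComputation(
//       compiled.front(), arguments, device);
//   // Materialize the first output as an aten tensor, inferring the dtype.
//   at::Tensor result = backend->MakeTensorFromComputationData(
//       outputs.front(), c10::nullopt);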
/**
* Device Configuration
* */
// Set or get the default device type.
// For backends used with virtual c10::Devices, this configures what real
// device type the backend should use, and matters if the backend supports
// more than one type of real device.
virtual std::shared_ptr<BackendDeviceType> GetDefaultDeviceType() const = 0;
virtual void SetDefaultDeviceType(int8_t type) = 0;
// Set or get the default device ordinal.
// For backends that support multiple devices, this configures which
// device the backend should use by default.
virtual int64_t GetDefaultDeviceOrdinal() const = 0;
virtual void SetDefaultDeviceOrdinal(int64_t) = 0;
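// A hedged usage sketch: a backend that supports both CPU and CUDA real
// devices might be pointed at CUDA device 0 like this (values illustrative):
//
//   backend->SetDefaultDeviceType(static_cast<int8_t>(at::kCUDA));
//   backend->SetDefaultDeviceOrdinal(0);
//   // BackendDeviceType exposes the configured type as an int8_t member.
//   int8_t type = backend->GetDefaultDeviceType()->type;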
// Specify which aten device should be used for eager fallback;
// this may change depending on the current 'Default' DeviceType.
virtual at::DeviceType EagerFallbackDeviceType() const = 0;
// Query all available backend devices
virtual std::vector<BackendDevice> GetBackendDevices() const = 0;
virtual std::string CreateMetricReport() const {
return "";
}
// Map a particular c10:: device to a concrete backend device
// Note: c10:: devices may be virtual or concrete. xla:: and lazy:: are
// virtual devices, meaning they may map to a gpu, tpu, etc. behind the
// scenes. In the future, non-virtual c10:: devices may also use lazy tensors
// through a mode, in which case these APIs should still work, but should be
// identity mappings.
virtual BackendDevice GetBackendDevice(c10::Device device) const = 0;
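// For example (hedged; the device string is illustrative), the virtual
// device "lazy:0" resolves to whatever concrete device the backend chose:
//
//   BackendDevice concrete = backend->GetBackendDevice(c10::Device("lazy:0"));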
// TODO(whc)
// Additional APIs expected for supporting distributed training, to be
// designed
/**
* Debug/Metrics
* */
// virtual std::map<std::string, Metric> GetMetrics() const = 0;
// virtual MemoryInfo GetMemoryInfo(const std::string& device) = 0;
virtual std::string GetComputationBackendText(
const ComputationPtr computation) const = 0;
};
class TORCH_API BackendRegistrar {
public:
BackendRegistrar(const BackendImplInterface* backend_impl_interface);
};
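// A hedged registration sketch: backends typically register themselves at
// static-initialization time. `GetMyBackendImpl` is a hypothetical factory
// returning a long-lived BackendImplInterface*:
//
//   static BackendRegistrar my_backend_registrar(GetMyBackendImpl());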
TORCH_API bool hasBackend();
TORCH_API const BackendImplInterface* getBackend();
TORCH_API const IrBuilder* getIrBuilder();
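// A hedged call-site sketch: guard with hasBackend() before using
// getBackend(), since the latter is only meaningful once a backend has
// registered itself:
//
//   if (torch::lazy::hasBackend()) {
//     torch::lazy::getBackend()->SetRngSeed(42);
//   }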
} // namespace lazy
} // namespace torch