#include <torch/extension.h>
#include <torch/library.h>
using namespace at;
static int test_int;
// Builds a dummy MAIA tensor: it carries the requested dtype and sizes but
// owns no real data (null DataPtr, zero-byte storage), which is enough to
// satisfy downstream shape/dtype checks in tests.
Tensor get_tensor(caffe2::TypeMeta dtype, IntArrayRef size) {
  Storage storage(
      Storage::use_byte_size_t(),
      /*size_bytes=*/0,
      at::DataPtr(nullptr, Device(DeviceType::MAIA, 0)),
      /*allocator=*/nullptr,
      /*resizable=*/false);
  auto impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
      std::move(storage), DispatchKey::MAIA, dtype);
  // This is a hack to workaround the shape checks in _convolution.
  impl->set_sizes_contiguous(size);
  return Tensor(std::move(impl));
}
// Fake `empty.memory_format` kernel. Records that it was dispatched
// (test_int = 0) and returns a storage-less tensor of the requested dtype and
// shape. layout / device / pin_memory / memory_format are accepted but ignored.
Tensor empty_override(IntArrayRef size, std::optional<ScalarType> dtype, std::optional<Layout> layout, std::optional<Device> device,
std::optional<bool> pin_memory, std::optional<c10::MemoryFormat> optional_memory_format) {
  test_int = 0;
  const auto meta = scalarTypeToTypeMeta(dtype_or_default(dtype));
  return get_tensor(meta, size);
}
// Fake `add.out` kernel. Only records that it was dispatched (test_int = 1);
// the inputs are ignored and `out` is handed back untouched.
Tensor& add_out_override(const Tensor & a, const Tensor & b , const Scalar& c, Tensor & out) {
  test_int = 1;
  return out;
}
// Fake convolution kernel. Records dispatch (test_int = 2) and fabricates an
// output tensor. Only the first two output dimensions (batch, out-channels)
// are correct; the spatial dimensions are copied verbatim from the input.
Tensor fake_convolution(
    const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation,
    bool transposed, IntArrayRef output_padding, int64_t groups) {
  test_int = 2;
  const auto batch = input.size(0);
  const auto out_channels = weight.size(0);
  return get_tensor(
      input.dtype(), {batch, out_channels, input.size(2), input.size(3)});
}
// Fake convolution backward kernel. Records dispatch (test_int = 3) and
// returns dummy gradients shaped like (input, weight, scalar bias); the
// output_mask is ignored and all three gradients are always produced.
std::tuple<Tensor,Tensor,Tensor> fake_convolution_backward(
    const Tensor & grad_output, const Tensor & input, const Tensor & weight,
    IntArrayRef stride, IntArrayRef padding,
    IntArrayRef dilation, bool transposed, IntArrayRef output_padding,
    int64_t groups, std::array<bool,3> output_mask) {
  test_int = 3;
  auto grad_input = get_tensor(input.dtype(), input.sizes());
  auto grad_weight = get_tensor(weight.dtype(), weight.sizes());
  auto grad_bias = get_tensor(input.dtype(), {});
  return std::make_tuple(
      std::move(grad_input), std::move(grad_weight), std::move(grad_bias));
}
// Fake `to.dtype` kernel: fabricates a tensor with the target dtype and the
// source tensor's shape. non_blocking / copy / memory_format are ignored.
at::Tensor maia_to_dtype_override(
    const at::Tensor & self, at::ScalarType dtype, bool non_blocking,
    bool copy, ::std::optional<at::MemoryFormat> memory_format
) {
  const auto target_meta = scalarTypeToTypeMeta(dtype);
  return get_tensor(target_meta, self.sizes());
}
// Fake `matmul` kernel: models only the plain 2-D matrix product, requiring
// matching dtype and device on the operands.
at::Tensor maia_matmul_override(const at::Tensor & self, const at::Tensor & other) {
  AT_ASSERT(self.dim() == 2);
  AT_ASSERT(other.dim() == 2);
  AT_ASSERT(self.dtype() == other.dtype());
  AT_ASSERT(self.device() == other.device());
  // (m x k) @ (k x n) yields an (m x n) dummy result.
  const auto rows = self.size(0);
  const auto cols = other.size(1);
  return get_tensor(self.dtype(), {rows, cols});
}
// Register the fake kernels above under the MAIA dispatch key. Each aten op
// exercised by the tests must be listed here; ops not registered would fall
// through to a (nonexistent) native MAIA implementation.
TORCH_LIBRARY_IMPL(aten, MAIA, m) {
m.impl("empty.memory_format", empty_override);
m.impl("add.out", add_out_override);
m.impl("convolution_overrideable", fake_convolution);
m.impl("convolution_backward_overrideable", fake_convolution_backward);
m.impl("to.dtype", maia_to_dtype_override);
m.impl("matmul", maia_matmul_override);
}
// TODO: Extend this to exercise multi-device setting. In that case,
// we need to add a thread local variable to track the current device.
// Minimal DeviceGuardImplInterface for the fake MAIA backend: exactly one
// device (index 0), only the default stream, and no event support (all
// event hooks fail loudly via TORCH_CHECK).
struct MAIAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
  static constexpr DeviceType static_type = DeviceType::MAIA;
  MAIAGuardImpl() {}
  MAIAGuardImpl(DeviceType t) {
    AT_ASSERT(t == DeviceType::MAIA);
  }
  DeviceType type() const override {
    return DeviceType::MAIA;
  }
  // Single-device backend: "exchanging" validates the request and returns it
  // unchanged (the previous device is always MAIA:0).
  Device exchangeDevice(Device d) const override {
    AT_ASSERT(d.type() == DeviceType::MAIA);
    AT_ASSERT(d.index() == 0);
    return d;
  }
  Device getDevice() const override {
    return Device(DeviceType::MAIA, 0);
  }
  // setDevice only validates; there is no real device state to switch.
  void setDevice(Device d) const override {
    AT_ASSERT(d.type() == DeviceType::MAIA);
    AT_ASSERT(d.index() == 0);
  }
  void uncheckedSetDevice(Device d) const noexcept override {
  }
  // Streams: only the default stream on MAIA:0 exists, so both accessors
  // ignore their argument and return it.
  Stream getStream(Device d) const noexcept override {
    return Stream(Stream::DEFAULT, Device(DeviceType::MAIA, 0));
  }
  Stream exchangeStream(Stream s) const noexcept override {
    return Stream(Stream::DEFAULT, Device(DeviceType::MAIA, 0));
  }
  DeviceIndex deviceCount() const noexcept override {
    return 1;
  }
  // Event-related functions: MAIA has no event support, so every hook fails
  // loudly instead of silently pretending to succeed.
  void record(void** event,
      const Stream& stream,
      const DeviceIndex device_index,
      const EventFlag flag) const override {
    TORCH_CHECK(false, "MAIA backend doesn't support events.");
  }
  void block(
      void* event,
      const Stream& stream) const override {
    TORCH_CHECK(false, "MAIA backend doesn't support events.");
  }
  bool queryEvent(void* event) const override {
    TORCH_CHECK(false, "MAIA backend doesn't support events.");
  }
  // destroyEvent must be a silent no-op: it is called unconditionally from
  // event destructors, even when no event was ever created.
  void destroyEvent(
      void* event,
      const DeviceIndex device_index) const noexcept override { }
};
// NOTE: no out-of-class definition of `static_type` is needed — since C++17 a
// `static constexpr` data member is implicitly inline, and the separate
// definition is deprecated.
C10_REGISTER_GUARD_IMPL(MAIA, MAIAGuardImpl);
// Exposed to Python so tests can observe which fake kernel was dispatched
// last (see the test_int assignments in the kernels above).
int get_test_int() {
  const int last_dispatched = test_int;
  return last_dispatched;
}
// Python bindings: expose get_test_int() so Python-side tests can query
// which fake MAIA kernel ran most recently.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("get_test_int", &get_test_int);
}