1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
|
#pragma once
#include <ATen/ATen.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/codegen/fuser/fused_kernel.h>
#include <torch/csrc/utils/disallow_copy.h>
#include <cstdint>
#include <memory>
#include <string>
// Forward declaration only: this header uses at::DynamicLibrary solely as the
// pointee of a std::unique_ptr member, so its full definition is not needed here.
namespace at { struct DynamicLibrary; } // namespace at
namespace torch {
namespace jit {
namespace fuser {
namespace cpu {
// Represents a compiled CPU kernel and the metadata necessary to run it
// CPU-backend FusedKernel: a compiled CPU kernel plus the state needed to
// launch it (the loaded library handle and the entry-point pointer).
//
// The constructor is only declared here; its out-of-line definition is
// presumably where `code` is turned into `so_lib`/`kernel`.
// NOTE(review): confirm against the corresponding .cpp.
struct TORCH_API FusedKernelCPU : public FusedKernel {
  // All descriptors are taken by value; see the out-of-line definition for how
  // they are handed on to FusedKernel.
  FusedKernelCPU(
      std::string name,
      std::string code,
      std::vector<TensorDesc> input_desc,
      std::vector<TensorDesc> output_desc,
      std::vector<PartitionDesc> chunk_desc,
      std::vector<PartitionDesc> concat_desc,
      bool has_random);

  // This kernel type always reports the CPU backend.
  at::Backend backend() const override { return at::Backend::CPU; }

  // Calls the entry point with the element count and the raw pointer to the
  // packed argument array. `kernel` defaults to nullptr, so this relies on the
  // constructor having assigned it before any launch.
  void launch_raw(const uint32_t numel, std::vector<void*>& arguments)
      const override {
    void** const packed_args = arguments.data();
    (*kernel)(numel, packed_args);
  }

 private:
  // Signature of the entry point: (element count, array of argument pointers).
  using KernelFn = void (*)(uint32_t, void**);

  // Owned handle to the loaded library; it lives exactly as long as this
  // object. Presumably this is the library that `kernel` points into.
  std::unique_ptr<at::DynamicLibrary> so_lib;

  // Entry point invoked by launch_raw. It defaults to nullptr and is expected
  // to be assigned by the out-of-line constructor.
  KernelFn kernel = nullptr;
};
} // namespace cpu
} // namespace fuser
} // namespace jit
} // namespace torch
|