1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
|
#pragma once
#include <memory>
#include <string>

#include <c10/util/Exception.h>
#include <c10/util/Registry.h>
namespace torch {
namespace jit {
namespace mobile {
namespace nnc {
// Function type shared by every NNC kernel entry point wrapped in this file:
// the kernel receives an array of opaque argument pointers and returns an int.
// NOTE(review): the meaning of the int result (status code?) and the layout of
// `args` are defined by whatever generates the kernels, not here -- confirm.
using nnc_kernel_function_type = int(void**);

// Abstract interface for an executable NNC kernel. Concrete subclasses are
// generated by REGISTER_NNC_KERNEL and are instantiated through
// NNCKernelRegistry (see the registry:: helpers).
struct TORCH_API NNCKernel {
  // Virtual so kernels can be destroyed through std::unique_ptr<NNCKernel>.
  virtual ~NNCKernel() = default;
  // Runs the kernel on the given argument-pointer array and returns its int
  // result.
  virtual int execute(void** /* args */) = 0;
};

// Declares the registry that maps a kernel id to a factory for an NNCKernel
// subclass; the registry:: helpers query it with std::string ids.
// NOTE(review): the matching definition is presumably a C10_DEFINE_REGISTRY in
// a .cpp file -- verify.
C10_DECLARE_REGISTRY(NNCKernelRegistry, NNCKernel);
// Declares an NNC kernel function and registers an NNCKernel wrapper for it.
//
//   id     - registry key handed to C10_REGISTER_TYPED_CLASS; the registry::
//            helpers look kernels up with a std::string id.
//   kernel - symbol name of the kernel function. It is declared here with C
//            linkage and type nnc_kernel_function_type; presumably its
//            definition comes from separately compiled kernel code.
//   ...    - accepted but ignored.
//
// The expansion defines a wrapper struct NNCKernel_<kernel> whose execute()
// forwards straight to `kernel`, then registers that struct under `id`.
// The expansion already ends with a semicolon.
//
// NNCKernel, NNCKernelRegistry and nnc_kernel_function_type are referenced
// unqualified, so invoke this where those names are visible (e.g. inside
// namespace torch::jit::mobile::nnc).
//
// The wrapper is a leaf type, so it is marked `final`: calls through it can be
// devirtualized and it cannot be subclassed by accident.
#define REGISTER_NNC_KERNEL(id, kernel, ...)           \
  extern "C" {                                         \
  nnc_kernel_function_type kernel;                     \
  }                                                    \
  struct NNCKernel_##kernel final : public NNCKernel { \
    int execute(void** args) override {                \
      return kernel(args);                             \
    }                                                  \
  };                                                   \
  C10_REGISTER_TYPED_CLASS(NNCKernelRegistry, id, NNCKernel_##kernel);
namespace registry {
inline bool has_nnc_kernel(const std::string& id) {
return NNCKernelRegistry()->Has(id);
}
inline std::unique_ptr<NNCKernel> get_nnc_kernel(const std::string& id) {
return NNCKernelRegistry()->Create(id);
}
} // namespace registry
} // namespace nnc
} // namespace mobile
} // namespace jit
} // namespace torch
|