1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
|
#pragma once
#include <memory>
#include <string>
#include <vector>

#include <torch/torch.h>
namespace torch::inductor {
// Forward declaration only: this header refers to the runner solely through a
// std::unique_ptr member of MyAOTIClass, so the runner's own header is not
// included here.
class AOTIModelContainerRunner;
} // namespace torch::inductor
namespace torch::aot_inductor {
/// Custom class (derives from torch::CustomClassHolder) that owns, through a
/// std::unique_ptr, a torch::inductor::AOTIModelContainerRunner and keeps two
/// immutable strings alongside it: a library path and a device name.
///
/// Instances are neither copyable nor movable; all four copy/move operations
/// are explicitly deleted.
class MyAOTIClass : public torch::CustomClassHolder {
 public:
  /// @param model_path Path string identifying the model.
  ///        NOTE(review): presumably this is what lib_path() returns —
  ///        the constructor is defined out of line; confirm there.
  /// @param device Device name string; defaults to "cuda".
  explicit MyAOTIClass(
      const std::string& model_path,
      const std::string& device = "cuda");

  // NOTE(review): this destructor is defined inline while runner_ is a
  // std::unique_ptr to a type that is only forward-declared in this header.
  // Destroying such a unique_ptr needs the complete type, so every
  // translation unit that emits this destructor must also see the full
  // AOTIModelContainerRunner definition. Consider declaring the destructor
  // here and defining it in the .cpp instead.
  ~MyAOTIClass() {}

  // Rule of Five: the class is pinned in place. The move constructor is
  // deleted explicitly so the intent does not depend on the implicit
  // suppression caused by the user-declared copy operations.
  MyAOTIClass(const MyAOTIClass&) = delete;
  MyAOTIClass& operator=(const MyAOTIClass&) = delete;
  MyAOTIClass(MyAOTIClass&&) = delete;
  MyAOTIClass& operator=(MyAOTIClass&&) = delete;

  /// @return Reference to the stored library path; valid for the lifetime of
  ///         this object.
  const std::string& lib_path() const {
    return lib_path_;
  }

  /// @return Reference to the stored device name; valid for the lifetime of
  ///         this object.
  const std::string& device() const {
    return device_;
  }

  /// Takes a vector of tensors by value and returns a vector of tensors.
  /// NOTE(review): defined out of line; presumably hands `inputs` to runner_
  /// — confirm against the implementation.
  std::vector<torch::Tensor> forward(std::vector<torch::Tensor> inputs);

 private:
  // Both strings are const: fixed at construction, never reassigned.
  const std::string lib_path_;
  const std::string device_;
  // Sole owner of the runner; the pointee type is incomplete in this header.
  std::unique_ptr<torch::inductor::AOTIModelContainerRunner> runner_;
};
} // namespace torch::aot_inductor
|