File: aoti_custom_class.cpp

package info (click to toggle)
pytorch 2.6.0+dfsg-8
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 161,672 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (55 lines) | stat: -rw-r--r-- 1,780 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#include <stdexcept>

#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
#if defined(USE_CUDA) || defined(USE_ROCM)
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
#endif
#if defined(USE_XPU)
#include <torch/csrc/inductor/aoti_runner/model_container_runner_xpu.h>
#endif

#include "aoti_custom_class.h"

namespace torch::aot_inductor {

// Registers MyAOTIClass as a TorchScript custom class under the qualified
// name "aoti.MyAOTIClass", exposing its (lib_path, device) constructor,
// the forward() method, and pickle support.
static auto registerMyAOTIClass =
    torch::class_<MyAOTIClass>("aoti", "MyAOTIClass")
        .def(torch::init<std::string, std::string>())
        .def("forward", &MyAOTIClass::forward)
        .def_pickle(
            // __getstate__: serialize the object as its two constructor
            // arguments, in order: {lib_path, device}.
            [](const c10::intrusive_ptr<MyAOTIClass>& self)
                -> std::vector<std::string> {
              return {self->lib_path(), self->device()};
            },
            // __setstate__: rebuild from the serialized fields. Reject a
            // malformed/truncated state explicitly instead of indexing out
            // of bounds (which would be undefined behavior).
            [](std::vector<std::string> params) {
              if (params.size() != 2) {
                throw std::runtime_error(
                    "MyAOTIClass: expected 2 serialized fields, got " +
                    std::to_string(params.size()));
              }
              return c10::make_intrusive<MyAOTIClass>(params[0], params[1]);
            });

// Constructs MyAOTIClass by loading an AOTInductor-compiled model container
// from `model_path` onto the requested `device`.
//
// @param model_path path to the AOTI-compiled shared library; also stored
//                   as lib_path_ so the object can be pickled.
// @param device     "cpu", "cuda" (when built with CUDA/ROCm support), or
//                   "xpu" (when built with XPU support).
// @throws std::runtime_error if `device` names a backend that is unknown
//         or was compiled out of this build.
MyAOTIClass::MyAOTIClass(
    const std::string& model_path,
    const std::string& device)
    : lib_path_(model_path), device_(device) {
  if (device_ == "cpu") {
    // Pass the std::string directly: the runner constructors take the path
    // by const reference, so .c_str() would only force the construction of
    // a redundant temporary std::string.
    runner_ = std::make_unique<torch::inductor::AOTIModelContainerRunnerCpu>(
        model_path);
#if defined(USE_CUDA) || defined(USE_ROCM)
  } else if (device_ == "cuda") {
    runner_ = std::make_unique<torch::inductor::AOTIModelContainerRunnerCuda>(
        model_path);
#endif
#if defined(USE_XPU)
  } else if (device_ == "xpu") {
    runner_ = std::make_unique<torch::inductor::AOTIModelContainerRunnerXpu>(
        model_path);
#endif
  } else {
    throw std::runtime_error("invalid device: " + device);
  }
}

// Runs the loaded AOTI model container on `inputs` and returns the model's
// output tensors.
std::vector<torch::Tensor> MyAOTIClass::forward(
    std::vector<torch::Tensor> inputs) {
  auto outputs = runner_->run(inputs);
  return outputs;
}

} // namespace torch::aot_inductor