| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
 100
 101
 102
 103
 104
 105
 106
 107
 108
 109
 110
 111
 112
 
 | #if !defined(C10_MOBILE) && !defined(ANDROID)
#pragma once
#include <ATen/ATen.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/core/function_schema.h>
#include <torch/csrc/dynamo/guards.h>
#include <torch/csrc/inductor/aoti_eager/kernel_meta_info.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner.h>
#include <torch/csrc/utils/pybind.h>
#include <string>
namespace torch::inductor {
// Represents an AOTI kernel: the metadata of all of its parameters together
// with the AOTI model runner used to execute it.
struct AOTIKernelMetadata {
  // Metadata for every parameter of the AOTI kernel. check() compares this
  // list position-by-position against a candidate list.
  std::vector<ParameterMetadata> parameter_metadata_list_;
  // AOTI model runner used to run the AOTI kernel; null for a
  // default-constructed object.
  std::shared_ptr<AOTIModelContainerRunner> kernel_runner_;

  AOTIKernelMetadata() : parameter_metadata_list_(), kernel_runner_(nullptr) {}

  // Returns true iff `parameter_metadata_list` has the same length as this
  // kernel's parameter metadata list and every element compares equal (using
  // ParameterMetadata's `==`) to the element at the same position.
  bool check(
      const std::vector<ParameterMetadata>& parameter_metadata_list) const {
    // std::vector's operator== first compares sizes and then compares the
    // elements pairwise with `==` (lhs = ours, rhs = candidate), which is
    // exactly the matching rule required here.
    return parameter_metadata_list_ == parameter_metadata_list;
  }
};
// The AOTIPythonKernelHolder class uses AOT Inductor to generate a kernel for
// a specified operation. To speed this up, the generated kernel library is
// cached on disk, keyed by detailed information about the input tensors. On
// subsequent runs, the same information is extracted from the current inputs
// and used to look up the cache. On a cache hit, the cached kernel library is
// loaded and executed; on a cache miss, AOT Inductor is invoked to generate a
// new kernel library.
class AOTIPythonKernelHolder : public c10::OperatorKernel {
  // A DispatchKey object that represents the dispatch key for the kernel.
  c10::DispatchKey dispatch_key_;
  // Namespace of the kernel.
  std::string ns_;
  // Name of the operation the kernel performs.
  std::string op_name_with_overload_;
  // The device on which the kernel is to be executed.
  c10::Device device_;
  // The Python interpreter to get OpOverload object with the given op_name and
  // op_overload_name.
  c10::impl::PyInterpreter* pyinterpreter_;
  // Cache the produced kernels by AOTI and its metadata
  std::vector<AOTIKernelMetadata> aoti_kernel_cache_;
 public:
  AOTIPythonKernelHolder(
      c10::DispatchKey dispatch_key,
      std::string_view ns,
      std::string_view op_name_with_overload);
  void operator()(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet keyset,
      torch::jit::Stack* stack);
 private:
  bool cache_lookup(
      const c10::OperatorHandle& op,
      const c10::DispatchKeySet& keyset,
      const torch::jit::Stack* stack,
      AOTIKernelMetadata& aoti_kernel_metadata);
  void cache_miss(
      const c10::OperatorHandle& op,
      const c10::DispatchKeySet& keyset,
      torch::jit::Stack* stack);
  void cache_hit(
      const AOTIKernelMetadata& aoti_kernel_metadata,
      const c10::OperatorHandle& op,
      const c10::DispatchKeySet& keyset,
      torch::jit::Stack* stack);
  // Invoke python utility function on the Inductor side to produce AOTI kernel
  // for the given operation.
  //   Inductor utility function -
  //   torch._inductor.utils.aoti_compile_with_persistent_cache
  std::string produce_aoti_kernel_lib(
      const c10::OperatorHandle& op,
      const c10::DispatchKeySet& keyset,
      const torch::jit::Stack* stack);
  // Invoke python utility function on the Inductor side to load AOTI kernel for
  // the given operation.
  //   Inductor utility function - torch._inductor.utils.load_aoti_eager_cache
  void init_aoti_kernel_cache();
  // Load the AOTIModelContainerRunner object from the given file path.
  std::shared_ptr<AOTIModelContainerRunner> load_aoti_model_runner(
      const std::string&);
};
} // namespace torch::inductor
#endif
 |