1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
|
#include <ATen/core/ivalue.h>
#include <torch/csrc/utils/init.h>
#include <torch/csrc/utils/throughput_benchmark.h>
#include <pybind11/functional.h>
#include <torch/csrc/utils/pybind.h>
namespace torch {
namespace throughput_benchmark {

// Registers the Python bindings for the throughput benchmark utilities on
// `module`: BenchmarkConfig (mutable settings), BenchmarkExecutionStats
// (read-only results) and ThroughputBenchmark (the benchmark driver).
void initThroughputBenchmarkBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  // Plain settings struct; every field is readable and writable from Python.
  py::class_<BenchmarkConfig>(m, "BenchmarkConfig")
      .def(py::init<>())
      .def_readwrite(
          "num_calling_threads", &BenchmarkConfig::num_calling_threads)
      .def_readwrite("num_worker_threads", &BenchmarkConfig::num_worker_threads)
      .def_readwrite("num_warmup_iters", &BenchmarkConfig::num_warmup_iters)
      .def_readwrite("num_iters", &BenchmarkConfig::num_iters)
      .def_readwrite(
          "profiler_output_path", &BenchmarkConfig::profiler_output_path);

  // Read-only results object. No constructor is bound, so Python code cannot
  // create instances directly; it can only read the exposed fields.
  py::class_<BenchmarkExecutionStats>(m, "BenchmarkExecutionStats")
      .def_readonly("latency_avg_ms", &BenchmarkExecutionStats::latency_avg_ms)
      .def_readonly("num_iters", &BenchmarkExecutionStats::num_iters);

  // py::dynamic_attr() lets Python code attach arbitrary attributes to
  // ThroughputBenchmark instances.
  py::class_<ThroughputBenchmark>(m, "ThroughputBenchmark", py::dynamic_attr())
      // Construct from a TorchScript module ...
      .def(py::init<jit::Module>())
      // ... or from an arbitrary Python object (presumably an nn.Module, per
      // the run_once / benchmark comments below).
      .def(py::init<py::object>())
      .def(
          "add_input",
          [](ThroughputBenchmark& self, py::args args, py::kwargs kwargs) {
            self.addInput(std::move(args), std::move(kwargs));
          })
      .def(
          "run_once",
          [](ThroughputBenchmark& self, py::args args, py::kwargs kwargs) {
            // Depending on whether this wraps a ScriptModule or an nn.Module,
            // the GIL is released or not further down in the stack.
            return self.runOnce(std::move(args), std::move(kwargs));
          })
      .def(
          "benchmark",
          // Take the config by const reference: it holds a std::string, so a
          // by-value parameter would copy it on every call.
          [](ThroughputBenchmark& self, const BenchmarkConfig& config) {
            // The benchmark always runs without the GIL. The GIL is used only
            // where needed, which happens only in nn.Module mode when
            // manipulating inputs and running the actual inference.
            py::gil_scoped_release no_gil_guard;
            return self.benchmark(config);
          });
}

} // namespace throughput_benchmark
} // namespace torch
|