1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169
|
#include <ATen/ThreadLocalState.h>
#include <ATen/cpp_custom_type_hack.h>
#include <ATen/record_function.h>
#include <torch/csrc/autograd/record_function_ops.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/library.h>
namespace caffe2 {
// Declares at::RecordFunction through the CAFFE_KNOWN_TYPE macro.
// Required for cpp_custom_type_hack to work: the legacy ops in this file store
// an at::RecordFunction inside a Tensor via at::cpp_custom_type_hack::create
// and read it back with at::cpp_custom_type_hack::cast.
// NOLINTNEXTLINE(bugprone-exception-escape)
CAFFE_KNOWN_TYPE(at::RecordFunction);
} // namespace caffe2
namespace torch {
namespace autograd {
namespace profiler {
// Opens a profiling scope on `rec` under `name`, invoking its starting
// callbacks. Nothing happens when `rec` is not active.
//
// `args` is forwarded as a single input value only when `rec` reports that it
// needs inputs and `args` actually holds a value; in every other active case
// the scope is started with the name alone.
void record_function_enter(
    const std::string& name,
    const c10::optional<std::string>& args,
    at::RecordFunction& rec) {
  if (!rec.isActive()) {
    return;
  }

  const bool forward_args = rec.needsInputs() && args.has_value();
  if (!forward_args) {
    rec.before(name);
    return;
  }

  rec.before(
      name, c10::ArrayRef<const c10::IValue>{c10::IValue{args.value()}});
}
// Legacy signature using cpp_custom_type_hack.
// Allocates a USER_SCOPE RecordFunction, enters the profiling scope on it, and
// transfers ownership into the Tensor returned to the caller.
at::Tensor record_function_enter_legacy(
    const std::string& name,
    const c10::optional<std::string>& args) {
  std::unique_ptr<at::RecordFunction> scope =
      std::make_unique<at::RecordFunction>(at::RecordScope::USER_SCOPE);
  record_function_enter(name, args, *scope);
  return at::cpp_custom_type_hack::create(
      std::move(scope), at::TensorOptions());
}
// New signature using custom_class.
// Creates a USER_SCOPE PythonRecordFunction, enters the profiling scope on the
// RecordFunction it holds, and returns the holder.
c10::intrusive_ptr<PythonRecordFunction> record_function_enter_new(
    const std::string& name,
    const c10::optional<std::string>& args) {
  c10::intrusive_ptr<PythonRecordFunction> holder =
      c10::make_intrusive<PythonRecordFunction>(at::RecordScope::USER_SCOPE);
  at::RecordFunction& scope = holder->record;
  record_function_enter(name, args, scope);
  return holder;
}
// Returns a reference to the RecordFunction stored in `handle`, obtained
// through at::cpp_custom_type_hack::cast.
at::RecordFunction& getRecordFunctionFromTensor(const at::Tensor& handle) {
  return at::cpp_custom_type_hack::cast<at::RecordFunction>(handle);
}
// Ends the profiling scope created with record_function_enter by calling
// end() on the given RecordFunction.
void record_function_exit(at::RecordFunction& rec) {
rec.end();
}
// Legacy signature using cpp_custom_type_hack.
// Unwraps the RecordFunction carried by `handle` and ends its profiling scope.
void record_function_exit_legacy(const at::Tensor& handle) {
  // Nothing else is done with the handle itself; it only has to keep the
  // RecordFunction alive up to this point.
  record_function_exit(getRecordFunctionFromTensor(handle));
}
// New signature using custom_class.
// Ends the profiling scope of the RecordFunction held by `record`.
void record_function_exit_new(
    const c10::intrusive_ptr<PythonRecordFunction>& record) {
  at::RecordFunction& scope = record->record;
  record_function_exit(scope);
}
// Chains a continuation onto `fut` that ends the RecordFunction returned by
// `get_record()` and then yields the value of `fut`.
//
// `get_record` is a callable returning a reference to the RecordFunction to
// end; it is moved into the continuation. The returned future is the result of
// `fut->then(...)` and uses the element type of `fut`.
template <typename Func>
c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut(
    Func get_record,
    const c10::intrusive_ptr<c10::ivalue::Future>& fut) {
  using FutureCallback = std::function<c10::IValue(c10::ivalue::Future&)>;

  // Continuation: end the associated record_function first, then hand back
  // the value of the completed future.
  FutureCallback end_scope_then_forward =
      [get_record = std::move(get_record)](
          c10::ivalue::Future& completed) -> c10::IValue {
    get_record().end();
    // The chained future is what the user receives, so that wait() on it
    // guarantees the profiling callbacks have run. To stay transparent it
    // must carry the value of the RPC future; value() is used rather than
    // constValue() so that errors are propagated as well.
    return completed.value();
  };

  // The resulting future completes after the profiling callbacks are run.
  return fut->then(
      at::wrapPropagateTLSState(std::move(end_scope_then_forward)),
      fut->elementType());
}
// Legacy signature using cpp_custom_type_hack.
// Captures a copy of `handle` in a closure that resolves it to its
// RecordFunction, and attaches the end-of-scope continuation to `fut`.
c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut_legacy(
    const at::Tensor& handle,
    const c10::intrusive_ptr<c10::ivalue::Future>& fut) {
  auto record_from_handle = [handle]() -> at::RecordFunction& {
    // The handle must still be defined when the future completes.
    TORCH_INTERNAL_ASSERT(
        handle.defined(),
        "Undefined RecordFunction handle. This can happen if the handle is "
        "not correctly persisted and is destroyed before the future is "
        "realized.");
    return getRecordFunctionFromTensor(handle);
  };
  return _call_end_callbacks_on_fut(std::move(record_from_handle), fut);
}
// New signature using custom_class.
// Captures a copy of the `record` pointer in a closure that exposes the
// RecordFunction it holds, and attaches the end-of-scope continuation to `fut`.
c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut_new(
    const c10::intrusive_ptr<PythonRecordFunction>& record,
    const c10::intrusive_ptr<c10::ivalue::Future>& fut) {
  auto record_accessor = [record]() -> at::RecordFunction& {
    return record->record;
  };
  return _call_end_callbacks_on_fut(std::move(record_accessor), fut);
}
// Internal only, do not use directly, use Python's record_function()
//
// Registers the _RecordFunction custom class, the enter/exit ops in both their
// legacy (Tensor handle) and custom-class flavours, and the two JIT operators
// that attach the end-of-scope callbacks to a future.
TORCH_LIBRARY_FRAGMENT(profiler, m) {
  m.class_<PythonRecordFunction>("_RecordFunction");

  // Scope entry: legacy Tensor-handle variant, then custom-class variant.
  m.def(
      "_record_function_enter(str name, str? args=None) -> Tensor",
      &record_function_enter_legacy);
  m.def(
      "_record_function_enter_new(str name, str? args=None) -> "
      "__torch__.torch.classes.profiler._RecordFunction",
      &record_function_enter_new);

  // Scope exit: legacy Tensor-handle variant, then custom-class overload.
  m.def("_record_function_exit", &record_function_exit_legacy);
  m.def("_record_function_exit._RecordFunction", &record_function_exit_new);

  // JIT operator taking the legacy Tensor handle.
  torch::jit::registerOperator(torch::jit::Operator(
      "profiler::_call_end_callbacks_on_jit_fut(Tensor x, Future(t) y) -> Future(t)",
      [](jit::Stack& stack) {
        // Inputs come off the stack in reverse order: the future first, then
        // the tensor handle.
        auto fut = jit::pop(stack).toFuture();
        auto handle = jit::pop(stack).toTensor();
        // Push the future that completes once profiling callbacks have run.
        jit::push(stack, _call_end_callbacks_on_fut_legacy(handle, fut));
      },
      c10::AliasAnalysisKind::FROM_SCHEMA));

  // JIT operator taking the _RecordFunction custom class.
  torch::jit::registerOperator(torch::jit::Operator(
      "profiler::_call_end_callbacks_on_jit_fut._RecordFunction("
      "__torch__.torch.classes.profiler._RecordFunction x, Future(t) y) -> Future(t)",
      [](c10::Stack& stack) {
        // Inputs come off the stack in reverse order: the future first, then
        // the PythonRecordFunction.
        auto fut = torch::jit::pop(stack).toFuture();
        auto record =
            torch::jit::pop(stack).toCustomClass<PythonRecordFunction>();
        // Push the future that completes once profiling callbacks have run.
        torch::jit::push(stack, _call_end_callbacks_on_fut_new(record, fut));
      },
      c10::AliasAnalysisKind::FROM_SCHEMA));
}
} // namespace profiler
} // namespace autograd
} // namespace torch
|