1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
|
#include <c10/core/Allocator.h>
#include <c10/util/ThreadLocalDebugInfo.h>
namespace c10 {
static void deleteInefficientStdFunctionContext(void* ptr) {
delete static_cast<InefficientStdFunctionContext*>(ptr);
}
// Builds a DataPtr for `ptr` whose cleanup goes through an arbitrary
// std::function `deleter`.
//
// `ptr` and `deleter` are stored together in a heap-allocated
// InefficientStdFunctionContext. The returned DataPtr carries `ptr` as its
// data, that context as its context, and deleteInefficientStdFunctionContext
// as the function that destroys the context.
//
// NOTE(review): if the `new` below throws (e.g. std::bad_alloc), `ptr` has not
// been handed to anything yet and is not released here. Callers presumably
// keep ownership in that case; confirm.
at::DataPtr InefficientStdFunctionContext::makeDataPtr(
    void* ptr,
    const std::function<void(void*)>& deleter,
    Device device) {
  auto* context = new InefficientStdFunctionContext({ptr, deleter});
  return {ptr, context, &deleteInefficientStdFunctionContext, device};
}
// Allocator registry, one slot per device type, indexed by
// static_cast<int>(DeviceType). Both arrays have static storage duration, so
// every slot starts as nullptr / priority 0 until SetAllocator() fills it in.
C10_API at::Allocator* allocator_array[at::COMPILE_TIME_MAX_DEVICE_TYPES] = {};
// Priority of the allocator currently held in the matching allocator_array
// slot. SetAllocator() compares against it before replacing an entry.
C10_API uint8_t allocator_priority[at::COMPILE_TIME_MAX_DEVICE_TYPES] = {};
// Registers `alloc` as the allocator for device type `t` with the given
// `priority`. An existing registration is kept only if its priority is
// strictly higher; equal priority lets the newcomer overwrite it.
// NOTE(review): no locking here. Presumably registration happens during
// static initialization; confirm before calling this concurrently with
// GetAllocator().
void SetAllocator(at::DeviceType t, at::Allocator* alloc, uint8_t priority) {
  const int slot = static_cast<int>(t);
  if (priority < allocator_priority[slot]) {
    // A higher-priority allocator already owns this slot; keep it.
    return;
  }
  allocator_array[slot] = alloc;
  allocator_priority[slot] = priority;
}
// Returns the allocator registered for device type `t` via SetAllocator().
// Throws c10::Error ("Allocator for <t> is not set.") if the slot is still
// empty.
at::Allocator* GetAllocator(const at::DeviceType& t) {
  auto* alloc = allocator_array[static_cast<int>(t)];
  // An empty slot is a caller/configuration problem (nothing was registered
  // for this device type), not a broken internal invariant. Use the
  // user-facing TORCH_CHECK instead of the deprecated AT_ASSERTM, which
  // reports the failure as an internal assert.
  TORCH_CHECK(alloc, "Allocator for ", t, " is not set.");
  return alloc;
}
// Reports whether memory profiling is on for the current thread. Looks up the
// debug info stored under DebugInfoKind::PROFILER_STATE in
// ThreadLocalDebugInfo. Returns false when nothing is stored there; otherwise
// defers to that reporter's memoryProfilingEnabled().
// NOTE(review): the static_cast assumes whatever is stored under
// PROFILER_STATE derives from MemoryReportingInfoBase; confirm at the
// registration site.
bool memoryProfilingEnabled() {
  const auto& profiler_info =
      ThreadLocalDebugInfo::get(DebugInfoKind::PROFILER_STATE);
  auto* reporter = static_cast<MemoryReportingInfoBase*>(profiler_info.get());
  if (reporter == nullptr) {
    return false;
  }
  return reporter->memoryProfilingEnabled();
}
// Forwards (ptr, alloc_size, device) to the MemoryReportingInfoBase stored
// under DebugInfoKind::PROFILER_STATE in ThreadLocalDebugInfo. Does nothing
// when no such info is present.
// NOTE(review): the sign convention of `alloc_size` (presumably negative for
// frees) is set by callers; confirm.
void reportMemoryUsageToProfiler(void* ptr, int64_t alloc_size, Device device) {
  const auto& profiler_info =
      ThreadLocalDebugInfo::get(DebugInfoKind::PROFILER_STATE);
  auto* reporter = static_cast<MemoryReportingInfoBase*>(profiler_info.get());
  if (reporter == nullptr) {
    // Nothing registered under PROFILER_STATE; nobody to report to.
    return;
  }
  reporter->reportMemoryUsage(ptr, alloc_size, device);
}
// Out-of-line definition of the constructor declared in the class (presumably
// in c10/core/Allocator.h). It performs no work of its own.
MemoryReportingInfoBase::MemoryReportingInfoBase() = default;
} // namespace c10
|