1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
|
#include <c10/core/Allocator.h>
#include <array>
#include <c10/util/ThreadLocalDebugInfo.h>
#include <cstring>
namespace c10 {
// Produces a freshly allocated copy of the first `n` bytes found at `data`.
// The storage is obtained from allocate(n) and the bytes are transferred
// through copy_data().
DataPtr Allocator::clone(const void* data, std::size_t n) {
  DataPtr duplicate = allocate(n);
  void* const destination = duplicate.mutable_get();
  copy_data(destination, data, n);
  return duplicate;
}
// Plain byte-wise copy of `count` bytes from `src` to `dest` via std::memcpy.
//
// dest:  start of the destination buffer.
// src:   start of the source buffer.
// count: number of bytes to copy; a zero-byte request is a no-op.
//
// As with std::memcpy, the two ranges must not overlap.
void Allocator::default_copy_data(
    void* dest,
    const void* src,
    std::size_t count) const {
  // Passing a null (or otherwise invalid) pointer to std::memcpy is undefined
  // behavior even when count is 0, so bail out before touching the pointers.
  if (count == 0) {
    return;
  }
  std::memcpy(dest, src, count);
}
// Reports whether `data_ptr`'s data pointer and its context pointer are the
// very same address.
bool Allocator::is_simple_data_ptr(const DataPtr& data_ptr) const {
  const void* const payload = data_ptr.get();
  const void* const context = data_ptr.get_context();
  return payload == context;
}
static void deleteInefficientStdFunctionContext(void* ptr) {
delete static_cast<InefficientStdFunctionContext*>(ptr);
}
// Builds a DataPtr for `ptr` on `device` whose context is a newly allocated
// InefficientStdFunctionContext holding `ptr` and `deleter`. The context is
// released through deleteInefficientStdFunctionContext.
at::DataPtr InefficientStdFunctionContext::makeDataPtr(
    void* ptr,
    std::function<void(void*)> deleter,
    Device device) {
  // Ownership of this allocation passes to the returned DataPtr, which frees
  // it through the deleter function registered below.
  auto* context = new InefficientStdFunctionContext(ptr, std::move(deleter));
  return {ptr, context, &deleteInefficientStdFunctionContext, device};
}
// File-local allocator registry with one slot per device type; slots are
// indexed by static_cast<int>(DeviceType). Value-initialized, so every entry
// starts out as a null pointer.
static std::array<at::Allocator*, at::COMPILE_TIME_MAX_DEVICE_TYPES>
allocator_array{};
// Priority recorded for the allocator in the matching slot of
// allocator_array. Value-initialized, so every entry starts at 0.
static std::array<uint8_t, at::COMPILE_TIME_MAX_DEVICE_TYPES>
allocator_priority{};
// Registers `alloc` as the allocator for device type `t`, unless an allocator
// with a strictly higher priority is already recorded for that slot. A
// registration of equal priority replaces the existing entry.
void SetAllocator(at::DeviceType t, at::Allocator* alloc, uint8_t priority) {
  const auto slot = static_cast<int>(t);
  if (priority < allocator_priority[slot]) {
    // The current registration outranks this one; keep it.
    return;
  }
  allocator_array[slot] = alloc;
  allocator_priority[slot] = priority;
}
// Looks up the allocator currently registered for device type `t`. The
// debug-only assertion checks that the slot has been populated.
at::Allocator* GetAllocator(const at::DeviceType& t) {
  const auto slot = static_cast<int>(t);
  at::Allocator* const alloc = allocator_array[slot];
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(alloc, "Allocator for ", t, " is not set.");
  return alloc;
}
// Returns true only when a PROFILER_STATE debug-info object is available from
// ThreadLocalDebugInfo and that object reports memory profiling as enabled.
bool memoryProfilingEnabled() {
  auto* reporter = static_cast<MemoryReportingInfoBase*>(
      ThreadLocalDebugInfo::get(DebugInfoKind::PROFILER_STATE));
  if (reporter == nullptr) {
    return false;
  }
  return reporter->memoryProfilingEnabled();
}
// Forwards a memory-usage event to the PROFILER_STATE debug-info object, if
// one is available from ThreadLocalDebugInfo; otherwise does nothing. All
// arguments are passed through to reportMemoryUsage() unchanged.
void reportMemoryUsageToProfiler(
    void* ptr,
    int64_t alloc_size,
    size_t total_allocated,
    size_t total_reserved,
    Device device) {
  auto* reporter = static_cast<MemoryReportingInfoBase*>(
      ThreadLocalDebugInfo::get(DebugInfoKind::PROFILER_STATE));
  if (reporter == nullptr) {
    return;
  }
  reporter->reportMemoryUsage(
      ptr, alloc_size, total_allocated, total_reserved, device);
}
// Forwards an out-of-memory event to the PROFILER_STATE debug-info object, if
// one is available from ThreadLocalDebugInfo; otherwise does nothing. All
// arguments are passed through to reportOutOfMemory() unchanged.
void reportOutOfMemoryToProfiler(
    int64_t alloc_size,
    size_t total_allocated,
    size_t total_reserved,
    Device device) {
  auto* reporter = static_cast<MemoryReportingInfoBase*>(
      ThreadLocalDebugInfo::get(DebugInfoKind::PROFILER_STATE));
  if (reporter == nullptr) {
    return;
  }
  reporter->reportOutOfMemory(
      alloc_size, total_allocated, total_reserved, device);
}
// Base-class implementation of the out-of-memory hook: takes the event
// details and deliberately does nothing with them.
void MemoryReportingInfoBase::reportOutOfMemory(
    int64_t /*alloc_size*/,
    size_t /*total_allocated*/,
    size_t /*total_reserved*/,
    Device /*device*/) {
  // Intentionally empty.
}
} // namespace c10
|