#pragma once
#include <c10/core/Allocator.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
namespace torch::cuda::CUDAPluggableAllocator {
using MallocFuncType = void*(size_t, int, cudaStream_t);
using FreeFuncType = void(void*, size_t, int, cudaStream_t);
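// A minimal sketch of functions matching these signatures, assuming plain
// cudaMalloc/cudaFree (a real allocator would also honor the `device` and
// `stream` arguments, e.g. via cudaSetDevice or cudaMallocAsync):
//
//   void* my_alloc(size_t size, int device, cudaStream_t stream) {
//     void* ptr = nullptr;
//     cudaMalloc(&ptr, size);
//     return ptr;
//   }
//
//   void my_free(void* ptr, size_t size, int device, cudaStream_t stream) {
//     cudaFree(ptr);
//   }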
// A CUDAPluggableAllocatorDeleterContext object is used as the `ctx`
// argument for DataPtr. A context is needed because a single PyTorch
// program can use multiple allocators, and each allocator can have a
// different free function, e.g. free, cudaFree, cudaFreeAsync, ncclMemFree.
struct TORCH_CUDA_CPP_API CUDAPluggableAllocatorDeleterContext {
explicit CUDAPluggableAllocatorDeleterContext(
std::function<FreeFuncType> free_fn,
void* data,
size_t size,
int device,
cudaStream_t stream);
void free();
private:
std::function<FreeFuncType> free_fn_;
void* data_;
size_t size_;
int device_;
cudaStream_t stream_;
};
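// Sketch of how the context is consumed (the actual deleter is an
// implementation detail; this is only an illustration): the DataPtr's
// deleter receives the context pointer and calls free(), which invokes
// free_fn_ with the stored data/size/device/stream:
//
//   void example_deleter(void* ctx) {
//     reinterpret_cast<CUDAPluggableAllocatorDeleterContext*>(ctx)->free();
//   }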
#if defined(TORCH_HIP_VERSION)
using streamType = c10::hip::HIPStream;
#else
using streamType = c10::cuda::CUDAStream;
#endif
TORCH_CUDA_CPP_API std::shared_ptr<
c10::cuda::CUDACachingAllocator::CUDAAllocator>
getCurrentAllocator();
TORCH_CUDA_CPP_API std::shared_ptr<
c10::cuda::CUDACachingAllocator::CUDAAllocator>
createCustomAllocator(
std::function<MallocFuncType> alloc_fn,
std::function<FreeFuncType> free_fn);
TORCH_CUDA_CPP_API void changeCurrentAllocator(
const std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>&
allocator);
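// Minimal usage sketch, assuming the my_alloc/my_free functions above.
// Note that the current allocator can typically only be swapped before any
// allocation has been made through it:
//
//   auto allocator =
//       torch::cuda::CUDAPluggableAllocator::createCustomAllocator(
//           my_alloc, my_free);
//   torch::cuda::CUDAPluggableAllocator::changeCurrentAllocator(allocator);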
struct _AllocationMetadata {
_AllocationMetadata();
_AllocationMetadata(
size_t size,
c10::DeviceIndex device_idx,
cudaStream_t stream);
size_t size;
c10::DeviceIndex device_idx;
cudaStream_t stream;
};
struct TORCH_CUDA_CPP_API CUDAPluggableAllocator
: public c10::cuda::CUDACachingAllocator::CUDAAllocator {
CUDAPluggableAllocator(
std::function<MallocFuncType> alloc_fn,
std::function<FreeFuncType> free_fn);
CUDAPluggableAllocator(const CUDAPluggableAllocator& other);
CUDAPluggableAllocator(CUDAPluggableAllocator&& other) = delete;
CUDAPluggableAllocator& operator=(const CUDAPluggableAllocator& other) =
delete;
CUDAPluggableAllocator& operator=(CUDAPluggableAllocator&& other) = delete;
~CUDAPluggableAllocator() override = default;
void set_init_fn(std::function<void(int)> init_fn);
void set_reset_fn(std::function<void()> reset_fn);
void set_memory_fraction_fn(
std::function<void(double, int)> memory_fraction_fn);
void set_base_alloc_fn(std::function<void*(void*, size_t*)> base_alloc_fn);
void set_record_stream_fn(
std::function<void(void* ptr, cudaStream_t stream)> record_stream_fn);
void set_begin_allocate_to_pool(
std::function<
void(int, c10::cuda::MempoolId_t, std::function<bool(cudaStream_t)>)>
capture_begin_fn);
void set_end_allocate_to_pool_fn(
std::function<void(int, c10::cuda::MempoolId_t)> capture_about_to_end_fn);
void set_release_pool(
std::function<void(int, c10::cuda::MempoolId_t)> capture_destroy_fn);
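// The hooks above are optional; a sketch of attaching some of them to a
// directly constructed allocator (lambda bodies are placeholders):
//
//   auto alloc = std::make_shared<CUDAPluggableAllocator>(my_alloc, my_free);
//   alloc->set_init_fn([](int device_count) { /* per-device setup */ });
//   alloc->set_memory_fraction_fn(
//       [](double fraction, int device) { /* cap usage on `device` */ });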
void* malloc(size_t size, c10::DeviceIndex device, cudaStream_t stream);
c10::DataPtr allocate(size_t size) override;
c10::DeleterFnPtr raw_deleter() const override;
void* raw_alloc(size_t nbytes) override;
void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) override;
void raw_delete(void* ptr) override;
void init(int device_count) override;
bool initialized() override;
double getMemoryFraction(c10::DeviceIndex device) override;
void setMemoryFraction(double fraction, c10::DeviceIndex device) override;
void emptyCache() override;
void enable(bool) override {}
bool isEnabled() const override {
return true;
}
void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) override;
void* getBaseAllocation(void* ptr, size_t* size) override;
void recordStream(const c10::DataPtr&, streamType stream) override;
c10::CachingDeviceAllocator::DeviceStats getDeviceStats(
c10::DeviceIndex device) override;
void resetAccumulatedStats(c10::DeviceIndex device) override;
void resetPeakStats(c10::DeviceIndex device) override;
c10::cuda::CUDACachingAllocator::SnapshotInfo snapshot() override;
void beginAllocateToPool(
c10::DeviceIndex device,
c10::cuda::MempoolId_t mempool_id,
std::function<bool(cudaStream_t)>) override;
void endAllocateToPool(
c10::DeviceIndex device,
c10::cuda::MempoolId_t mempool_id) override;
void releasePool(c10::DeviceIndex device, c10::cuda::MempoolId_t mempool_id)
override;
std::shared_ptr<void> getIpcDevPtr(std::string handle) override;
c10::cuda::CUDACachingAllocator::ShareableHandle shareIpcHandle(
void*) override;
void recordHistory(
bool enabled,
c10::cuda::CUDACachingAllocator::CreateContextFn context_recorder,
size_t alloc_trace_max_entries,
c10::cuda::CUDACachingAllocator::RecordContext when) override;
void attachOutOfMemoryObserver(
c10::cuda::CUDACachingAllocator::OutOfMemoryObserver observer) override;
void attachAllocatorTraceTracker(
c10::cuda::CUDACachingAllocator::AllocatorTraceTracker tracker) override;
std::shared_ptr<c10::cuda::CUDACachingAllocator::AllocatorState>
getCheckpointState(c10::DeviceIndex device, c10::cuda::MempoolId_t id)
override;
c10::cuda::CUDACachingAllocator::CheckpointDelta setCheckpointPoolState(
c10::DeviceIndex device,
std::shared_ptr<c10::cuda::CUDACachingAllocator::AllocatorState> pps)
override;
void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access)
override;
cudaError_t memcpyAsync(
void* dst,
int dstDevice,
const void* src,
int srcDevice,
size_t count,
cudaStream_t stream,
bool p2p_enabled) override;
std::string name() override;
void copy_data(void* dest, const void* src, std::size_t count) const final;
protected:
std::function<MallocFuncType> alloc_fn_;
std::function<FreeFuncType> free_fn_;
std::function<void(int)> init_fn_;
std::function<void()> reset_fn_;
std::function<void(double, int)> memory_fraction_fn_;
std::function<void*(void*, size_t*)> base_alloc_fn_;
std::function<void(void* ptr, cudaStream_t stream)> record_stream_fn_;
std::function<
void(int, c10::cuda::MempoolId_t, std::function<bool(cudaStream_t)>)>
begin_allocate_to_pool_fn_;
std::function<void(int, c10::cuda::MempoolId_t)> end_allocate_to_pool_fn_;
std::function<void(int, c10::cuda::MempoolId_t)> release_pool_fn_;
std::mutex allocator_mutex_;
// We do the bookkeeping here (size, device and stream per allocation) in
// order to simplify custom allocators.
std::unordered_map<void*, _AllocationMetadata> allocation_metadata_;
bool initialized_ = false;
};
} // namespace torch::cuda::CUDAPluggableAllocator