#pragma once

#include <map>
#include <optional>
#include <shared_mutex>
#include <string>
#include <unordered_map>
#include <vector>

#include <ATen/ATen.h>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>

namespace c10d::symmetric_memory {

// With CUDA driver API support, allocations are backed by CUDA virtual memory
// management handles (cuMemCreate); otherwise an opaque pointer is used.
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
using HandleType = CUmemGenericAllocationHandle;
#else
using HandleType = void*;
#endif

// Resource wrapper that owns a (vaddr, allocation handle) pair. Upon
// destruction, it unmaps the vaddr and releases the allocation handle.
struct AllocationRef : public c10::intrusive_ptr_target {
  void* ptr;
  HandleType handle;
  size_t block_size;
  int device_idx;

  AllocationRef(
      void* ptr,
      HandleType handle,
      size_t block_size,
      int device_idx);

  ~AllocationRef();
};
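
// Illustrative sketch (not part of this header): AllocationRef is intended to
// be shared via c10::intrusive_ptr so the virtual address is unmapped and the
// allocation handle released only when the last user drops its reference.
// `vaddr`, `handle`, `kBlockSize`, and `device_idx` below are placeholder
// names for values obtained from the CUDA virtual memory APIs.
//
//   auto alloc_ref = c10::make_intrusive<AllocationRef>(
//       vaddr, handle, kBlockSize, device_idx);
//   // Both a Block and any CUDASymmetricMemory objects created from it can
//   // hold this intrusive_ptr; the mapping stays valid until all are gone.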

// SymmetricMemory implementation for CUDA devices. A CUDASymmetricMemory
// object holds the local and peer-mapped buffer and signal pad pointers for
// every rank in the group, along with an optional multicast mapping.
class CUDASymmetricMemory : public SymmetricMemory {
 public:
  CUDASymmetricMemory(
      std::vector<c10::intrusive_ptr<AllocationRef>> alloc_refs,
      std::vector<void*> buffers,
      std::vector<void*> signal_pads,
      HandleType mc_handle,
      void* mc_addr,
      size_t buffer_size,
      int local_device_idx,
      int rank,
      int world_size);

  ~CUDASymmetricMemory() override {}

  std::vector<void*> get_buffer_ptrs() override;
  std::vector<void*> get_signal_pad_ptrs() override;
  void** get_buffer_ptrs_dev() override;
  void** get_signal_pad_ptrs_dev() override;
  size_t get_buffer_size() override;
  size_t get_signal_pad_size() override;

  bool has_multicast_support() override;
  void* get_multicast_ptr() override;

  at::Tensor get_buffer(
      int rank,
      c10::IntArrayRef sizes,
      c10::ScalarType dtype,
      int64_t storage_offset) override;

  at::Tensor get_signal_pad(
      int rank,
      c10::IntArrayRef sizes,
      std::optional<c10::ScalarType> dtype,
      int64_t storage_offset) override;

  void barrier(int channel, size_t timeout_ms) override;
  void put_signal(int dst_rank, int channel, size_t timeout_ms) override;
  void wait_signal(int src_rank, int channel, size_t timeout_ms) override;

  int get_rank() override;
  int get_world_size() override;

 private:
  std::vector<c10::intrusive_ptr<AllocationRef>> alloc_refs_;
  std::vector<void*> buffers_;
  std::vector<void*> signal_pads_;
  HandleType mc_handle_;
  void* mc_addr_;
  size_t buffer_size_;
  int local_device_idx_;
  int rank_;
  int world_size_;
  void** buffers_dev_;
  void** signal_pads_dev_;
};
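
// Illustrative usage sketch (assumed, not defined here): once ranks have
// rendezvoused, a CUDASymmetricMemory handle exposes each peer's buffer as a
// tensor and offers signal-pad based synchronization. `symm_mem` and
// `peer_rank` are placeholder names.
//
//   at::Tensor peer_buf = symm_mem->get_buffer(
//       peer_rank, {1024}, at::kFloat, /*storage_offset=*/0);
//   symm_mem->barrier(/*channel=*/0, /*timeout_ms=*/10000);
//   symm_mem->put_signal(peer_rank, /*channel=*/0, /*timeout_ms=*/10000);
//   symm_mem->wait_signal(peer_rank, /*channel=*/0, /*timeout_ms=*/10000);
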
// Metadata associated with each allocation performed by
// `CUDASymmetricMemoryAllocator`.
struct Block : public c10::intrusive_ptr_target {
  c10::intrusive_ptr<AllocationRef> alloc_ref;
  int device_idx;
  size_t block_size;
  size_t buffer_size;
  size_t signal_pad_offset;
  std::optional<std::string> default_group_name;
  std::map<std::string, c10::intrusive_ptr<CUDASymmetricMemory>> symm_mems;

  Block(
      c10::intrusive_ptr<AllocationRef> alloc_ref,
      int device_idx,
      size_t block_size,
      size_t buffer_size,
      size_t signal_pad_offset,
      const std::optional<std::string>& group_name);
};
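
// Illustrative sketch (assumed): a Block records the full mapped region and
// where the signal pad begins within it, and caches per-group rendezvous
// results in `symm_mems`. The names below are placeholders for values computed
// by the allocator.
//
//   auto block = c10::make_intrusive<Block>(
//       alloc_ref, device_idx, block_size, buffer_size,
//       signal_pad_offset, group_name);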

// SymmetricMemoryAllocator implementation that allocates CUDA symmetric
// memory blocks and tracks them so that `rendezvous()` can later establish
// cross-rank access for a given allocation.
class CUDASymmetricMemoryAllocator : public SymmetricMemoryAllocator {
 public:
  void* alloc(
      size_t size,
      int device_idx,
      const std::optional<std::string>& group_name) override;

  void free(void* ptr) override;
  size_t get_alloc_size(void* ptr) override;

  c10::intrusive_ptr<SymmetricMemory> rendezvous(
      void* ptr,
      const std::optional<std::string>& group_name) override;

  bool has_multicast_support(int device_idx) override;

 private:
  c10::intrusive_ptr<Block> find_block(void* ptr);

  std::shared_mutex mutex_;
  std::unordered_map<void*, c10::intrusive_ptr<Block>> ptr_to_block_;
};
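
// Illustrative end-to-end sketch (assumed): allocate a symmetric buffer on
// each rank, rendezvous to exchange handles, use the resulting
// CUDASymmetricMemory, then free. `allocator`, `nbytes`, `device_idx`, and
// `group_name` are placeholder names.
//
//   CUDASymmetricMemoryAllocator allocator;
//   void* ptr = allocator.alloc(nbytes, device_idx, group_name);
//   auto symm_mem = allocator.rendezvous(ptr, group_name);  // collective
//   // ... use symm_mem->get_buffer()/barrier()/put_signal()/wait_signal() ...
//   allocator.free(ptr);
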
} // namespace c10d::symmetric_memory