1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
|
// Copyright John Maddock 2016.
// Use, modification and distribution are subject to the
// Boost Software License, Version 1.0. (See accompanying file
// LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
#ifndef BOOST_MATH_CUDA_MANAGED_PTR_HPP
#define BOOST_MATH_CUDA_MANAGED_PTR_HPP
#ifdef _MSC_VER
#pragma once
#endif
#include <cuda_runtime.h>
class managed_holder_base
{
protected:
static int count;
managed_holder_base() { ++count; }
~managed_holder_base()
{
if(0 == --count)
cudaDeviceSynchronize();
}
};
int managed_holder_base::count = 0;
//
// Reset the device and exit:
// cudaDeviceReset causes the driver to clean up all state. While
// not mandatory in normal operation, it is good practice. It is also
// needed to ensure correct operation when the application is being
// profiled. Calling cudaDeviceReset causes all profile data to be
// flushed before the application exits.
//
// We have a global instance of this class, plus instances for each
// managed pointer. Last one out the door switches the lights off.
//
class cudaResetter
{
static int count;
public:
cudaResetter() { ++count; }
~cudaResetter()
{
if(--count == 0)
{
cudaError_t err = cudaDeviceReset();
if(err != cudaSuccess)
{
std::cerr << "Failed to deinitialize the device! error=" << cudaGetErrorString(err) << std::endl;
}
}
}
};
int cudaResetter::count = 0;
cudaResetter global_resetter;
template <class T>
class cuda_managed_ptr
{
T* data;
static const cudaResetter resetter;
cuda_managed_ptr(const cuda_managed_ptr&) = delete;
cuda_managed_ptr& operator=(cuda_managed_ptr const&) = delete;
void free()
{
if(data)
{
cudaDeviceSynchronize();
cudaError_t err = cudaFree(data);
if(err != cudaSuccess)
{
std::cerr << "Failed to deinitialize the device! error=" << cudaGetErrorString(err) << std::endl;
}
}
}
public:
cuda_managed_ptr() : data(0) {}
cuda_managed_ptr(std::size_t n)
{
cudaError_t err = cudaSuccess;
void *ptr;
err = cudaMallocManaged(&ptr, n * sizeof(T));
if(err != cudaSuccess)
throw std::runtime_error(cudaGetErrorString(err));
cudaDeviceSynchronize();
data = static_cast<T*>(ptr);
}
cuda_managed_ptr(cuda_managed_ptr&& o)
{
data = o.data;
o.data = 0;
}
cuda_managed_ptr& operator=(cuda_managed_ptr&& o)
{
free();
data = o.data;
o.data = 0;
return *this;
}
~cuda_managed_ptr()
{
free();
}
class managed_holder : managed_holder_base
{
T* pdata;
public:
managed_holder(T* p) : managed_holder_base(), pdata(p) {}
managed_holder(const managed_holder& o) : managed_holder_base(), pdata(o.pdata) {}
operator T* () { return pdata; }
T& operator[] (std::size_t n) { return pdata[n]; }
};
class const_managed_holder : managed_holder_base
{
const T* pdata;
public:
const_managed_holder(T* p) : managed_holder_base(), pdata(p) {}
const_managed_holder(const managed_holder& o) : managed_holder_base(), pdata(o.pdata) {}
operator const T* () { return pdata; }
const T& operator[] (std::size_t n) { return pdata[n]; }
};
managed_holder get() { return managed_holder(data); }
const_managed_holder get()const { return data; }
T& operator[](std::size_t n) { return data[n]; }
const T& operator[](std::size_t n)const { return data[n]; }
};
template <class T>
cudaResetter const cuda_managed_ptr<T>::resetter;
#endif
|