File: cuda_managed_ptr.hpp

package info (click to toggle)
scipy 1.16.0-1exp7
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 234,820 kB
  • sloc: cpp: 503,145; python: 344,611; ansic: 195,638; javascript: 89,566; fortran: 56,210; cs: 3,081; f90: 1,150; sh: 848; makefile: 785; pascal: 284; csh: 135; lisp: 134; xml: 56; perl: 51
file content (139 lines) | stat: -rw-r--r-- 3,661 bytes parent folder | download | duplicates (6)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139

//  Copyright John Maddock 2016.
//  Use, modification and distribution are subject to the
//  Boost Software License, Version 1.0. (See accompanying file
//  LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)

#ifndef BOOST_MATH_CUDA_MANAGED_PTR_HPP
#define BOOST_MATH_CUDA_MANAGED_PTR_HPP

#ifdef _MSC_VER
#pragma once
#endif

#include <cstddef>
#include <iostream>
#include <stdexcept>

#include <cuda_runtime.h>

//
// Common base of the proxy objects handed out by cuda_managed_ptr<T>::get().
// All derived objects share one static counter; when the last live proxy is
// destroyed the host blocks in cudaDeviceSynchronize().
//
class managed_holder_base
{
protected:
   // Number of live managed_holder_base sub-objects, shared by every derived type.
   // NOTE(review): defined (below) in a header, so including this file in more
   // than one translation unit yields multiple definitions.
   static int count;
   managed_holder_base() { ++count; }
   // Copies are live proxies too and are decremented by the destructor, so they
   // must be counted here. The implicitly generated copy constructor would skip
   // the increment, unbalancing count. Declaring this also suppresses the
   // implicit move constructor, so moves go through this counting path.
   managed_holder_base(const managed_holder_base&) { ++count; }
   ~managed_holder_base()
   {
      // Last proxy gone: wait for all outstanding device work.
      // NOTE(review): the return code is ignored (this runs in a destructor).
      if(0 == --count)
         cudaDeviceSynchronize();
   }
};

int managed_holder_base::count = 0;

//
// Reset the device and exit:
// cudaDeviceReset causes the driver to clean up all state. While
// not mandatory in normal operation, it is good practice.  It is also
// needed to ensure correct operation when the application is being
// profiled. Calling cudaDeviceReset causes all profile data to be
// flushed before the application exits.
//
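// There is one global instance of this class, plus a static member
// instance declared in each cuda_managed_ptr<T> specialization.  Every
// instance bumps a shared count; whichever one is destroyed last (taking
// the count to zero) calls cudaDeviceReset: last one out switches the lights off.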
//
class cudaResetter
{
   // Number of live cudaResetter objects, shared by all instances.
   static int count;
public:
   cudaResetter() { ++count; }
   // Every copy is decremented by ~cudaResetter(), so it must be counted here
   // too. The implicitly generated copy constructor would not increment, which
   // could make the count reach zero early and reset the device while other
   // resetters (or more than once) are still alive.
   cudaResetter(const cudaResetter&) { ++count; }
   ~cudaResetter()
   {
      // Last resetter destroyed: tear down all device state.
      if(--count == 0)
      {
         cudaError_t err = cudaDeviceReset();
         if(err != cudaSuccess)
         {
            // Report instead of throwing: we are inside a destructor.
            std::cerr << "Failed to deinitialize the device! error=" << cudaGetErrorString(err) << std::endl;
         }
      }
   }
};

// NOTE(review): non-inline definitions in a header; including this file in
// more than one translation unit yields multiple definitions.
int cudaResetter::count = 0;

// Process-wide resetter. Defined in this header, so in the including
// translation unit it is constructed before (and destroyed after) any
// namespace-scope objects defined later in that unit.
cudaResetter global_resetter;

//
// Move-only owner of an array of T allocated with cudaMallocManaged.
// The buffer is released (after a cudaDeviceSynchronize) when the pointer is
// destroyed or move-assigned over.
//
// NOTE(review): storage is raw memory - no T constructors/destructors are run,
// so T is presumably expected to be a trivial type. Confirm with callers.
//
template <class T>
class cuda_managed_ptr
{
   T* data;
   // Per-specialization resetter that participates in cudaResetter's
   // "last one out resets the device" count.
   static const cudaResetter resetter;
   cuda_managed_ptr(const cuda_managed_ptr&) = delete;
   cuda_managed_ptr& operator=(cuda_managed_ptr const&) = delete;

   // Waits for outstanding device work, then frees the buffer (if any).
   // Errors are reported on std::cerr rather than thrown because this is
   // called from the destructor.
   void free()
   {
      if(data)
      {
         cudaDeviceSynchronize();
         cudaError_t err = cudaFree(data);
         if(err != cudaSuccess)
         {
            std::cerr << "Failed to free managed memory! error=" << cudaGetErrorString(err) << std::endl;
         }
      }
   }
public:
   // Empty pointer: owns nothing.
   cuda_managed_ptr() : data(0) {}

   // Allocates managed storage for n objects of type T.
   // n == 0 yields an empty pointer (cudaMallocManaged rejects a zero size).
   // Throws std::runtime_error if n * sizeof(T) overflows or the allocation fails.
   cuda_managed_ptr(std::size_t n) : data(0)
   {
      // A static data member of a class template is only instantiated (and so
      // only constructed) when odr-used. Nothing else refers to `resetter`, so
      // take its address here to make the per-specialization resetter exist.
      static_cast<void>(&resetter);
      if(n == 0)
         return;
      if(n > static_cast<std::size_t>(-1) / sizeof(T))
         throw std::runtime_error("cuda_managed_ptr: requested allocation size overflows std::size_t");
      void* ptr = 0;
      cudaError_t err = cudaMallocManaged(&ptr, n * sizeof(T));
      if(err != cudaSuccess)
         throw std::runtime_error(cudaGetErrorString(err));
      cudaDeviceSynchronize();
      data = static_cast<T*>(ptr);
   }

   // Takes ownership of o's buffer and leaves o empty.
   cuda_managed_ptr(cuda_managed_ptr&& o) noexcept : data(o.data)
   {
      o.data = 0;
   }

   // Releases the current buffer and takes ownership of o's; o is left empty.
   // Self-move is a no-op (otherwise the buffer would be freed before being read).
   cuda_managed_ptr& operator=(cuda_managed_ptr&& o)
   {
      if(this != &o)
      {
         free();
         data = o.data;
         o.data = 0;
      }
      return *this;
   }

   ~cuda_managed_ptr()
   {
      free();
   }

   // Mutable proxy returned by get(). While any proxy derived from
   // managed_holder_base is alive the shared count stays non-zero; destroying
   // the last one triggers cudaDeviceSynchronize().
   class managed_holder : managed_holder_base
   {
      T* pdata;
   public:
      managed_holder(T* p) : managed_holder_base(), pdata(p) {}
      managed_holder(const managed_holder& o) : managed_holder_base(), pdata(o.pdata) {}
      operator T* () const { return pdata; }
      T& operator[] (std::size_t n) const { return pdata[n]; }
   };

   // Read-only proxy returned by get() const.
   class const_managed_holder : managed_holder_base
   {
      const T* pdata;
   public:
      const_managed_holder(T* p) : managed_holder_base(), pdata(p) {}
      // Explicit copy constructor so that copies are counted by the base.
      const_managed_holder(const const_managed_holder& o) : managed_holder_base(), pdata(o.pdata) {}
      // Goes through the public conversion: managed_holder::pdata is private
      // to that sibling class and cannot be read directly here.
      const_managed_holder(const managed_holder& o) : managed_holder_base(), pdata(static_cast<T*>(o)) {}
      operator const T* () const { return pdata; }
      const T& operator[] (std::size_t n) const { return pdata[n]; }
   };

   managed_holder get() { return managed_holder(data); }
   const_managed_holder get()const { return const_managed_holder(data); }
   T& operator[](std::size_t n) { return data[n]; }
   const T& operator[](std::size_t n)const { return data[n]; }
};

template <class T>
cudaResetter const cuda_managed_ptr<T>::resetter;

#endif