1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115
|
#pragma once
#include <thrust/detail/static_assert.h>
#include <string>
// Replace Thrust's compile-time static assertion with a runtime check so unit
// tests can observe (and assert on) failures via ASSERT_STATIC_ASSERT.
// The expansion is a plain expression with no trailing semicolon: call sites
// supply their own `;` (as they must for the original THRUST_STATIC_ASSERT),
// which also keeps the macro usable as the body of an unbraced if/else.
// THRUST_STATIC_ASSERT_MSG ignores its message argument.
#undef THRUST_STATIC_ASSERT
#undef THRUST_STATIC_ASSERT_MSG
#define THRUST_STATIC_ASSERT(B)          unittest::assert_static((B), __FILE__, __LINE__)
#define THRUST_STATIC_ASSERT_MSG(B, msg) unittest::assert_static((B), __FILE__, __LINE__)
namespace unittest
{
// Does nothing when `condition` is true; otherwise reports a failed assertion
// at `filename`:`lineno` (thrown on the host, recorded through
// detail::device_exception on the device). Defined later in this header.
_CCCL_HOST_DEVICE void assert_static(bool condition, const char* filename, int lineno);
} // namespace unittest
#include <thrust/device_delete.h>
#include <thrust/device_new.h>
#include <nv/target>
#if THRUST_DEVICE_SYSTEM == THRUST_DEVICE_SYSTEM_CUDA
# define ASSERT_STATIC_ASSERT(X) \
{ \
bool triggered = false; \
typedef unittest::static_assert_exception ex_t; \
thrust::device_ptr<ex_t> device_ptr = thrust::device_new<ex_t>(); \
ex_t* raw_ptr = thrust::raw_pointer_cast(device_ptr); \
::cudaMemcpyToSymbol(unittest::detail::device_exception, &raw_ptr, sizeof(ex_t*)); \
try \
{ \
X; \
} \
catch (ex_t) \
{ \
triggered = true; \
} \
if (!triggered) \
{ \
triggered = static_cast<ex_t>(*device_ptr).triggered; \
} \
thrust::device_free(device_ptr); \
raw_ptr = NULL; \
::cudaMemcpyToSymbol(unittest::detail::device_exception, &raw_ptr, sizeof(ex_t*)); \
if (!triggered) \
{ \
unittest::UnitTestFailure f; \
f << "[" << __FILE__ << ":" << __LINE__ << "] did not trigger a THRUST_STATIC_ASSERT"; \
throw f; \
} \
}
#else
# define ASSERT_STATIC_ASSERT(X) \
{ \
bool triggered = false; \
typedef unittest::static_assert_exception ex_t; \
try \
{ \
X; \
} \
catch (ex_t) \
{ \
triggered = true; \
} \
if (!triggered) \
{ \
unittest::UnitTestFailure f; \
f << "[" << __FILE__ << ":" << __LINE__ << "] did not trigger a THRUST_STATIC_ASSERT"; \
throw f; \
} \
}
#endif
namespace unittest
{
// Record of a THRUST_STATIC_ASSERT outcome. assert_static throws it on the host
// and, on the device, copies it into the object detail::device_exception points at.
class static_assert_exception
{
public:
  // "No failure" record (triggered == false). filename/lineno get defined
  // values so the whole object can be copied -- e.g. read back from device
  // memory by ASSERT_STATIC_ASSERT -- without touching uninitialized members.
  _CCCL_HOST_DEVICE static_assert_exception()
      : triggered(false)
      , filename(NULL)
      , lineno(0)
  {}

  // Record of a failed assertion at `filename`:`lineno` (triggered == true).
  _CCCL_HOST_DEVICE static_assert_exception(const char* filename, int lineno)
      : triggered(true)
      , filename(filename)
      , lineno(lineno)
  {}

  bool triggered; // true iff this records a failed assertion
  const char* filename; // file of the failing assertion; NULL for the default record
  int lineno; // line of the failing assertion; 0 for the default record
};

namespace detail
{
// Device-side destination for failed assertions. It starts out NULL, and
// ASSERT_STATIC_ASSERT (CUDA variant) points it at a device-allocated record via
// cudaMemcpyToSymbol while its statement runs, then resets it to NULL.
// `static` gives each translation unit its own copy.
// NOTE(review): `used` presumably keeps the symbol from being discarded as
// unreferenced -- confirm.
#if defined(_CCCL_COMPILER_GCC) || defined(_CCCL_COMPILER_CLANG)
__attribute__((used))
#endif
_CCCL_DEVICE static static_assert_exception* device_exception = NULL;
} // namespace detail

// Runtime replacement for THRUST_STATIC_ASSERT. Does nothing when `condition`
// is true. Otherwise it builds a static_assert_exception for `filename`:`lineno`:
//   - device: stores it through detail::device_exception, which must already
//     point at valid device memory (it is not checked for NULL);
//   - host: throws it.
_CCCL_HOST_DEVICE void assert_static(bool condition, const char* filename, int lineno)
{
  if (!condition)
  {
    static_assert_exception ex(filename, lineno);
    NV_IF_TARGET(NV_IS_DEVICE, (*detail::device_exception = ex;), (throw ex;));
  }
}
} // namespace unittest
|