1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
|
#include <thrust/detail/seq.h>
#include <thrust/system/cpp/detail/par.h>
#include <thrust/system/omp/detail/par.h>
#include <thrust/system/tbb/detail/par.h>
#include <unittest/unittest.h>
#if THRUST_DEVICE_SYSTEM == THRUST_DEVICE_SYSTEM_CUDA
# include <thrust/system/cuda/detail/par.h>
#endif
// Stateless stand-in for an allocator. The type-level attachment checks in
// this file only compare the policy type produced when one of these is bound
// to a policy, so no members are required.
template <class T>
struct test_allocator_t
{
};
// Shared allocator objects used by the attachment checks: one mutable lvalue
// and one const lvalue, so both value categories can be bound to a policy.
test_allocator_t<int> test_allocator{};
const test_allocator_t<int> const_test_allocator{};
// Memory resource used only to observe allocate/deallocate pairing. It never
// obtains real storage: the "pointer" it returns is simply the requested size
// reinterpreted as an address, and deallocation verifies that the pointer and
// size it receives still correspond. The second size_t argument of both hooks
// is ignored.
struct test_memory_resource_t final : thrust::mr::memory_resource<>
{
  void* do_allocate(std::size_t bytes, std::size_t) override
  {
    // Encode the request size in the returned address.
    void* const encoded_size = reinterpret_cast<void*>(bytes);
    return encoded_size;
  }

  void do_deallocate(void* ptr, std::size_t bytes, std::size_t) override
  {
    // The pointer must be the one do_allocate produced for this size.
    void* const expected_ptr = reinterpret_cast<void*>(bytes);
    ASSERT_EQUAL(ptr, expected_ptr);
  }
} test_memory_resource;
// Bundles an execution policy type together with a CRTP base template, so a
// test can spell out the policy type it expects an attachment to produce.
template <class Policy, template <class> class CRTPBase>
struct policy_info
{
  // The execution policy type carried by this bundle.
  using policy = Policy;

  // Instantiates the two-parameter Template with Argument as its first
  // parameter and this bundle's CRTPBase as its second; the result is
  // exposed as the nested ::type.
  template <template <class, template <class> class> class Template, class Argument>
  struct apply_base_second
  {
    using type = Template<Argument, CRTPBase>;
  };
};
template <typename PolicyInfo>
struct TestAllocatorAttachment
{
template <typename Expected, typename T>
static void assert_correct(T)
{
ASSERT_EQUAL(
(::cuda::std::is_same<
T,
typename PolicyInfo::template apply_base_second<thrust::detail::execute_with_allocator, Expected>::type>::value),
true);
}
template <typename ExpectedResource, typename T>
static void assert_npa_correct(T)
{
ASSERT_EQUAL(
(::cuda::std::is_same<T,
typename PolicyInfo::template apply_base_second<
thrust::detail::execute_with_allocator,
thrust::mr::allocator<thrust::detail::max_align_t, ExpectedResource>>::type>::value),
true);
}
template <typename Policy>
void test_temporary_allocation_valid(Policy policy)
{
using thrust::detail::get_temporary_buffer;
return_temporary_buffer(policy, get_temporary_buffer<int>(policy, 123).first, 123);
}
void operator()()
{
typename PolicyInfo::policy policy;
// test correctness of attachment
assert_correct<test_allocator_t<int>>(policy(test_allocator_t<int>()));
assert_correct<test_allocator_t<int>&>(policy(test_allocator));
assert_correct<test_allocator_t<int>>(policy(const_test_allocator));
assert_npa_correct<test_memory_resource_t>(policy(&test_memory_resource));
// test whether the resulting policy is actually usable
// a real allocator is necessary here, unlike above
std::allocator<int> alloc;
const std::allocator<int> const_alloc;
test_temporary_allocation_valid(policy(std::allocator<int>()));
test_temporary_allocation_valid(policy(alloc));
test_temporary_allocation_valid(policy(const_alloc));
test_temporary_allocation_valid(policy(&test_memory_resource));
test_temporary_allocation_valid(policy(std::allocator<int>()).after(1));
test_temporary_allocation_valid(policy(alloc).after(1));
test_temporary_allocation_valid(policy(const_alloc).after(1));
}
};
// One policy_info bundle per backend: the backend's policy type paired with
// the CRTP base template used to spell the expected attached-policy type.
using sequential_info = policy_info<thrust::detail::seq_t, thrust::system::detail::sequential::execution_policy>;
using cpp_par_info    = policy_info<thrust::system::cpp::detail::par_t, thrust::system::cpp::detail::execution_policy>;
using omp_par_info    = policy_info<thrust::system::omp::detail::par_t, thrust::system::omp::detail::execution_policy>;
using tbb_par_info    = policy_info<thrust::system::tbb::detail::par_t, thrust::system::tbb::detail::execution_policy>;
#if THRUST_DEVICE_SYSTEM == THRUST_DEVICE_SYSTEM_CUDA
// Only declared when the device system is CUDA.
using cuda_par_info = policy_info<thrust::system::cuda::detail::par_t, thrust::cuda_cub::execute_on_stream_base>;
#endif
SimpleUnitTest<TestAllocatorAttachment,
unittest::type_list<sequential_info,
#if THRUST_DEVICE_SYSTEM == THRUST_DEVICE_SYSTEM_CUDA
cuda_par_info,
#endif
cpp_par_info,
omp_par_info,
tbb_par_info>>
TestAllocatorAttachmentInstance;
|