File: set_intersection_by_key.cu

package info (click to toggle)
libthrust 1.17.2-2
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 10,900 kB
  • sloc: ansic: 29,519; cpp: 23,989; python: 1,421; sh: 811; perl: 460; makefile: 112
file content (137 lines) | stat: -rw-r--r-- 4,296 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#include <unittest/unittest.h>
#include <thrust/set_operations.h>
#include <thrust/execution_policy.h>


template<typename ExecutionPolicy, typename Iterator1, typename Iterator2, typename Iterator3, typename Iterator4, typename Iterator5, typename Iterator6>
__global__
void set_intersection_by_key_kernel(ExecutionPolicy exec,
                                    Iterator1 keys_first1, Iterator1 keys_last1,
                                    Iterator2 keys_first2, Iterator2 keys_last2,
                                    Iterator3 values_first1,
                                    Iterator4 keys_result,
                                    Iterator5 values_result,
                                    Iterator6 result)
{
  *result = thrust::set_intersection_by_key(exec, keys_first1, keys_last1, keys_first2, keys_last2, values_first1, keys_result, values_result);
}


template<typename ExecutionPolicy>
void TestSetIntersectionByKeyDevice(ExecutionPolicy exec)
{
  typedef thrust::device_vector<int> Vector;
  typedef typename Vector::iterator Iterator;

  Vector a_key(3), b_key(4);
  Vector a_val(3);

  a_key[0] = 0; a_key[1] = 2; a_key[2] = 4;
  a_val[0] = 0; a_val[1] = 0; a_val[2] = 0;

  b_key[0] = 0; b_key[1] = 3; b_key[2] = 3; b_key[3] = 4;

  Vector ref_key(2), ref_val(2);
  ref_key[0] = 0; ref_key[1] = 4;
  ref_val[0] = 0; ref_val[1] = 0;

  Vector result_key(2), result_val(2);

  typedef thrust::pair<Iterator,Iterator> iter_pair;
  thrust::device_vector<iter_pair> end_vec(1);

  set_intersection_by_key_kernel<<<1,1>>>(exec,
                                          a_key.begin(), a_key.end(),
                                          b_key.begin(), b_key.end(),
                                          a_val.begin(),
                                          result_key.begin(),
                                          result_val.begin(),
                                          end_vec.begin());
  cudaError_t const err = cudaDeviceSynchronize();
  ASSERT_EQUAL(cudaSuccess, err);

  thrust::pair<Iterator,Iterator> end = end_vec.front();

  ASSERT_EQUAL_QUIET(result_key.end(), end.first);
  ASSERT_EQUAL_QUIET(result_val.end(), end.second);
  ASSERT_EQUAL(ref_key, result_key);
  ASSERT_EQUAL(ref_val, result_val);
}


void TestSetIntersectionByKeyDeviceSeq()
{
  TestSetIntersectionByKeyDevice(thrust::seq);
}
DECLARE_UNITTEST(TestSetIntersectionByKeyDeviceSeq);


void TestSetIntersectionByKeyDeviceDevice()
{
  TestSetIntersectionByKeyDevice(thrust::device);
}
DECLARE_UNITTEST(TestSetIntersectionByKeyDeviceDevice);


void TestSetIntersectionByKeyDeviceNoSync()
{
  TestSetIntersectionByKeyDevice(thrust::cuda::par_nosync);
}
DECLARE_UNITTEST(TestSetIntersectionByKeyDeviceNoSync);


template<typename ExecutionPolicy>
void TestSetIntersectionByKeyCudaStreams(ExecutionPolicy policy)
{
  typedef thrust::device_vector<int> Vector;
  typedef Vector::iterator Iterator;

  Vector a_key(3), b_key(4);
  Vector a_val(3);

  a_key[0] = 0; a_key[1] = 2; a_key[2] = 4;
  a_val[0] = 0; a_val[1] = 0; a_val[2] = 0;

  b_key[0] = 0; b_key[1] = 3; b_key[2] = 3; b_key[3] = 4;

  Vector ref_key(2), ref_val(2);
  ref_key[0] = 0; ref_key[1] = 4;
  ref_val[0] = 0; ref_val[1] = 0;

  Vector result_key(2), result_val(2);

  cudaStream_t s;
  cudaStreamCreate(&s);

  auto streampolicy = policy.on(s);

  thrust::pair<Iterator,Iterator> end =
    thrust::set_intersection_by_key(streampolicy,
                                    a_key.begin(), a_key.end(),
                                    b_key.begin(), b_key.end(),
                                    a_val.begin(),
                                    result_key.begin(),
                                    result_val.begin());
  cudaStreamSynchronize(s);

  ASSERT_EQUAL_QUIET(result_key.end(), end.first);
  ASSERT_EQUAL_QUIET(result_val.end(), end.second);
  ASSERT_EQUAL(ref_key, result_key);
  ASSERT_EQUAL(ref_val, result_val);

  cudaStreamDestroy(s);
}

void TestSetIntersectionByKeyCudaStreamsSync()
{
  TestSetIntersectionByKeyCudaStreams(thrust::cuda::par);
}
DECLARE_UNITTEST(TestSetIntersectionByKeyCudaStreamsSync);


void TestSetIntersectionByKeyCudaStreamsNoSync()
{
  TestSetIntersectionByKeyCudaStreams(thrust::cuda::par_nosync);
}
DECLARE_UNITTEST(TestSetIntersectionByKeyCudaStreamsNoSync);