1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
|
/**
* Copyright 2020-2023 by XGBoost contributors
*/
#include <gtest/gtest.h>
#include <string>
#include <utility>
#include <vector>
#include "../../../plugin/sycl/common/row_set.h"
#include "../../../plugin/sycl/device_manager.h"
#include "../helpers.h"
namespace xgboost::sycl::common {
/**
 * Checks RowSetCollection::Init and RowSetCollection::AddSplit on a SYCL queue.
 *
 * The row-index buffer is filled with 0..num_rows-1 by a device kernel and the
 * collection is initialised; the test then expects a single element (node 0)
 * spanning the whole buffer. Node 0 is split into node 1 (the first n_left rows)
 * and node 2 (the remaining n_right rows); afterwards the test expects node 0 to
 * be an empty element (null begin/end, node_id -1).
 *
 * gtest assertions are used instead of CHECK_EQ so that mismatches are reported
 * as ordinary test failures with both values printed, and the non-fatal element
 * checks keep running after the first mismatch. ASSERT_* guards the collection
 * size because the element lookups that follow index into the collection.
 */
TEST(SyclRowSetCollection, AddSplits) {
  const size_t num_rows = 16;

  DeviceManager device_manager;
  auto qu = device_manager.GetQueue(DeviceOrd::SyclDefault());

  RowSetCollection row_set_collection;
  auto& row_indices = row_set_collection.Data();
  row_indices.Resize(qu, num_rows);

  // Fill the row-index buffer with 0..num_rows-1 on the device and wait for it.
  size_t* p_row_indices = row_indices.Data();
  qu->submit([&](::sycl::handler& cgh) {
    cgh.parallel_for<>(::sycl::range<1>(num_rows),
                       [p_row_indices](::sycl::item<1> pid) {
      const size_t idx = pid.get_id(0);
      p_row_indices[idx] = idx;
    });
  }).wait_and_throw();
  row_set_collection.Init();

  // Compares element `nid` of the collection against an expected [begin, end)
  // range and node id. The parameters are generic so callers can pass either
  // row_indices positions or nullptr without this helper assuming the exact
  // pointer / integer types of the element's fields.
  auto expect_elem = [&row_set_collection](size_t nid, auto begin, auto end,
                                           auto node_id) {
    SCOPED_TRACE(::testing::Message() << "nid = " << nid);
    const auto& elem = row_set_collection[nid];
    EXPECT_EQ(elem.begin, begin);
    EXPECT_EQ(elem.end, end);
    EXPECT_EQ(elem.node_id, node_id);
  };

  // After Init: one root element covering every row.
  ASSERT_EQ(row_set_collection.Size(), 1);
  expect_elem(0, row_indices.Begin(), row_indices.End(), 0);

  const size_t nid = 0;
  const size_t nid_left = 1;
  const size_t nid_right = 2;
  const size_t n_left = 4;
  const size_t n_right = num_rows - n_left;
  row_set_collection.AddSplit(nid, nid_left, nid_right, n_left, n_right);
  ASSERT_EQ(row_set_collection.Size(), 3);

  // The split parent is expected to be reset to an empty element.
  expect_elem(nid, nullptr, nullptr, -1);
  // Left child holds the first n_left rows, right child holds the remainder.
  expect_elem(nid_left, row_indices.Begin(), row_indices.Begin() + n_left, nid_left);
  expect_elem(nid_right, row_indices.Begin() + n_left, row_indices.End(), nid_right);
}
}  // namespace xgboost::sycl::common
|