File: transform.h

package info (click to toggle)
xgboost 3.0.4-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 13,848 kB
  • sloc: cpp: 67,603; python: 35,537; java: 4,676; ansic: 1,426; sh: 1,352; xml: 1,226; makefile: 204; javascript: 19
file content (35 lines) | stat: -rw-r--r-- 926 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
/**
 * Copyright 2021-2024, XGBoost Contributors
 * \file transform.h
 */
#ifndef PLUGIN_SYCL_COMMON_TRANSFORM_H_
#define PLUGIN_SYCL_COMMON_TRANSFORM_H_

#include "../device_manager.h"

#include <sycl/sycl.hpp>

namespace xgboost {
namespace sycl {
namespace common {

/**
 * \brief Run `_func(idx, _spans...)` on a SYCL device for every index in `_range`,
 *        blocking until the submitted kernel has completed.
 *
 * \param device  Device whose queue (obtained from sycl::DeviceManager) executes the kernel.
 * \param _func   Functor invoked as `_func(idx, _spans...)` with a `size_t` index; it is
 *                captured by copy into the kernel.
 * \param _range  Half-open index range [*_range.begin(), *_range.end()). One work item is
 *                launched per index; an empty or inverted range submits no kernel.
 *                NOTE(review): any non-unit step stored in the Range is not honoured —
 *                indices are treated as contiguous; confirm callers only use unit step.
 * \param _spans  Extra arguments captured by copy and passed to every `_func` invocation.
 */
template <typename Functor, typename... SpanType>
void LaunchSyclKernel(DeviceOrd device, Functor&& _func, xgboost::common::Range _range,
                      SpanType... _spans) {
  // Range iterators dereference to signed values: compute the extent before converting
  // so an empty or inverted range can never wrap around to a huge size_t launch size.
  const auto range_begin = *(_range.begin());
  const auto range_end = *(_range.end());
  if (range_end <= range_begin) {
    return;
  }
  const size_t offset = static_cast<size_t>(range_begin);
  const size_t size = static_cast<size_t>(range_end - range_begin);

  sycl::DeviceManager device_manager;
  auto* qu = device_manager.GetQueue(device);

  qu->submit([&](::sycl::handler& cgh) {
    cgh.parallel_for<>(::sycl::range<1>(size),
                       [=](::sycl::id<1> pid) {
      // Kernel ids always start at 0; shift them so the functor sees indices of _range.
      const size_t idx = offset + pid[0];
      // The kernel lambda is invoked through a const call operator, so the captured copy
      // of `_func` is const here; the cast permits calling a non-const operator().
      const_cast<Functor&&>(_func)(idx, _spans...);
    });
  }).wait();
}

}  // namespace common
}  // namespace sycl
}  // namespace xgboost
#endif  // PLUGIN_SYCL_COMMON_TRANSFORM_H_