File: test_worker.cuh

package info (click to toggle)
xgboost 3.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 13,796 kB
  • sloc: cpp: 67,502; python: 35,503; java: 4,676; ansic: 1,426; sh: 1,320; xml: 1,197; makefile: 204; javascript: 19
file content (32 lines) | stat: -rw-r--r-- 894 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
/**
 * Copyright 2023, XGBoost Contributors
 */
#pragma once
#include <memory>  // for shared_ptr

#include "../../../src/collective/coll.h"  // for Coll
#include "../../../src/collective/comm.h"  // for Comm
#include "test_worker.h"
#include "xgboost/context.h"  // for Context

namespace xgboost::collective {
class NCCLWorkerForTest : public WorkerForTest {
 protected:
  std::shared_ptr<Coll> coll_;
  std::shared_ptr<xgboost::collective::Comm> nccl_comm_;
  std::shared_ptr<Coll> nccl_coll_;
  Context ctx_;

 public:
  using WorkerForTest::WorkerForTest;

  void Setup() {
    ctx_ = MakeCUDACtx(comm_.Rank());
    coll_.reset(new Coll{});
    nccl_comm_.reset(this->comm_.MakeCUDAVar(&ctx_, coll_));
    nccl_coll_.reset(coll_->MakeCUDAVar());
    ASSERT_EQ(comm_.World(), nccl_comm_->World());
    ASSERT_EQ(comm_.Rank(), nccl_comm_->Rank());
  }
};
}  // namespace xgboost::collective