File: test_loop.cc

package info (click to toggle)
xgboost 3.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 13,796 kB
  • sloc: cpp: 67,502; python: 35,503; java: 4,676; ansic: 1,426; sh: 1,320; xml: 1,197; makefile: 204; javascript: 19
file content (104 lines) | stat: -rw-r--r-- 2,793 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
/**
 * Copyright 2023-2024, XGBoost Contributors
 */
#include <gtest/gtest.h>                // for ASSERT_TRUE, ASSERT_EQ
#include <xgboost/collective/socket.h>  // for TCPSocket, Connect, SocketFinalize, SocketStartup
#include <xgboost/string_view.h>        // for StringView

#include <chrono>        // for seconds
#include <cstdint>       // for int8_t
#include <memory>        // for make_shared, shared_ptr
#include <system_error>  // for make_error_code, errc
#include <utility>       // for pair
#include <vector>        // for vector

#include "../../../src/collective/loop.h"  // for Loop

namespace xgboost::collective {
namespace {
class LoopTest : public ::testing::Test {
 protected:
  std::pair<TCPSocket, TCPSocket> pair_;
  std::shared_ptr<Loop> loop_;

 protected:
  void SetUp() override {
    system::SocketStartup();
    std::chrono::seconds timeout{1};

    auto domain = SockDomain::kV4;
    pair_.first = TCPSocket::Create(domain);
    std::int32_t port{0};
    auto rc = Success() << [&] {
      return pair_.first.BindHost(&port);
    } << [&] {
      return pair_.first.Listen();
    };
    SafeColl(rc);

    auto const& addr = SockAddrV4::Loopback().Addr();
    rc = Connect(StringView{addr}, port, 1, timeout, &pair_.second);
    SafeColl(rc);
    rc = pair_.second.NonBlocking(true);
    SafeColl(rc);

    pair_.first = pair_.first.Accept();
    rc = pair_.first.NonBlocking(true);
    SafeColl(rc);

    loop_ = std::shared_ptr<Loop>{new Loop{timeout}};
  }

  void TearDown() override {
    pair_ = decltype(pair_){};
    system::SocketFinalize();
  }
};
}  // namespace

TEST_F(LoopTest, Timeout) {
  // Nothing is ever written by the peer, so this one-byte read can only end via the loop's
  // timeout, and Block() must report it as errc::timed_out.
  std::vector<std::int8_t> buffer(1);
  loop_->Submit(Loop::Op{Loop::Op::kRead, 0, buffer.data(), buffer.size(), &pair_.second, 0});

  auto result = loop_->Block();
  ASSERT_FALSE(result.OK());
  ASSERT_EQ(result.Code(), std::make_error_code(std::errc::timed_out)) << result.Report();
}

TEST_F(LoopTest, Op) {
  // Round-trip a single byte: the accepted socket writes it, the connecting socket reads it.
  TCPSocket& writer = pair_.first;
  TCPSocket& reader = pair_.second;

  std::vector<std::int8_t> outgoing(1, 1);
  std::vector<std::int8_t> incoming(1, 0);

  loop_->Submit(Loop::Op{Loop::Op::kWrite, 0, outgoing.data(), outgoing.size(), &writer, 0});
  loop_->Submit(Loop::Op{Loop::Op::kRead, 0, incoming.data(), incoming.size(), &reader, 0});

  auto status = loop_->Block();
  SafeColl(status);

  ASSERT_EQ(incoming[0], outgoing[0]);
}

TEST_F(LoopTest, Block) {
  // We need to ensure that a blocking call doesn't go unanswered: Block() must wait for the
  // queued sleep op, while Submit() itself only enqueues it.
  using Clock = std::chrono::steady_clock;  // monotonic, unaffected by wall-clock changes
  auto seconds_since = [](Clock::time_point since) {
    return std::chrono::duration<double>{Clock::now() - since}.count();
  };

  auto op = Loop::Op::Sleep(2);

  // Submit is non-blocking.
  auto const submit_begin = Clock::now();
  loop_->Submit(std::move(op));
  ASSERT_LT(seconds_since(submit_begin), 1);

  // Time Block() on its own so the Submit interval is not folded into this measurement.
  auto const block_begin = Clock::now();
  auto rc = loop_->Block();
  auto const block_seconds = seconds_since(block_begin);
  SafeColl(rc);
  ASSERT_GE(block_seconds, 1);
}
}  // namespace xgboost::collective