/**
* Copyright 2023-2024, XGBoost Contributors
*/
#include <gtest/gtest.h>                // for ASSERT_FALSE, ASSERT_EQ, ASSERT_LT, ASSERT_GE
#include <xgboost/collective/socket.h>  // for TCPSocket, Connect, SocketFinalize, SocketStartup
#include <xgboost/string_view.h>        // for StringView

#include <chrono>        // for seconds
#include <cstdint>       // for int8_t, int32_t
#include <memory>        // for shared_ptr
#include <system_error>  // for make_error_code, errc
#include <utility>       // for pair, move
#include <vector>        // for vector

#include "../../../src/collective/loop.h"  // for Loop
#include "../../../src/common/timer.h"     // for Timer
namespace xgboost::collective {
namespace {
class LoopTest : public ::testing::Test {
 protected:
  std::pair<TCPSocket, TCPSocket> pair_;
  std::shared_ptr<Loop> loop_;

 protected:
  void SetUp() override {
    system::SocketStartup();
    std::chrono::seconds timeout{1};

    auto domain = SockDomain::kV4;
    pair_.first = TCPSocket::Create(domain);
    std::int32_t port{0};
    auto rc = Success() << [&] {
      return pair_.first.BindHost(&port);
    } << [&] {
      return pair_.first.Listen();
    };
    SafeColl(rc);

    auto const& addr = SockAddrV4::Loopback().Addr();
    rc = Connect(StringView{addr}, port, 1, timeout, &pair_.second);
    SafeColl(rc);
    rc = pair_.second.NonBlocking(true);
    SafeColl(rc);

    pair_.first = pair_.first.Accept();
    rc = pair_.first.NonBlocking(true);
    SafeColl(rc);

    loop_ = std::shared_ptr<Loop>{new Loop{timeout}};
  }

  void TearDown() override {
    pair_ = decltype(pair_){};
    system::SocketFinalize();
  }
};
} // namespace
TEST_F(LoopTest, Timeout) {
  std::vector<std::int8_t> data(1);
  Loop::Op op{Loop::Op::kRead, 0, data.data(), data.size(), &pair_.second, 0};
  loop_->Submit(std::move(op));
  auto rc = loop_->Block();
  ASSERT_FALSE(rc.OK());
  ASSERT_EQ(rc.Code(), std::make_error_code(std::errc::timed_out)) << rc.Report();
}
TEST_F(LoopTest, Op) {
  TCPSocket& send = pair_.first;
  TCPSocket& recv = pair_.second;

  std::vector<std::int8_t> wbuf(1, 1);
  std::vector<std::int8_t> rbuf(1, 0);

  Loop::Op wop{Loop::Op::kWrite, 0, wbuf.data(), wbuf.size(), &send, 0};
  Loop::Op rop{Loop::Op::kRead, 0, rbuf.data(), rbuf.size(), &recv, 0};

  loop_->Submit(std::move(wop));
  loop_->Submit(std::move(rop));

  auto rc = loop_->Block();
  SafeColl(rc);

  ASSERT_EQ(rbuf[0], wbuf[0]);
}
TEST_F(LoopTest, Block) {
  // Check that Submit returns without waiting, while Block waits on the submitted op.
  auto op = Loop::Op::Sleep(2);

  common::Timer t;
  t.Start();
  loop_->Submit(std::move(op));
  t.Stop();
  // Submit is non-blocking, so it should return before the sleep op finishes.
  ASSERT_LT(t.ElapsedSeconds(), 1);

  t.Start();
  auto rc = loop_->Block();
  t.Stop();
  SafeColl(rc);
  // Block should have waited on the sleep op.
  ASSERT_GE(t.ElapsedSeconds(), 1);
}
} // namespace xgboost::collective