1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
|
/**
* Copyright 2022-2024, XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/collective/socket.h>

#include <cerrno>        // EADDRNOTAVAIL
#include <chrono>        // std::chrono::seconds
#include <cstdint>       // std::int32_t
#include <string>        // std::string
#include <system_error>  // std::error_code, std::system_category

#include "test_worker.h"  // for SocketTest
namespace xgboost::collective {
/**
 * @brief Round-trip a string over a loopback TCP connection, first over IPv4 and then IPv6.
 *
 * The IPv6 half is skipped when the fixture's `SkipTest()` returns true. It is also skipped
 * when connecting to the IPv6 loopback fails with `EADDRNOTAVAIL`, which happens in some
 * restricted environments such as docker.
 */
TEST_F(SocketTest, Basic) {
  SockAddress addr{SockAddrV6::Loopback()};
  ASSERT_TRUE(addr.IsV6());
  addr = SockAddress{SockAddrV4::Loopback()};
  ASSERT_TRUE(addr.IsV4());

  std::string const skip_v6_msg{"Skipping IPv6 test"};

  // Captured by reference: the lambda is only invoked synchronously within this test body.
  auto run_test = [&skip_v6_msg](SockDomain domain) {
    // Server side: bind to a host-assigned port, then start listening.
    auto server = TCPSocket::Create(domain);
    ASSERT_EQ(server.Domain(), domain);
    std::int32_t port{0};
    auto rc = Success() << [&] {
      return server.BindHost(&port);
    } << [&] {
      return server.Listen();
    };
    SafeColl(rc);

    // Client side: connect to the loopback address of the same family. The address string
    // is copied by value so it cannot dangle regardless of how `Addr()` returns it.
    TCPSocket client;
    if (domain == SockDomain::kV4) {
      auto const host = SockAddrV4::Loopback().Addr();
      rc = Connect(StringView{host}, port, 1, std::chrono::seconds{3}, &client);
      SafeColl(rc);
    } else {
      auto const host = SockAddrV6::Loopback().Addr();
      rc = Connect(StringView{host}, port, 1, std::chrono::seconds{3}, &client);
      // some environment (docker) has restricted network configuration.
      if (!rc.OK() && rc.Code() == std::error_code{EADDRNOTAVAIL, std::system_category()}) {
        GTEST_SKIP() << skip_v6_msg;
      }
      ASSERT_EQ(rc, Success()) << rc.Report();
    }
    ASSERT_EQ(client.Domain(), domain);

    // Send from the accepted connection and check that the client receives the same bytes.
    auto accepted = server.Accept();
    StringView hello{"Hello world."};
    accepted.Send(hello);

    std::string received;
    rc = client.Recv(&received);
    SafeColl(rc);
    ASSERT_EQ(StringView{received}, hello);
  };

  // A fatal assertion inside the lambda only returns from the lambda. Stop the test here
  // instead of continuing into the IPv6 case after an IPv4 failure.
  ASSERT_NO_FATAL_FAILURE(run_test(SockDomain::kV4));
  if (SkipTest()) {
    GTEST_SKIP() << skip_msg_;
  }
  run_test(SockDomain::kV6);
}

/**
 * @brief Binding a fresh socket to the wildcard address must report a non-zero port,
 *        checked for IPv4 and then for IPv6.
 *
 * The IPv6 half is skipped when the fixture's `SkipTest()` returns true.
 */
TEST_F(SocketTest, Bind) {
  auto run = [](SockDomain domain) {
    auto const any = domain == SockDomain::kV4 ? SockAddrV4::InaddrAny().Addr()
                                                : SockAddrV6::InaddrAny().Addr();
    auto sock = TCPSocket::Create(domain);
    std::int32_t port{0};
    auto rc = sock.Bind(any, &port);
    SafeColl(rc);
    // The port starts at 0, so a non-zero value shows that Bind wrote one back.
    ASSERT_NE(port, 0);
  };

  // See `Basic`: do not run the IPv6 case after a fatal failure in the IPv4 case.
  ASSERT_NO_FATAL_FAILURE(run(SockDomain::kV4));
  if (SkipTest()) {
    GTEST_SKIP() << skip_msg_;
  }
  run(SockDomain::kV6);
}
}  // namespace xgboost::collective
|