File: test_socket.cc

package info (click to toggle)
xgboost 3.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 13,796 kB
  • sloc: cpp: 67,502; python: 35,503; java: 4,676; ansic: 1,426; sh: 1,320; xml: 1,197; makefile: 204; javascript: 19
file content (83 lines) | stat: -rw-r--r-- 2,306 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
/**
 * Copyright 2022-2024, XGBoost Contributors
 */
#include <gtest/gtest.h>
#include <xgboost/collective/socket.h>

#include <cerrno>        // EADDRNOTAVAIL
#include <system_error>  // std::error_code, std::system_category

#include "test_worker.h"  // for SocketTest

namespace xgboost::collective {
/**
 * Round-trip test for TCPSocket over the IPv4 and IPv6 loopback addresses.
 *
 * Checks first that SockAddress reports the right family for the V6/V4 loopback addresses.
 * Then, for each domain: create a server socket, bind + listen, connect a client to the
 * bound port, accept the connection, send a message from the accepted socket and check
 * that the client receives it unchanged.
 *
 * The IPv6 case is skipped when the fixture's SkipTest() says so, or when connecting to
 * the IPv6 loopback fails with EADDRNOTAVAIL.
 */
TEST_F(SocketTest, Basic) {
  SockAddress addr{SockAddrV6::Loopback()};
  ASSERT_TRUE(addr.IsV6());
  addr = SockAddress{SockAddrV4::Loopback()};
  ASSERT_TRUE(addr.IsV4());

  std::string const v6_skip_msg{"Skipping IPv6 test"};

  auto run_test = [&v6_skip_msg](SockDomain domain) {
    auto server = TCPSocket::Create(domain);
    ASSERT_EQ(server.Domain(), domain);

    // `port` starts at 0, is filled in by BindHost, and is then used by the client below.
    std::int32_t port{0};
    auto rc = Success() << [&] {
      return server.BindHost(&port);
    } << [&] {
      return server.Listen();
    };
    SafeColl(rc);

    // Hold the address string by value so it cannot dangle, whatever Addr() returns.
    auto const host = domain == SockDomain::kV4 ? SockAddrV4::Loopback().Addr()
                                                : SockAddrV6::Loopback().Addr();
    TCPSocket client;
    // NOTE(review): `1` and `std::chrono::seconds{3}` are presumably a retry count and a
    // connect timeout -- confirm against Connect's declaration.
    rc = Connect(StringView{host}, port, 1, std::chrono::seconds{3}, &client);
    // Some environments (docker) have restricted network configuration and reject the
    // IPv6 loopback address; treat that as "IPv6 unavailable" rather than a failure.
    if (domain == SockDomain::kV6 && !rc.OK() &&
        rc.Code() == std::error_code{EADDRNOTAVAIL, std::system_category()}) {
      GTEST_SKIP_(v6_skip_msg.c_str());
    }
    ASSERT_EQ(rc, Success()) << rc.Report();
    ASSERT_EQ(client.Domain(), domain);

    auto accepted = server.Accept();
    StringView payload{"Hello world."};
    accepted.Send(payload);

    std::string received;
    rc = client.Recv(&received);
    SafeColl(rc);
    ASSERT_EQ(StringView{received}, payload);
  };

  // ASSERT_* inside the lambda only returns from the lambda, not from the test. Propagate a
  // fatal IPv4 failure so the test stops here instead of moving on to the IPv6 case.
  ASSERT_NO_FATAL_FAILURE(run_test(SockDomain::kV4));

  if (SkipTest()) {
    GTEST_SKIP_(skip_msg_.c_str());
  }
  run_test(SockDomain::kV6);
}

/**
 * Binding a TCPSocket to the InaddrAny() address of each domain must succeed and report a
 * non-zero port through Bind's out-parameter.
 *
 * The IPv6 case is skipped when the fixture's SkipTest() says so.
 */
TEST_F(SocketTest, Bind) {
  auto run = [](SockDomain domain) {
    // Address string for the requested domain; both branches call the same Addr() accessor.
    auto any =
        domain == SockDomain::kV4 ? SockAddrV4::InaddrAny().Addr() : SockAddrV6::InaddrAny().Addr();
    auto sock = TCPSocket::Create(domain);
    std::int32_t port{0};
    auto rc = sock.Bind(any, &port);
    SafeColl(rc);
    ASSERT_NE(port, 0);
  };

  // ASSERT_* inside the lambda only returns from the lambda, not from the test. Stop here
  // if the IPv4 case failed fatally instead of continuing with IPv6.
  ASSERT_NO_FATAL_FAILURE(run(SockDomain::kV4));
  if (SkipTest()) {
    GTEST_SKIP_(skip_msg_.c_str());
  }
  run(SockDomain::kV6);
}
}  // namespace xgboost::collective