1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
|
// Copyright (c) Meta Platforms, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <chrono>
#include <cstdint>
#include <memory>
#include <string>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <torch/csrc/distributed/c10d/Backoff.hpp>
#include <torch/csrc/distributed/c10d/exception.h>
namespace c10d::detail {
class SocketOptions {
public:
SocketOptions& prefer_ipv6(bool value) noexcept {
prefer_ipv6_ = value;
return *this;
}
bool prefer_ipv6() const noexcept {
return prefer_ipv6_;
}
SocketOptions& connect_timeout(std::chrono::milliseconds value) noexcept {
connect_timeout_ = value;
return *this;
}
std::chrono::milliseconds connect_timeout() const noexcept {
return connect_timeout_;
}
// Sets the backoff policy to use for socket connect ops.
SocketOptions& connect_backoff(std::shared_ptr<Backoff> value) noexcept {
connect_backoff_ = std::move(value);
return *this;
}
const std::shared_ptr<Backoff>& connect_backoff() const noexcept {
return connect_backoff_;
}
private:
bool prefer_ipv6_ = true;
std::chrono::milliseconds connect_timeout_{std::chrono::seconds{30}};
std::shared_ptr<Backoff> connect_backoff_{
std::make_shared<FixedBackoff>(std::chrono::milliseconds(1000))};
};
class SocketImpl;
class Socket {
public:
// This function initializes the underlying socket library and must be called
// before any other socket function.
static void initialize();
static Socket listen(std::uint16_t port, const SocketOptions& opts = {});
static Socket listenFromFd(int fd, std::uint16_t expected_port);
static Socket connect(
const std::string& host,
std::uint16_t port,
const SocketOptions& opts = {});
Socket() noexcept = default;
Socket(const Socket& other) = delete;
Socket& operator=(const Socket& other) = delete;
Socket(Socket&& other) noexcept;
Socket& operator=(Socket&& other) noexcept;
~Socket();
Socket accept() const;
int handle() const noexcept;
std::uint16_t port() const;
bool waitForInput(std::chrono::milliseconds timeout);
std::string repr() const;
private:
explicit Socket(std::unique_ptr<SocketImpl>&& impl) noexcept;
std::unique_ptr<SocketImpl> impl_;
};
} // namespace c10d::detail
|