1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169
|
#pragma once
#include <poll.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/un.h>
#include <unistd.h>

#include <cstddef>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>

#include <libshm/alloc_info.h>
#include <libshm/err.h>
/// Owner of an AF_UNIX stream socket descriptor.
///
/// The descriptor is closed in the destructor unless it is -1 (which is what
/// a moved-from Socket holds). System-call failures are reported through
/// SYSCHECK_ERR_RETURN_NEG1 (libshm/err.h); protocol-level failures throw
/// std::runtime_error.
class Socket {
 public:
  int socket_fd;

 protected:
  /// Creates a new, unconnected AF_UNIX stream socket.
  Socket() {
    SYSCHECK_ERR_RETURN_NEG1(socket_fd = socket(AF_UNIX, SOCK_STREAM, 0));
  }

  Socket(const Socket& other) = delete;

  /// Takes ownership of `other`'s descriptor and leaves `other` holding -1 so
  /// that its destructor does not close the transferred descriptor.
  Socket(Socket&& other) noexcept
      : socket_fd(std::exchange(other.socket_fd, -1)) {}

  /// Adopts an already-open descriptor (e.g. one returned by accept()).
  explicit Socket(int fd) : socket_fd(fd) {}

  virtual ~Socket() {
    if (socket_fd != -1)
      close(socket_fd);
  }

  /// Builds a zero-initialised AF_UNIX address for the filesystem `path`.
  /// @throws std::runtime_error if `path` plus its NUL terminator does not fit
  ///         in sockaddr_un::sun_path.
  struct sockaddr_un prepare_address(const char* path) {
    struct sockaddr_un address;
    std::memset(&address, 0, sizeof(address));
    address.sun_family = AF_UNIX;
    const size_t path_len = std::strlen(path);
    // sun_path is a small fixed-size array; copying an over-long path into it
    // would overflow the stack object.
    if (path_len >= sizeof(address.sun_path))
      throw std::runtime_error(
          "Socket path is too long: " + std::string(path));
    std::memcpy(address.sun_path, path, path_len + 1);
    return address;
  }

  // Implemented based on https://man7.org/linux/man-pages/man7/unix.7.html
  /// Length to pass to bind()/connect() for a pathname address: the offset of
  /// sun_path plus the path length including its NUL terminator.
  size_t address_length(const struct sockaddr_un& address) {
    return offsetof(sockaddr_un, sun_path) + strlen(address.sun_path) + 1;
  }

  /// Reads exactly `num_bytes` bytes into `_buffer`. Each wait for readable
  /// data is bounded by a 1000 ms poll().
  /// @throws std::runtime_error if the peer closes the connection, poll()
  ///         reports an error condition, or a wait times out.
  void recv(void* _buffer, size_t num_bytes) {
    char* buffer = static_cast<char*>(_buffer);
    size_t bytes_received = 0;
    ssize_t step_received;
    struct pollfd pfd = {};
    pfd.fd = socket_fd;
    pfd.events = POLLIN;
    while (bytes_received < num_bytes) {
      SYSCHECK_ERR_RETURN_NEG1(poll(&pfd, 1, 1000));
      if (pfd.revents & POLLIN) {
        SYSCHECK_ERR_RETURN_NEG1(
            step_received =
                ::read(socket_fd, buffer, num_bytes - bytes_received));
        if (step_received == 0)
          throw std::runtime_error("Other end has closed the connection");
        bytes_received += step_received;
        buffer += step_received;
      } else if (pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) {
        // POLLNVAL (invalid descriptor) is an error, not a timeout.
        throw std::runtime_error(
            "An error occurred while waiting for the data");
      } else {
        // poll() returned with no events set: the wait timed out.
        throw std::runtime_error(
            "Shared memory manager connection has timed out");
      }
    }
  }

  /// Writes all `num_bytes` bytes from `_buffer`, retrying after partial
  /// writes until everything has been sent.
  void send(const void* _buffer, size_t num_bytes) {
    const char* buffer = static_cast<const char*>(_buffer);
    size_t bytes_sent = 0;
    ssize_t step_sent;
    while (bytes_sent < num_bytes) {
      // Only the unsent remainder may be written. Passing num_bytes here after
      // a partial write would read past the end of the caller's buffer.
      SYSCHECK_ERR_RETURN_NEG1(
          step_sent = ::write(socket_fd, buffer, num_bytes - bytes_sent));
      bytes_sent += step_sent;
      buffer += step_sent;
    }
  }
};
/// Manager-side end of an accepted client connection: reads AllocInfo
/// records from the client and acknowledges them.
class ManagerSocket : public Socket {
 public:
  /// Adopts a connected descriptor, typically the result of accept().
  explicit ManagerSocket(int fd) : Socket(fd) {}

  /// Blocks until one complete AllocInfo record has been read, then returns it.
  AllocInfo receive() {
    AllocInfo incoming;
    recv(&incoming, sizeof(incoming));
    return incoming;
  }

  /// Acknowledges a request by writing the two bytes "OK" (no terminator).
  void confirm() {
    const char reply[] = "OK";
    send(reply, sizeof(reply) - 1);
  }
};
class ManagerServerSocket : public Socket {
public:
explicit ManagerServerSocket(const std::string& path) {
socket_path = path;
try {
struct sockaddr_un address = prepare_address(path.c_str());
size_t len = address_length(address);
SYSCHECK_ERR_RETURN_NEG1(
bind(socket_fd, (struct sockaddr*)&address, len));
SYSCHECK_ERR_RETURN_NEG1(listen(socket_fd, 10));
} catch (std::exception& e) {
SYSCHECK_ERR_RETURN_NEG1(close(socket_fd));
throw;
}
}
void remove() {
struct stat file_stat;
if (fstat(socket_fd, &file_stat) == 0)
SYSCHECK_ERR_RETURN_NEG1(unlink(socket_path.c_str()));
}
virtual ~ManagerServerSocket() {
unlink(socket_path.c_str());
}
ManagerSocket accept() {
int client_fd;
struct sockaddr_un addr;
socklen_t addr_len = sizeof(addr);
SYSCHECK_ERR_RETURN_NEG1(
client_fd = ::accept(socket_fd, (struct sockaddr*)&addr, &addr_len));
return ManagerSocket(client_fd);
}
std::string socket_path;
};
class ClientSocket : public Socket {
public:
explicit ClientSocket(const std::string& path) {
try {
struct sockaddr_un address = prepare_address(path.c_str());
size_t len = address_length(address);
SYSCHECK_ERR_RETURN_NEG1(
connect(socket_fd, (struct sockaddr*)&address, len));
} catch (std::exception& e) {
SYSCHECK_ERR_RETURN_NEG1(close(socket_fd));
throw;
}
}
void register_allocation(AllocInfo& info) {
char buffer[3] = {0, 0, 0};
send(&info, sizeof(info));
recv(buffer, 2);
if (strcmp(buffer, "OK") != 0)
throw std::runtime_error(
"Shared memory manager didn't respond with an OK");
}
void register_deallocation(AllocInfo& info) {
send(&info, sizeof(info));
}
};
|