1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164
|
#include "duckdb.hpp"
#include "duckdb/common/types/column/column_data_collection.hpp"
#include "duckdb/optimizer/optimizer_extension.hpp"
#include "duckdb/planner/operator/logical_column_data_get.hpp"
#include "duckdb/planner/operator/logical_get.hpp"
#include "duckdb/common/serializer/binary_serializer.hpp"
#include "duckdb/common/serializer/binary_deserializer.hpp"
#include "duckdb/common/serializer/memory_stream.hpp"
using namespace duckdb;
// POSIX socket and networking headers used by the optimizer's remote (TCP) round trip below
#include <signal.h>
#include <sys/mman.h>
#include <unistd.h>
#include <stdio.h>
#include <netdb.h>
#include <netinet/in.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <arpa/inet.h>
#ifdef __MVS__
#define _XOPEN_SOURCE_EXTENDED 1
#include <strings.h>
#endif
class WaggleExtension : public OptimizerExtension {
public:
WaggleExtension() {
optimize_function = WaggleOptimizeFunction;
}
static bool HasParquetScan(LogicalOperator &op) {
if (op.type == LogicalOperatorType::LOGICAL_GET) {
auto &get = op.Cast<LogicalGet>();
return get.function.name == "parquet_scan";
}
for (auto &child : op.children) {
if (HasParquetScan(*child)) {
return true;
}
}
return false;
}
static void WriteChecked(int sockfd, void *data, idx_t write_size) {
auto bytes_written = write(sockfd, data, write_size);
if (bytes_written < 0) {
throw InternalException("Failed to write \"%lld\" bytes to socket: %s", write_size, strerror(errno));
}
if (idx_t(bytes_written) != write_size) {
throw InternalException("Failed to write \"%llu\" bytes from socket - wrote %llu instead", write_size,
bytes_written);
}
}
static void ReadChecked(int sockfd, void *data, idx_t read_size) {
auto bytes_read = read(sockfd, data, read_size);
if (bytes_read < 0) {
throw InternalException("Failed to read \"%lld\" bytes from socket: %s", read_size, strerror(errno));
}
if (idx_t(bytes_read) != read_size) {
throw InternalException("Failed to read \"%llu\" bytes from socket - read %llu instead", read_size,
bytes_read);
}
}
static void WaggleOptimizeFunction(OptimizerExtensionInput &input, duckdb::unique_ptr<LogicalOperator> &plan) {
if (!HasParquetScan(*plan)) {
return;
}
// rpc
auto &context = input.context;
Value host, port;
if (!context.TryGetCurrentSetting("waggle_location_host", host) ||
!context.TryGetCurrentSetting("waggle_location_port", port)) {
throw InvalidInputException("Need the parameters damnit");
}
// socket create and verification
auto sockfd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
if (sockfd == -1) {
throw InternalException("Failed to create socket");
}
struct sockaddr_in servaddr;
bzero(&servaddr, sizeof(servaddr));
// assign IP, PORT
servaddr.sin_family = AF_INET;
auto host_string = host.ToString();
servaddr.sin_addr.s_addr = inet_addr(host_string.c_str());
servaddr.sin_port = htons(port.GetValue<int32_t>());
// connect the client socket to server socket
if (connect(sockfd, (struct sockaddr *)&servaddr, sizeof(servaddr)) != 0) {
throw IOException("Failed to connect socket %s", string(strerror(errno)));
}
Allocator allocator;
MemoryStream stream(allocator);
BinarySerializer serializer(stream);
serializer.Begin();
plan->Serialize(serializer);
serializer.End();
auto data = stream.GetData();
idx_t len = stream.GetPosition();
WriteChecked(sockfd, &len, sizeof(idx_t));
WriteChecked(sockfd, data, len);
auto chunk_collection = make_uniq<ColumnDataCollection>(Allocator::DefaultAllocator());
idx_t n_chunks;
ReadChecked(sockfd, &n_chunks, sizeof(idx_t));
for (idx_t i = 0; i < n_chunks; i++) {
idx_t chunk_len;
ReadChecked(sockfd, &chunk_len, sizeof(idx_t));
auto buffer = malloc(chunk_len);
D_ASSERT(buffer);
ReadChecked(sockfd, buffer, chunk_len);
MemoryStream source(data_ptr_cast(buffer), chunk_len);
DataChunk chunk;
BinaryDeserializer deserializer(source);
deserializer.Begin();
chunk.Deserialize(deserializer);
deserializer.End();
chunk_collection->Initialize(chunk.GetTypes());
chunk_collection->Append(chunk);
free(buffer);
}
auto types = chunk_collection->Types();
plan = make_uniq<LogicalColumnDataGet>(0, types, std::move(chunk_collection));
len = 0;
(void)len;
WriteChecked(sockfd, &len, sizeof(idx_t));
// close the socket
close(sockfd);
}
};
//===--------------------------------------------------------------------===//
// Extension load + setup
//===--------------------------------------------------------------------===//
extern "C" {
//! Extension entry point: registers the Waggle optimizer extension and the two settings it reads
//! (waggle_location_host, waggle_location_port) on the database's configuration.
DUCKDB_CPP_EXTENSION_ENTRY(loadable_extension_optimizer_demo, loader) {
	auto &db = loader.GetDatabaseInstance();
	auto &config = DBConfig::GetConfig(db);
	// add the optimizer extension
	OptimizerExtension::Register(config, WaggleExtension());
	// address of the remote server that WaggleExtension connects to
	config.AddExtensionOption("waggle_location_host", "host for remote callback", LogicalType::VARCHAR);
	config.AddExtensionOption("waggle_location_port", "port for remote callback", LogicalType::INTEGER);
}
}
|