1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196
|
#pragma once
#include "server-task.h"

#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <deque>
#include <functional>
#include <mutex>
#include <unordered_set>
#include <vector>
// struct for managing server tasks
// in most cases, use server_response_reader to post new tasks and retrieve results
struct server_queue {
private:
    // NOTE(review): presumably the counter behind get_new_id() - confirm in the implementation file
    int id = 0;

    // loop / sleep state flags; `sleeping` is read under mutex_tasks by is_sleeping()
    bool running = false;
    bool sleeping = false;
    bool req_stop_sleeping = false;

    // NOTE(review): looks like the time of the most recent task (for the idle timeout); unit not visible here
    int64_t time_last_task = 0;

    // main task queue, and tasks that are deferred until a slot is available
    std::deque<server_task> queue_tasks;
    std::deque<server_task> queue_tasks_deferred;

    // mutex taken by the inline accessors below before reading the shared state
    std::mutex mutex_tasks;
    std::condition_variable condition_tasks;

    // hooks registered through on_new_task() / on_update_slots() / on_sleeping_state()
    std::function<void(server_task &&)> callback_new_task;
    std::function<void()> callback_update_slots;
    std::function<void(bool)> callback_sleeping_state;

public:
    // Post one task; by default it is added to the end of the queue
    // NOTE(review): `front` presumably requests insertion at the front instead - confirm in the definition
    int post(server_task && task, bool front = false);

    // Same as post(), for several tasks at once
    int post(std::vector<server_task> && tasks, bool front = false);

    // Add a new task, but keep it deferred until one slot is available
    void defer(server_task && task);

    // Get the next id to use when creating a new task
    int get_new_id();

    // To be called when the state of one slot changes: moves one task from the deferred queue to the main queue
    void pop_deferred_task();

    // If sleeping: request leaving the sleeping state and wait until that is done
    // If not sleeping: returns immediately
    void wait_until_no_sleep();

    // Read the sleeping flag while holding mutex_tasks
    bool is_sleeping() {
        std::lock_guard<std::mutex> guard(mutex_tasks);
        return sleeping;
    }

    // End the start_loop() routine
    void terminate();

    /**
     * Run the main loop. It consists of these steps:
     * - wait until a new task arrives
     * - process the task (i.e. maybe copy data into a slot)
     * - check if a multitask is finished
     * - update all slots
     *
     * Sleeping procedure (disabled if idle_sleep_ms < 0):
     * - if no task arrives within idle_sleep_ms, enter the sleeping state
     * - call callback_sleeping_state(true)
     * - wait until req_stop_sleeping is set to true
     * - call callback_sleeping_state(false)
     * - exit the sleeping state
     */
    void start_loop(int64_t idle_sleep_ms = -1);

    // For metrics: number of deferred tasks, read while holding mutex_tasks
    size_t queue_tasks_deferred_size() {
        std::lock_guard<std::mutex> guard(mutex_tasks);
        return queue_tasks_deferred.size();
    }

    //
    // The registration functions below are not thread-safe
    // They must only be used before start_loop() is called
    //

    // Register the function that processes a new task
    void on_new_task(std::function<void(server_task &&)> callback) {
        callback_new_task = std::move(callback);
    }

    // Register the function to be called when all slots data is ready to be processed
    void on_update_slots(std::function<void()> callback) {
        callback_update_slots = std::move(callback);
    }

    // Register the callback for sleeping state changes
    // note: on entering the sleeping state, the callback is called AFTER sleeping is set to true
    //       on leaving the sleeping state, the callback is called BEFORE sleeping is set to false
    void on_sleeping_state(std::function<void(bool)> callback) {
        callback_sleeping_state = std::move(callback);
    }

private:
    // NOTE(review): defined elsewhere; presumably drops queued work belonging to id_target - confirm
    void cleanup_pending_task(int id_target);
};
// struct for managing server responses
// in most cases, use server_response_reader to retrieve results
struct server_response {
private:
    bool running = true;

    // ids of all the tasks that are currently waiting for a result
    std::unordered_set<int> waiting_task_ids;

    // the main result queue (results are held by pointer for polymorphism)
    std::vector<server_task_result_ptr> queue_results;

    std::mutex mutex_results;
    std::condition_variable condition_results;

public:
    // Add id_task to the list of tasks waiting for a response
    void add_waiting_task_id(int id_task);

    // Multi-task version of add_waiting_task_id()
    void add_waiting_tasks(const std::vector<server_task> & tasks);

    // Once the request is finished, the task associated with it can be removed
    void remove_waiting_task_id(int id_task);

    // Remove multiple tasks from the waiting list
    void remove_waiting_task_ids(const std::unordered_set<int> & id_tasks);

    // Block the calling thread until there is a response for one of the id_tasks
    server_task_result_ptr recv(const std::unordered_set<int> & id_tasks);

    // Same as recv(), but with a timeout given in seconds
    // nullptr is returned if the timeout is reached
    server_task_result_ptr recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout);

    // Single-task version of recv()
    server_task_result_ptr recv(int id_task);

    // Send a new result to a waiting id_task
    void send(server_task_result_ptr && result);

    // Terminate the waiting loop
    void terminate();
};
// utility class to make working with server_queue and server_response easier
// it provides a generator-like API for server responses
// supports polling the connection state and aggregating multiple results
struct server_response_reader {
std::unordered_set<int> id_tasks;
server_queue & queue_tasks;
server_response & queue_results;
size_t received_count = 0;
bool cancelled = false;
int polling_interval_seconds;
// tracking generation state and partial tool calls
// only used by streaming completions
std::vector<task_result_state> states;
// should_stop function will be called each polling_interval_seconds
server_response_reader(server_queue & queue_tasks, server_response & queue_results, int polling_interval_seconds)
: queue_tasks(queue_tasks), queue_results(queue_results), polling_interval_seconds(polling_interval_seconds) {}
~server_response_reader() {
stop();
}
int get_new_id() {
return queue_tasks.get_new_id();
}
// if front = true, the task will be posted to the front of the queue (high priority)
void post_task(server_task && task, bool front = false);
void post_tasks(std::vector<server_task> && tasks, bool front = false);
bool has_next() const;
// return nullptr if should_stop() is true before receiving a result
// note: if one error is received, it will stop further processing and return error result
server_task_result_ptr next(const std::function<bool()> & should_stop);
struct batch_response {
bool is_terminated = false; // if true, indicates that processing was stopped before all results were received
std::vector<server_task_result_ptr> results;
server_task_result_ptr error; // nullptr if no error
};
// aggregate multiple results
batch_response wait_for_all(const std::function<bool()> & should_stop);
void stop();
};
|