1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524
|
#pragma once
#include "common.h"
#include "llama.h"
#include <string>
#include <unordered_set>
#include <list>
#include <map>
// TODO: prevent including the whole server-common.h as we only use server_tokens
#include "server-common.h"
using json = nlohmann::ordered_json;
enum server_task_type {
SERVER_TASK_TYPE_COMPLETION,
SERVER_TASK_TYPE_EMBEDDING,
SERVER_TASK_TYPE_RERANK,
SERVER_TASK_TYPE_INFILL,
SERVER_TASK_TYPE_CANCEL,
SERVER_TASK_TYPE_NEXT_RESPONSE,
SERVER_TASK_TYPE_METRICS,
SERVER_TASK_TYPE_SLOT_SAVE,
SERVER_TASK_TYPE_SLOT_RESTORE,
SERVER_TASK_TYPE_SLOT_ERASE,
SERVER_TASK_TYPE_GET_LORA,
SERVER_TASK_TYPE_SET_LORA,
};
// TODO: change this to more generic "response_format" to replace the "format_response_*" in server-common
enum task_response_type {
TASK_RESPONSE_TYPE_NONE, // llama.cpp native format
TASK_RESPONSE_TYPE_OAI_CHAT,
TASK_RESPONSE_TYPE_OAI_CMPL,
TASK_RESPONSE_TYPE_OAI_EMBD,
TASK_RESPONSE_TYPE_ANTHROPIC,
};
enum stop_type {
STOP_TYPE_NONE,
STOP_TYPE_EOS,
STOP_TYPE_WORD,
STOP_TYPE_LIMIT,
};
struct task_params {
bool stream = true;
bool include_usage = false;
bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
bool return_tokens = false;
bool return_progress = false;
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
int32_t n_predict = -1; // new tokens to predict
int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters
int32_t n_cmpl = 1; // number of completions to generate from this prompt
int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled)
int64_t t_max_prompt_ms = -1; // TODO: implement
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
std::map<int, float> lora; // mapping adapter ID -> scale
std::vector<std::string> antiprompt;
std::vector<std::string> response_fields;
bool timings_per_token = false;
bool post_sampling_probs = false;
struct common_params_sampling sampling;
struct common_params_speculative speculative;
// response formatting
bool verbose = false;
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_syntax oaicompat_chat_syntax;
// Embeddings
int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) const;
json to_json(bool only_metrics = false) const;
};
// struct for tracking the state of a task (e.g., for streaming)
struct task_result_state {
// tracking diffs for partial tool calls
std::vector<common_chat_msg_diff> diffs;
common_chat_syntax oaicompat_chat_syntax;
common_chat_msg chat_msg;
std::string generated_text; // append new chunks of generated text here
std::vector<std::string> generated_tool_call_ids;
task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
: oaicompat_chat_syntax(oaicompat_chat_syntax) {}
// parse partial tool calls and update the internal state
common_chat_msg update_chat_msg(
const std::string & text_added,
bool is_partial,
std::vector<common_chat_msg_diff> & diffs);
};
struct server_task {
int id = -1; // to be filled by server_queue
// TODO @ngxson : remove this field and implement a mapping task_id -> idx in the response_reader
size_t index = 0; // used when there are multiple prompts (batch request)
// used by SERVER_TASK_TYPE_CANCEL
int id_target = -1;
int id_slot = -1;
// used by parallel sampling (multiple completions from same prompt)
size_t n_children = 0; // number of tasks reusing this prompt
int id_parent = -1;
// used by SERVER_TASK_TYPE_INFERENCE
task_params params;
server_tokens tokens;
// only used by CLI, this delegates the tokenization to the server
json cli_input = nullptr;
std::vector<raw_buffer> cli_files;
server_task_type type;
// used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
struct slot_action {
int slot_id;
std::string filename;
std::string filepath;
};
slot_action slot_action;
// used by SERVER_TASK_TYPE_METRICS
bool metrics_reset_bucket = false;
// used by SERVER_TASK_TYPE_SET_LORA
std::map<int, float> set_lora; // mapping adapter ID -> scale
server_task() = default;
server_task(server_task_type type) : type(type) {}
int32_t n_tokens() const {
return tokens.size();
}
static task_params params_from_json_cmpl(
const llama_vocab * vocab,
const common_params & params_base,
const int n_ctx_slot,
const json & data);
// utility function
static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
std::unordered_set<int> ids(tasks.size());
for (size_t i = 0; i < tasks.size(); i++) {
ids.insert(tasks[i].id);
}
return ids;
}
server_task create_child(int id_parent, int id_child) const {
server_task copy;
copy.id = id_child;
copy.id_parent = id_parent;
copy.params = params;
copy.type = type;
copy.tokens = tokens.clone();
return copy;
}
// the task will be moved into queue, then onto slots
// however, the state must be kept by caller (e.g., HTTP thread)
task_result_state create_state() const {
return task_result_state(params.oaicompat_chat_syntax);
}
};
struct result_timings {
int32_t cache_n = -1;
int32_t prompt_n = -1;
double prompt_ms;
double prompt_per_token_ms;
double prompt_per_second;
int32_t predicted_n = -1;
double predicted_ms;
double predicted_per_token_ms;
double predicted_per_second;
// Optional speculative metrics - only included when > 0
int32_t draft_n = 0;
int32_t draft_n_accepted = 0;
json to_json() const;
};
struct result_prompt_progress {
int32_t total = 0;
int32_t cache = 0;
int32_t processed = 0;
int64_t time_ms = 0;
json to_json() const;
};
struct server_task_result {
int id = -1;
int id_slot = -1;
// TODO @ngxson : remove this field and implement a mapping task_id -> idx in the response_reader
size_t index = 0; // to be used for batched tasks
virtual bool is_error() {
// only used by server_task_result_error
return false;
}
virtual bool is_stop() {
// only used by server_task_result_cmpl_*
return true;
}
virtual void update(task_result_state &) {
// only used by server_task_result_cmpl_*
}
virtual json to_json() = 0;
virtual ~server_task_result() = default;
};
// using shared_ptr for polymorphism of server_task_result
using server_task_result_ptr = std::unique_ptr<server_task_result>;
struct completion_token_output {
llama_token tok;
float prob;
std::string text_to_send;
struct prob_info {
llama_token tok;
std::string txt;
float prob;
};
std::vector<prob_info> probs;
json to_json(bool post_sampling_probs) const;
static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs);
static float logarithm(float x);
static std::vector<unsigned char> str_to_bytes(const std::string & str);
};
struct server_task_result_cmpl_final : server_task_result {
std::string content;
llama_tokens tokens;
bool stream;
bool include_usage;
result_timings timings;
std::string prompt;
bool truncated;
int32_t n_decoded;
int32_t n_prompt_tokens;
int32_t n_tokens_cached;
bool has_new_line;
std::string stopping_word;
stop_type stop = STOP_TYPE_NONE;
bool post_sampling_probs;
std::vector<completion_token_output> probs_output;
std::vector<std::string> response_fields;
task_params generation_params;
// response formatting
bool verbose = false;
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_msg oaicompat_msg; // to be populated by update()
std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
bool is_updated = false;
virtual bool is_stop() override {
return true; // in stream mode, final responses are considered stop
}
virtual json to_json() override;
virtual void update(task_result_state & state) override {
is_updated = true;
oaicompat_msg = state.update_chat_msg(content, false, oaicompat_msg_diffs);
}
json to_json_non_oaicompat();
json to_json_oaicompat();
json to_json_oaicompat_chat();
json to_json_oaicompat_chat_stream();
json to_json_anthropic();
json to_json_anthropic_stream();
};
struct server_task_result_cmpl_partial : server_task_result {
std::string content;
llama_tokens tokens;
int32_t n_decoded;
int32_t n_prompt_tokens;
bool post_sampling_probs;
bool is_progress = false;
completion_token_output prob_output;
result_timings timings;
result_prompt_progress progress;
// response formatting
bool verbose = false;
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
bool is_updated = false;
virtual bool is_stop() override {
return false; // in stream mode, partial responses are not considered stop
}
virtual json to_json() override;
virtual void update(task_result_state & state) override {
is_updated = true;
state.update_chat_msg(content, true, oaicompat_msg_diffs);
}
json to_json_non_oaicompat();
json to_json_oaicompat();
json to_json_oaicompat_chat();
json to_json_anthropic();
};
struct server_task_result_embd : server_task_result {
std::vector<std::vector<float>> embedding;
int32_t n_tokens;
// response formatting
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
virtual json to_json() override;
json to_json_non_oaicompat();
json to_json_oaicompat();
};
struct server_task_result_rerank : server_task_result {
float score = -1e6;
int32_t n_tokens;
virtual json to_json() override;
};
struct server_task_result_error : server_task_result {
error_type err_type = ERROR_TYPE_SERVER;
std::string err_msg;
// for ERROR_TYPE_EXCEED_CONTEXT_SIZE
int32_t n_prompt_tokens = 0;
int32_t n_ctx = 0;
virtual bool is_error() override {
return true;
}
virtual json to_json() override;
};
struct server_task_result_metrics : server_task_result {
int n_idle_slots;
int n_processing_slots;
int n_tasks_deferred;
int64_t t_start;
// TODO: somehow reuse server_metrics in the future, instead of duplicating the fields
uint64_t n_prompt_tokens_processed_total = 0;
uint64_t t_prompt_processing_total = 0;
uint64_t n_tokens_predicted_total = 0;
uint64_t t_tokens_generation_total = 0;
uint64_t n_tokens_max = 0;
uint64_t n_prompt_tokens_processed = 0;
uint64_t t_prompt_processing = 0;
uint64_t n_tokens_predicted = 0;
uint64_t t_tokens_generation = 0;
uint64_t n_decode_total = 0;
uint64_t n_busy_slots_total = 0;
// while we can also use std::vector<server_slot> this requires copying the slot object which can be quite messy
// therefore, we use json to temporarily store the slot.to_json() result
json slots_data = json::array();
virtual json to_json() override;
};
struct server_task_result_slot_save_load : server_task_result {
std::string filename;
bool is_save; // true = save, false = load
size_t n_tokens;
size_t n_bytes;
double t_ms;
virtual json to_json() override;
};
struct server_task_result_slot_erase : server_task_result {
size_t n_erased;
virtual json to_json() override;
};
struct server_task_result_get_lora : server_task_result {
struct lora {
common_adapter_lora_info info;
std::string alora_invocation_string;
llama_tokens alora_invocation_tokens;
};
std::vector<lora> loras;
virtual json to_json() override;
};
struct server_task_result_apply_lora : server_task_result {
virtual json to_json() override;
};
struct server_prompt_checkpoint {
llama_pos pos_min;
llama_pos pos_max;
std::vector<uint8_t> data;
size_t size() const {
return data.size();
}
};
struct server_prompt {
server_tokens tokens;
std::vector<uint8_t> data;
std::list<server_prompt_checkpoint> checkpoints;
size_t size() const {
size_t res = data.size();
for (const auto & checkpoint : checkpoints) {
res += checkpoint.size();
}
return res;
}
int n_tokens() const {
return tokens.size();
}
server_prompt clone() const {
return server_prompt {
tokens.clone(),
data,
checkpoints
};
}
};
struct server_prompt_cache {
server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) {
this->limit_size = 1024ull*1024ull*(limit_size_mib < 0 ? 0 : limit_size_mib);
this->limit_tokens = limit_tokens;
}
std::list<server_prompt> states;
// in bytes, 0 = no limit
size_t limit_size = 0;
// in tokens, 0 = no limit
size_t limit_tokens = 0;
size_t size() const;
size_t n_tokens() const;
server_prompt * alloc(const server_prompt & prompt, size_t state_size);
bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot);
void update();
};
|