1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
|
#include "jjml_llm.h"
#include <stddef.h>
/*
* SHARED PLAIN C++ UTILITIES
*/
void jjml_llm_batch_add(struct llama_batch &batch, llama_token id,
llama_pos pos, const std::vector<llama_seq_id> &seq_ids, bool logits) {
batch.token[batch.n_tokens] = id;
batch.pos[batch.n_tokens] = pos;
batch.n_seq_id[batch.n_tokens] = seq_ids.size();
for (size_t i = 0; i < seq_ids.size(); ++i) {
batch.seq_id[batch.n_tokens][i] = seq_ids[i];
}
batch.logits[batch.n_tokens] = logits;
batch.n_tokens++;
}
/**
 * Mark `batch` as empty by resetting its token count to zero.
 *
 * Only `batch.n_tokens` is changed; the per-token arrays are neither
 * cleared nor released, so previously written entries are simply
 * overwritten by later appends.
 *
 * @param batch Batch to reset; modified in place.
 */
void jjml_llm_batch_clear(struct llama_batch &batch) {
    batch.n_tokens = 0;
}
|