File: jjml_llm_embedding.cpp

#include <cmath>

#include <llama.h>

#include <argeo/jni/argeo_jni.h>

#include "org_argeo_jjml_llm_LlamaCppEmbeddingProcessor.h" // IWYU pragma: keep

#include "jjml_llm.h"
#include "org_argeo_jjml_llm_.h"

/*
 * EMBEDDING
 */
// from llama.cpp's common llama_embd_normalize
static void embd_normalize(const float *inp, float *out, int n, int embd_norm) {
	double sum = 0.0;

	switch (embd_norm) {
	case -1: // no normalisation
		sum = 1.0;
		break;
	case 0: // max absolute
		for (int i = 0; i < n; i++) {
			if (sum < std::abs(inp[i]))
				sum = std::abs(inp[i]);
		}
		sum /= 32760.0; // make an int16 range
		break;
	case 2: // euclidean
		for (int i = 0; i < n; i++) {
			sum += inp[i] * inp[i];
		}
		sum = std::sqrt(sum);
		break;
	default: // p-norm (euclidean is p-norm p=2)
		for (int i = 0; i < n; i++) {
			sum += std::pow(std::abs(inp[i]), embd_norm);
		}
		sum = std::pow(sum, 1.0 / embd_norm);
		break;
	}

	const float norm = sum > 0.0 ? 1.0 / sum : 0.0f;

	for (int i = 0; i < n; i++) {
		out[i] = inp[i] * norm;
	}
}
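
/*
 * Usage sketch (not part of the build): euclidean normalisation of a small
 * vector. With p = 2, sum = sqrt(1 + 4 + 4 + 16) = 5, so each component is
 * divided by 5.
 *
 *   float inp[] = { 1.0f, 2.0f, 2.0f, 4.0f };
 *   float out[4];
 *   embd_normalize(inp, out, 4, 2); // euclidean
 *   // out == { 0.2f, 0.4f, 0.4f, 0.8f }
 */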

// from llama.cpp's examples/embedding
static void embd_batch_decode(llama_context *ctx, llama_batch &batch,
		float *output, int n_seq, int n_embd, int embd_norm) {
	const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
	const struct llama_model *model = llama_get_model(ctx);

	// clear previous kv_cache values (irrelevant for embeddings)
	llama_memory_t memory = llama_get_memory(ctx);
	llama_memory_clear(memory, true);

	// run model
//    LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
	if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
		// encoder-only model
		if (llama_encode(ctx, batch) < 0) {
//            LOG_ERR("%s : failed to encode\n", __func__);
		}
	} else if (!llama_model_has_encoder(model)
			&& llama_model_has_decoder(model)) {
		// decoder-only model
		if (llama_decode(ctx, batch) < 0) {
//            LOG_ERR("%s : failed to decode\n", __func__);
		}
	}
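	// NB: encode/decode failures are silently ignored here, and a model
	// with both an encoder and a decoder (or neither) is never run at all;
	// the loop below then reads whatever embeddings the context currently
	// holds.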

	for (int i = 0; i < batch.n_tokens; i++) {
		if (!batch.logits[i]) {
			continue;
		}

		const float *embd = nullptr;
		int embd_pos = 0;

		if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
			// try to get token embeddings
			embd = llama_get_embeddings_ith(ctx, i);
			embd_pos = i;
			GGML_ASSERT(embd != NULL && "failed to get token embeddings");
		} else {
			// try to get sequence embeddings - supported only when pooling_type is not NONE
			embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
			embd_pos = batch.seq_id[i][0];
			GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
		}

		float *out = output + embd_pos * n_embd;
		embd_normalize(embd, out, n_embd, embd_norm);
	}
}
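
/*
 * Output layout sketch (an assumption matching the indexing above): with a
 * pooling type other than NONE, output must hold n_seq * n_embd floats and
 * the pooled, normalised embedding of sequence s lands at
 * output + s * n_embd; with LLAMA_POOLING_TYPE_NONE it is indexed by token
 * position instead, so it must hold batch.n_tokens * n_embd floats.
 *
 *   std::vector<float> output(n_seq * n_embd); // needs <vector>
 *   embd_batch_decode(ctx, batch, output.data(), n_seq, n_embd, 2);
 *   const float *seq0 = output.data(); // first sequence
 */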

JNIEXPORT void JNICALL Java_org_argeo_jjml_llm_LlamaCppEmbeddingProcessor_doProcessEmbeddings(
		JNIEnv *env, jclass, jlong contextPointer, jobjectArray tokenLists,
		jfloatArray res) {
	auto *ctx = argeo::jni::as_pointer<llama_context*>(contextPointer);

	// TODO deal with normalization
	int embd_norm = -1; // no normalisation

	int n_embd = llama_model_n_embd(llama_get_model(ctx));
	int n_batch = llama_n_batch(ctx);
	const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);

	struct llama_batch batch = llama_batch_init(n_batch, 0, 1);

	int n_prompts = env->GetArrayLength(tokenLists);

	// Plain GetFloatArrayElements rather than a critical section: the loop
	// below calls other JNI functions and runs llama_decode, neither of
	// which is allowed inside a JNI critical region.
	jfloat *emb = env->GetFloatArrayElements(res, nullptr);

	// break into batches
	int e = 0; // number of embeddings already stored
	int s = 0; // number of prompts in current batch
	for (int k = 0; k < n_prompts; k++) {
		jintArray tokenList = (jintArray) env->GetObjectArrayElement(tokenLists,
				k);

		const uint64_t n_toks = env->GetArrayLength(tokenList);

		// encode if at capacity
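		// e.g. with n_batch = 512 and two prompts of 300 tokens each: the
		// second prompt would overflow the batch, so the first is decoded
		// alone (s = 1), e advances by one pooled embedding (or by
		// batch.n_tokens without pooling), and the batch is cleared.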
		if (batch.n_tokens + n_toks > n_batch) {
			float *out = emb + e * n_embd;
			embd_batch_decode(ctx, batch, out, s, n_embd, embd_norm);
			e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
			s = 0;
			jjml_llm_batch_clear(batch);
		}

		// add to batch, all tokens of prompt k in sequence s
		int *tokens = (int*) env->GetPrimitiveArrayCritical(tokenList, nullptr);
		for (uint64_t i = 0; i < n_toks; i++) {
			jjml_llm_batch_add(batch, tokens[i], i, { s }, true);
		}
		env->ReleasePrimitiveArrayCritical(tokenList, tokens, JNI_ABORT); // read-only
		env->DeleteLocalRef(tokenList); // avoid filling the local reference table

		s += 1;
	}

	// final batch
	float *out = emb + e * n_embd;
	embd_batch_decode(ctx, batch, out, s, n_embd, embd_norm);

	env->ReleaseFloatArrayElements(res, emb, 0); // copy back and free
}
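
/*
 * Sizing sketch (an assumption implied by the indexing above): with a pooled
 * context, res is expected to hold one embedding per prompt, i.e.
 * n_prompts * n_embd floats, and the embedding of prompt k is written at
 * offset k * n_embd.
 *
 *   jsize needed = (jsize) (n_prompts * n_embd);
 *   // a defensive check before the copy could be:
 *   // if (env->GetArrayLength(res) < needed) { ... throw ... }
 */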