File: test_rerank.py

package info (click to toggle)
llama.cpp 7593+dfsg-3
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 71,012 kB
  • sloc: cpp: 329,391; ansic: 48,249; python: 32,103; lisp: 10,053; sh: 6,070; objc: 1,349; javascript: 924; xml: 384; makefile: 233
file content (146 lines) | stat: -rw-r--r-- 4,837 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import pytest
from utils import *

# Bind `server` at import time; the autouse fixture below rebinds it to a
# fresh preset before every test in this module.
server = ServerPreset.jina_reranker_tiny()


@pytest.fixture(autouse=True)
def create_server():
    """Give each test its own jina-reranker-tiny preset via the module global."""
    global server
    fresh_preset = ServerPreset.jina_reranker_tiny()
    server = fresh_preset


# Shared rerank corpus. For the query "Machine learning is" the tests expect
# index 2 (the ML definition) to rank highest and index 3 (French text about
# Paris) to rank lowest. Long entries are split via implicit string
# concatenation; the resulting strings are unchanged.
TEST_DOCUMENTS = [
    (
        "A machine is a physical system that uses power to apply forces and control "
        "movement to perform an action. The term is commonly applied to artificial "
        "devices, such as those employing engines or motors, but also to natural "
        "biological macromolecules, such as molecular machines."
    ),
    (
        "Learning is the process of acquiring new understanding, knowledge, "
        "behaviors, skills, values, attitudes, and preferences. The ability to learn "
        "is possessed by humans, non-human animals, and some machines; there is also "
        "evidence for some kind of learning in certain plants."
    ),
    (
        "Machine learning is a field of study in artificial intelligence concerned "
        "with the development and study of statistical algorithms that can learn "
        "from data and generalize to unseen data, and thus perform tasks without "
        "explicit instructions."
    ),
    (
        "Paris, capitale de la France, est une grande ville européenne et un centre "
        "mondial de l'art, de la mode, de la gastronomie et de la culture. Son "
        "paysage urbain du XIXe siècle est traversé par de larges boulevards et la "
        "Seine."
    ),
]


def test_rerank():
    """Rerank TEST_DOCUMENTS using the native ``documents`` request format.

    The response carries one entry per document under ``results``, each with
    an ``index`` and a ``relevance_score``. For the query "Machine learning is"
    the ML definition (index 2) must score highest and the French text about
    Paris (index 3) lowest.
    """
    global server
    server.start()
    res = server.make_request("POST", "/rerank", data={
        "query": "Machine learning is",
        "documents": TEST_DOCUMENTS,
    })
    assert res.status_code == 200
    results = res.body["results"]
    assert len(results) == len(TEST_DOCUMENTS)
    # Every input document must be scored exactly once.
    assert sorted(doc["index"] for doc in results) == list(range(len(TEST_DOCUMENTS)))

    # max()/min() return the first extreme on ties, same as a strict >/< scan.
    most_relevant = max(results, key=lambda doc: doc["relevance_score"])
    least_relevant = min(results, key=lambda doc: doc["relevance_score"])

    assert most_relevant["relevance_score"] > least_relevant["relevance_score"]
    assert most_relevant["index"] == 2
    assert least_relevant["index"] == 3


def test_rerank_tei_format():
    """Rerank TEST_DOCUMENTS using the TEI-style request (``texts`` field).

    In this format the response body itself is the list of entries, each with
    an ``index`` and a ``score``. For the query "Machine learning is" the ML
    definition (index 2) must score highest and the French text about Paris
    (index 3) lowest.
    """
    global server
    server.start()
    res = server.make_request("POST", "/rerank", data={
        "query": "Machine learning is",
        "texts": TEST_DOCUMENTS,
    })
    assert res.status_code == 200
    results = res.body
    assert len(results) == len(TEST_DOCUMENTS)
    # Every input text must be scored exactly once.
    assert sorted(doc["index"] for doc in results) == list(range(len(TEST_DOCUMENTS)))

    # max()/min() return the first extreme on ties, same as a strict >/< scan.
    most_relevant = max(results, key=lambda doc: doc["score"])
    least_relevant = min(results, key=lambda doc: doc["score"])

    assert most_relevant["score"] > least_relevant["score"]
    assert most_relevant["index"] == 2
    assert least_relevant["index"] == 3


@pytest.mark.parametrize("documents", [
    [],         # empty list
    None,       # missing value
    123,        # not a list
    [1, 2, 3],  # list of non-strings
])
def test_invalid_rerank_req(documents):
    """A malformed ``documents`` field must yield HTTP 400 with an ``error`` body."""
    global server
    server.start()
    payload = {"query": "Machine learning is", "documents": documents}
    response = server.make_request("POST", "/rerank", data=payload)
    assert response.status_code == 400
    assert "error" in response.body


@pytest.mark.parametrize(
    "query,doc1,doc2,n_tokens",
    [
        ("Machine learning is", "A machine", "Learning is", 19),
        ("Which city?", "Machine learning is ", "Paris, capitale de la", 26),
    ]
)
def test_rerank_usage(query, doc1, doc2, n_tokens):
    """Check the ``usage`` accounting of a two-document rerank request.

    The assertions expect ``total_tokens`` to equal ``prompt_tokens`` and the
    prompt token count to equal the parametrized ``n_tokens``.
    """
    global server
    server.start()

    payload = {"query": query, "documents": [doc1, doc2]}
    res = server.make_request("POST", "/rerank", data=payload)
    assert res.status_code == 200

    usage = res.body['usage']
    assert usage['prompt_tokens'] == usage['total_tokens']
    assert usage['prompt_tokens'] == n_tokens


@pytest.mark.parametrize("top_n,expected_len", [
    (None, len(TEST_DOCUMENTS)),  # top_n omitted -> every document returned
    (2, 2),
    (4, 4),
    (99, len(TEST_DOCUMENTS)),    # more than available -> capped at doc count
])
def test_rerank_top_n(top_n, expected_len):
    """``top_n`` limits how many native-format ``results`` are returned."""
    global server
    server.start()
    extra = {} if top_n is None else {"top_n": top_n}
    payload = {"query": "Machine learning is", "documents": TEST_DOCUMENTS, **extra}

    res = server.make_request("POST", "/rerank", data=payload)
    assert res.status_code == 200
    assert len(res.body["results"]) == expected_len


@pytest.mark.parametrize("top_n,expected_len", [
    (None, len(TEST_DOCUMENTS)),  # top_n omitted -> every text returned
    (2, 2),
    (4, 4),
    (99, len(TEST_DOCUMENTS)),    # more than available -> capped at text count
])
def test_rerank_tei_top_n(top_n, expected_len):
    """``top_n`` limits how many TEI-format entries the response list holds."""
    global server
    server.start()
    extra = {} if top_n is None else {"top_n": top_n}
    payload = {"query": "Machine learning is", "texts": TEST_DOCUMENTS, **extra}

    res = server.make_request("POST", "/rerank", data=payload)
    assert res.status_code == 200
    assert len(res.body) == expected_len