File: test_vision_api.py

package info (click to toggle)
llama.cpp 5882%2Bdfsg-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 34,020 kB
  • sloc: cpp: 189,548; ansic: 115,889; python: 24,977; objc: 6,050; lisp: 5,741; sh: 5,571; makefile: 1,293; javascript: 807; xml: 259
file content (60 lines) | stat: -rw-r--r-- 2,326 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import pytest
from utils import *
import base64
import requests

# Handle to the server under test; assigned by the create_server fixture.
server: ServerProcess

IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"

# Download the first test image once, when this module is imported, and keep
# it as a base64 data URI so a test case can send inline image data instead
# of a URL.
# An explicit timeout is passed because requests waits indefinitely by
# default; without it a stalled host would hang the import of this module.
response = requests.get(IMG_URL_0, timeout=30)
response.raise_for_status() # Raise an exception for bad status codes
IMG_BASE64_0 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")


@pytest.fixture(autouse=True)
def create_server():
    """Point the module-level ``server`` at the tinygemma3 preset.

    The fixture is autouse, so it is applied to every test in this module
    without being requested explicitly. It only assigns the global; the
    server is started by the test itself.
    """
    global server
    preset = ServerPreset.tinygemma3()
    server = preset


@pytest.mark.parametrize(
    "prompt, image_url, success, re_content",
    [
        # the test model was trained on CIFAR-10, but it is tiny and therefore not very accurate
        ("What is this:\n", IMG_URL_0, True, "(cat)+"),
        # placeholder string, swapped for the real base64 data URI inside the test,
        # so that the long payload does not clog up the log
        ("What is this:\n", "IMG_BASE64_0", True, "(cat)+"),
        ("What is this:\n", IMG_URL_1, True, "(frog)+"),
        # different prompt, same image: test cache invalidation
        ("Test test\n", IMG_URL_1, True, "(frog)+"),
        # not a URL at all
        ("What is this:\n", "malformed", False, None),
        # non-existent image
        ("What is this:\n", "https://google.com/404", False, None),
        # non-image data
        ("What is this:\n", "https://ggml.ai", False, None),
        # TODO @ngxson : test with multiple images, no images and with audio
    ]
)
def test_vision_chat_completion(prompt, image_url, success, re_content):
    """Send one text + image chat completion request and check the outcome.

    When ``success`` is true the request must return HTTP 200, the first
    choice must be an assistant message, and its content must match the
    ``re_content`` regex. Otherwise any status code other than 200 is
    accepted.
    """
    # vision model may take longer to load due to download size
    server.start(timeout_seconds=60)

    # Resolve the placeholder used in the parameter table to the actual data URI.
    url = IMG_BASE64_0 if image_url == "IMG_BASE64_0" else image_url

    user_content = [
        {"type": "text", "text": prompt},
        {"type": "image_url", "image_url": {"url": url}},
    ]
    payload = {
        "temperature": 0.0,
        "top_k": 1,
        "messages": [{"role": "user", "content": user_content}],
    }
    reply = server.make_request("POST", "/chat/completions", data=payload)

    # Failure cases only require that the server did not answer with 200.
    if not success:
        assert reply.status_code != 200
        return

    assert reply.status_code == 200
    message = reply.body["choices"][0]["message"]
    assert message["role"] == "assistant"
    assert match_regex(re_content, message["content"])