File: test_vision_api.py

package info (click to toggle)
llama.cpp 6641+dfsg-2
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 43,824 kB
  • sloc: cpp: 218,020; ansic: 117,624; python: 29,020; lisp: 9,094; sh: 5,776; objc: 1,045; javascript: 828; xml: 259; makefile: 219
file content (160 lines) | stat: -rw-r--r-- 6,314 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import pytest
from utils import *
import base64
import requests

# Module-level server handle read by every test in this file; the autouse
# `create_server` fixture rebinds it to ServerPreset.tinygemma3() before each test.
server: ServerProcess

def get_img_url(id: str, *, timeout: float = 30.0) -> str:
    """Resolve a test-image placeholder id into the value sent to the server.

    Recognised ids:
      * ``IMG_URL_0`` / ``IMG_URL_1`` -> the remote image URL itself.
      * ``IMG_BASE64_URI_0`` / ``IMG_BASE64_URI_1`` -> the downloaded image
        as a ``data:image/png;base64,...`` URI.
      * ``IMG_BASE64_0`` / ``IMG_BASE64_1`` -> the downloaded image as a bare
        base64 string.

    Any other value (e.g. ``"malformed"``, ``""`` or an arbitrary URL) is
    returned unchanged, so tests can pass deliberately invalid inputs through.

    Args:
        id: Placeholder name or literal value. It shadows the ``id`` builtin;
            the name is kept so existing keyword callers keep working.
        timeout: Passed to ``requests.get``. It is requests' connect/read
            timeout, not a total download deadline. Without it, a stalled
            connection would hang the test run indefinitely.

    Raises:
        requests.HTTPError: If the image download returns an error status.
        requests.Timeout: If the connection or a read exceeds ``timeout``.
    """
    IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
    IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"
    plain_urls = {"IMG_URL_0": IMG_URL_0, "IMG_URL_1": IMG_URL_1}
    # placeholder id -> (source URL, wrap the base64 payload in a data URI?)
    encoded = {
        "IMG_BASE64_URI_0": (IMG_URL_0, True),
        "IMG_BASE64_0": (IMG_URL_0, False),
        "IMG_BASE64_URI_1": (IMG_URL_1, True),
        "IMG_BASE64_1": (IMG_URL_1, False),
    }

    if id in plain_urls:
        return plain_urls[id]
    if id not in encoded:
        # Unknown ids pass through untouched; tests rely on this for bad inputs.
        return id

    url, as_data_uri = encoded[id]
    response = requests.get(url, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of base64-encoding an error page.
    response.raise_for_status()
    payload = base64.b64encode(response.content).decode("utf-8")
    return "data:image/png;base64," + payload if as_data_uri else payload

# Keys of the structured prompt objects the tests send to /completions and
# /embeddings: the text prompt, plus a list of (base64) multimodal payloads.
JSON_MULTIMODAL_KEY, JSON_PROMPT_STRING_KEY = "multimodal_data", "prompt_string"

@pytest.fixture(autouse=True)
def create_server():
    """Rebind the module-level ``server`` before every test.

    It is bound to a ``ServerPreset.tinygemma3()`` preset. This presumably
    yields a new instance per call (TODO confirm in utils). If so,
    per-test tweaks such as ``server_embeddings`` / ``n_batch`` do not
    leak between tests.
    """
    global server
    fresh_preset = ServerPreset.tinygemma3()
    server = fresh_preset

def _check_models_endpoint_reports_multimodal(path: str) -> None:
    """Check that a models-listing endpoint advertises multimodal support.

    Starts the server, issues ``GET path`` and asserts a 200 response.
    The first entry of ``body["models"]`` must list both ``completion``
    and ``multimodal`` among its capabilities.

    Shared by the ``/models`` and ``/v1/models`` tests, so both endpoints are
    held to exactly the same assertions.
    """
    # `server` is only read here; the autouse fixture rebinds the module global.
    server.start()
    res = server.make_request("GET", path, data={})
    assert res.status_code == 200
    model_info = res.body["models"][0]
    print(model_info)  # debug aid: dump the model entry to captured stdout
    capabilities = model_info["capabilities"]
    assert "completion" in capabilities
    assert "multimodal" in capabilities


def test_models_supports_multimodal_capability():
    _check_models_endpoint_reports_multimodal("/models")


def test_v1_models_supports_multimodal_capability():
    _check_models_endpoint_reports_multimodal("/v1/models")

@pytest.mark.parametrize(
    "prompt, image_url, success, re_content",
    [
        # test model is trained on CIFAR-10, but it's quite dumb due to small size
        # (e.g. IMG_URL_0 is a truck picture, yet the expected answer is "cat")
        ("What is this:\n", "IMG_URL_0",              True, "(cat)+"),
        ("What is this:\n", "IMG_BASE64_URI_0",       True, "(cat)+"),
        ("What is this:\n", "IMG_URL_1",              True, "(frog)+"),
        ("Test test\n",     "IMG_URL_1",              True, "(frog)+"), # test invalidate cache
        ("What is this:\n", "malformed",              False, None),
        ("What is this:\n", "https://google.com/404", False, None), # non-existent image
        ("What is this:\n", "https://ggml.ai",        False, None), # non-image data
        # TODO @ngxson : test with multiple images, no images and with audio
    ]
)
def test_vision_chat_completion(prompt, image_url, success, re_content):
    """Send a chat message that mixes a text part and an ``image_url`` part.

    Cases marked ``success`` expect a 200 whose assistant reply matches
    ``re_content``. Every other case only expects a non-200 status.
    """
    server.start()
    user_parts = [
        {"type": "text", "text": prompt},
        {"type": "image_url", "image_url": {"url": get_img_url(image_url)}},
    ]
    payload = {
        "temperature": 0.0,
        "top_k": 1,
        "messages": [{"role": "user", "content": user_parts}],
    }
    res = server.make_request("POST", "/chat/completions", data=payload)

    if not success:
        assert res.status_code != 200
        return

    assert res.status_code == 200
    reply = res.body["choices"][0]["message"]
    assert reply["role"] == "assistant"
    assert match_regex(re_content, reply["content"])


@pytest.mark.parametrize(
    "prompt, image_data, success, re_content",
    [
        # test model is trained on CIFAR-10, but it's quite dumb due to small size
        ("What is this: <__media__>\n", "IMG_BASE64_0",         True, "(cat)+"),
        ("What is this: <__media__>\n", "IMG_BASE64_1",         True, "(frog)+"),
        ("What is this: <__media__>\n", "malformed",            False, None), # non-image data
        ("What is this:\n",             "",                     False, None), # empty string
    ]
)
def test_vision_completion(prompt, image_data, success, re_content):
    """POST a structured prompt object to ``/completions``.

    The object carries the text prompt plus a one-element list of image
    data. In the success cases the text contains a ``<__media__>`` marker,
    presumably where the server splices the image in (TODO confirm).
    Success cases expect a 200 whose ``content`` matches ``re_content``.
    Failure cases expect a non-200 status.
    """
    server.start()
    prompt_object = {
        JSON_PROMPT_STRING_KEY: prompt,
        JSON_MULTIMODAL_KEY: [get_img_url(image_data)],
    }
    res = server.make_request(
        "POST",
        "/completions",
        data={"temperature": 0.0, "top_k": 1, "prompt": prompt_object},
    )

    if not success:
        assert res.status_code != 200
        return

    assert res.status_code == 200
    assert match_regex(re_content, res.body["content"])


@pytest.mark.parametrize(
    "prompt, image_data, success",
    [
        # test model is trained on CIFAR-10, but it's quite dumb due to small size
        ("What is this: <__media__>\n", "IMG_BASE64_0",         True),
        ("What is this: <__media__>\n", "IMG_BASE64_1",         True),
        ("What is this: <__media__>\n", "malformed",            False), # non-image data
        ("What is this:\n",             "base64",               False), # non-image data
    ]
)
def test_vision_embeddings(prompt, image_data, success):
    """Request embeddings for the same prompt sent three ways in one batch.

    The batch holds two identical prompt+image entries and one text-only
    entry. On success, the two multimodal embeddings must be identical,
    and they must differ from the text-only one. Failure cases expect a
    non-200 status.
    """
    server.server_embeddings = True
    # NOTE(review): presumably sized so an image's tokens fit in one batch — confirm.
    server.n_batch = 512
    server.start()
    image = get_img_url(image_data)
    batch = [
        {JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [image]}
        for _ in range(2)
    ]
    batch.append({JSON_PROMPT_STRING_KEY: prompt})
    res = server.make_request("POST", "/embeddings", data={"content": batch})

    if not success:
        assert res.status_code != 200
        return

    assert res.status_code == 200
    results = res.body
    # Same multimodal input twice -> embeddings must be stable.
    assert results[0]['embedding'] == results[1]['embedding']
    # Dropping the image must change the embedding for the same prompt text.
    assert results[0]['embedding'] != results[2]['embedding']