File: test_template.py

package info (click to toggle)
llama.cpp 5882%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 34,040 kB
  • sloc: cpp: 189,548; ansic: 115,889; python: 24,977; objc: 6,050; sh: 5,819; lisp: 5,741; makefile: 1,293; javascript: 807; xml: 259
file content (108 lines) | stat: -rw-r--r-- 4,121 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#!/usr/bin/env python
import pytest

# Make the directory one level above this file's own directory importable by
# putting it at the front of sys.path (Path(__file__).resolve().parents[1]).
from pathlib import Path
import sys

# NOTE(review): this import executes before the sys.path insertion below, so it
# relies on the `unit` package already being importable some other way
# (presumably via the directory pytest is launched from) -- confirm.
from unit.test_tool_call import TEST_TOOL
path = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(path))

import datetime
# Star import: supplies ServerProcess and ServerPreset used below.
from utils import *

# Module-level handle to the server under test; annotation only, no value is
# bound here. The create_server fixture assigns it.
server: ServerProcess

# Passed as `timeout_seconds` to server.start(): 15 * 60 s = 15 minutes.
TIMEOUT_SERVER_START = 15*60

@pytest.fixture(autouse=True)
def create_server():
    """Rebind the module-level ``server`` to a freshly configured tinyllama2 preset.

    Declared ``autouse`` so tests in this module get it without requesting it.
    The fixture only configures the instance; each test starts the server
    itself after applying its own template settings.
    """
    global server
    preset = ServerPreset.tinyllama2()
    preset.model_alias = "tinyllama-2"
    preset.n_slots = 1
    preset.server_port = 8081
    server = preset


@pytest.mark.parametrize("tools", [None, [], [TEST_TOOL]])
@pytest.mark.parametrize("template_name,reasoning_budget,expected_end", [
    ("deepseek-ai-DeepSeek-R1-Distill-Qwen-32B", None, "<think>\n"),
    ("deepseek-ai-DeepSeek-R1-Distill-Qwen-32B",   -1, "<think>\n"),
    ("deepseek-ai-DeepSeek-R1-Distill-Qwen-32B",    0, "<think>\n</think>"),

    ("Qwen-Qwen3-0.6B", -1, "<|im_start|>assistant\n"),
    ("Qwen-Qwen3-0.6B",  0, "<|im_start|>assistant\n<think>\n\n</think>\n\n"),

    ("Qwen-QwQ-32B", -1, "<|im_start|>assistant\n<think>\n"),
    ("Qwen-QwQ-32B",  0, "<|im_start|>assistant\n<think>\n</think>"),

    ("CohereForAI-c4ai-command-r7b-12-2024-tool_use", -1, "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"),
    ("CohereForAI-c4ai-command-r7b-12-2024-tool_use",  0, "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|><|END_THINKING|>"),
])
def test_reasoning_budget(template_name: str, reasoning_budget: int | None, expected_end: str, tools: list[dict] | None):
    """Check how the rendered prompt ends for a given template and reasoning budget.

    The server is started with the named Jinja chat template and the given
    ``reasoning_budget``; a single user message is then rendered through
    ``POST /apply-template`` and the resulting prompt must end with
    ``expected_end``.

    Args:
        template_name: Base name of the ``.jinja`` file under
            ``models/templates`` (path is relative, as given to the server).
        reasoning_budget: Value assigned to ``server.reasoning_budget``. In the
            table above, ``0`` rows expect a suffix whose thinking section is
            already closed, while ``None``/``-1`` rows do not.
        expected_end: Exact suffix the rendered prompt must have.
        tools: Value sent as the request's ``"tools"`` field; ``None``, an
            empty list, or a list with one tool definition.
    """
    global server
    server.jinja = True
    server.reasoning_budget = reasoning_budget
    server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
    server.start(timeout_seconds=TIMEOUT_SERVER_START)

    res = server.make_request("POST", "/apply-template", data={
        "messages": [
            {"role": "user", "content": "What is today?"},
        ],
        "tools": tools,
    })
    assert res.status_code == 200
    prompt = res.body["prompt"]

    # The expected suffixes differ only in trailing newlines/tags, so show both
    # values with repr() to keep whitespace visible in the failure message.
    assert prompt.endswith(expected_end), f"Expected prompt to end with {expected_end!r}, got {prompt!r}"


@pytest.mark.parametrize("tools", [None, [], [TEST_TOOL]])
@pytest.mark.parametrize("template_name,format", [
    ("meta-llama-Llama-3.3-70B-Instruct",    "%d %b %Y"),
    ("fireworks-ai-llama-3-firefunction-v2", "%b %d %Y"),
])
def test_date_inside_prompt(template_name: str, format: str, tools: list[dict] | None):
    """Check that the rendered prompt contains today's date in the template's format.

    The server is started with the named Jinja chat template, a single user
    message is rendered through ``POST /apply-template``, and the prompt must
    contain the current date formatted with ``format``.

    Args:
        template_name: Base name of the ``.jinja`` file under
            ``models/templates`` (path is relative, as given to the server).
        format: ``strftime`` pattern the date is expected to appear in.
        tools: Value sent as the request's ``"tools"`` field; ``None``, an
            empty list, or a list with one tool definition.
    """
    global server
    server.jinja = True
    server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
    server.start(timeout_seconds=TIMEOUT_SERVER_START)

    # Sample the date on both sides of the request: the prompt is rendered
    # somewhere in between, so a run that straddles midnight must accept either.
    # NOTE(review): assumes the server renders the date using the same local
    # clock/timezone as this test process -- confirm.
    date_before = datetime.date.today()
    res = server.make_request("POST", "/apply-template", data={
        "messages": [
            {"role": "user", "content": "What is today?"},
        ],
        "tools": tools,
    })
    date_after = datetime.date.today()
    assert res.status_code == 200
    prompt = res.body["prompt"]

    # dict.fromkeys de-duplicates while preserving order, so the usual case
    # (same day before and after) yields a single candidate string.
    accepted = list(dict.fromkeys(day.strftime(format) for day in (date_before, date_after)))
    today_str = " or ".join(accepted)
    assert any(candidate in prompt for candidate in accepted), f"Expected today's date ({today_str}) in content ({prompt})"


@pytest.mark.parametrize("add_generation_prompt", [False, True])
@pytest.mark.parametrize("template_name,expected_generation_prompt", [
    ("meta-llama-Llama-3.3-70B-Instruct",    "<|start_header_id|>assistant<|end_header_id|>"),
])
def test_add_generation_prompt(template_name: str, expected_generation_prompt: str, add_generation_prompt: bool):
    """Check that ``add_generation_prompt`` controls the presence of the assistant header.

    Starts the server with the named Jinja chat template, renders one user
    message through ``POST /apply-template`` with the flag set as
    parametrized, and asserts that ``expected_generation_prompt`` occurs in
    the prompt exactly when the flag is true.
    """
    global server
    server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
    server.jinja = True
    server.start(timeout_seconds=TIMEOUT_SERVER_START)

    payload = {
        "messages": [
            {"role": "user", "content": "What is today?"},
        ],
        "add_generation_prompt": add_generation_prompt,
    }
    res = server.make_request("POST", "/apply-template", data=payload)
    assert res.status_code == 200

    prompt = res.body["prompt"]
    has_generation_prompt = expected_generation_prompt in prompt
    if not add_generation_prompt:
        assert not has_generation_prompt, f"Did not expect generation prompt ({expected_generation_prompt}) in content ({prompt})"
    else:
        assert has_generation_prompt, f"Expected generation prompt ({expected_generation_prompt}) in content ({prompt})"