File: test_optimization_history.py

package info (click to toggle)
optuna 4.1.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 4,784 kB
  • sloc: python: 40,634; sh: 97; makefile: 30
file content (28 lines) | stat: -rw-r--r-- 1,210 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from __future__ import annotations

from io import BytesIO

import pytest

from optuna.visualization._optimization_history import _OptimizationHistoryInfo
from optuna.visualization.matplotlib._matplotlib_imports import plt
from optuna.visualization.matplotlib._optimization_history import _get_optimization_history_plot
from tests.visualization_tests.test_optimization_history import optimization_history_info_lists


@pytest.mark.parametrize("target_name", ["Objective Value", "Target Name"])
@pytest.mark.parametrize("info_list", optimization_history_info_lists)
def test_get_optimization_history_plot(
    target_name: str, info_list: list[_OptimizationHistoryInfo]
) -> None:
    """Check the y-axis label and legend entries of the matplotlib history plot.

    The y-axis label must equal ``target_name``. The legend must contain one entry
    per ``info.values_info.label_name``, plus one per
    ``info.best_values_info.label_name`` whose ``best_values_info`` is not ``None``.
    The entries are compared order-insensitively. The figure is also rendered to an
    in-memory buffer so that drawing errors surface in the test.
    """
    # NOTE(review): the returned object is used via get_ylabel()/legend(), so it is
    # presumably a matplotlib Axes despite the variable name — confirm.
    figure = _get_optimization_history_plot(info_list, target_name=target_name)
    try:
        assert figure.get_ylabel() == target_name

        expected_legends = [info.values_info.label_name for info in info_list]
        expected_legends += [
            info.best_values_info.label_name
            for info in info_list
            if info.best_values_info is not None
        ]
        # Legend order follows plotting order, which this test does not pin down,
        # so compare the entries as sorted lists.
        legends = [legend.get_text() for legend in figure.legend().get_texts()]
        assert sorted(legends) == sorted(expected_legends)

        # Render to a throwaway buffer to exercise the full drawing path.
        plt.savefig(BytesIO())
    finally:
        # Close even when an assertion fails. Otherwise the open figure leaks into
        # later parametrized cases, because plt.savefig/plt.close act on the
        # *current* figure.
        plt.close()