File: test_config.py

package info (click to toggle)
xgboost 3.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 13,796 kB
  • sloc: cpp: 67,502; python: 35,503; java: 4,676; ansic: 1,426; sh: 1,320; xml: 1,197; makefile: 204; javascript: 19
file content (76 lines) | stat: -rw-r--r-- 2,428 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import multiprocessing
from concurrent.futures import ThreadPoolExecutor

import pytest

import xgboost as xgb


@pytest.mark.parametrize("verbosity_level", [0, 1, 2, 3])
def test_global_config_verbosity(verbosity_level):
    """Check that ``config_context`` applies and then undoes a verbosity override.

    The test expects the global verbosity to be 1 on entry, to equal
    ``verbosity_level`` inside the context, and to match the entry value
    again once the context has exited.
    """
    baseline = xgb.get_config()["verbosity"]
    assert baseline == 1

    with xgb.config_context(verbosity=verbosity_level):
        inside = xgb.get_config()["verbosity"]
        assert inside == verbosity_level

    # Leaving the context must bring back the value seen on entry.
    assert baseline == xgb.get_config()["verbosity"]


@pytest.mark.parametrize("use_rmm", [False, True])
def test_global_config_use_rmm(use_rmm):
    """Check that ``config_context`` applies and then undoes a ``use_rmm`` override.

    Whatever the flag is on entry, it must equal ``use_rmm`` inside the
    context and match the entry value again after the context exits.
    """
    baseline = xgb.get_config()["use_rmm"]

    with xgb.config_context(use_rmm=use_rmm):
        inside = xgb.get_config()["use_rmm"]
        assert inside == use_rmm

    # Leaving the context must bring back the value seen on entry.
    assert baseline == xgb.get_config()["use_rmm"]


def test_nested_config() -> None:
    """Check that nested ``config_context`` blocks stack and unwind correctly.

    Three scenarios are exercised:

    1. Three levels of nested contexts, verifying the value at every level
       on the way in and on the way out.
    2. A nested context with ``verbosity=None``, which is expected to leave
       the enclosing value untouched.
    3. A context entered after ``set_config`` changed the value outside of
       any context manager.

    The test expects verbosity to be 1 on entry and leaves it at the entry
    value on exit, even when one of the assertions in scenario 3 fails.
    """
    verbosity = xgb.get_config()["verbosity"]
    assert verbosity == 1

    with xgb.config_context(verbosity=3):
        assert xgb.get_config()["verbosity"] == 3
        with xgb.config_context(verbosity=2):
            assert xgb.get_config()["verbosity"] == 2
            with xgb.config_context(verbosity=1):
                assert xgb.get_config()["verbosity"] == 1
            assert xgb.get_config()["verbosity"] == 2
        assert xgb.get_config()["verbosity"] == 3

    with xgb.config_context(verbosity=3):
        assert xgb.get_config()["verbosity"] == 3
        with xgb.config_context(verbosity=None):
            assert xgb.get_config()["verbosity"] == 3  # None has no effect

    xgb.set_config(verbosity=2)
    try:
        assert xgb.get_config()["verbosity"] == 2
        with xgb.config_context(verbosity=3):
            assert xgb.get_config()["verbosity"] == 3
        # Exiting the context must restore the value installed by set_config.
        assert xgb.get_config()["verbosity"] == 2
    finally:
        # Always undo the set_config above; otherwise a failed assertion
        # would leave the modified verbosity behind for later tests (and for
        # later tasks that run on the same worker thread).
        xgb.set_config(verbosity=verbosity)  # reset

    verbosity = xgb.get_config()["verbosity"]
    assert verbosity == 1


def test_thread_safety():
    """Run ``test_nested_config`` 256 times on a thread pool.

    The pool has one worker per CPU reported by ``multiprocessing``. Any
    assertion failure raised inside a worker surfaces here when the
    corresponding future's result is collected.
    """
    worker_count = multiprocessing.cpu_count()
    with ThreadPoolExecutor(max_workers=worker_count) as pool:
        pending = [pool.submit(test_nested_config) for _ in range(256)]

    # The pool has been shut down by now; result() re-raises whatever
    # exception a task hit in its worker thread.
    for task in pending:
        task.result()


def test_nthread() -> None:
    """Check that the global configuration reports ``nthread`` as 0."""
    assert xgb.get_config()["nthread"] == 0