File: test_sycl_training_continuation.py

package info (click to toggle)
xgboost 3.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 13,796 kB
  • sloc: cpp: 67,502; python: 35,503; java: 4,676; ansic: 1,426; sh: 1,320; xml: 1,197; makefile: 204; javascript: 19
file content (59 lines) | stat: -rw-r--r-- 2,131 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import numpy as np
import xgboost as xgb
import json

# Module-level NumPy generator with a fixed seed, so data drawn from it is
# reproducible across runs.
rng = np.random.RandomState(seed=1994)


class TestSYCLTrainingContinuation:
    def run_training_continuation(self, use_json):
        kRows = 64
        kCols = 32
        X = np.random.randn(kRows, kCols)
        y = np.random.randn(kRows)
        dtrain = xgb.DMatrix(X, y)
        params = {
            "device": "sycl",
            "max_depth": "2",
            "gamma": "0.1",
            "alpha": "0.01",
            "enable_experimental_json_serialization": use_json,
        }
        bst_0 = xgb.train(params, dtrain, num_boost_round=64)
        dump_0 = bst_0.get_dump(dump_format="json")

        bst_1 = xgb.train(params, dtrain, num_boost_round=32)
        bst_1 = xgb.train(params, dtrain, num_boost_round=32, xgb_model=bst_1)
        dump_1 = bst_1.get_dump(dump_format="json")

        def recursive_compare(obj_0, obj_1):
            if isinstance(obj_0, float):
                assert np.isclose(obj_0, obj_1, atol=1e-6)
            elif isinstance(obj_0, str):
                assert obj_0 == obj_1
            elif isinstance(obj_0, int):
                assert obj_0 == obj_1
            elif isinstance(obj_0, dict):
                keys_0 = list(obj_0.keys())
                keys_1 = list(obj_1.keys())
                values_0 = list(obj_0.values())
                values_1 = list(obj_1.values())
                for i in range(len(obj_0.items())):
                    assert keys_0[i] == keys_1[i]
                    if list(obj_0.keys())[i] != "missing":
                        recursive_compare(values_0[i], values_1[i])
            else:
                for i in range(len(obj_0)):
                    recursive_compare(obj_0[i], obj_1[i])

        assert len(dump_0) == len(dump_1)
        for i in range(len(dump_0)):
            obj_0 = json.loads(dump_0[i])
            obj_1 = json.loads(dump_1[i])
            recursive_compare(obj_0, obj_1)

    def test_sycl_training_continuation_binary(self):
        self.run_training_continuation(False)

    def test_sycl_training_continuation_json(self):
        self.run_training_continuation(True)