File: test_pickling.py

package info (click to toggle)
xgboost 3.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 13,796 kB
  • sloc: cpp: 67,502; python: 35,503; java: 4,676; ansic: 1,426; sh: 1,320; xml: 1,197; makefile: 204; javascript: 19
file content (63 lines) | stat: -rw-r--r-- 1,611 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import json
import os
import pickle
import tempfile

import numpy as np

import xgboost as xgb

# Dimensions (rows, columns) of the random matrix built by generate_data().
kRows, kCols = 100, 10


def generate_data():
    """Draw a random dataset from the global NumPy random state.

    Returns a ``(X, y)`` pair where ``X`` has shape ``(kRows, kCols)`` and
    ``y`` has shape ``(kRows,)``; both are standard-normal samples.
    """
    features = np.random.standard_normal(size=(kRows, kCols))
    targets = np.random.standard_normal(size=kRows)
    return features, targets


class TestPickling:
    """Checks that a trained model is unchanged by pickle round trips."""

    def run_model_pickling(self, xgb_params) -> dict:
        """Train a model, pickle it through a file twice, and compare.

        The model is trained with ``xgb.train`` on data from
        ``generate_data()``. Its JSON dump and its saved configuration are
        captured before pickling and asserted equal to those of the object
        obtained after two dump/load cycles.

        Parameters
        ----------
        xgb_params :
            Training parameters passed straight to ``xgb.train``.

        Returns
        -------
        dict
            The configuration captured before pickling, parsed with
            ``json.loads``.

        Raises
        ------
        AssertionError
            If the dump is empty, or if the dump or the configuration
            differs after the round trips.
        """
        X, y = generate_data()
        dtrain = xgb.DMatrix(X, y)
        bst = xgb.train(xgb_params, dtrain)

        dump_0 = bst.get_dump(dump_format='json')
        assert dump_0
        config_0 = bst.save_config()

        # A private temporary directory keeps the pickle out of the current
        # working directory (no clashes between concurrent runs) and is
        # removed even when something below raises.
        with tempfile.TemporaryDirectory() as tmpdir:
            filename = os.path.join(tmpdir, 'model.pkl')

            # Two cycles on purpose: the second one pickles an object that
            # was itself restored from a pickle.
            for _ in range(2):
                with open(filename, 'wb') as fd:
                    pickle.dump(bst, fd)

                with open(filename, 'rb') as fd:
                    bst = pickle.load(fd)

        assert bst.get_dump(dump_format='json') == dump_0

        config_1 = bst.save_config()
        assert config_0 == config_1
        return json.loads(config_0)

    def test_model_pickling_json(self):
        """Run the pickling check for the ``hist`` and ``exact`` tree methods.

        For each method the returned configuration must report the
        ``subsample`` value of 0.5 that was passed in the parameters.
        """
        def check(config):
            tree_param = config["learner"]["gradient_booster"]["tree_train_param"]
            # float() makes the comparison work whether the configuration
            # stores the value as text or as a number.
            assert float(tree_param["subsample"]) == 0.5

        for tree_method in ("hist", "exact"):
            params = {"nthread": 8, "tree_method": tree_method, "subsample": 0.5}
            check(self.run_model_pickling(params))