1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
|
import copy
import json
import os
import urllib.request
import zipfile
import generate_models as gm
import pytest
import xgboost
from xgboost import testing as tm
def run_model_param_check(config):
    """Assert the model parameters shared by every checked model.

    ``config`` is the dictionary parsed from a booster's JSON configuration.
    The check requires ``num_feature`` to be the string ``'4'`` and the
    booster type to be ``'gbtree'``.
    """
    learner = config['learner']
    assert learner['learner_model_param']['num_feature'] == '4'
    assert learner['learner_train_param']['booster'] == 'gbtree'
def run_booster_check(booster, name):
    """Validate a loaded booster against expectations derived from ``name``.

    The booster's JSON configuration is parsed and passed through
    ``run_model_param_check``. The kind of model is then picked from
    substrings of ``name`` (``cls``, ``logitraw``, ``logit``, ``ltr``,
    otherwise ``reg``), and the matching objective is asserted. Most kinds
    also assert the number of dumped trees; ``cls`` and ``reg`` assert the
    base score as well.
    """
    config = json.loads(booster.save_config())
    run_model_param_check(config)

    learner = config['learner']
    model_param = learner['learner_model_param']
    train_param = learner['learner_train_param']

    if 'cls' in name:
        n_trees = len(booster.get_dump())
        assert n_trees == gm.kForests * gm.kRounds * gm.kClasses
        assert float(model_param['base_score']) == 0.5
        assert train_param['objective'] == 'multi:softmax'
        return

    # 'logitraw' has to be tested before 'logit', which is a substring of it.
    if 'logitraw' in name:
        n_trees = len(booster.get_dump())
        assert n_trees == gm.kForests * gm.kRounds
        assert model_param['num_class'] == '0'
        assert train_param['objective'] == 'binary:logitraw'
        return

    if 'logit' in name:
        n_trees = len(booster.get_dump())
        assert n_trees == gm.kForests * gm.kRounds
        assert model_param['num_class'] == '0'
        assert train_param['objective'] == 'binary:logistic'
        return

    if 'ltr' in name:
        # Only the objective is checked for ranking models.
        assert train_param['objective'] == 'rank:ndcg'
        return

    # Anything left must be a regression model.
    assert 'reg' in name
    n_trees = len(booster.get_dump())
    assert n_trees == gm.kForests * gm.kRounds
    assert float(model_param['base_score']) == 0.5
    assert train_param['objective'] == 'reg:squarederror'
def run_scikit_model_check(name, path):
    """Load a model through the scikit-learn interface and validate it.

    The estimator class and the expected objective are selected from
    substrings of ``name``: ``reg``, ``cls``, ``ltr``, ``logitraw`` or
    ``logit`` (tested in that order).

    Parameters
    ----------
    name
        Base name of the model file; decides which checks are run.
    path
        Location of the model file handed to ``load_model``.

    Raises
    ------
    AssertionError
        When any check fails, or when ``name`` matches none of the known
        model kinds.
    """
    if 'reg' in name:
        reg = xgboost.XGBRegressor()
        reg.load_model(path)
        config = json.loads(reg.get_booster().save_config())
        objective = config['learner']['learner_train_param']['objective']
        # Files with '0.90' in their name are expected to carry the
        # 'reg:linear' objective instead of 'reg:squarederror'.
        if '0.90' in name:
            assert objective == 'reg:linear'
        else:
            assert objective == 'reg:squarederror'
        assert (len(reg.get_booster().get_dump()) ==
                gm.kRounds * gm.kForests)
        run_model_param_check(config)
    elif 'cls' in name:
        cls = xgboost.XGBClassifier()
        cls.load_model(path)
        # Class metadata is only checked for files without '0.90' in the name.
        if '0.90' not in name:
            assert len(cls.classes_) == gm.kClasses
            assert cls.n_classes_ == gm.kClasses
        assert (len(cls.get_booster().get_dump()) ==
                gm.kRounds * gm.kForests * gm.kClasses), path
        config = json.loads(cls.get_booster().save_config())
        assert config['learner']['learner_train_param'][
            'objective'] == 'multi:softprob', path
        run_model_param_check(config)
    elif 'ltr' in name:
        ltr = xgboost.XGBRanker()
        ltr.load_model(path)
        assert (len(ltr.get_booster().get_dump()) ==
                gm.kRounds * gm.kForests)
        config = json.loads(ltr.get_booster().save_config())
        assert config['learner']['learner_train_param'][
            'objective'] == 'rank:ndcg'
        run_model_param_check(config)
    elif 'logitraw' in name:
        # Must precede the 'logit' branch: 'logit' is a substring of
        # 'logitraw'.
        logit = xgboost.XGBClassifier()
        logit.load_model(path)
        assert (len(logit.get_booster().get_dump()) ==
                gm.kRounds * gm.kForests)
        config = json.loads(logit.get_booster().save_config())
        assert config['learner']['learner_train_param'][
            'objective'] == 'binary:logitraw'
    elif 'logit' in name:
        logit = xgboost.XGBClassifier()
        logit.load_model(path)
        assert (len(logit.get_booster().get_dump()) ==
                gm.kRounds * gm.kForests)
        config = json.loads(logit.get_booster().save_config())
        assert config['learner']['learner_train_param'][
            'objective'] == 'binary:logistic'
    else:
        # Raise explicitly rather than ``assert False``: the failure survives
        # ``python -O`` and the message identifies the offending file.
        raise AssertionError(
            'Unrecognized scikit-learn model name: {!r}'.format(name))
@pytest.mark.skipif(**tm.no_sklearn())
def test_model_compatibility():
    """Test model compatibility, can only be run on CI as others don't
    have the credentials.

    If no ``models`` directory exists next to this file, the model archive
    is downloaded and extracted into it. Every file found under that
    directory, except those named ``version``, is then checked: names
    starting with ``xgboost-`` are loaded as a ``Booster`` and checked both
    directly and after ``copy.copy``; names starting with ``xgboost_scikit``
    are checked through the scikit-learn interface.

    Raises
    ------
    AssertionError
        When no model is found, a check fails, or a file name matches
        neither prefix.
    """
    test_dir = os.path.dirname(os.path.abspath(__file__))
    model_dir = os.path.join(test_dir, "models")

    if not os.path.exists(model_dir):
        try:
            zip_path, _ = urllib.request.urlretrieve(
                "https://xgboost-ci-jenkins-artifacts.s3-us-west-2"
                + ".amazonaws.com/xgboost_model_compatibility_test.zip"
            )
            with zipfile.ZipFile(zip_path, "r") as z:
                z.extractall(model_dir)
        finally:
            # urlretrieve() without a target filename stores the download in
            # a temporary file; remove it once extraction is done (or failed).
            urllib.request.urlcleanup()

    models = [
        os.path.join(root, f)
        for root, _subdirs, files in os.walk(model_dir)
        for f in files
        if f != "version"
    ]
    assert models, "No model files found under {!r}".format(model_dir)

    for model_path in models:
        name = os.path.basename(model_path)
        if name.startswith("xgboost-"):
            booster = xgboost.Booster(model_file=model_path)
            run_booster_check(booster, name)
            # Do full serialization.
            booster = copy.copy(booster)
            run_booster_check(booster, name)
        elif name.startswith("xgboost_scikit"):
            run_scikit_model_check(name, model_path)
        else:
            # Raise explicitly rather than ``assert False`` so the failure
            # survives ``python -O`` and names the unexpected file.
            raise AssertionError(
                "Unexpected model file: {!r}".format(model_path))
|