import json
import os
from typing import List, Optional, Tuple, cast

import numpy as np
import pytest

import xgboost as xgb
from xgboost import testing as tm

# Directory holding this module's test data files (e.g. 'veterans_lung_cancer.csv'
# used below); located via xgboost.testing from this file's path.
dpath = tm.data_dir(__file__)


@pytest.fixture(scope="module")
def toy_data() -> Tuple[xgb.DMatrix, np.ndarray, np.ndarray]:
    """Build a five-row, single-feature interval-labelled dataset.

    Row bounds are (10, inf), (15, inf), (-inf, 20), (30, 50) and (100, inf).

    Returns:
        The DMatrix with ``label_lower_bound``/``label_upper_bound`` attached,
        followed by the raw lower-bound and upper-bound arrays.
    """
    features = np.array([1, 2, 3, 4, 5]).reshape((-1, 1))
    lower = np.array([10, 15, -np.inf, 30, 100])
    upper = np.array([np.inf, np.inf, 20, 50, np.inf])

    matrix = xgb.DMatrix(features)
    for field, bound in (("label_lower_bound", lower), ("label_upper_bound", upper)):
        matrix.set_float_info(field, bound)
    return matrix, lower, upper


def test_default_metric(toy_data: Tuple[xgb.DMatrix, np.ndarray, np.ndarray]) -> None:
    """Check that the default AFT metric picks up the objective's distribution.

    Each case is trained both with and without an evaluation set. In both
    situations the test expects exactly one configured metric, whose
    ``aft_loss_distribution`` matches the expected value.
    """
    Xy, _, _ = toy_data

    # (extra training params, distribution expected on the default metric).
    # Without an explicit distribution the metric is expected to report "normal".
    cases = [
        ({"aft_loss_distribution": "extreme"}, "extreme"),
        ({}, "normal"),
    ]

    def run(evals: Optional[list]) -> None:
        # test with or without actual evaluation.
        for extra_params, expected_dist in cases:
            booster = xgb.train(
                {"objective": "survival:aft", **extra_params},
                Xy,
                num_boost_round=1,
                evals=evals,
            )
            config = json.loads(booster.save_config())
            metrics = config["learner"]["metrics"]
            assert len(metrics) == 1
            assert (
                metrics[0]["aft_loss_param"]["aft_loss_distribution"] == expected_dist
            )

    run([(Xy, "Train")])
    run(None)


def test_aft_survival_toy_data(
    toy_data: Tuple[xgb.DMatrix, np.ndarray, np.ndarray]
) -> None:
    """Train AFT on the toy data and check the learning curve and split points.

    See demo/aft_survival/aft_survival_viz_demo.py

    Asserts that:
    - the training ``aft-nloglik`` never increases from one round to the next;
    - the "accuracy" never decreases and ends at 1.0;
    - every split threshold is one of 2.5, 3.5 or 4.5.
    """
    dmat, y_lower, y_upper = toy_data

    # "Accuracy" = the fraction of data points whose ranged label (y_lower, y_upper)
    #              includes the corresponding predicted label (y_pred)
    acc_rec: List[float] = []

    class Callback(xgb.callback.TrainingCallback):
        """Append the model's "accuracy" on ``dmat`` to ``acc_rec`` after each round."""

        def after_iteration(
            self,
            model: xgb.Booster,
            epoch: int,
            evals_log: xgb.callback.TrainingCallback.EvalsLog,
        ) -> bool:
            y_pred = model.predict(dmat)
            in_range = np.logical_and(y_pred >= y_lower, y_pred <= y_upper)
            # Mean of a boolean mask == fraction of rows inside their label range.
            acc_rec.append(float(np.mean(in_range)))
            return False  # do not request early stopping

    evals_result: xgb.callback.TrainingCallback.EvalsLog = {}
    params = {
        "max_depth": 3,
        "objective": "survival:aft",
        "min_child_weight": 0,
        "tree_method": "exact",
    }
    bst = xgb.train(
        params,
        dmat,
        15,
        [(dmat, "train")],
        evals_result=evals_result,
        callbacks=[Callback()],
    )

    nloglik_rec = cast(List[float], evals_result["train"]["aft-nloglik"])
    # AFT metric (negative log likelihood) improves monotonically: pair every
    # round with the round after it.
    assert all(p >= q for p, q in zip(nloglik_rec, nloglik_rec[1:]))
    # "Accuracy" improves monotonically.
    # Over time, XGBoost model makes predictions that fall within given label ranges.
    assert all(p <= q for p, q in zip(acc_rec, acc_rec[1:]))
    assert acc_rec[-1] == 1.0

    def gather_split_thresholds(tree: dict) -> set:
        """Return every ``split_condition`` found in a JSON-dumped (sub)tree."""
        if "split_condition" not in tree:
            return set()  # leaf node
        left, right = tree["children"][0], tree["children"][1]
        return (
            gather_split_thresholds(left)
            | gather_split_thresholds(right)
            | {tree["split_condition"]}
        )

    # Only 2.5, 3.5, and 4.5 are used as split thresholds.
    for tree_dump in bst.get_dump(dump_format="json"):
        assert gather_split_thresholds(json.loads(tree_dump)).issubset({2.5, 3.5, 4.5})


def test_aft_empty_dmatrix():
    """AFT training with ``hist`` must complete on a DMatrix with zero rows."""
    n_features = 2
    features = np.array([]).reshape((0, n_features))
    lower_bounds = np.array([])
    upper_bounds = np.array([])

    dtrain = xgb.DMatrix(features)
    dtrain.set_info(label_lower_bound=lower_bounds, label_upper_bound=upper_bounds)

    # Two boosting rounds, evaluating on the (empty) training set itself.
    train_params = {'objective': 'survival:aft', 'tree_method': 'hist'}
    xgb.train(train_params, dtrain, num_boost_round=2, evals=[(dtrain, 'train')])


@pytest.mark.skipif(**tm.no_pandas())
def test_aft_survival_demo_data():
    """Train AFT with each loss distribution on the veterans lung cancer data.

    For each distribution ('normal', 'logistic', 'extreme'), asserts that the
    training ``aft-nloglik`` never increases between rounds. Also asserts that
    only the normal distribution finishes below 4.9.
    """
    import pandas as pd
    df = pd.read_csv(os.path.join(dpath, 'veterans_lung_cancer.csv'))

    y_lower_bound = df['Survival_label_lower_bound']
    y_upper_bound = df['Survival_label_upper_bound']
    X = df.drop(['Survival_label_lower_bound', 'Survival_label_upper_bound'], axis=1)

    dtrain = xgb.DMatrix(X)
    dtrain.set_float_info('label_lower_bound', y_lower_bound)
    dtrain.set_float_info('label_upper_bound', y_upper_bound)

    base_params = {'verbosity': 0,
                   'objective': 'survival:aft',
                   'eval_metric': 'aft-nloglik',
                   'tree_method': 'hist',
                   'learning_rate': 0.05,
                   'aft_loss_distribution_scale': 1.20,
                   'max_depth': 6,
                   'lambda': 0.01,
                   'alpha': 0.02}
    nloglik_rec = {}
    for dist in ('normal', 'logistic', 'extreme'):
        # Fresh dict per distribution so base_params itself is never mutated.
        params = {**base_params, 'aft_loss_distribution': dist}
        evals_result = {}
        xgb.train(params, dtrain, num_boost_round=500, evals=[(dtrain, 'train')],
                  evals_result=evals_result)
        nloglik_rec[dist] = evals_result['train']['aft-nloglik']
        # AFT metric (negative log likelihood) improves monotonically: pair every
        # round with the round after it.
        assert all(p >= q for p, q in zip(nloglik_rec[dist], nloglik_rec[dist][1:]))
    # For this data, normal distribution works the best
    assert nloglik_rec['normal'][-1] < 4.9
    assert nloglik_rec['logistic'][-1] > 4.9
    assert nloglik_rec['extreme'][-1] > 4.9
