from collections import OrderedDict

import numpy as np
import pandas as pd
import pytest
from pandas.testing import assert_frame_equal, assert_index_equal, assert_series_equal

from upsetplot import from_contents, from_indicators, from_memberships, generate_data


@pytest.mark.parametrize("typ", [set, list, tuple, iter])
def test_from_memberships_no_data(typ):
    """Exercise ``from_memberships`` without ``data`` for several container types.

    Malformed memberships must raise. Valid memberships must produce a Series
    of ones indexed by a boolean MultiIndex with one level per category. The
    levels are ordered by category name regardless of input order.
    """
    invalid_cases = [
        ([], ValueError, "at least one category"),
        ([[], []], ValueError, "at least one category"),
        ([[1]], ValueError, "strings"),
        ([[1, "str"]], ValueError, "strings"),
        ([1], TypeError, None),
    ]
    for bad_memberships, exc_type, pattern in invalid_cases:
        with pytest.raises(exc_type, match=pattern):
            from_memberships(bad_memberships)

    def ones_series(rows):
        # Expected result: boolean "hello"/"world" index levels, all values 1.
        frame = pd.DataFrame(rows, columns=["hello", "world", "ones"])
        return frame.set_index(["hello", "world"])["ones"]

    actual = from_memberships(
        [typ([]), typ(["hello"]), typ(["world"]), typ(["hello", "world"])]
    )
    expected = ones_series(
        [[False, False, 1], [True, False, 1], [False, True, 1], [True, True, 1]]
    )
    assert isinstance(expected.index, pd.MultiIndex)
    assert_series_equal(expected, actual)

    # Category levels are sorted by name, whatever order they appear in.
    ordering_cases = [
        ([typ(["hello"]), typ(["world"])], [[True, False, 1], [False, True, 1]]),
        ([typ(["world"]), typ(["hello"])], [[False, True, 1], [True, False, 1]]),
    ]
    for memberships, rows in ordering_cases:
        assert_series_equal(ones_series(rows), from_memberships(memberships))


@pytest.mark.parametrize(
    ("data", "ndim"),
    [
        ([1, 2, 3, 4], 1),
        (np.array([1, 2, 3, 4]), 1),
        (pd.Series([1, 2, 3, 4], name="foo"), 1),
        ([[1, "a"], [2, "b"], [3, "c"], [4, "d"]], 2),
        (
            pd.DataFrame(
                [[1, "a"], [2, "b"], [3, "c"], [4, "d"]],
                columns=["foo", "bar"],
                index=["q", "r", "s", "t"],
            ),
            2,
        ),
    ],
)
def test_from_memberships_with_data(data, ndim):
    """Check that ``from_memberships`` attaches ``data`` to the memberships.

    The result must meet all of these conditions:

    - It is a new object that still shares the underlying buffer for numeric
      pandas input.
    - It is a Series for 1-D data and a DataFrame for 2-D data.
    - Its values equal the input's values.
    - Its index equals the index produced without data.

    A memberships/data length mismatch must raise.
    """
    memberships = [[], ["hello"], ["world"], ["hello", "world"]]
    result = from_memberships(memberships, data=data)

    # The container itself must be a copy...
    assert result is not data
    is_pandas = hasattr(data, "loc")
    if is_pandas and np.asarray(data).dtype.kind in "ifb":
        # ...but numeric pandas input should not be deep-copied.
        assert result.values.base is np.asarray(data).base

    expected_type = pd.Series if ndim == 1 else pd.DataFrame
    assert isinstance(result, expected_type)

    def values_only(obj):
        # Compare values irrespective of index.
        return pd.DataFrame(obj).reset_index(drop=True)

    assert_frame_equal(values_only(result), values_only(data))

    membership_only = from_memberships(memberships=memberships)
    assert_index_equal(result.index, membership_only.index)

    with pytest.raises(ValueError, match="length"):
        from_memberships(memberships[:-1], data=data)


@pytest.mark.parametrize(
    "data", [None, {"attr1": [3, 4, 5, 6, 7, 8], "attr2": list("qrstuv")}]
)
@pytest.mark.parametrize("typ", [set, list, tuple, iter])
@pytest.mark.parametrize("id_column", ["id", "blah"])
def test_from_contents_vs_memberships(data, typ, id_column):
    """Check that ``from_contents`` agrees with the equivalent ``from_memberships``.

    Besides the membership index and data columns, ``from_contents`` adds an
    ``id_column`` holding the identifiers. That column is checked separately
    and then dropped before the frames are compared.
    """
    ids = ["aa", "bb", "cc", "dd", "ee", "ff"]
    contents = OrderedDict(
        (name, typ(members))
        for name, members in [
            ("cat1", ["aa", "bb", "cc"]),
            ("cat2", ["cc", "dd"]),
            ("cat3", ["ee"]),
        ]
    )
    # "ff" has a row in data_df but belongs to no category.
    data_df = pd.DataFrame(data, index=ids)

    actual = from_contents(contents, data=data_df, id_column=id_column)
    expected = from_memberships(
        memberships=[{"cat1"}, {"cat1"}, {"cat1", "cat2"}, {"cat2"}, {"cat3"}, []],
        data=data_df,
    )

    assert_series_equal(
        actual[id_column].reset_index(drop=True),
        pd.Series(ids, name=id_column),
    )

    without_ids = actual.drop(columns=[id_column])
    # Skip the column-index type check when no data columns remain (data=None).
    assert_frame_equal(
        without_ids,
        expected,
        check_column_type=without_ids.shape[1] > 0,
    )


def test_from_contents(typ=set, id_column="id"):
    """Check that ``from_contents`` is invariant to several input variations.

    A baseline is built from ``contents`` plus an empty-columned ``data``
    frame. The following variations are then compared against it:

    - omitting ``data`` (after sorting rows by ``id_column``);
    - passing the categories in a different dict order (after reordering the
      index levels);
    - adding an empty category (which must appear as an all-False level).

    ``typ`` is the container type used for each category's members, and
    ``id_column`` is the identifier column name. Both are plain defaults, not
    pytest-parametrized. ``typ`` must build a re-iterable container because
    each category is read by several ``from_contents`` calls.
    """
    contents = OrderedDict(
        [
            ("cat1", typ(["aa", "bb", "cc"])),
            ("cat2", typ(["cc", "dd"])),
            ("cat3", typ(["ee"])),
        ]
    )
    empty_data = pd.DataFrame(index=["aa", "bb", "cc", "dd", "ee"])
    baseline = from_contents(contents, data=empty_data, id_column=id_column)

    # data=None: same content, possibly in a different row order
    out = from_contents(contents, id_column=id_column)
    assert_frame_equal(out.sort_values(id_column), baseline)

    # unordered contents dict: same content, different index level order
    out = from_contents(
        {"cat3": contents["cat3"], "cat2": contents["cat2"], "cat1": contents["cat1"]},
        data=empty_data,
        id_column=id_column,
    )
    assert_frame_equal(out.reorder_levels(["cat1", "cat2", "cat3"]), baseline)

    # empty category: adds a fourth, all-False level and nothing else
    out = from_contents(
        {
            "cat1": contents["cat1"],
            "cat2": contents["cat2"],
            "cat3": contents["cat3"],
            "cat4": typ([]),
        },
        data=empty_data,
        id_column=id_column,
    )
    assert not out.index.to_frame()["cat4"].any()  # cat4 should be all-false
    assert len(out.index.names) == 4
    # Drop the cat4 level so the remainder can be compared with the baseline.
    out.index = out.index.to_frame().set_index(["cat1", "cat2", "cat3"]).index
    assert_frame_equal(out, baseline)


@pytest.mark.parametrize("id_column", ["id", "blah"])
def test_from_contents_invalid(id_column):
    """Check that ``from_contents`` rejects inconsistent contents/data inputs.

    Each case must raise ``ValueError`` with a message matching the given
    pattern.
    """
    contents = OrderedDict(
        [("cat1", {"aa", "bb", "cc"}), ("cat2", {"cc", "dd"}), ("cat3", {"ee"})]
    )
    # data has a column with the same name as a category
    with pytest.raises(ValueError, match="columns overlap"):
        from_contents(
            contents, data=pd.DataFrame({"cat1": [1, 2, 3, 4, 5]}), id_column=id_column
        )
    # an identifier is repeated within a single category
    with pytest.raises(ValueError, match="duplicate ids"):
        from_contents({"cat1": ["aa", "bb"], "cat2": ["dd", "dd"]}, id_column=id_column)
    # a category shares its name with the id column
    with pytest.raises(ValueError, match="cannot be named"):
        from_contents(
            {
                id_column: {"aa", "bb", "cc"},
                "cat2": {"cc", "dd"},
            },
            id_column=id_column,
        )
    # data already has a column named like the id column
    with pytest.raises(ValueError, match="cannot contain"):
        from_contents(
            contents,
            data=pd.DataFrame(
                {id_column: [1, 2, 3, 4, 5]}, index=["aa", "bb", "cc", "dd", "ee"]
            ),
            id_column=id_column,
        )
    # contents mention an identifier ("aa") absent from data's index
    with pytest.raises(ValueError, match="identifiers in contents"):
        from_contents({"cat1": ["aa"]}, data=pd.DataFrame([[1]]), id_column=id_column)


@pytest.mark.parametrize(
    ("indicators", "data", "exc_type", "match"),
    [
        # column names / callables need data to resolve against
        (["a", "b"], None, ValueError, "data must be provided"),
        (lambda df: [True, False, True], None, ValueError, "data must be provided"),
        # a listed column missing from data
        (["a", "unknown_col"], {"a": [1, 2, 3]}, KeyError, "unknown_col"),
        # tuples are not accepted as a column list
        (("a",), {"a": [1, 2, 3]}, ValueError, "tuple"),
        # indicator values must be boolean, not 0/1 integers
        ({"cat1": [0, 1, 1]}, {"a": [1, 2, 3]}, ValueError, "must all be boolean"),
        # indicator index values that data's index lacks
        (
            pd.DataFrame({"cat1": [True, False, True]}, index=["a", "b", "c"]),
            {"A": [1, 2, 3]},
            ValueError,
            "all its values must be present",
        ),
    ],
)
def test_from_indicators_invalid(indicators, data, exc_type, match):
    """Check that ``from_indicators`` rejects each malformed input with the given error."""
    call_kwargs = {"indicators": indicators, "data": data}
    with pytest.raises(exc_type, match=match):
        from_indicators(**call_kwargs)


@pytest.mark.parametrize(
    "indicators",
    [
        pd.DataFrame({"cat1": [False, True, False]}),
        pd.DataFrame({"cat1": [False, True, False]}, dtype="O"),
        {"cat1": [False, True, False]},
        lambda data: {"cat1": {pd.DataFrame(data).index.values[1]: True}},
    ],
)
@pytest.mark.parametrize(
    "data",
    [
        pd.DataFrame({"val1": [3, 4, 5]}),
        pd.DataFrame({"val1": [3, 4, 5]}, index=["a", "b", "c"]),
        {"val1": [3, 4, 5]},
    ],
)
def test_from_indicators_equivalence(indicators, data):
    """Check that every indicator form gives the same result as ``from_memberships``.

    Each form marks only the second row as belonging to ``cat1``.
    """
    via_indicators = from_indicators(indicators, data)
    via_memberships = from_memberships([[], ["cat1"], []], data)
    assert_frame_equal(via_indicators, via_memberships)


def test_generate_data_warning():
    """Check that calling the deprecated ``generate_data`` emits a DeprecationWarning."""
    expected_category = DeprecationWarning
    with pytest.warns(expected_category):
        generate_data()
