from math import log

import pytest

import wn
from wn import similarity as sim
from wn.ic import information_content as infocont
from wn.taxonomy import taxonomy_depth


def get_synsets(w):
    """Fetch the fixture synsets used by the similarity tests.

    Args:
        w: an object exposing ``synset(id)``, e.g. ``wn.Wordnet("test-en")``.

    Returns:
        A dict mapping a short human-readable label to the synset object
        returned by ``w.synset`` for the corresponding fixture ID.
    """
    # Label -> synset ID; lookups happen in this order.
    label_to_id = {
        "information": "test-en-0001-n",
        "example": "test-en-0002-n",
        "sample": "test-en-0004-n",
        "random sample": "test-en-0005-n",
        "random sample2": "test-en-0008-n",
        "datum": "test-en-0006-n",
        "exemplify": "test-en-0003-v",
    }
    return {label: w.synset(synset_id) for label, synset_id in label_to_id.items()}


# Fake information-content data for the tests, computed from a tiny
# made-up corpus with:
#   words = ['example', 'example', 'sample', 'random sample', 'illustrate']
#   ic = compute(words, wn.Wordnet('test-en'), distribute_weight=False)

# Outer keys are part-of-speech codes; inner keys are synset IDs.
# NOTE(review): the ``None`` entry presumably holds the per-POS total used
# by wn.ic as a denominator -- confirm against wn.ic.information_content.
ic = dict(
    n={
        "test-en-0001-n": 5.0,  # information
        "test-en-0002-n": 5.0,  # example, illustration
        "test-en-0004-n": 3.0,  # sample
        "test-en-0005-n": 2.0,  # random sample
        "test-en-0008-n": 2.0,  # random sample 2
        "test-en-0006-n": 1.0,  # datum
        None: 6.0,
    },
    v={
        "test-en-0003-v": 2.0,  # exemplify, illustrate
        "test-en-0007-v": 1.0,  # resignate
        None: 2.0,
    },
    a={None: 1.0},
    r={None: 1.0},
)


@pytest.mark.usefixtures("mini_db")
def test_path():
    """Check ``sim.path`` scores on the test wordnet.

    Expected values are written as ``1 / (path length + 1)``. Synsets with
    no connecting path score 0 unless ``simulate_root=True`` is given.
    Comparing synsets of different parts of speech must raise ``wn.Error``.
    """
    ss = get_synsets(wn.Wordnet("test-en"))
    assert sim.path(ss["information"], ss["information"]) == 1 / 1
    assert sim.path(ss["information"], ss["example"]) == 1 / 2
    assert sim.path(ss["information"], ss["sample"]) == 1 / 3
    assert sim.path(ss["information"], ss["random sample"]) == 1 / 4
    assert sim.path(ss["random sample"], ss["datum"]) == 1 / 5
    # "random sample2" is not connected to "datum" without a simulated root.
    assert sim.path(ss["random sample2"], ss["datum"]) == 0
    assert sim.path(ss["random sample2"], ss["datum"], simulate_root=True) == 1 / 4
    assert (
        sim.path(ss["random sample"], ss["random sample2"], simulate_root=True) == 1 / 6
    )
    # Noun vs. verb: incompatible parts of speech, with or without a
    # simulated root. (Previously the second check called sim.wup by mistake.)
    with pytest.raises(wn.Error):
        sim.path(ss["example"], ss["exemplify"])
    with pytest.raises(wn.Error):
        sim.path(ss["example"], ss["exemplify"], simulate_root=True)


@pytest.mark.usefixtures("mini_db")
def test_wup():
    """Check ``sim.wup`` (Wu-Palmer) scores on the test wordnet.

    Expected values are written in the form ``(2 * d) / (n1 + n2 + 2 * d)``.
    Presumably ``d`` is the depth of the shared subsumer and ``n1``/``n2``
    are each synset's distance to it. Disconnected pairs and
    cross-part-of-speech pairs must raise ``wn.Error``.
    """
    ss = get_synsets(wn.Wordnet("test-en"))
    assert sim.wup(ss["information"], ss["information"]) == (2 * 1) / (0 + 0 + 2 * 1)
    assert sim.wup(ss["information"], ss["example"]) == (2 * 1) / (0 + 1 + 2 * 1)
    assert sim.wup(ss["information"], ss["sample"]) == (2 * 1) / (0 + 2 + 2 * 1)
    assert sim.wup(ss["information"], ss["random sample"]) == (2 * 1) / (0 + 3 + 2 * 1)
    assert sim.wup(ss["random sample"], ss["datum"]) == (2 * 1) / (3 + 1 + 2 * 1)
    # Unlike sim.path, wup raises (rather than returning 0) for disconnected
    # synsets. The call is not wrapped in `assert`, so a non-raising call
    # fails with pytest's "DID NOT RAISE" message.
    with pytest.raises(wn.Error):
        sim.wup(ss["random sample2"], ss["datum"])
    assert sim.wup(ss["random sample2"], ss["datum"], simulate_root=True) == (2 * 1) / (
        1 + 2 + 2 * 1
    )
    assert sim.wup(ss["random sample"], ss["random sample2"], simulate_root=True) == (
        2 * 1
    ) / (4 + 1 + 2 * 1)
    # Noun vs. verb: incompatible parts of speech.
    with pytest.raises(wn.Error):
        sim.wup(ss["example"], ss["exemplify"])
    with pytest.raises(wn.Error):
        sim.wup(ss["example"], ss["exemplify"], simulate_root=True)


@pytest.mark.usefixtures("mini_db")
def test_lch():
    """Check ``sim.lch`` (Leacock-Chodorow) scores on the test wordnet.

    Expected values are written as ``-log((p + 1) / (2 * d_n))``. Here ``d_n``
    is ``taxonomy_depth(w, "n")``, and ``p`` matches the path lengths used in
    ``test_path``. Disconnected pairs and cross-part-of-speech pairs must
    raise ``wn.Error``.
    """
    w = wn.Wordnet("test-en")
    ss = get_synsets(w)
    d_n = taxonomy_depth(w, "n")
    assert sim.lch(ss["information"], ss["information"], d_n) == -log(
        (0 + 1) / (2 * d_n)
    )
    assert sim.lch(ss["information"], ss["example"], d_n) == -log((1 + 1) / (2 * d_n))
    assert sim.lch(ss["information"], ss["sample"], d_n) == -log((2 + 1) / (2 * d_n))
    assert sim.lch(ss["information"], ss["random sample"], d_n) == -log(
        (3 + 1) / (2 * d_n)
    )
    assert sim.lch(ss["random sample"], ss["datum"], d_n) == -log((4 + 1) / (2 * d_n))
    # Disconnected without a simulated root. The call is not wrapped in
    # `assert`, so a non-raising call fails with pytest's "DID NOT RAISE".
    with pytest.raises(wn.Error):
        sim.lch(ss["random sample2"], ss["datum"], d_n)
    assert sim.lch(ss["random sample2"], ss["datum"], d_n, simulate_root=True) == -log(
        (3 + 1) / (2 * d_n)
    )
    assert sim.lch(
        ss["random sample"], ss["random sample2"], d_n, simulate_root=True
    ) == -log((5 + 1) / (2 * d_n))
    # Noun vs. verb: incompatible parts of speech.
    with pytest.raises(wn.Error):
        sim.lch(ss["example"], ss["exemplify"], d_n)
    with pytest.raises(wn.Error):
        sim.lch(ss["example"], ss["exemplify"], d_n, simulate_root=True)


@pytest.mark.usefixtures("mini_db")
def test_res():
    """Check ``sim.res`` (Resnik) scores against the fake IC data.

    Every connected noun pair below is expected to score exactly the
    information content of "information". Disconnected pairs and
    cross-part-of-speech pairs must raise ``wn.Error``.
    """
    wordnet = wn.Wordnet("test-en")
    synsets = get_synsets(wordnet)
    information_ic = infocont(synsets["information"], ic)

    connected_pairs = [
        ("information", "information"),
        ("information", "example"),
        ("information", "sample"),
        ("information", "random sample"),
        ("random sample", "datum"),
    ]
    for first, second in connected_pairs:
        score = sim.res(synsets[first], synsets[second], ic)
        assert score == information_ic, (first, second)

    # No shared subsumer (no simulated root is offered here).
    with pytest.raises(wn.Error):
        sim.res(synsets["random sample2"], synsets["datum"], ic)
    # Noun vs. verb: incompatible parts of speech.
    with pytest.raises(wn.Error):
        sim.res(synsets["example"], synsets["exemplify"], ic)


@pytest.mark.usefixtures("mini_db")
def test_jcn():
    """Check ``sim.jcn`` (Jiang-Conrath) scores against the fake IC data.

    Expected values are ``1 / ((ic_a + ic_b) - 2 * ic_information)``. Some
    pairs are expected to be infinite. Disconnected pairs and
    cross-part-of-speech pairs must raise ``wn.Error``.
    """
    wordnet = wn.Wordnet("test-en")
    synsets = get_synsets(wordnet)

    def ic_of(label):
        # Information content of the fixture synset with this label.
        return infocont(synsets[label], ic)

    info_ic = ic_of("information")

    def expected(ic_a, ic_b):
        # Same operand order as the reference formula, so floats match exactly.
        return 1 / ((ic_a + ic_b) - 2 * info_ic)

    def jcn(a, b):
        return sim.jcn(synsets[a], synsets[b], ic)

    infinity = float("inf")
    assert jcn("information", "information") == infinity
    assert jcn("information", "example") == infinity
    assert jcn("information", "sample") == expected(info_ic, ic_of("sample"))
    assert jcn("information", "random sample") == expected(
        info_ic, ic_of("random sample")
    )
    assert jcn("random sample", "datum") == expected(
        ic_of("random sample"), ic_of("datum")
    )

    with pytest.raises(wn.Error):
        jcn("random sample2", "datum")
    # Noun vs. verb: incompatible parts of speech.
    with pytest.raises(wn.Error):
        jcn("example", "exemplify")


@pytest.mark.usefixtures("mini_db")
def test_lin():
    """Check ``sim.lin`` (Lin) scores against the fake IC data.

    Expected values are ``(2 * ic_information) / (ic_a + ic_b)``. Some pairs
    are expected to score exactly 1.0. Disconnected pairs and
    cross-part-of-speech pairs must raise ``wn.Error``.
    """
    wordnet = wn.Wordnet("test-en")
    synsets = get_synsets(wordnet)

    def ic_of(label):
        # Information content of the fixture synset with this label.
        return infocont(synsets[label], ic)

    info_ic = ic_of("information")

    def expected(ic_a, ic_b):
        # Same operand order as the reference formula, so floats match exactly.
        return (2 * info_ic) / (ic_a + ic_b)

    def lin(a, b):
        return sim.lin(synsets[a], synsets[b], ic)

    assert lin("information", "information") == 1.0
    assert lin("information", "example") == 1.0
    assert lin("information", "sample") == expected(info_ic, ic_of("sample"))
    assert lin("information", "random sample") == expected(
        info_ic, ic_of("random sample")
    )
    assert lin("random sample", "datum") == expected(
        ic_of("random sample"), ic_of("datum")
    )

    with pytest.raises(wn.Error):
        lin("random sample2", "datum")
    # Noun vs. verb: incompatible parts of speech.
    with pytest.raises(wn.Error):
        lin("example", "exemplify")
