import pytest
from .testutils import compare_data, tol


@pytest.fixture
def model_pump(neuron_instance):
    """Build a single-dendrite rxd model of the sodium-potassium pump.

    The dendrite carries three regions:
    - an intracellular region (``cyt``, mapped to NEURON's ``"i"``);
    - an extracellular region (``ecs``, mapped to ``"o"``);
    - the membrane separating them (``mem``).

    Sodium and potassium species exist in both cyt and ecs. A
    non-mass-action MultiCompartmentReaction across ``mem`` converts
    ``3 nai + 2 ko`` into ``3 nao + 2 ki``. Its rate is a product of two
    sigmoids: one in intracellular Na and one in extracellular K.

    Yields:
        ``(neuron_instance, model)``: the incoming fixture value, passed
        through unchanged, and a tuple of every object created here, in the
        order (dend, cyt, ecs, mem, na, k, nai, nao, ki, ko, pump,
        volume_scale, pump_current).
    """
    h, rxd, _data, _save_path = neuron_instance

    # One dendrite: diameter 2, 101 segments, length 100.
    dend = h.Section(name="dend")
    dend.diam = 2
    dend.nseg = 101
    dend.L = 100

    cyt = rxd.Region(dend, name="cyt", nrn_region="i")
    ecs = rxd.Region(dend, name="ecs", nrn_region="o")
    mem = rxd.Region(dend, name="mem", geometry=rxd.membrane())

    def sodium_start(node):
        # 18 in the intracellular region, 144 everywhere else (i.e. ecs).
        return 18 if node.region == cyt else 144

    def potassium_start(node):
        # 140 in the intracellular region, 3 everywhere else (i.e. ecs).
        return 140 if node.region == cyt else 3

    na = rxd.Species([cyt, ecs], name="na", charge=1, initial=sodium_start)
    k = rxd.Species([cyt, ecs], name="k", charge=1, initial=potassium_start)

    # Per-region handles used to write the reaction scheme and rate.
    nai = na[cyt]
    nao = na[ecs]
    ki = k[cyt]
    ko = k[ecs]

    # The pump rate is a product of two logistic factors, so it saturates
    # with high intracellular Na and high extracellular K.
    rxd_exp = rxd.rxdmath.exp
    sodium_gate = 0.8 / (1.0 + rxd_exp((25.0 - nai) / 3.0))
    potassium_gate = 1.0 / (1.0 + rxd_exp(3.5 - ko))
    pump = sodium_gate * potassium_gate

    # diam / 4 equals the volume-to-surface ratio of a cylinder (r / 2).
    # Together with Avogadro's number and 1e-18 it presumably converts the
    # rate into the units MultiCompartmentReaction expects for a membrane
    # flux. NOTE(review): confirm against the rxd documentation.
    volume_scale = 1e-18 * rxd.constants.NA() * (dend.diam / 4.0)

    # Each reaction event moves 3 Na out of the cell and 2 K into it.
    consumed = 2 * ko + 3 * nai
    produced = 2 * ki + 3 * nao
    pump_current = rxd.MultiCompartmentReaction(
        consumed,
        produced,
        pump * volume_scale,
        mass_action=False,
        membrane=mem,
        membrane_flux=True,
    )

    # Handing every object back to the test presumably keeps the rxd
    # objects referenced (alive) for the duration of the test.
    model = (
        dend,
        cyt,
        ecs,
        mem,
        na,
        k,
        nai,
        nao,
        ki,
        ko,
        pump,
        volume_scale,
        pump_current,
    )
    yield (neuron_instance, model)


def test_currents(model_pump):
    """Test currents generated by a Na/K-pump fixed step integration.

    The model is initialized with legacy units switched off. Legacy units
    are switched back on before the 10-unit run, which exercises a unit
    change after initialization. Unless ``save_path`` is set, the recorded
    data must match the stored reference within ``tol``.
    """
    (h, _rxd, data, save_path), _model = model_pump

    # Initialize under non-legacy unit constants, then switch to legacy
    # units once initialization is done.
    h.nrnunit_use_legacy(False)
    h.finitialize(-65)
    h.nrnunit_use_legacy(True)
    h.continuerun(10)

    if save_path:
        # Comparison is skipped when save_path is set; presumably this run
        # is recording reference data instead.
        return
    worst_error = compare_data(data)
    assert worst_error < tol


def test_currents_cvode(model_pump):
    """Test currents generated by a Na/K-pump variable step integration.

    The test enables CVode and then initializes with legacy units switched
    off. Legacy units are switched back on before the 10-unit run, which
    exercises a unit change after initialization. Unless ``save_path`` is
    set, the recorded data must match the stored reference within ``tol``.
    """
    fixture_env, _model = model_pump
    h, _rxd, data, save_path = fixture_env

    # Switch from fixed-step to variable-step integration.
    integrator = h.CVode()
    integrator.active(True)

    # Initialize under non-legacy unit constants, then switch to legacy
    # units once initialization is done.
    h.nrnunit_use_legacy(False)
    h.finitialize(-65)
    h.nrnunit_use_legacy(True)
    h.continuerun(10)

    if save_path:
        # Comparison is skipped when save_path is set; presumably this run
        # is recording reference data instead.
        return
    worst_error = compare_data(data)
    assert worst_error < tol
