"""
Blissdata version used by each branch of the Bliss repository (Bliss is the data-publishing side in production):

* Bliss branch master: blissdata version 2.0.0 (python_requires = >=3.9)
* Bliss branch  2.1.x: blissdata version 1.1.2 (python_requires = >=3.8)
* Bliss branch  2.0.x: blissdata version 1.0.3 (python_requires = >=3.8, <3.10)
* Bliss branch 1.11.x: blissdata version 0.3.4 (python_requires = >=3.7, <3.10)
"""

import bliss  # noqa F401 patch environment the way Bliss wants it (not standard)

import os
import sys
from pprint import pprint

if sys.version_info >= (3, 8):
    from importlib.metadata import version
else:
    from importlib_metadata import version
from packaging.specifiers import SpecifierSet

import gevent
import blissdemo
from bliss.config import static
from bliss.shell import standard

from ewoksdata.data.blissdata import last_lima_image
from ewoksdata.data.blissdata import iter_bliss_scan_data_from_memory
from ewoksdata.data.blissdata import iter_bliss_scan_data_from_memory_slice


# Installed versions of the publisher (bliss) and of the data library
# (blissdata); both select which code paths below are used.
_BLISS_VERSION, _BLISSDATA_VERSION = version("bliss"), version("blissdata")

if _BLISS_VERSION not in SpecifierSet("<2", prereleases=True):
    # Bliss >= 2: scans are discovered through the blissdata Redis engine.
    from blissdata.beacon.data import BeaconData
    from blissdata.redis_engine.store import DataStore
    from blissdata.redis_engine.exceptions import NoScanAvailable
else:
    # Bliss < 2: scans are discovered by walking the session node tree.
    from bliss.data.node import get_session_node


def test_iter_memory(scan_key) -> None:
    """Iterate all 'difflab6' and 'diode1' data of a scan and count the items.

    Each item yielded by ``iter_bliss_scan_data_from_memory`` is a mapping of
    name -> array-like; only the shapes are printed. Exactly 10 items are
    expected, matching the ``loopscan(10, ...)`` calls made by ``run_scans``.
    """
    print(f"Iterate scan {scan_key} ...")
    stream = iter_bliss_scan_data_from_memory(scan_key, ["difflab6"], ["diode1"])
    received = 0
    for received, point in enumerate(stream, start=1):
        pprint({name: array.shape for name, array in point.items()})
    assert received == 10
    print(f"Received all data from {scan_key}.")


def test_iter_slice_memory(scan_key) -> None:
    """Iterate a slice of the 'difflab6' and 'diode1' data of a scan.

    ``slice_range=(3, 5)`` is expected to yield exactly 2 items (presumably
    the half-open point range [3, 5) -- confirm against ewoksdata). Only the
    shapes of each item's arrays are printed.
    """
    print(f"Iterate scan slice {scan_key} ...")
    stream = iter_bliss_scan_data_from_memory_slice(
        scan_key, ["difflab6"], ["diode1"], slice_range=(3, 5)
    )
    received = 0
    for received, point in enumerate(stream, start=1):
        pprint({name: array.shape for name, array in point.items()})
    assert received == 2
    print(f"Received all data from {scan_key}.")


# Only the way the last 'difflab6' image is looked up depends on the
# blissdata version; the checks are shared by ``test_last_lima_image``.
if _BLISSDATA_VERSION in SpecifierSet("<1", prereleases=True):

    def _last_difflab6_image(scan_key):
        # blissdata < 1: the image channel is addressed by its database name.
        return last_lima_image(f"{scan_key}:timer:difflab6:image")

elif _BLISSDATA_VERSION in SpecifierSet("<2", prereleases=True):

    def _last_difflab6_image(scan_key):
        # blissdata 1.x: load the scan from the data store and pass the
        # stream info of the 'difflab6:image' stream.
        data_store = DataStore(BeaconData().get_redis_data_db())
        scan = data_store.load_scan(scan_key)
        return last_lima_image(scan.streams["difflab6:image"].info)

else:

    def _last_difflab6_image(scan_key):
        # blissdata >= 2: scan key and detector name are passed directly.
        return last_lima_image(scan_key, "difflab6")


def test_last_lima_image(scan_key) -> None:
    """Fetch the last 'difflab6' image of a scan and check that it is 2-D.

    Sleeps 5 s first, presumably to let the scan publish images.
    """
    print(f"Get last 'difflab6' image from {scan_key} ...")
    gevent.sleep(5)
    image = _last_difflab6_image(scan_key)
    pprint({"difflab6": image.shape})
    assert image.ndim == 2


# Tests to run, in order, one per scan. The slice-iteration test is only run
# against blissdata >= 1.
_TESTS = (test_iter_memory, test_last_lima_image)
if _BLISSDATA_VERSION not in SpecifierSet("<1", prereleases=True):
    _TESTS += (test_iter_slice_memory,)


if _BLISS_VERSION in SpecifierSet("<2", prereleases=True):

    def init_execute_tests() -> tuple:
        """Start listening for new nodes under the "demo_session" session node.

        Called before the scans are started, presumably so that no scan is
        missed.

        Returns:
            A 1-tuple ``(event_iterator,)`` meant to be unpacked into
            ``execute_tests``.
        """
        session = get_session_node("demo_session")
        # Children of scan/scan_group nodes are excluded; presumably so that
        # only the scan nodes themselves are reported -- verify.
        scan_types = ("scan", "scan_group")
        event_iterator = session.walk_on_new_events(exclude_children=scan_types)
        return (event_iterator,)

    def execute_tests(event_iterator) -> None:
        """Run the tests of ``_TESTS`` in order, one per new scan node.

        Each test receives the ``db_name`` of the scan node.

        Raises:
            RuntimeError: if the event iterator is exhausted before every
                test has been run, so skipped tests are never reported as
                a success.
        """
        pending = iter(_TESTS)
        run_test = next(pending, None)
        if run_test is None:
            return
        for ev in event_iterator:
            if ev.type == ev.type.NEW_NODE and ev.node.type == "scan":
                run_test(ev.node.db_name)
                run_test = next(pending, None)
                if run_test is None:
                    return
        raise RuntimeError("Session event stream ended before all tests were run")

else:

    def init_execute_tests() -> tuple:
        """Connect to the Redis data store whose URL is obtained from Beacon.

        Returns:
            ``(data_store, since)`` meant to be unpacked into
            ``execute_tests``; ``since`` is the timetag of the last scan at
            call time, presumably so that only later scans are picked up.
        """
        redis_url = BeaconData().get_redis_data_db()
        data_store = DataStore(redis_url)
        since = data_store.get_last_scan_timetag()
        return data_store, since

    def execute_tests(data_store, since) -> None:
        """Run the tests of ``_TESTS`` in order, one per scan after ``since``.

        Polls ``get_next_scan`` with a 1 s timeout until a scan is available;
        the caller is expected to bound the overall waiting time.
        """
        for run_test in _TESTS:
            while True:
                try:
                    since, scan_key = data_store.get_next_scan(since=since, timeout=1)
                    break
                except NoScanAvailable:
                    continue
            # Run the test outside the ``try``: a NoScanAvailable raised by the
            # test itself must fail it, not be mistaken for "no scan yet".
            run_test(scan_key)


def run_scans(session, nscans):
    """Run ``loopscan(10, 0.1, difflab6, diode1)`` *nscans* times.

    ``loopscan`` and both detectors are looked up in ``session.env_dict``.
    """
    env = session.env_dict
    loopscan = env["loopscan"]
    detectors = tuple(env[name] for name in ("difflab6", "diode1"))
    for _ in range(nscans):
        print("Scan starts ...")
        loopscan(10, 0.1, *detectors)
        print("Scan finished.")


def start_bliss_session():
    """Load and set up the "demo_session" Bliss session from static config.

    The session environment is seeded with a copy of the namespace of
    ``bliss.shell.standard``.

    Returns:
        The session object obtained from ``static.get_config()``, after setup.

    Raises:
        RuntimeError: if ``setup()`` reports failure.
    """
    config = static.get_config()
    bliss_session = config.get("demo_session")
    if _BLISS_VERSION in SpecifierSet(">=2.1.0dev0", prereleases=True):
        # Only called for Bliss >= 2.1.0dev0; presumably older versions do not
        # have or need active_session() -- verify.
        bliss_session.active_session()

    env_dict = dict(standard.__dict__)

    # Not an ``assert``: under ``python -O`` the whole assert statement,
    # including the setup() call itself, would be skipped.
    if not bliss_session.setup(env_dict=env_dict):
        raise RuntimeError("Session setup failed")
    return bliss_session


if __name__ == "__main__":
    # Default connection settings for a local Bliss demo; values already set
    # in the environment take precedence.
    env_defaults = {
        "BEACON_HOST": "localhost:10001",
        "TANGO_HOST": "localhost:10000",
        "DEMO_ROOT": blissdemo.__path__[0],
    }
    for env_name, env_value in env_defaults.items():
        os.environ.setdefault(env_name, env_value)

    bliss_session = start_bliss_session()
    # Start listening before scanning so the tests see every scan.
    listener = gevent.spawn(execute_tests, *init_execute_tests())

    scan_count = len(_TESTS)
    run_scans(bliss_session, scan_count)
    # Allow up to 60 s per scan plus a 30 s margin for the tests to finish.
    listener.get(timeout=60 * scan_count + 30)
