#! @PYTHON@
import subprocess
import sys
import os
import glob
from collections import defaultdict

#DOCUMENTATION/TUTORIAL @ https://gitlab.dkrz.de/mpim-sw/cdo/-/wikis/Writing-CDO-tests

# The @...@ placeholders are filled in when this file is generated from its
# template (presumably by the configure step -- verify against the build).
BUILD_DIR = "@abs_top_builddir@"
# "$CDO_DEBUG" is left for the shell to expand: commands here run with shell=True.
CDO = "@abs_top_builddir@/src/cdo $CDO_DEBUG"
# Test-data repository; clean() refuses to delete anything inside it.
DATAPATH = "@abs_top_srcdir@/test/data"
# Width (in columns) of the separator lines printed around tests.
lineLength = 80

# Overview:
#   print_flush      -- print, then flush stdout so output order is preserved
#   cdo_check_req    -- query a `cdo --config` feature flag
#
#   class TAPTest    => group of commands
#   class TestModule => group of TAPTests

# Short file-format key -> human-readable name; unknown keys yield "undefined".
formats = defaultdict(
    lambda: "undefined",
    {
        "srv": "SERVICE",
        "ext": "EXTRA",
        "ieg": "IEG",
        "grb": "GRIB",
        "grb2": "GRIB_API",
        "nc": "netCDF",
        "nc2": "netCDF2",
        "nc4": "netCDF4",
        "nc5": "netCDF5",
        "nczarr": "netCDF zarr",
    },
)

def fileformat(p_format):
    """Describe the file format with short key *p_format*.

    :returns: a pair ``(available, name)`` where ``available`` is the result
        of ``cdo_check_req("has-<p_format>")`` and ``name`` is the readable
        format name from ``formats`` ("undefined" for unknown keys).
    """
    available = cdo_check_req(f"has-{p_format}")
    name = formats[p_format]
    return available, name

def print_flush(*args):
    """Print *args* like ``print`` and flush stdout right away.

    Flushing after every call keeps this script's output correctly
    interleaved with the output of the subprocesses it starts.
    """
    print(*args, flush=True)

def print_seperator(symbol,length,spacing=0):
    """Print a horizontal rule built from *symbol*, about *length* columns wide.

    Every repetition of *symbol* is followed by *spacing* blanks; surrounding
    whitespace of the finished rule is stripped before printing.
    """
    unit = symbol + " " * spacing
    count = int(length / len(unit))
    rule = "".join([unit] * count)
    print_flush(rule.strip())


def run_command(commands):
    """Run the shell command line *commands* and return its exit status.

    The child inherits this process's stdout/stderr; nothing is captured.
    """
    with subprocess.Popen(commands, shell=True, universal_newlines=True) as child:
        return child.wait()

def clean(file_or_wildcard, protected_dir=None):
    """Delete every file matching *file_or_wildcard* (a glob pattern or plain path).

    Prints a warning when nothing matches.  If a match lies inside the
    protected test-data directory, prints an error and calls ``sys.exit(-1)``
    so a wrong clean-up pattern can never delete reference data.  Matches that
    are processed before the offending one have already been removed.

    :param file_or_wildcard: path or glob pattern of generated files to remove.
    :param protected_dir: directory whose contents must never be deleted;
        defaults to DATAPATH.
    """
    files = glob.glob(file_or_wildcard)
    if not files:
        print_flush("WARNING: no files were cleaned ")

    protected = os.path.realpath(DATAPATH if protected_dir is None else protected_dir)
    for path in files:
        # Resolve the directory the entry lives in (following symlinked
        # directories and relative/".." components) but not the entry itself:
        # removing a symlink that merely points into the data directory is harmless.
        absolute = os.path.abspath(path)
        entry = os.path.join(os.path.realpath(os.path.dirname(absolute)),
                             os.path.basename(absolute))
        if os.path.commonpath([entry, protected]) == protected:
            print_flush("ERROR, trying to clean up a file from the test data repository")
            sys.exit(-1)

        try:
            os.remove(path)
        except FileNotFoundError:
            # Already gone (e.g. removed concurrently) -- nothing left to clean.
            pass

def cdo_check_req(req):
    """Query one entry of CDO's build configuration.

    Runs ``<CDO> --config <req>`` through the shell, with *req* lower-cased
    (e.g. ``"has-nc4"`` as built by fileformat()).

    :param req: configuration key to query.
    :returns: True if the command's standard output contains "yes", else False.
    """
    call = "{CDO} --config {req}".format(CDO=CDO, req=req.lower())
    # subprocess.run() drains stdout while waiting for the child; calling
    # wait() before communicate() with stdout=PIPE can deadlock if the pipe fills.
    result = subprocess.run(
        call, shell=True, universal_newlines=True, stdout=subprocess.PIPE
    )
    return "yes" in result.stdout

#-------------------------------------------------------------------

class TAPTest:
    """One TAP test: an ordered list of shell commands with expected exit codes.

    go() runs the commands in order; the test passes only if every command
    returns its expected exit status.  Files registered with clean() are
    removed after the run (skipped tests run nothing and clean nothing).
    """

    def __init__(self, msg=""):
        self.commands = []        # shell command strings, in execution order
        self.skip = False         # set by skip_test(); go() then returns 77
        self.message = msg        # description reported with the test result
        self.cleanFiles = []      # paths / glob patterns removed by cleanUp()
        # Expected exit status per command string; unknown commands expect 0.
        self.expectedRetVal = defaultdict(int)

    def skip_test(self, msg):
        """Mark this test as skipped, with *msg* as the reason."""
        self.skip = True
        self.message = msg

    def execute_commands(self):
        """Run all commands in order, stopping at the first unexpected exit status.

        Commands containing "cmp" or "diff" are treated as result checks and
        run as-is.  Any other command is prefixed with the value of the
        CDO_TEST_PREPEND environment variable when that variable is set.

        :returns: 0 if every command returned its expected status, else -1.
        """
        wrap_with = os.getenv("CDO_TEST_PREPEND")
        for c in self.commands:
            is_check = any(tool in c for tool in ("cmp", "diff"))
            if is_check:
                print_flush("Starting Check: ", c)
            else:
                print_flush("Starting Cdo Call: ", c)

            # Only CDO calls get the prefix; checks must run unwrapped.
            if wrap_with is not None and not is_check:
                call_status = run_command(wrap_with + " " + c)
            else:
                call_status = run_command(c)

            expected = self.expectedRetVal[c]
            if call_status != expected:
                print_flush("ERROR: Unexpected return value: ", call_status, "expected:", expected)
                print_flush("--- FAILURE ---")
                return -1

            print_seperator('-', lineLength, 3)
        return 0

    def go(self):
        """Run the test and return its status: 0 passed, -1 failed, 77 skipped."""
        if self.skip:
            return 77

        print_seperator('-', lineLength, 3)
        status = self.execute_commands()
        if status != -1:
            print_flush("+++ SUCCESS +++")

        self.cleanUp()
        return status

    def cleanUp(self):
        """Delete every file / pattern registered via clean()."""
        for f in self.cleanFiles:
            clean(f)

    def diff(self, a, b):
        """Add a ``cdo diff a b`` comparison that must exit with status 0."""
        self.add("{} diff {} {}".format(CDO, a, b))

    def clean(self, *args):
        """Register files or glob patterns to delete after the test ran."""
        for a in args:
            self.cleanFiles.append(a)

    def add(self, command, expectedRetVal=0):
        """Append shell *command*, which must exit with *expectedRetVal*.

        Expected values are keyed by the command string, so adding the same
        command twice keeps only the last expected value.
        """
        self.commands.append(command)
        self.expectedRetVal[command] = expectedRetVal

#-------------------------------------------------------------------

class TestModule:
    """A TAP test script: an ordered mix of data-preparation steps and TAPTests.

    Entries of ``commands`` are ``(item, is_test)`` pairs: ``(TAPTest, True)``
    for tests added with add()/add_skip() and ``(command_string, False)`` for
    preparation steps added with prepare().  run() prints the TAP plan,
    executes all entries in order, prints one ok/not ok line per test and
    exits with status 0 (a failing preparation step exits with -1).
    """

    def __init__(self):
        self.cleanFiles = []   # paths / glob patterns removed after all entries ran
        self.commands = []     # (TAPTest or command string, is_test) pairs
        self.testID = -1       # number of the current test; reset to 1 by run()

    def cleanUp(self):
        """Delete every file / pattern registered via clean()."""
        for f in self.cleanFiles:
            clean(f)

    def __get_num_test(self):
        """Return the number of tests (the TAP plan); exit if there are none."""
        num = sum(1 for _, is_test in self.commands if is_test is True)
        if num == 0:
            print_flush("ERROR NO TESTS EXECUTED")
            sys.exit(-1)
        return num

    def add(self, test):
        """Append a TAPTest."""
        self.commands.append((test, True))

    def add_skip(self, msg):
        """Append a test that is reported as skipped with reason *msg*."""
        t = TAPTest()
        t.skip_test(msg)
        self.commands.append((t, True))

    def print(self):
        """Print the commands of every entry, followed by a blank line per entry."""
        for item, is_test in self.commands:
            # Preparation steps are stored as plain command strings.
            lines = item.commands if is_test else [item]
            for c in lines:
                print_flush(c)
            print_flush()

    def prepare(self, command):
        """Append a shell command that prepares data for the following tests."""
        self.commands.append((command, False))

    def clean(self, *args):
        """Register files or glob patterns to delete after all entries ran."""
        for a in args:
            self.cleanFiles.append(a)

    def __prepare_data(self, c, remaining=None):
        """Run preparation command *c*; on failure report pending tests and exit.

        :param remaining: entries that have not run yet (defaults to all
            entries).  Only test entries get a "not ok" line, so the number of
            result lines stays consistent with the announced TAP plan.
        """
        print_flush("Preparing data:", c)
        status = run_command(c)
        if status != 0:
            print_flush("ERROR: Fail in test preperation")
            pending = self.commands if remaining is None else remaining
            for _, is_test in pending:
                if is_test:
                    print_flush("not ok: test not executed due to error in test preperation")
            sys.exit(-1)

    def __run_check(self, t):
        """Run TAPTest *t*, print its TAP result line and advance the test number.

        :returns: 0 on success, 77 if skipped, -1 on failure.
        """
        print_flush("Running test: %i" % self.testID)
        test_status = t.go()
        if test_status == 0:
            print_flush(f'ok: {t.message}')
            result = 0
        elif test_status == 77:
            print_flush("ok: # SKIP: {}".format(t.message))
            result = 77
        else:
            print_flush("not ok: {}".format(t.message))
            result = -1

        # Previously unreachable (placed after the returns): without it the
        # test number never advanced.
        print_seperator('-', lineLength)
        self.testID += 1
        print_flush()
        return result

    def run(self):
        """Execute all entries in order, clean up, and exit the process with 0.

        Test outcomes are reported through the ok/not ok lines, not through
        the exit status.
        """
        print_flush("1..%s" % self.__get_num_test())
        print_seperator('=', lineLength)

        self.testID = 1
        for index, (item, is_test) in enumerate(self.commands):
            if is_test:
                self.__run_check(item)
            else:
                self.__prepare_data(item, self.commands[index + 1:])

        self.cleanUp()
        sys.exit(0)  # important to exit here: tap tests/ctest need this return value

# Public names exported by ``from <this module> import *``.
__all__ = [
    "cdo_check_req",
    "TAPTest",
    "TestModule",
    "print_flush",
    "DATAPATH",
    "CDO",
    "BUILD_DIR",
    "fileformat",
]

def main():
    """Self-test of this helper module.

    Runs one CDO call that must succeed and one that must exit with status 1.
    TestModule.run() terminates the process itself, so the return statement
    is normally not reached.
    """
    module = TestModule()

    works = TAPTest("check if cdoTest works")
    works.add(f'{CDO} --operators')
    module.add(works)

    fails = TAPTest("check if cdoTest works error is returned")
    fails.add(f'{CDO} -add adasd', 1)
    module.add(fails)

    return module.run()

if __name__ == '__main__':
    # main() normally does not return: TestModule.run() exits the process itself.
    sys.exit(main())

