#!/usr/bin/env python3

"""
Various checks for certain C++ style rules that we choose for BornAgain,
in addition to the rules imposed by clang-format.

Copyright Forschungszentrum Jülich GmbH 2025.
License:  Public Domain
"""

import os, re, subprocess, sys

####################################################################################################
# generic utilities
####################################################################################################

def bash(cmd):
    """
    Run `cmd` through an interactive bash and return its stdout as a list of lines.

    Trailing newlines are stripped before splitting. If the command produced no
    output, an empty list is returned (rather than a list holding one empty string).
    If the command exits with a non-zero status, its stderr is printed and the
    whole script exits with status 1.
    """
    # -i to respect .bashrc
    process = subprocess.Popen(['bash', '-i', '-c', cmd],
                               stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE)

    # Wait for the process to finish and collect both streams
    stdout, stderr = process.communicate()

    if process.returncode != 0:
        print(f"Error: {stderr.decode('utf-8')}")
        sys.exit(1)

    text = stdout.decode('utf-8').rstrip('\n')
    # ''.split('\n') would yield [''], i.e. one bogus empty "line"
    return text.split('\n') if text else []

def split_fname(fn):
    """
    Split a file name at its last dot.

    Returns the pair (stem, suffix), where suffix is the run of word characters
    after the final dot. If `fn` does not end in such a suffix, returns (fn, "").
    """
    found = re.match(r'(.*)\.(\w+)$', fn)
    if found is None:
        return (fn, "")
    stem, suffix = found.groups()
    return (stem, suffix)

####################################################################################################
# checks
####################################################################################################

def check_parallel_includes(stem, th, tc, errs):
    """
    Report headers that are #included both by a header and by its source file.

    stem: path without suffix, used only in the error message.
    th: contents of `stem`.h.
    tc: contents of `stem`.cpp.
    errs: list to which at most one error message is appended; it lists the
          doubly included headers in the order they appear in the header.
    """
    include_rex = r'#include ["<](.*?)[">]'

    # Compare include targets literally. Splicing them into a regex would treat
    # '.' or '+' in file names as metacharacters ("x.h" matching "xyh",
    # "c++/..." failing to compile or to match).
    included_by_c = {m.group(1) for m in re.finditer(include_rex, tc)}

    duplicates = []
    for m in re.finditer(include_rex, th):
        inc = m.group(1)
        if inc in included_by_c and inc not in duplicates:
            duplicates.append(inc)

    if duplicates:
        errs.append(f"duplicate includes in source pair {stem}.h|cpp: {', '.join(duplicates)}")

def rm_namespace(t):
    """
    Collapse namespace bodies in `t` into the placeholder line `namespace /*...*/`.

    A namespace region starts at a line beginning with `namespace` (the line must
    follow a newline, so the very first line of `t` is not considered) and runs
    up to and including the first subsequent line that begins with `}`.
    """
    namespace_region = r'\nnamespace ([\w:]+ )?{\n(.*\n)*?}.*'
    placeholder = r'\nnamespace /*...*/'
    collapsed = re.sub(namespace_region, placeholder, t)
    return collapsed

def check_block(fn, t, name, errs):
    """
    Check that the lines of one kind form a single block set off by single blank lines.

    fn: file name, used only in messages.
    t: file contents.
    name: "include", "using", or "fwd decl".
    errs: list to which error messages are appended.

    Several blocks of the same kind are reported as an error. The exception is
    #include blocks separated by an #if or #pragma. For a single block, the
    check requires exactly one blank line above and below it. Blocks preceded
    by a `template` line are skipped.

    Raises ValueError for any other `name`.
    """

    if name == "include":
        rex = r'(#include) .*?'
    elif name == "using":
        rex = r'(using) [^;]*?;'
        # ignore using-declarations that live inside namespace bodies
        t = rm_namespace(t)
    elif name == "fwd decl":
        rex = r'(class|struct) \w+;'
        # ignore declarations that live inside namespace bodies
        t = rm_namespace(t)
    else:
        # without this, `rex` would be unbound below
        raise ValueError(f"unknown block kind {name!r}")
    # Groups: 1 = newlines (or a template line) before the block,
    # 2 = last line of the block, 3 = keyword captured by `rex`,
    # 4 = optional trailing // comment, 5 = newlines after the block.
    matches = [m for m in re.finditer(r'(\n*|template.*)\n('+ rex + r'([ ]*//.*)?\n)+(\n*)', t)]

    if len(matches) > 1:

        # tolerate multiple #include blocks if they are separated by some #if or #pragma
        if name == "include" and re.search('#include.*#(if|pragma).*#include', t, re.DOTALL):
            return

        errs.append(f"several {name} blocks in file {fn}")

        return

    elif len(matches) == 0:
        return

    m = matches[0]

    if re.match(r'^template', m.group(1)):
        return
    if m.group(1) == '':
        # no blank line above: acceptable only if the block starts the file
        if not re.match(rex, t):
            errs.append(f"missing blank line above {name} block in file {fn}")
        else:
            print(f'In file {fn} block {name} at head')
    elif m.group(1) != '\n':
        errs.append(f"more than one blank line above {name} block in file {fn}")
    if m.group(5) == '':
        errs.append(f"missing blank line below {name} block in file {fn}")
    elif m.group(5) != '\n':
        errs.append(f"more than one blank line below {name} block in file {fn}")

def check_fwd_decl_sorted(fn, t, errs):
    """
    Require every run of consecutive forward-declaration lines to be sorted.

    A run is a sequence of lines of the form `class X;` or `struct X;`, each
    optionally followed by a // comment. Lines within a run are compared as whole
    strings. At most one error, naming file `fn`, is appended to `errs`.
    """
    decl_runs = re.finditer(r'((class|struct) (\w+);([ ]*//.*)?\n)+', t)
    for run in decl_runs:
        lines = run.group(0).rstrip('\n').split('\n')
        if sorted(lines) == lines:
            continue
        errs.append(f"forward declarations not sorted in file {fn}")
        return

def check_no_trivial_getters_in_cppfile(fc, t, errs):
    """
    Flag a trivial const getter defined in a source file.

    The check looks for a function defined as `Type Class::name() const`,
    where `name` starts with a lowercase letter and the body consists only of
    `return m_<member>;`. Such a getter should be defined in the header file
    instead.

    fc: file name, used only in the message.
    t: contents of the source file.
    errs: list to which at most one error message is appended.
    """
    # search the text passed in (`t`), not some unrelated outer variable
    m = re.search(r'\n\S+ (\w+::[a-z]\w*)\(\) const\n{\s*return m_\w+;\s*}', t)
    if m:
        errs.append(f"trivial getter {m.group(1)} in file {fc}, should go to header file")

####################################################################################################
# main
####################################################################################################

if __name__ == "__main__":

    # The script takes no arguments; "-h" or anything else just prints usage.
    if len(sys.argv) != 1:
        print("To be called without any arguments")
        sys.exit(1)

    errs = []

    # Trees excluded from the scan (generated code, third-party code, build dirs).
    pruned_paths = ['./auto', './devtools', './build', './debug', './cover', './tidy', './qbuild']
    pruned_names = ['*ThirdParty*', '*3rdparty*']
    # Name patterns are quoted so that bash does not glob-expand them before find runs.
    prune_expr = ' -o '.join([f'-path {p} -prune' for p in pruned_paths]
                             + [f'-name "{n}" -prune' for n in pruned_names])
    find_cmd = (r'find . -not \( ' + prune_expr + r' \) -a -type f'
                r' -a \( -name "*.cpp" -o -name "*.c" -o -name "*.h" \)')
    # Drop empty entries (e.g. when find prints nothing at all).
    allfiles = [fn for fn in bash(find_cmd) if fn]

    # tests on each file:
    for fn in allfiles:
        stem, suffix = split_fname(fn)
        with open(fn, 'r', encoding='utf-8') as fd:
            t = fd.read()
        check_block(fn, t, 'include', errs)
        check_block(fn, t, 'using', errs)
        if suffix == 'h':
            check_block(fn, t, 'fwd decl', errs)
            check_fwd_decl_sorted(fn, t, errs)

    # tests on file pairs (source file with a header of the same stem):
    for fn in allfiles:
        stem, suffix = split_fname(fn)
        if suffix != 'cpp':
            continue
        fh = stem + '.h'
        fc = fn
        if not os.path.exists(fh):
            continue
        with open(fh, 'r', encoding='utf-8') as fd:
            th = fd.read()
        with open(fc, 'r', encoding='utf-8') as fd:
            tc = fd.read()
        check_parallel_includes(stem, th, tc, errs)

    if errs:
        print("Found error(s):")
        for err in errs:
            print("- " + err)
        sys.exit(1)

    sys.exit(0)
