#!/usr/bin/env python
# type: ignore
"""
Generate per-sqlstate errors from PostgreSQL source code.

The script can be run at a new PostgreSQL release to refresh the module.
"""

# Copyright (C) 2020 The Psycopg Team

import os
import re
import sys
import logging
from collections import defaultdict, namedtuple
from urllib.request import urlopen

from psycopg.errors import get_base_exception

# Configure the root logger for the script, then keep a module-level handle
# on it for the functions below.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger()

# One generated exception class: the SQLSTATE code, the upper-cased error
# label, the Python class name and the name of its base exception class.
Error = namedtuple("Error", ["sqlstate", "errlabel", "clsname", "basename"])


def main():
    """Regenerate the autogenerated regions of the errors module and docs.

    Fetch the error definitions of every listed PostgreSQL version, then
    rewrite the autogenerated region of ``psycopg/psycopg/errors.py`` and of
    ``docs/api/errors.rst``, both located relative to this script's directory.
    """
    # Note: add "master" for a preview
    classes, errors = fetch_errors("9.6 10 11 12 13 14 15 16 17 18".split())

    # Anchor on the absolute script directory: with a relative ``__file__``
    # whose dirname is empty, concatenating "/../..." would produce a path
    # starting at the filesystem root.
    here = os.path.dirname(os.path.abspath(__file__))

    fn = os.path.join(here, "..", "psycopg", "psycopg", "errors.py")
    update_file(fn, generate_module_data(classes, errors))

    fn = os.path.join(here, "..", "docs", "api", "errors.rst")
    update_file(fn, generate_docs_data(classes, errors))


def parse_errors_txt(url):
    """Parse the ``errcodes.txt`` document found at *url*.

    Return a tuple ``(classes, errors)`` where:

    - ``classes`` maps each two-character class code to its section label
      (e.g. ``"22"`` -> ``"Class 22 - Data Exception"``);
    - ``errors`` maps each class code to a dict ``{sqlstate: errlabel}``,
      ``errlabel`` being the upper-cased spec name of the error.

    Error lines without a spec name are skipped.

    Raise `ValueError` if a line is neither blank, a comment, a section
    header nor an error definition, or if an error definition with a spec
    name appears before any section header.
    """
    classes = {}
    errors = defaultdict(dict)
    class_ = None  # code of the section currently being parsed

    # Use the response as a context manager so it is closed deterministically
    # instead of being left to the garbage collector.
    with urlopen(url) as page:
        text = page.read().decode("ascii")

    for raw in text.splitlines():
        # Strip comments and skip blanks
        line = raw.split("#")[0].strip()
        if not line:
            continue

        # Parse a section
        if m := re.match("Section: (Class (..) - .+)", line):
            label, class_ = m.groups()
            classes[class_] = label
            continue

        # Parse an error
        if m := re.match(r"(.....)\s+(?:E|W|S)\s+ERRCODE_(\S+)(?:\s+(\S+))?$", line):
            sqlstate, _macro, spec = m.groups()
            # skip sqlstates without specs as they are not publicly visible
            if not spec:
                continue
            # Without a section there is no class to file the error under:
            # report it as a parse error rather than failing on an unbound name.
            if class_ is None:
                raise ValueError("error found before any section:\n%s" % line)
            errors[class_][sqlstate] = spec.upper()
            continue

        # We don't expect anything else
        raise ValueError("unexpected line:\n%s" % line)

    return classes, errors


def tag_from_version(version: str) -> str:
    """Return the name of the PostgreSQL git branch for *version*.

    ``"master"`` is returned unchanged. Any other version maps to its stable
    branch: the prefix is ``REL_`` when the major number is 10 or greater
    (``"14"`` -> ``"REL_14_STABLE"``) and ``REL`` otherwise
    (``"9.6"`` -> ``"REL9_6_STABLE"``).
    """
    if version == "master":
        return version

    tver = tuple(map(int, version.split()[0].split(".")))
    # A conditional expression instead of the fragile ``cond and a or b`` trick.
    prefix = "REL_" if tver[0] >= 10 else "REL"
    suffix = version.replace(".", "_")
    return f"{prefix}{suffix}_STABLE"


def fetch_errors(versions):
    """Collect the error classes and errors defined by the given *versions*.

    For each version, download and parse its ``errcodes.txt`` and merge the
    result into the running totals; versions are processed in order, so data
    from a later version replaces data from an earlier one for the same code.

    Return a tuple ``(classes, errors)``: ``classes`` maps class codes to
    their labels, ``errors`` maps class codes to ``{sqlstate: Error}`` dicts.
    Classes ``00`` and ``01`` are removed from both.

    Raise `Exception` if two errors would generate the same class name, or a
    name equal to one of the reserved base exception names.
    """
    classes = {}
    errors = defaultdict(dict)

    for version in versions:
        logger.info("fetching errors from version %s", version)
        branch = tag_from_version(version)
        url = (
            "https://raw.githubusercontent.com/postgres/postgres"
            f"/refs/heads/{branch}/src/backend/utils/errcodes.txt"
        )
        version_classes, version_errors = parse_errors_txt(url)
        classes.update(version_classes)
        for code, by_sqlstate in version_errors.items():
            errors[code].update(by_sqlstate)

    # Clean up: success (00) and warning (01) are never raised.
    for mapping in (classes, errors):
        for code in ("00", "01"):
            del mapping[code]

    # Explicit class names for the sqlstates whose generated name is not
    # wanted (presumably because it would clash with another class).
    overrides = {
        "38002": "ModifyingSqlDataNotPermittedExt",
        "38003": "ProhibitedSqlStatementAttemptedExt",
        "38004": "ReadingSqlDataNotPermittedExt",
        "39004": "NullValueNotAllowedExt",
        "XX000": "InternalError_",
    }

    # Names already taken: seeded with the base exception names, then grown
    # with every class name generated below.
    taken = {
        "Error",
        "Warning",
        "InterfaceError",
        "DataError",
        "DatabaseError",
        "ProgrammingError",
        "IntegrityError",
        "InternalError",
        "NotSupportedError",
        "OperationalError",
    }

    for by_sqlstate in errors.values():
        # Iterate over a snapshot, as the values are replaced in the loop.
        for sqlstate, errlabel in list(by_sqlstate.items()):
            clsname = overrides.get(sqlstate)
            if clsname is None:
                # e.g. DIVISION_BY_ZERO -> DivisionByZero
                clsname = errlabel.title().replace("_", "")
            if clsname in taken:
                raise Exception("class already existing: %s" % clsname)
            taken.add(clsname)

            basename = get_base_exception(sqlstate).__name__
            by_sqlstate[sqlstate] = Error(sqlstate, errlabel, clsname, basename)

    return classes, errors


def generate_module_data(classes, errors):
    """Yield the source chunks of the generated exception classes.

    *classes* maps class codes to labels; *errors* maps class codes to
    ``{sqlstate: error}`` dicts, each error exposing ``clsname``,
    ``basename``, ``sqlstate`` and ``errlabel``. Classes, and errors within
    a class, are emitted in sorted code order.

    The chunks have no final line terminator but may span several lines. The
    first and last chunks are empty strings.
    """
    yield ""

    for clscode in sorted(classes):
        # Comment header introducing the class section
        yield f"\n# {classes[clscode]}\n"

        class_errors = errors[clscode]
        for sqlstate in sorted(class_errors):
            err = class_errors[sqlstate]
            yield (
                f"class {err.clsname}({err.basename},\n"
                f"    code={err.sqlstate!r}, name={err.errlabel!r}):\n"
                "    pass\n"
            )

    yield ""


def generate_docs_data(classes, errors):
    """Yield the lines of the reST table documenting the generated errors.

    The table has three columns (SQLSTATE, exception class, base exception),
    each as wide as its longest cell. Before the first error of every class
    the class label is emitted, with its ``Class NN`` part in bold, followed
    by a dashed line as wide as the table.

    *classes* maps class codes to labels; *errors* maps class codes to
    ``{sqlstate: error}`` dicts, each error exposing ``sqlstate``,
    ``clsname`` and ``basename``. The first and last lines yielded are empty.
    """
    Row = namedtuple("Row", "colstate colexc colbase sqlstate")

    # Only the error rows carry a sqlstate; header and rules have None.
    header = Row("SQLSTATE", "Exception", "Base exception", None)
    rows = [
        Row(
            f"``{err.sqlstate}``",
            f"`!{err.clsname}`",
            f"`!{err.basename}`",
            err.sqlstate,
        )
        for clscode in sorted(classes)
        for _, err in sorted(errors[clscode].items())
    ]

    # Each column is as wide as its longest cell, header included.
    widths = [max(len(row[col]) for row in [header, *rows]) for col in range(3)]
    rule = Row(*("=" * width for width in widths), None)
    table = [rule, header, rule, *rows, rule]

    # Dashed line spanning the columns and the single spaces between them.
    dashes = "-" * (sum(widths) + len(widths) - 1)
    current = None

    yield ""
    for row in table:
        if row.sqlstate:
            clscode = row.sqlstate[:2]
            if clscode != current:
                # First error of a new class: emit the class heading.
                yield re.sub(r"(Class\s+[^\s]+)", r"**\1**", classes[clscode])
                yield dashes
                current = clscode

        cells = (cell.ljust(width) for cell, width in zip(row[:3], widths))
        yield " ".join(cells).rstrip()

    yield ""


def update_file(fn, new_lines):
    """Replace the autogenerated region of the file *fn* with *new_lines*.

    The region is delimited by two marker lines matching
    ``autogenerated: start`` or ``autogenerated: end`` after a ``#`` or ``..``
    prefix. The marker lines are kept; everything between them is replaced
    by the items of *new_lines*, and the file is rewritten with every item
    followed by a newline.

    Raise `ValueError` if the file doesn't contain exactly two marker lines.
    """
    logger.info("updating %s", fn)

    # Be explicit about the encoding rather than depending on the locale.
    with open(fn, encoding="utf-8") as f:
        lines = f.read().splitlines()

    marker = re.compile(r"\s*(#|\.\.)\s*autogenerated:\s+(start|end)")
    found = [i for i, line in enumerate(lines) if marker.match(line)]
    # Fail with a readable message instead of a bare tuple-unpacking error.
    if len(found) != 2:
        raise ValueError(
            f"expected 2 'autogenerated' markers in {fn}, found {len(found)}"
        )
    istart, iend = found

    lines[istart + 1 : iend] = new_lines

    with open(fn, "w", encoding="utf-8") as f:
        f.writelines(f"{line}\n" for line in lines)


# Script entry point: the value returned by main() becomes the exit status.
if __name__ == "__main__":
    exit_status = main()
    sys.exit(exit_status)
