# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.

import os
import sys
from collections import defaultdict
from io import StringIO

import buildconfig
import yaml
from mozbuild.dirutils import ensureParentDir
from mozbuild.preprocessor import Preprocessor
from mozbuild.util import FileAvoidWrite

# Keys that may appear in a pref definition; check_pref_list rejects others.
VALID_KEYS = set(
    "name type value mirror do_not_use_directly include rust "
    "set_spidermonkey_pref".split()
)

# Boolean C++ pref types, each mapped to its non-atomic equivalent ("bool").
# Every entry except "bool" itself is defined in StaticPrefsBase.h.
VALID_BOOL_TYPES = {
    cxx_type: "bool"
    for cxx_type in (
        "bool",
        "RelaxedAtomicBool",
        "ReleaseAcquireAtomicBool",
        "SequentiallyConsistentAtomicBool",
    )
}

# Every accepted C++ pref type (the boolean ones first), each mapped to its
# non-atomic C++ equivalent. "String" maps to None, so it has no RUST_TYPES
# entry; "DataMutexString" maps to "nsACString". The atomic types and
# AtomicFloat are defined in StaticPrefsBase.h.
VALID_TYPES = {
    **VALID_BOOL_TYPES,
    "int32_t": "int32_t",
    "uint32_t": "uint32_t",
    "float": "float",
    "RelaxedAtomicInt32": "int32_t",
    "RelaxedAtomicUint32": "uint32_t",
    "ReleaseAcquireAtomicInt32": "int32_t",
    "ReleaseAcquireAtomicUint32": "uint32_t",
    "SequentiallyConsistentAtomicInt32": "int32_t",
    "SequentiallyConsistentAtomicUint32": "uint32_t",
    "AtomicFloat": "float",
    "String": None,
    "DataMutexString": "nsACString",
}

# Rust spelling of each non-atomic C++ type, used for the Rust declarations of
# the generated C getters. NOTE(review): the "DataMutexString" entry looks
# unused, because lookups use VALID_TYPES *values* and DataMutexString maps to
# "nsACString", which the generator handles separately.
RUST_TYPES = dict(
    bool="bool",
    int32_t="i32",
    uint32_t="u32",
    float="f32",
    DataMutexString="nsCString",
)

# First line of every generated file. generate_code() fills {input_filenames}
# with a comma-separated list of the input YAML file names.
HEADER_LINE = (
    "// This file was generated by generate_static_pref_list.py from {input_filenames}."
    " DO NOT EDIT."
)

# C++ macro invocation emitted for each pref, keyed by the pref's `mirror`
# value. Prefs whose type starts with "DataMutex" use the key with a
# "_datamutex" suffix. The "never" template has no {base_id}/{full_id}
# placeholders; str.format simply ignores the extra keyword arguments passed.
# NOTE: these strings are emitted verbatim (including their odd indentation).
MIRROR_TEMPLATES = {
    "never": """\
NEVER_PREF("{name}", {typ}, {value})
""",
    "once": """\
ONCE_PREF(
  "{name}",
   {base_id},
   {full_id},
  {typ}, {value}
)
""",
    "always": """\
ALWAYS_PREF(
  "{name}",
   {base_id},
   {full_id},
  {typ}, {value}
)
""",
    "always_datamutex": """\
ALWAYS_DATAMUTEX_PREF(
  "{name}",
   {base_id},
   {full_id},
  {typ}, {value}
)
""",
}

# First half of StaticPrefs_<group>.h: a usage comment and the opening of the
# include guard. Any per-group #include lines are emitted after this part.
STATIC_PREFS_GROUP_H_TEMPLATE1 = """\
// Include it to gain access to StaticPrefs::{group}_*.

#ifndef mozilla_StaticPrefs_{group}_h
#define mozilla_StaticPrefs_{group}_h
"""

# Second half of StaticPrefs_<group>.h: includes the group's pref list between
# StaticPrefListBegin.h and StaticPrefListEnd.h, then closes the include guard.
STATIC_PREFS_GROUP_H_TEMPLATE2 = """\
#include "mozilla/StaticPrefListBegin.h"
#include "mozilla/StaticPrefList_{group}.h"
#include "mozilla/StaticPrefListEnd.h"

#endif  // mozilla_StaticPrefs_{group}_h
"""

# extern "C" getter (for Rust) of a pref with a plain value type; {typ} is the
# non-atomic C++ type. Literal braces are doubled so str.format keeps them.
STATIC_PREFS_C_GETTERS_TEMPLATE = """\
extern "C" {typ} StaticPrefs_{full_id}() {{
  return mozilla::StaticPrefs::{full_id}();
}}
"""

# extern "C" getter (for Rust) of a pref whose non-atomic type is nsACString:
# appends the current value to the caller-supplied string. NOTE(review):
# `preflock` is presumably a lock guard over the DataMutex value that is
# dereferenced with `*` -- confirm against StaticPrefsBase.h.
STATIC_PREFS_C_GETTERS_NSSTRING_TEMPLATE = """\
extern "C" void StaticPrefs_{full_id}(nsACString *result) {{
  const auto preflock = mozilla::StaticPrefs::{full_id}();
  result->Append(*preflock);
}}
"""


def error(msg):
    """Report a validation problem by raising a ValueError carrying `msg`."""
    exc = ValueError(msg)
    raise exc


def mk_id(name):
    "Replace '.' and '-' with '_', e.g. 'foo.bar-baz' becomes 'foo_bar_baz'."
    # One translation table maps both separator characters to underscores.
    table = str.maketrans(".-", "__")
    return name.translate(table)


def mk_group(pref):
    """Return the group of `pref`: the part of its name before the first '.',
    passed through mk_id (e.g. 'foo-x.bar.baz' -> 'foo_x')."""
    head, _sep, _rest = pref["name"].partition(".")
    return mk_id(head)


def check_pref_list(pref_list):
    """Validate the pref definitions parsed from one StaticPrefList YAML file.

    Checks the following:
    - `pref_list` is a list of mappings.
    - Every key is in VALID_KEYS.
    - Names are unique strings containing a '.'.
    - Prefs are ordered by group (see mk_group).
    - Each key's value has an acceptable type and value.

    Raises:
        ValueError: (via `error`) describing the first problem found.
    """
    # yaml.safe_load yields None for an empty document; reject anything that
    # is not a list up front rather than failing with a TypeError below.
    if not isinstance(pref_list, list):
        error(f"expected a list of prefs, found `{pref_list}`")

    # Pref names seen so far. Used to detect any duplicates.
    seen_names = set()

    # The previous pref. Used to detect mis-ordered prefs.
    prev_pref = None

    for pref in pref_list:
        # Each entry must be a mapping; for e.g. a bare string the key checks
        # below would silently degrade into character/substring tests.
        if not isinstance(pref, dict):
            error(f"pref entry `{pref}` is not a mapping")

        # Check all given keys are known ones.
        for key in pref:
            if key not in VALID_KEYS:
                error(f"invalid key `{key}`")

        # 'name' must be present, valid, and in the right section.
        if "name" not in pref:
            error("missing `name` key")
        name = pref["name"]
        if type(name) is not str:
            error(f"non-string `name` value `{name}`")
        if "." not in name:
            error(f"`name` value `{name}` lacks a '.'")
        if name in seen_names:
            error(f"`{name}` pref is defined more than once")
        seen_names.add(name)

        # Prefs must be ordered appropriately: groups in non-decreasing order.
        if prev_pref is not None and mk_group(prev_pref) > mk_group(pref):
            error(f"`{name}` pref must come before `{prev_pref['name']}` pref")

        # 'type' must be present and valid. The str check comes first so that
        # unhashable YAML values (lists, mappings) get a clean error rather
        # than a TypeError from the dict lookup.
        if "type" not in pref:
            error(f"missing `type` key for pref `{name}`")
        typ = pref["type"]
        if type(typ) is not str or typ not in VALID_TYPES:
            error(f"invalid `type` value `{typ}` for pref `{name}`")

        # 'value' must be present and valid.
        if "value" not in pref:
            error(f"missing `value` key for pref `{name}`")
        value = pref["value"]
        if typ == "String" or typ == "DataMutexString":
            if type(value) is not str:
                error(
                    f"non-string `value` value `{value}` for `{typ}` pref `{name}`; "
                    "add double quotes"
                )
        elif typ in VALID_BOOL_TYPES:
            if value not in (True, False):
                error(f"invalid boolean value `{value}` for pref `{name}`")

        # 'mirror' must be present and valid. DataMutex types use the
        # "_datamutex" variant of the mirror template.
        if "mirror" not in pref:
            error(f"missing `mirror` key for pref `{name}`")
        mirror = pref["mirror"]
        if type(mirror) is not str:
            error(f"invalid `mirror` value `{mirror}` for pref `{name}`")
        if typ.startswith("DataMutex"):
            mirror += "_datamutex"
        if mirror not in MIRROR_TEMPLATES:
            error(f"invalid `mirror` value `{mirror}` for pref `{name}`")

        # Check 'do_not_use_directly' if present.
        if "do_not_use_directly" in pref:
            do_not_use_directly = pref["do_not_use_directly"]
            if type(do_not_use_directly) is not bool:
                error(
                    f"non-boolean `do_not_use_directly` value `{do_not_use_directly}` for pref "
                    f"`{name}`"
                )
            if do_not_use_directly and mirror == "never":
                error(
                    "`do_not_use_directly` uselessly set with `mirror` value "
                    f"`never` for pref `{name}`"
                )

        # Check 'include' if present.
        if "include" in pref:
            include = pref["include"]
            if type(include) is not str:
                error(f"non-string `include` value `{include}` for pref `{name}`")
            if include.startswith("<") and not include.endswith(">"):
                error(
                    f"`include` value `{include}` starts with `<` but does not "
                    f"end with `>` for pref `{name}`"
                )

        # Check 'rust' if present.
        if "rust" in pref:
            rust = pref["rust"]
            if type(rust) is not bool:
                error(f"non-boolean `rust` value `{rust}` for pref `{name}`")
            if rust and mirror == "never":
                error(
                    "`rust` uselessly set with `mirror` value `never` for "
                    f"pref `{name}`"
                )
            # A Rust getter can only be generated when the non-atomic C++ type
            # is nsACString or has a RUST_TYPES entry; otherwise (e.g.
            # "String", which maps to None) code generation would hit a
            # KeyError instead of reporting a useful error.
            cxx_type = VALID_TYPES[typ]
            if rust and cxx_type != "nsACString" and cxx_type not in RUST_TYPES:
                error(f"`rust` is not supported for `{typ}` pref `{name}`")

        prev_pref = pref


def _escape_cxx_string(s):
    """Return `s` escaped for use between the quotes of a C++ string literal.

    Backslashes are escaped first so the escapes added afterwards for double
    quotes and line breaks are not themselves doubled.
    """
    return (
        s.replace("\\", "\\\\")
        .replace('"', '\\"')
        .replace("\n", "\\n")
        .replace("\r", "\\r")
    )


def generate_code(pref_list, input_filenames):
    """Generate the text of every static-pref output file.

    Args:
        pref_list: pref definitions (dicts), expected to have passed
            check_pref_list.
        input_filenames: input file names, listed in each file's header line.

    Returns:
        A dict with these keys:
        - "static_pref_list_all_h", "static_prefs_all_h",
          "static_prefs_c_getters_cpp" and "static_prefs_rs", each mapping to
          the file's text.
        - "static_pref_list_group_h" and "static_prefs_group_h", each mapping
          to a dict from group name to the file's text.
    """
    first_line = HEADER_LINE.format(input_filenames=", ".join(input_filenames))

    # The required includes for StaticPrefs_<group>.h.
    includes = defaultdict(set)

    # StaticPrefList_<group>.h contains all the pref definitions for this
    # group.
    static_pref_list_group_h = defaultdict(lambda: [first_line, ""])

    # StaticPrefsCGetters.cpp contains C getters for all the mirrored prefs,
    # for use by Rust code.
    static_prefs_c_getters_cpp = [first_line, ""]

    # static_prefs.rs contains C getter declarations and a macro.
    static_prefs_rs_decls = []
    static_prefs_rs_macro = []

    # Generate the per-pref code (spread across multiple files).
    for pref in pref_list:
        name = pref["name"]
        typ = pref["type"]
        value = pref["value"]
        mirror = pref["mirror"]
        do_not_use_directly = pref.get("do_not_use_directly")
        include = pref.get("include")
        rust = pref.get("rust")

        base_id = mk_id(name)
        full_id = base_id
        if mirror == "once":
            full_id += "_AtStartup"
        if do_not_use_directly:
            full_id += "_DoNotUseDirectly"
        # DataMutex types use the "_datamutex" variant of the mirror template.
        if typ.startswith("DataMutex"):
            mirror += "_datamutex"

        pref_group = mk_group(pref)

        if include:
            if not include.startswith("<"):
                # It's not a system header. Add double quotes.
                include = f'"{include}"'
            includes[pref_group].add(include)

        if typ == "String":
            # Emit a quoted, escaped C++ string literal.
            value = f'"{_escape_cxx_string(value)}"'
        elif typ == "DataMutexString":
            # Emit a quoted, escaped C++ string literal with the `_ns` suffix.
            value = f'"{_escape_cxx_string(value)}"_ns'
        elif typ in VALID_BOOL_TYPES:
            # Convert Python bools to C++ bools.
            if value is True:
                value = "true"
            elif value is False:
                value = "false"

        # Append the C++ definition to the relevant output file's code.
        static_pref_list_group_h[pref_group].append(
            MIRROR_TEMPLATES[mirror].format(
                name=name,
                base_id=base_id,
                full_id=full_id,
                typ=typ,
                value=value,
            )
        )

        if rust:
            passed_type = VALID_TYPES[typ]
            if passed_type == "nsACString":
                # Generate the C getter.
                static_prefs_c_getters_cpp.append(
                    STATIC_PREFS_C_GETTERS_NSSTRING_TEMPLATE.format(full_id=full_id)
                )

                # Generate the C getter declaration, in Rust.
                decl = "    pub fn StaticPrefs_{full_id}(result: *mut nsstring::nsACString);"
                static_prefs_rs_decls.append(decl.format(full_id=full_id))

                # Generate the Rust macro entry.
                macro = '    ("{name}") => (unsafe {{ let mut result = $crate::nsCString::new(); $crate::StaticPrefs_{full_id}(&mut *result); result }});'
                static_prefs_rs_macro.append(macro.format(name=name, full_id=full_id))

            else:
                # Generate the C getter.
                static_prefs_c_getters_cpp.append(
                    STATIC_PREFS_C_GETTERS_TEMPLATE.format(
                        typ=passed_type, full_id=full_id
                    )
                )

                # Generate the C getter declaration, in Rust.
                decl = "    pub fn StaticPrefs_{full_id}() -> {typ};"
                static_prefs_rs_decls.append(
                    decl.format(full_id=full_id, typ=RUST_TYPES[passed_type])
                )

                # Generate the Rust macro entry.
                macro = (
                    '    ("{name}") => (unsafe {{ $crate::StaticPrefs_{full_id}() }});'
                )
                static_prefs_rs_macro.append(macro.format(name=name, full_id=full_id))

    # Every group that has at least one pref, in sorted order.
    groups = sorted(static_pref_list_group_h)

    # StaticPrefListAll.h contains one `#include "mozilla/StaticPrefList_X.h`
    # line per pref group.
    static_pref_list_all_h = [first_line, ""]
    static_pref_list_all_h.extend(
        f'#include "mozilla/StaticPrefList_{group}.h"' for group in groups
    )
    static_pref_list_all_h.append("")

    # StaticPrefsAll.h contains one `#include "mozilla/StaticPrefs_X.h` line per
    # pref group.
    static_prefs_all_h = [first_line, ""]
    static_prefs_all_h.extend(
        f'#include "mozilla/StaticPrefs_{group}.h"' for group in groups
    )
    static_prefs_all_h.append("")

    # StaticPrefs_<group>.h wraps StaticPrefList_<group>.h. It is the header
    # used directly by application code.
    static_prefs_group_h = {}
    for group in groups:
        lines = [first_line, STATIC_PREFS_GROUP_H_TEMPLATE1.format(group=group)]
        if group in includes:
            # Add any necessary includes, from the prefs' 'include' values.
            lines.extend(f"#include {inc}" for inc in sorted(includes[group]))
            lines.append("")
        lines.append(STATIC_PREFS_GROUP_H_TEMPLATE2.format(group=group))
        static_prefs_group_h[group] = lines

    # static_prefs.rs contains the Rust macro getters.
    static_prefs_rs = [first_line, "", "pub use nsstring::nsCString;", 'extern "C" {']
    static_prefs_rs.extend(static_prefs_rs_decls)
    static_prefs_rs.extend(["}", "", "#[macro_export]", "macro_rules! pref {"])
    static_prefs_rs.extend(static_prefs_rs_macro)
    static_prefs_rs.extend(["}", ""])

    def fold(lines):
        return "\n".join(lines)

    return {
        "static_pref_list_all_h": fold(static_pref_list_all_h),
        "static_prefs_all_h": fold(static_prefs_all_h),
        "static_pref_list_group_h": {
            k: fold(v) for k, v in static_pref_list_group_h.items()
        },
        "static_prefs_group_h": {k: fold(v) for k, v in static_prefs_group_h.items()},
        "static_prefs_c_getters_cpp": fold(static_prefs_c_getters_cpp),
        "static_prefs_rs": fold(static_prefs_rs),
    }


def emit_code(fd, *pref_list_filenames):
    """Preprocess, validate and generate code for the given pref list files.

    Args:
        fd: the already-open output file for StaticPrefListAll.h (created by
            the build system). The remaining outputs are written with
            FileAvoidWrite into fd's directory or its parent directory.
        pref_list_filenames: paths of the StaticPrefList YAML inputs.

    Prints an error and exits with status 1 if an input cannot be parsed or
    fails validation.
    """
    pp = Preprocessor()
    pp.context.update(buildconfig.defines["ALLDEFINES"])

    # A necessary hack until MOZ_DEBUG_FLAGS are part of buildconfig.defines.
    if buildconfig.substs.get("MOZ_DEBUG"):
        pp.context["DEBUG"] = "1"

    if buildconfig.substs.get("TARGET_CPU") == "aarch64":
        pp.context["MOZ_AARCH64"] = True

    if buildconfig.substs.get("MOZ_ANDROID_CONTENT_SERVICE_ISOLATED_PROCESS"):
        pp.context["MOZ_ANDROID_CONTENT_SERVICE_ISOLATED_PROCESS"] = True

    pref_list = []
    input_files = []
    for this_filename in pref_list_filenames:
        pp.out = StringIO()
        pp.do_filter("substitution")
        pp.do_include(this_filename)

        try:
            this_pref_list = yaml.safe_load(pp.out.getvalue())
            check_pref_list(this_pref_list)
            pref_list.extend(this_pref_list)
            input_files.append(
                os.path.relpath(
                    this_filename,
                    os.environ.get("GECKO_PATH", os.environ.get("TOPSRCDIR")),
                )
            )
        # yaml.YAMLError (malformed YAML) is not a ValueError; list it
        # explicitly so it gets the same friendly message, not a traceback.
        except (OSError, ValueError, yaml.YAMLError) as e:
            print(f"{this_filename}: error:\n  {e}\n")
            sys.exit(1)

    code = generate_code(pref_list, input_files)
    # When generating multiple files from a script, the build system treats the
    # first named output file (StaticPrefListAll.h in this case) specially -- it
    # is created elsewhere, and written to via `fd`.
    fd.write(code["static_pref_list_all_h"])

    # We must create the remaining output files ourselves. This requires
    # creating the output directory directly if it doesn't already exist.
    ensureParentDir(fd.name)
    init_dirname = os.path.dirname(fd.name)
    dirname = os.path.dirname(init_dirname)

    # Writes one output file; uses its own name for the handle so the `fd`
    # parameter is never shadowed.
    def write_output(directory, filename, text):
        with FileAvoidWrite(os.path.join(directory, filename)) as out:
            out.write(text)

    write_output(dirname, "StaticPrefsAll.h", code["static_prefs_all_h"])

    for group, text in sorted(code["static_pref_list_group_h"].items()):
        write_output(init_dirname, f"StaticPrefList_{group}.h", text)

    for group, text in sorted(code["static_prefs_group_h"].items()):
        write_output(dirname, f"StaticPrefs_{group}.h", text)

    write_output(
        init_dirname, "StaticPrefsCGetters.cpp", code["static_prefs_c_getters_cpp"]
    )

    write_output(dirname, "static_prefs.rs", code["static_prefs_rs"])
