#!/usr/bin/env python3

import sys
import json
import textwrap

# Largest number of pubkeys found in any single vector file. Each section
# below updates it, and it is emitted as MUSIG_VECTORS_MAX_PUBKEYS at the end.
max_pubkeys = 0

if len(sys.argv) < 2:
    # stdout carries the generated C file, so help text goes to stderr and
    # cannot end up inside a redirected output file.
    print(
        "This script converts BIP MuSig2 test vectors in a given directory to a C file that can be used in the test framework.",
        file=sys.stderr,
    )
    print("Usage: %s <dir>" % sys.argv[0], file=sys.stderr)
    sys.exit(1)


def hexstr_to_intarray(str):
    """Render a hex string as a comma-separated list of C byte literals.

    Example: "00ff" becomes "0x00, 0xFF". Invalid hex raises ValueError.
    """
    raw = bytes.fromhex(str)
    literals = ("0x%02X" % byte for byte in raw)
    return ", ".join(literals)


def create_init(name):
    """Open the C initializer for the static `musig_<name>_vector` constant."""
    return f"\nstatic const struct musig_{name}_vector musig_{name}_vector = {{\n"


def init_array(key):
    """Emit one indented `{ bytes },` initializer line for the hex string data[key].

    Reads the module-level `data` dict, i.e. the JSON file currently being
    processed.
    """
    body = hexstr_to_intarray(data[key])
    return textwrap.indent("{ " + body + " },\n", " " * 4)


def init_arrays(key):
    """Emit a nested `{ {..}, {..} },` initializer for the hex list data[key].

    Reads the module-level `data` dict, i.e. the JSON file currently being
    processed. The outer braces are indented by 4 spaces and each row by 8.
    """
    rows = ",\n".join("{ %s }" % hexstr_to_intarray(item) for item in data[key])
    opening = textwrap.indent("{\n", " " * 4)
    inner = textwrap.indent(rows, " " * 8)
    closing = textwrap.indent("\n},\n", " " * 4)
    return opening + inner + closing


def init_indices(array):
    """Render an index list as the C fragment " <len>, { <items> }".

    The fragment fills a `size_t *_len; size_t *[N];` field pair. An empty
    list still emits "{ 0 }", because ISO C before C23 does not accept empty
    braces as an initializer.
    """
    # Decide the item text outside of join(); joining the string "0" only
    # worked by accident because it happens to be a single character.
    items = ", ".join(map(str, array)) if array else "0"
    return " %d, { %s }" % (len(array), items)


def init_is_xonly(case):
    """Render case["is_xonly"] as a list of C ints ("1"/"0").

    When the case has no tweaks, "0" is returned and `is_xonly` is never read.
    """
    if not case["tweak_indices"]:
        return "0"
    flags = ("1" if flag else "0" for flag in case["is_xonly"])
    return ", ".join(flags)


def init_optional_expected(case):
    """Return C byte literals for case["expected"], or the int 0 if absent.

    Callers interpolate the result with %s, so 0 renders as a zero initializer.
    """
    if "expected" not in case:
        return 0
    return hexstr_to_intarray(case["expected"])


def init_cases(cases, f):
    """Emit a braced list of test-case initializers, one per case.

    Args:
        cases: iterable of test-case dicts (any values accepted by `f`).
        f: callable that renders one case as a single C initializer string.

    Returns:
        "    {\\n", then each rendered case indented by 8 spaces on its own
        line, then "    },\\n".
    """
    # Each rendered case is indented separately; join avoids repeated +=.
    rows = "".join(textwrap.indent("%s\n" % f(case), 8 * " ") for case in cases)
    return textwrap.indent("{\n", 4 * " ") + rows + textwrap.indent("},\n", 4 * " ")


def finish_init():
    """Close the brace initializer opened by create_init()."""
    closing = "};"
    return closing + "\n"


# Header comment for the generated C file. This script reads no KeySort
# vectors, so the header points to where those tests live instead of
# claiming they are included.
s = (
    """/**
 * Automatically generated by %s.
 *
 * The test vectors for the KeySort function are not included in this file.
 * They can be found in src/modules/extrakeys/tests_impl.h. */
"""
    % sys.argv[0]
)


# Error categories referenced by the generated error test cases
s += """
enum MUSIG_ERROR {
    MUSIG_PUBKEY,
    MUSIG_TWEAK,
    MUSIG_PUBNONCE,
    MUSIG_AGGNONCE,
    MUSIG_SECNONCE,
    MUSIG_SIG,
    MUSIG_SIG_VERIFY,
    MUSIG_OTHER
};
"""

# key agg vectors
with open(sys.argv[1] + "/key_agg_vectors.json", "r") as f:
    data = json.load(f)

    # Both the valid-case and error-case structs size key_indices with this
    # bound, so it must cover both kinds of cases. Computing it from the valid
    # cases alone would let a longer error case overflow its C initializer.
    max_key_indices = max(
        len(test_case["key_indices"])
        for test_case in data["valid_test_cases"] + data["error_test_cases"]
    )
    # Only the error-case struct carries tweaks.
    max_tweak_indices = max(
        len(test_case["tweak_indices"]) for test_case in data["error_test_cases"]
    )
    num_pubkeys = len(data["pubkeys"])
    max_pubkeys = max(num_pubkeys, max_pubkeys)
    num_tweaks = len(data["tweaks"])
    num_valid_cases = len(data["valid_test_cases"])
    num_error_cases = len(data["error_test_cases"])

    # Add structures for valid and error cases
    s += (
        """
struct musig_key_agg_valid_test_case {
    size_t key_indices_len;
    size_t key_indices[%d];
    unsigned char expected[32];
};
"""
        % max_key_indices
    )
    s += """
struct musig_key_agg_error_test_case {
    size_t key_indices_len;
    size_t key_indices[%d];
    size_t tweak_indices_len;
    size_t tweak_indices[%d];
    int is_xonly[%d];
    enum MUSIG_ERROR error;
};
""" % (
        max_key_indices,
        max_tweak_indices,
        max_tweak_indices,
    )

    # Add structure for entire vector
    s += """
struct musig_key_agg_vector {
    unsigned char pubkeys[%d][33];
    unsigned char tweaks[%d][32];
    struct musig_key_agg_valid_test_case valid_case[%d];
    struct musig_key_agg_error_test_case error_case[%d];
};
""" % (
        num_pubkeys,
        num_tweaks,
        num_valid_cases,
        num_error_cases,
    )

    s += create_init("key_agg")
    # Add pubkeys and tweaks to the vector
    s += init_arrays("pubkeys")
    s += init_arrays("tweaks")

    # Add valid cases to the vector
    s += init_cases(
        data["valid_test_cases"],
        lambda case: "{ %s, { %s }},"
        % (init_indices(case["key_indices"]), hexstr_to_intarray(case["expected"])),
    )

    def comment_to_error(case):
        """Map an error case to a MUSIG_ERROR name by matching its comment text.

        Exits the script if the comment matches no known category.
        """
        comment = case["comment"]
        if "public key" in comment.lower():
            return "MUSIG_PUBKEY"
        elif "tweak" in comment.lower():
            return "MUSIG_TWEAK"
        else:
            sys.exit("Unknown error")

    # Add error cases to the vector
    s += init_cases(
        data["error_test_cases"],
        lambda case: "{ %s, %s, { %s }, %s },"
        % (
            init_indices(case["key_indices"]),
            init_indices(case["tweak_indices"]),
            init_is_xonly(case),
            comment_to_error(case),
        ),
    )

    s += finish_init()

# nonce gen vectors
# Emits struct musig_nonce_gen_test_case, struct musig_nonce_gen_vector and one
# initializer entry per test case that survives the message-length filter.
with open(sys.argv[1] + "/nonce_gen_vectors.json", "r") as f:
    data = json.load(f)

    # The MuSig2 implementation only allows messages of length 32
    # (64 hex characters). Cases without a message (None) are kept and encoded
    # with has_msg = 0; cases with any other message length are dropped.
    data["test_cases"] = list(
        filter(lambda c: c["msg"] is None or len(c["msg"]) == 64, data["test_cases"])
    )

    num_tests = len(data["test_cases"])

    # Each optional input (sk, aggpk, msg, extra_in) is paired with a has_*
    # flag; init_array_maybe() below produces the flag and bytes together.
    s += """
struct musig_nonce_gen_test_case {
    unsigned char rand_[32];
    int has_sk;
    unsigned char sk[32];
    unsigned char pk[33];
    int has_aggpk;
    unsigned char aggpk[32];
    int has_msg;
    unsigned char msg[32];
    int has_extra_in;
    unsigned char extra_in[32];
    unsigned char expected_secnonce[97];
    unsigned char expected_pubnonce[66];
};
"""

    s += (
        """
struct musig_nonce_gen_vector {
    struct musig_nonce_gen_test_case test_case[%d];
};
"""
        % num_tests
    )

    s += create_init("nonce_gen")

    def init_array_maybe(array):
        """Render an optional hex field as "<has_flag> , { <bytes> }".

        A None value renders as "0 , { 0 }".
        """
        return "%d , { %s }" % (
            0 if array is None else 1,
            hexstr_to_intarray(array) if array is not None else 0,
        )

    # Field order here must match struct musig_nonce_gen_test_case above.
    s += init_cases(
        data["test_cases"],
        lambda case: "{ { %s },  %s, { %s }, %s, %s, %s, { %s }, { %s } },"
        % (
            hexstr_to_intarray(case["rand_"]),
            init_array_maybe(case["sk"]),
            hexstr_to_intarray(case["pk"]),
            init_array_maybe(case["aggpk"]),
            init_array_maybe(case["msg"]),
            init_array_maybe(case["extra_in"]),
            hexstr_to_intarray(case["expected_secnonce"]),
            hexstr_to_intarray(case["expected_pubnonce"]),
        ),
    )

    s += finish_init()

# nonce agg vectors
with open(sys.argv[1] + "/nonce_agg_vectors.json", "r") as f:
    data = json.load(f)

    num_pnonces = len(data["pnonces"])
    num_valid_cases = len(data["valid_test_cases"])
    num_error_cases = len(data["error_test_cases"])

    # The C struct has a fixed-size pnonce_indices array with no length field,
    # so every case must reference exactly this many pubnonces. This is an
    # explicit check rather than an assert so it still runs under `python -O`.
    pnonce_indices_len = 2
    for case in data["valid_test_cases"] + data["error_test_cases"]:
        if len(case["pnonce_indices"]) != pnonce_indices_len:
            sys.exit(
                "Expected %d pnonce indices in nonce agg case" % pnonce_indices_len
            )

    # Valid and error cases share one struct layout
    s += (
        """
struct musig_nonce_agg_test_case {
    size_t pnonce_indices[%d];
    /* if valid case */
    unsigned char expected[66];
    /* if error case */
    int invalid_nonce_idx;
};
"""
        % pnonce_indices_len
    )
    # Add structure for entire vector
    s += """
struct musig_nonce_agg_vector {
    unsigned char pnonces[%d][66];
    struct musig_nonce_agg_test_case valid_case[%d];
    struct musig_nonce_agg_test_case error_case[%d];
};
""" % (
        num_pnonces,
        num_valid_cases,
        num_error_cases,
    )

    s += create_init("nonce_agg")
    s += init_arrays("pnonces")

    # Cases without "expected" emit 0 for it; cases without "error" emit 0 for
    # invalid_nonce_idx.
    for cases in (data["valid_test_cases"], data["error_test_cases"]):
        s += init_cases(
            cases,
            lambda case: "{ { %s }, { %s }, %d },"
            % (
                ", ".join(map(str, case["pnonce_indices"])),
                init_optional_expected(case),
                case["error"]["signer"] if "error" in case else 0,
            ),
        )
    s += finish_init()

# sign/verify vectors
with open(sys.argv[1] + "/sign_verify_vectors.json", "r") as f:
    data = json.load(f)

    # The MuSig2 implementation only allows messages of length 32. Keep only
    # msgs[0], which must itself be 32 bytes (64 hex characters), and the cases
    # that reference it. This is an explicit check rather than an assert so it
    # still runs under `python -O`.
    if not data["msgs"] or len(data["msgs"][0]) != 64:
        sys.exit("Expected msgs[0] to be a 32-byte message")
    data["msgs"] = [data["msgs"][0]]

    def filter_msg32(k):
        """Return the cases under key k that use msgs[0]."""
        return list(filter(lambda x: x["msg_index"] == 0, data[k]))

    data["valid_test_cases"] = filter_msg32("valid_test_cases")
    data["sign_error_test_cases"] = filter_msg32("sign_error_test_cases")
    data["verify_error_test_cases"] = filter_msg32("verify_error_test_cases")
    data["verify_fail_test_cases"] = filter_msg32("verify_fail_test_cases")

    # The C arrays below are sized in bytes: a secnonce is 97 bytes and a
    # pubnonce 66 bytes, the same sizes as expected_secnonce / expected_pubnonce
    # in the nonce gen struct. The old size of 194 was the hex-character count.
    # Refuse oversized entries rather than emit an overflowing initializer.
    for key, width in (("secnonces", 97), ("pnonces", 66)):
        if any(len(bytes.fromhex(x)) > width for x in data[key]):
            sys.exit("Unexpected %s length" % key)

    num_pubkeys = len(data["pubkeys"])
    max_pubkeys = max(num_pubkeys, max_pubkeys)
    num_secnonces = len(data["secnonces"])
    num_pubnonces = len(data["pnonces"])
    num_aggnonces = len(data["aggnonces"])
    num_msgs = len(data["msgs"])
    num_valid_cases = len(data["valid_test_cases"])
    num_sign_error_cases = len(data["sign_error_test_cases"])
    num_verify_fail_cases = len(data["verify_fail_test_cases"])
    num_verify_error_cases = len(data["verify_error_test_cases"])

    all_cases = (
        data["valid_test_cases"]
        + data["sign_error_test_cases"]
        + data["verify_error_test_cases"]
        + data["verify_fail_test_cases"]
    )
    max_key_indices = max(len(test_case["key_indices"]) for test_case in all_cases)
    max_nonce_indices = max(
        len(test_case["nonce_indices"]) if "nonce_indices" in test_case else 0
        for test_case in all_cases
    )
    # Add structures for valid and error cases
    s += (
        """
/* Omit pubnonces in the test vectors because our partial signature verification
 * implementation is able to accept the aggnonce directly. */
struct musig_valid_case {
    size_t key_indices_len;
    size_t key_indices[%d];
    size_t aggnonce_index;
    size_t msg_index;
    size_t signer_index;
    unsigned char expected[32];
};
"""
        % max_key_indices
    )

    s += (
        """
struct musig_sign_error_case {
    size_t key_indices_len;
    size_t key_indices[%d];
    size_t aggnonce_index;
    size_t msg_index;
    size_t secnonce_index;
    enum MUSIG_ERROR error;
};
"""
        % max_key_indices
    )

    s += """
struct musig_verify_fail_error_case {
    unsigned char sig[32];
    size_t key_indices_len;
    size_t key_indices[%d];
    size_t nonce_indices_len;
    size_t nonce_indices[%d];
    size_t msg_index;
    size_t signer_index;
    enum MUSIG_ERROR error;
};
""" % (
        max_key_indices,
        max_nonce_indices,
    )

    # Add structure for entire vector
    s += """
struct musig_sign_verify_vector {
    unsigned char sk[32];
    unsigned char pubkeys[%d][33];
    unsigned char secnonces[%d][97];
    unsigned char pubnonces[%d][66];
    unsigned char aggnonces[%d][66];
    unsigned char msgs[%d][32];
    struct musig_valid_case valid_case[%d];
    struct musig_sign_error_case sign_error_case[%d];
    struct musig_verify_fail_error_case verify_fail_case[%d];
    struct musig_verify_fail_error_case verify_error_case[%d];
};
""" % (
        num_pubkeys,
        num_secnonces,
        num_pubnonces,
        num_aggnonces,
        num_msgs,
        num_valid_cases,
        num_sign_error_cases,
        num_verify_fail_cases,
        num_verify_error_cases,
    )

    s += create_init("sign_verify")
    s += init_array("sk")
    s += init_arrays("pubkeys")
    s += init_arrays("secnonces")
    s += init_arrays("pnonces")
    s += init_arrays("aggnonces")
    s += init_arrays("msgs")

    s += init_cases(
        data["valid_test_cases"],
        lambda case: "{ %s, %d, %d, %d, { %s }},"
        % (
            init_indices(case["key_indices"]),
            case["aggnonce_index"],
            case["msg_index"],
            case["signer_index"],
            init_optional_expected(case),
        ),
    )

    def sign_error(case):
        """Map a sign error case to a MUSIG_ERROR name via its comment text."""
        comment = case["comment"]
        if "pubkey" in comment or "public key" in comment:
            return "MUSIG_PUBKEY"
        elif "Aggregate nonce" in comment:
            return "MUSIG_AGGNONCE"
        elif "Secnonce" in comment:
            return "MUSIG_SECNONCE"
        else:
            sys.exit("Unknown sign error")

    s += init_cases(
        data["sign_error_test_cases"],
        lambda case: "{ %s, %d, %d, %d, %s },"
        % (
            init_indices(case["key_indices"]),
            case["aggnonce_index"],
            case["msg_index"],
            case["secnonce_index"],
            sign_error(case),
        ),
    )

    def verify_error(case):
        """Map a verify fail/error case to a MUSIG_ERROR name via its comment."""
        comment = case["comment"]
        if "exceeds" in comment:
            return "MUSIG_SIG"
        elif "Wrong signer" in comment or "Wrong signature" in comment:
            return "MUSIG_SIG_VERIFY"
        elif "pubnonce" in comment:
            return "MUSIG_PUBNONCE"
        elif "pubkey" in comment:
            return "MUSIG_PUBKEY"
        else:
            sys.exit("Unknown verify error")

    # verify_fail_case and verify_error_case share one struct layout.
    for cases in ("verify_fail_test_cases", "verify_error_test_cases"):
        s += init_cases(
            data[cases],
            lambda case: "{ { %s }, %s, %s, %d, %d, %s },"
            % (
                hexstr_to_intarray(case["sig"]),
                init_indices(case["key_indices"]),
                init_indices(case["nonce_indices"]),
                case["msg_index"],
                case["signer_index"],
                verify_error(case),
            ),
        )

    s += finish_init()

# tweak vectors
with open(sys.argv[1] + "/tweak_vectors.json", "r") as f:
    data = json.load(f)

    # pubnonces are stored as 66-byte arrays (not the 194-character hex width
    # used before). Refuse oversized entries rather than emit an overflowing
    # initializer.
    if any(len(bytes.fromhex(x)) > 66 for x in data["pnonces"]):
        sys.exit("Unexpected pnonces length")

    num_pubkeys = len(data["pubkeys"])
    max_pubkeys = max(num_pubkeys, max_pubkeys)
    num_pubnonces = len(data["pnonces"])
    num_tweaks = len(data["tweaks"])
    num_valid_cases = len(data["valid_test_cases"])
    num_error_cases = len(data["error_test_cases"])

    all_cases = data["valid_test_cases"] + data["error_test_cases"]
    max_key_indices = max(len(test_case["key_indices"]) for test_case in all_cases)
    max_tweak_indices = max(len(test_case["tweak_indices"]) for test_case in all_cases)
    max_nonce_indices = max(len(test_case["nonce_indices"]) for test_case in all_cases)
    # Add structures for valid and error cases
    s += """
struct musig_tweak_case {
    size_t key_indices_len;
    size_t key_indices[%d];
    size_t nonce_indices_len;
    size_t nonce_indices[%d];
    size_t tweak_indices_len;
    size_t tweak_indices[%d];
    int is_xonly[%d];
    size_t signer_index;
    unsigned char expected[32];
};
""" % (
        max_key_indices,
        max_nonce_indices,
        max_tweak_indices,
        max_tweak_indices,
    )

    # Add structure for entire vector
    s += """
struct musig_tweak_vector {
    unsigned char sk[32];
    unsigned char secnonce[97];
    unsigned char aggnonce[66];
    unsigned char msg[32];
    unsigned char pubkeys[%d][33];
    unsigned char pubnonces[%d][66];
    unsigned char tweaks[%d][32];
    struct musig_tweak_case valid_case[%d];
    struct musig_tweak_case error_case[%d];
};
""" % (
        num_pubkeys,
        num_pubnonces,
        num_tweaks,
        num_valid_cases,
        num_error_cases,
    )
    s += create_init("tweak")
    s += init_array("sk")
    s += init_array("secnonce")
    s += init_array("aggnonce")
    s += init_array("msg")
    s += init_arrays("pubkeys")
    s += init_arrays("pnonces")
    s += init_arrays("tweaks")

    # Valid and error cases share struct musig_tweak_case and the same
    # initializer layout; a case without an "expected" field gets 0 from
    # init_optional_expected().
    for cases in (data["valid_test_cases"], data["error_test_cases"]):
        s += init_cases(
            cases,
            lambda case: "{ %s, %s, %s, { %s }, %d, { %s }},"
            % (
                init_indices(case["key_indices"]),
                init_indices(case["nonce_indices"]),
                init_indices(case["tweak_indices"]),
                init_is_xonly(case),
                case["signer_index"],
                init_optional_expected(case),
            ),
        )

    s += finish_init()

# sigagg vectors
# Emits struct musig_sig_agg_case / musig_sig_agg_vector and an initializer
# with the pubkeys, tweaks, psigs, msg, and then the valid and error cases.
with open(sys.argv[1] + "/sig_agg_vectors.json", "r") as f:
    data = json.load(f)

    num_pubkeys = len(data["pubkeys"])
    max_pubkeys = max(num_pubkeys, max_pubkeys)
    num_tweaks = len(data["tweaks"])
    num_psigs = len(data["psigs"])
    num_valid_cases = len(data["valid_test_cases"])
    num_error_cases = len(data["error_test_cases"])

    # Array bounds in the shared case struct must fit every valid and error case.
    all_cases = data["valid_test_cases"] + data["error_test_cases"]
    max_key_indices = max(len(test_case["key_indices"]) for test_case in all_cases)
    max_tweak_indices = max(len(test_case["tweak_indices"]) for test_case in all_cases)
    max_psig_indices = max(len(test_case["psig_indices"]) for test_case in all_cases)

    # Add structures for valid and error cases
    s += """
/* Omit pubnonces in the test vectors because they're only needed for
 * implementations that do not directly accept an aggnonce. */
struct musig_sig_agg_case {
    size_t key_indices_len;
    size_t key_indices[%d];
    size_t tweak_indices_len;
    size_t tweak_indices[%d];
    int is_xonly[%d];
    unsigned char aggnonce[66];
    size_t psig_indices_len;
    size_t psig_indices[%d];
    /* if valid case */
    unsigned char expected[64];
    /* if error case */
    int invalid_sig_idx;
};
""" % (
        max_key_indices,
        max_tweak_indices,
        max_tweak_indices,
        max_psig_indices,
    )

    # Add structure for entire vector
    s += """
struct musig_sig_agg_vector {
    unsigned char pubkeys[%d][33];
    unsigned char tweaks[%d][32];
    unsigned char psigs[%d][32];
    unsigned char msg[32];
    struct musig_sig_agg_case valid_case[%d];
    struct musig_sig_agg_case error_case[%d];
};
""" % (
        num_pubkeys,
        num_tweaks,
        num_psigs,
        num_valid_cases,
        num_error_cases,
    )

    # Initializer order must follow the field order of musig_sig_agg_vector.
    s += create_init("sig_agg")
    s += init_arrays("pubkeys")
    s += init_arrays("tweaks")
    s += init_arrays("psigs")
    s += init_array("msg")

    # Cases without "expected" emit 0 for it; cases without "error" emit 0 for
    # invalid_sig_idx.
    for cases in (data["valid_test_cases"], data["error_test_cases"]):
        s += init_cases(
            cases,
            lambda case: "{ %s, %s, { %s }, { %s }, %s, { %s }, %d },"
            % (
                init_indices(case["key_indices"]),
                init_indices(case["tweak_indices"]),
                init_is_xonly(case),
                hexstr_to_intarray(case["aggnonce"]),
                init_indices(case["psig_indices"]),
                init_optional_expected(case),
                case["error"]["signer"] if "error" in case else 0,
            ),
        )
    s += finish_init()
# Trailer: append the largest pubkey count seen across all vector files, then
# write the whole generated C source to stdout.
footer = "enum { MUSIG_VECTORS_MAX_PUBKEYS = %d };" % max_pubkeys
s += footer
print(s)
