"""
Testing CryptoBackendXMLSecurity with SoftHSM as backend
"""

# Command used to convert the private key test.key into the unencrypted
# PKCS#8 PEM file test.key.p8 that softhsm can import:
#
#  openssl pkcs8 -topk8 -inform PEM -outform PEM -in test.key -out test.key.p8 \
#      -nocrypt
#

__author__ = "leifj"  # based on p11_test from pyXMLSecurity

import logging
import os
import subprocess
import tempfile
import traceback

from pathutils import full_path

from saml2 import class_name
from saml2 import saml
from saml2 import sigver
from saml2 import time_util
from saml2.s_utils import do_attribute_statement
from saml2.s_utils import factory


# xmlsec = pytest.importorskip("xmlsec")


def _find_alts(alts):
    for a in alts:
        if os.path.exists(a):
            return a
    return None


# PEM certificate and unencrypted PKCS#8 private key that setup_class loads
# into the SoftHSM test token.
PUB_KEY = full_path("test.pem")
PRIV_KEY = full_path("test.key.p8")

# Candidate locations of the SoftHSM PKCS#11 module. P11_MODULE is the first
# one that exists on this machine, or None when none of them is installed.
P11_MODULES = ["/usr/lib/libsofthsm.so", "/usr/lib/softhsm/libsofthsm.so"]
P11_MODULE = _find_alts(P11_MODULES)
# NOTE(review): P11_ENGINE is not referenced anywhere in the visible code.
P11_ENGINE = "/usr/lib/engines/engine_pkcs11.so"


def _eq(l1, l2):
    return set(l1) == set(l2)


class FakeConfig:
    """
    Configuration parameters for signature validation test cases.

    A minimal stand-in for a pysaml2 configuration object, passed to
    ``sigver.security_context`` by the PKCS#11 test below.
    """

    def __init__(self, pub_key=PUB_KEY):
        """
        :param pub_key: Path of the PEM certificate used as ``cert_file``.
        """
        self.xmlsec_binary = None
        self.crypto_backend = "XMLSecurity"
        self.only_use_keys_in_metadata = False
        self.metadata = None
        self.cert_file = pub_key
        # The private key is not a file on disk. It is referenced through a
        # pkcs11:// URI built from the SoftHSM module path, "0" / "test"
        # (matching the slot and label initialised in TestPKCS11.setup_class)
        # and the user PIN.
        self.key_file = f"pkcs11://{P11_MODULE}:0/test?pin=secret1"
        self.debug = False
        self.cert_handler_extra_class = None
        self.generate_cert_info = False
        self.tmp_cert_file = None
        self.tmp_key_file = None
        self.validate_certificate = False
        self.delete_tmpfiles = True


class TestPKCS11:
    """
    Sign and verify a SAML assertion with a private key held in a SoftHSM
    PKCS#11 token.

    Instantiated and driven manually by ``test_xmlsec_cryptobackend``:
    ``setup_class`` -> ``test_SAML_sign_with_pkcs11`` -> ``teardown_class``.
    """

    def __init__(self):
        self.private_keyspec = None
        self.public_keyspec = None
        # Paths of every temporary file created via _tf(); removed in teardown_class().
        self.p11_test_files = []
        self.softhsm_conf = None
        self.softhsm_db = None
        self.signer_cert_pem = None
        self.openssl_conf = None
        self.signer_cert_der = None
        self.configured = False
        self.sec = None
        self._assertion = None

    def setup_class(self):
        """
        Create a SoftHSM token holding the test key and certificate.

        Also build the security context and the unsigned assertion used by
        the signing test. Temporary files created here are removed by
        ``teardown_class``.

        :raises Exception: Any setup failure is printed with its traceback,
            logged, and re-raised.
        """
        logging.debug("Creating test pkcs11 token using softhsm")
        try:
            if P11_MODULE is None:
                # Fail with a clear message instead of passing None to subprocess.
                raise RuntimeError(f"SoftHSM PKCS#11 module not found, tried {P11_MODULES!r}")

            self.softhsm_db = self._tf()
            self.softhsm_conf = self._tf()
            self.signer_cert_pem = self._tf()
            self.openssl_conf = self._tf()
            self.signer_cert_der = self._tf()

            logging.debug("Generating softhsm.conf")
            with open(self.softhsm_conf, "w") as f:
                # Map slot 0 to the temporary token database file.
                f.write(f"#Generated by pysaml2 cryptobackend test\n0:{self.softhsm_db}\n")
            logging.debug("Initializing the token")
            self._p(
                ["softhsm", "--slot", "0", "--label", "test", "--init-token", "--pin", "secret1", "--so-pin", "secret2"]
            )

            logging.debug(f"Importing test key {PRIV_KEY!r} into SoftHSM")
            self._p(
                [
                    "softhsm",
                    "--slot",
                    "0",
                    "--label",
                    "test",
                    "--import",
                    PRIV_KEY,
                    "--id",
                    "a1b2",
                    "--pin",
                    "secret1",
                    "--so-pin",
                    "secret2",
                ]
            )

            logging.debug("Transforming PEM certificate to DER")
            self._p(
                ["openssl", "x509", "-inform", "PEM", "-outform", "DER", "-in", PUB_KEY, "-out", self.signer_cert_der]
            )

            logging.debug("Importing certificate into token")

            # Same --id as the imported private key, so the two objects are paired.
            self._p(
                [
                    "pkcs11-tool",
                    "--module",
                    P11_MODULE,
                    "-l",
                    "--slot",
                    "0",
                    "--id",
                    "a1b2",
                    "--label",
                    "test",
                    "-y",
                    "cert",
                    "-w",
                    self.signer_cert_der,
                    "--pin",
                    "secret1",
                ]
            )

            # List the contents of SoftHSM (objects, token info, slots) for debugging.
            for listing_flag in ("-O", "-T", "-L"):
                self._p(["pkcs11-tool", "--module", P11_MODULE, "-l", "--pin", "secret1", listing_flag])

            self.sec = sigver.security_context(FakeConfig(pub_key=PUB_KEY))
            self._assertion = factory(
                saml.Assertion,
                version="2.0",
                id="11111",
                issue_instant="2009-10-30T13:20:28Z",
                signature=sigver.pre_signature_part("11111", self.sec.my_cert, 1),
                attribute_statement=do_attribute_statement(
                    {
                        ("", "", "surName"): ("Foo", ""),
                        ("", "", "givenName"): ("Bar", ""),
                    }
                ),
            )
            self.configured = True
        except Exception as ex:
            print("-" * 64)
            traceback.print_exc()
            print("-" * 64)
            logging.warning(f"PKCS11 tests disabled: unable to initialize test token: {ex}")
            raise

    def teardown_class(self):
        """
        Remove temporary files created in setup_class and reset the state.
        """
        for path in self.p11_test_files:
            try:
                os.unlink(path)
            except FileNotFoundError:
                # Already gone; cleanup is best-effort.
                pass
        self.configured = False
        self.p11_test_files = []

    def _tf(self):
        """
        Create an empty temporary file, register it for cleanup and return its path.
        """
        # delete=False keeps the file on disk after the handle is closed.
        # Closing it immediately avoids leaking an open file descriptor.
        with tempfile.NamedTemporaryFile(delete=False) as f:
            name = f.name
        self.p11_test_files.append(name)
        return name

    def _p(self, args):
        """
        Run an external command, pointing SOFTHSM_CONF at the test token config.

        The command's stderr is logged as an error and its stdout at debug level.

        :param args: Command and arguments as a list (executed without a shell).
        :raises RuntimeError: If the command exits with a non-zero status.
        """
        overrides = {}
        if self.softhsm_conf is not None:
            overrides["SOFTHSM_CONF"] = self.softhsm_conf
        logging.debug(f"Environment {overrides!r}")
        logging.debug(f"Executing {args!r}")
        # Inherit the caller's environment (PATH, library paths, ...) and only
        # override the SoftHSM configuration location.
        env = {**os.environ, **overrides}
        proc = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env, check=False)
        if proc.stderr:
            logging.error(proc.stderr)
        if proc.stdout:
            logging.debug(proc.stdout)
        rv = proc.returncode
        if rv:
            raise RuntimeError(f"command exited with code != 0: {int(rv)}")

    def test_SAML_sign_with_pkcs11(self):
        """
        Test signing a SAML assertion using PKCS#11 and then verifying it.

        Requires ``setup_class`` to have completed successfully.
        """
        # Exported to this process as well. Presumably read by the PKCS#11
        # module loaded in-process by the crypto backend — TODO confirm.
        os.environ["SOFTHSM_CONF"] = self.softhsm_conf

        ass = self._assertion
        print(ass)
        sign_ass = self.sec.sign_assertion(f"{ass}", node_id=ass.id)
        sass = saml.assertion_from_string(sign_ass)
        assert _eq(sass.keyswv(), ["attribute_statement", "issue_instant", "version", "signature", "id"])
        assert sass.version == "2.0"
        assert sass.id == "11111"
        assert time_util.str_to_time(sass.issue_instant)

        print(f"Crypto version : {self.sec.crypto.version()}")

        item = self.sec.check_signature(sass, class_name(sass), sign_ass)

        assert isinstance(item, saml.Assertion)

        print("Test PASSED")


def test_xmlsec_cryptobackend():
    """
    Run the PKCS#11 signing test end to end.

    Builds the SoftHSM token, signs and verifies an assertion with it, and
    always removes the temporary files afterwards, even on failure.
    """
    t = TestPKCS11()
    try:
        t.setup_class()
        t.test_SAML_sign_with_pkcs11()
    finally:
        # Without this, every run leaked the token DB, config and cert files.
        t.teardown_class()


if __name__ == "__main__":
    # Allow running this test module directly, outside pytest.
    test_xmlsec_cryptobackend()
