# Copyright: (c) 2020, Jordan Borean (@jborean93) <jborean93@gmail.com>
# MIT License (see LICENSE or https://opensource.org/licenses/MIT)

from __future__ import annotations

import base64
import collections
import dataclasses
import getpass
import io
import os
import pytest
import re
import requests
import sansldap
import socket
import spnego
import spnego.channel_bindings
import spnego.iov
import spnego.tls
import ssl
import struct
import subprocess
import sys
import tempfile
import typing as t
import uuid
import warnings

from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from requests.packages.urllib3.exceptions import InsecureRequestWarning
from xml.etree import ElementTree as ET


# Environment specific settings. The double-curly-brace placeholders are
# template expressions that must be rendered before this file is valid Python.
USERNAME = '{{ domain_upn }}'
PASSWORD = '{{ domain_password }}'
HOSTNAME = socket.gethostname()
HOST_FQDN = '%s.{{ domain_name }}' % HOSTNAME
WIN_DC = '{{ groups['win_controller'][0] | lower }}.{{ domain_name | lower }}'
# NOTE: the gethostbyname() calls below perform DNS lookups at import time.
WIN_DC_IP = socket.gethostbyname(WIN_DC)
WIN_SERVER_UNTRUSTED = '{{ groups["win_children"][0] }}.{{ domain_name }}'  # Not trusted for delegation in AD
WIN_SERVER_TRUSTED = '{{ groups["win_children"][1] }}.{{ domain_name }}'  # Trusted for delegation in AD
WIN_SERVER_TRUSTED_IP = socket.gethostbyname(WIN_SERVER_TRUSTED)
KERBEROS_PROVIDER = '{{ krb_provider }}'

# Only ever True on Windows (os.name == 'nt') and only when the current user
# name equals '<HOSTNAME>$', i.e. the machine account name.
IS_SYSTEM = False
if os.name == 'nt':
    current_user = getpass.getuser()
    IS_SYSTEM = current_user == '%s$' % HOSTNAME

# RPC syntax identifiers stored as (UUID, int, int). The helpers in this file
# serialise them as the little-endian UUID followed by the two integers packed
# as 16-bit values - presumably the major and minor version; confirm against
# MS-RPCE before relying on that.
BIND_TIME_FEATURE_NEGOTIATION = (uuid.UUID("6cb71c2c-9812-4540-0300-000000000000"), 1, 0)
EMP = (uuid.UUID("e1af8308-5d1f-11c9-91a4-08002b14a0fa"), 3, 0)
# Written into the verification trailer that create_request builds.
ISD_KEY = (uuid.UUID("b9785960-524f-11df-8b6d-83dcded72085"), 1, 0)
NDR = (uuid.UUID("8a885d04-1ceb-11c9-9fe8-08002b104860"), 2, 0)
NDR64 = (uuid.UUID("71710533-beba-4937-8319-b5dbef9ccc36"), 1, 0)

# Result records produced by smb_send (SMBHeader) and by the negotiate and
# session setup response parsers that fill in SMBHeader.data.
SMBHeader = collections.namedtuple('SMBHeader', ['status', 'command', 'session_id', 'data'])
SMBNegotiateResponse = collections.namedtuple('SMBNegotiateResponse', ['dialect', 'buffer'])
SMBSessionResponse = collections.namedtuple('SMBSessionResponse', ['flags', 'buffer'])

# XML namespace prefix -> namespace URI. It is used both to build qualified
# '{uri}Tag' element names and as the namespace mapping handed to the
# ElementTree find()/findall() calls made by the wsman_* helpers.
WSMAN_NS = {
    "s": "http://www.w3.org/2003/05/soap-envelope",
    "xs": "http://www.w3.org/2001/XMLSchema",
    "xsi": "http://www.w3.org/2001/XMLSchema-instance",
    "wsa": "http://schemas.xmlsoap.org/ws/2004/08/addressing",
    "wsman": "http://schemas.dmtf.org/wbem/wsman/1/wsman.xsd",
    "wsmid": "http://schemas.dmtf.org/wbem/wsman/identify/1/wsmanidentity.xsd",
    "wsmanfault": "http://schemas.microsoft.com/wbem/wsman/1/wsmanfault",
    "cim": "http://schemas.dmtf.org/wbem/wscim/1/common",
    "wsmv": "http://schemas.microsoft.com/wbem/wsman/1/wsman.xsd",
    "cfg": "http://schemas.microsoft.com/wbem/wsman/1/config",
    "sub": "http://schemas.microsoft.com/wbem/wsman/1/subscription",
    "rsp": "http://schemas.microsoft.com/wbem/wsman/1/windows/shell",
    "m": "http://schemas.microsoft.com/wbem/wsman/1/machineid",
    "cert": "http://schemas.microsoft.com/wbem/wsman/1/config/service/certmapping",
    "plugin": "http://schemas.microsoft.com/wbem/wsman/1/config/PluginConfiguration",
    "wsen": "http://schemas.xmlsoap.org/ws/2004/09/enumeration",
    "wsdl": "http://schemas.xmlsoap.org/wsdl",
    "wst": "http://schemas.xmlsoap.org/ws/2004/09/transfer",
    "wsp": "http://schemas.xmlsoap.org/ws/2004/09/policy",
    "wse": "http://schemas.xmlsoap.org/ws/2004/08/eventing",
    "i": "http://schemas.microsoft.com/wbem/wsman/1/cim/interactive.xsd",
    "xml": "http://www.w3.org/XML/1998/namespace",
    "pwsh": "http://schemas.microsoft.com/powershell",
}


class HTTPWinRMAuth(requests.auth.AuthBase):
    """requests auth handler that answers a 401 by running a security context.

    ``context`` must expose ``protocol``, ``complete`` and ``step()``. After a
    successful negotiation ``header`` holds the name of the authentication
    scheme that was used on the wire.
    """

    def __init__(self, context):
        self.context = context
        self.header = None

        is_credssp = context.protocol == 'credssp'
        self.valid_protocols = ['CredSSP'] if is_credssp else ['Negotiate', 'Kerberos', 'NTLM']

        # Group 1 is the scheme name, group 2 the base64 token following it.
        pattern = r'(%s)\s*([^,]*),?' % '|'.join(self.valid_protocols)
        self._regex = re.compile(pattern, re.I)

    def __call__(self, request):
        """Mark the request keep-alive and hook the response for 401 handling."""
        request.headers['Connection'] = 'Keep-Alive'
        request.register_hook('response', self.response_hook)

        return request

    def response_hook(self, response, **kwargs):
        """Pass the response through unless it is a 401 that needs authenticating."""
        if response.status_code != 401:
            return response

        return self.handle_401(response, **kwargs)

    def handle_401(self, response, **kwargs):
        """Exchange tokens with the server until the context has nothing left to send.

        Returns the last response received; the original 401 is returned
        untouched when the server offers none of the supported schemes.
        """
        offered = response.headers.get('www-authenticate', '').upper()
        chosen = next((name for name in self.valid_protocols if name.upper() in offered), None)
        if chosen is None:
            return response
        self.header = chosen

        out_token = self.context.step()
        while not self.context.complete or out_token is not None:
            # Read the body and hand the connection back so it can be reused
            # for the follow-up request.
            response.content
            response.raw.release_conn()

            retry = response.request.copy()
            retry.headers['Authorization'] = b'%s %s' % (self.header.encode(), base64.b64encode(out_token))
            response = response.connection.send(retry, **kwargs)

            found = self._regex.search(response.headers.get('www-authenticate', ''))
            in_token = found.group(2) if found else None
            if not in_token:
                break

            out_token = self.context.step(base64.b64decode(in_token))

        return response


@dataclasses.dataclass(frozen=True)
class SecTrailer:
    """Authentication trailer that create_pdu appends to the end of a PDU."""

    type: int  # written as one unsigned byte; callers pass their auth_type here
    level: int  # written as one unsigned byte; callers pass their auth_level here
    pad_length: int  # written as one unsigned byte
    context_id: int  # written as a little-endian unsigned 32-bit integer
    data: bytes  # raw auth bytes; their length is also written into the PDU header


@dataclasses.dataclass(frozen=True)
class Tower:
    """Immutable record describing an RPC endpoint.

    NOTE(review): nothing in the visible part of this file creates or reads a
    Tower, so the field meanings below are inferred from their names and
    types - confirm against the code that uses it.
    """

    service: t.Tuple[uuid.UUID, int, int]  # same (UUID, int, int) shape as the ISD_KEY style constants
    data_rep: t.Tuple[uuid.UUID, int, int]  # same shape as the NDR/NDR64 constants
    protocol: int
    port: int
    addr: int


@dataclasses.dataclass
class GetKeyRequest:
    """Arguments of the GetKey call plus its request/response serialisation."""

    # Operation number of GetKey; fixed, so excluded from __init__ and repr.
    opnum: int = dataclasses.field(init=False, repr=False, default=0)

    target_sd: bytes
    root_key_id: t.Optional[uuid.UUID]
    l0_key_id: int
    l1_key_id: int
    l2_key_id: int

    # MS-GKDI 3.1.4.1 GetKey (Opnum 0)
    # https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gkdi/4cac87a3-521e-4918-a272-240f8fabed39
    # HRESULT GetKey(
    #     [in] handle_t hBinding,
    #     [in] ULONG cbTargetSD,
    #     [in] [size_is(cbTargetSD)] [ref] char* pbTargetSD,
    #     [in] [unique] GUID* pRootKeyID,
    #     [in] LONG L0KeyID,
    #     [in] LONG L1KeyID,
    #     [in] LONG L2KeyID,
    #     [out] unsigned long* pcbOut,
    #     [out] [size_is(, *pcbOut)] byte** ppbOut);

    def pack(self) -> bytes:
        """Serialise the request arguments into the stub data bytes."""
        # cbTargetSD is strictly a 4 byte value but it is written as 8 bytes
        # so the field that follows stays 8 byte aligned for NDR64.
        sd_length = len(self.target_sd).to_bytes(8, byteorder="little")
        sd_padding = b"\x00" * (-len(self.target_sd) % 8)

        # pRootKeyID - 8 zero bytes when absent, otherwise an 8 byte pointer
        # header followed by the GUID.
        if self.root_key_id:
            root_key = b"\x00\x00\x02\x00\x00\x00\x00\x00" + self.root_key_id.bytes_le
        else:
            root_key = b"\x00" * 8

        # L0KeyID, L1KeyID, L2KeyID as signed 32-bit values.
        key_ids = b"".join(
            key_id.to_bytes(4, byteorder="little", signed=True)
            for key_id in (self.l0_key_id, self.l1_key_id, self.l2_key_id)
        )

        # cbTargetSD, then pbTargetSD (length prefix, value, padding), then the rest.
        return sd_length + sd_length + self.target_sd + sd_padding + root_key + key_ids

    @classmethod
    def unpack_response(
        cls,
        data: t.Union[bytes, bytearray, memoryview],
    ) -> GroupKeyEnvelope:
        """Parse the GetKey response stub data into a GroupKeyEnvelope.

        Raises an Exception when the trailing HRESULT is not zero.
        """
        buffer = memoryview(data)

        # The HRESULT is the final 4 bytes of the response.
        body, status = buffer[:-4], buffer[-4:]
        (hresult,) = struct.unpack("<I", status.tobytes())
        if hresult != 0:
            raise Exception(f"GetKey failed 0x{hresult:08X}")

        (key_length,) = struct.unpack("<I", body[:4])

        # 8 bytes for the length and its padding, then 16 more for the
        # referent id and the doubled up pointer size.
        key_start = 8 + 16
        key = body[key_start : key_start + key_length].tobytes()
        assert len(key) == key_length

        return GroupKeyEnvelope.unpack(key)


@dataclasses.dataclass(frozen=True)
class GroupKeyEnvelope:
    """Parsed form of the Group Key Envelope returned by GetKey."""

    # https://winprotocoldoc.blob.core.windows.net/productionwindowsarchives/MS-GKDI/%5bMS-GKDI%5d.pdf
    # 2.2.4 Group Key Envelope
    version: int
    flags: int
    l0: int
    l1: int
    l2: int
    root_key_identifier: uuid.UUID
    kdf_algorithm: str
    kdf_parameters: bytes
    secret_algorithm: str
    secret_parameters: bytes
    private_key_length: int
    public_key_length: int
    domain_name: str
    forest_name: str
    l1_key: bytes
    l2_key: bytes

    @property
    def is_public_key(self) -> bool:
        """Whether the lowest bit of ``flags`` is set."""
        return bool(self.flags & 1)

    @classmethod
    def unpack(
        cls,
        data: t.Union[bytes, bytearray, memoryview],
    ) -> GroupKeyEnvelope:
        """Parse the envelope from its serialised bytes.

        The blob starts with an 80 byte fixed header (version, the ``KDSK``
        magic, flags, the three key indexes, the root key GUID and ten length
        fields) followed by the variable length values in the order read
        below.

        Raises struct.error when fewer than 80 bytes are supplied and
        AssertionError when the magic does not match.
        """
        view = memoryview(data)

        (
            version,
            magic,
            flags,
            l0_index,
            l1_index,
            l2_index,
            root_key_raw,
            kdf_algo_len,
            kdf_para_len,
            sec_algo_len,
            sec_para_len,
            priv_key_len,
            publ_key_len,
            l1_key_len,
            l2_key_len,
            domain_len,
            forest_len,
        ) = struct.unpack("<I4s4I16s10I", view[:80])
        assert magic == b"\x4B\x44\x53\x4B"
        view = view[80:]

        def take(length: int) -> bytes:
            # Consume and return the next ``length`` bytes of the payload.
            nonlocal view
            value = view[:length].tobytes()
            view = view[length:]
            return value

        def take_string(length: int) -> str:
            # Strings are UTF-16-LE and their length includes a 2 byte
            # terminator which is dropped. The clamp keeps an absent (zero
            # length) string empty instead of turning the slice into [:-2],
            # which would swallow the fields that follow.
            return take(length)[: max(length - 2, 0)].decode("utf-16-le")

        kdf_algo = take_string(kdf_algo_len)
        kdf_param = take(kdf_para_len)
        secret_algo = take_string(sec_algo_len)
        secret_param = take(sec_para_len)
        domain = take_string(domain_len)
        forest = take_string(forest_len)
        l1_key = take(l1_key_len)
        l2_key = take(l2_key_len)

        return cls(
            version=version,
            flags=flags,
            l0=l0_index,
            l1=l1_index,
            l2=l2_index,
            root_key_identifier=uuid.UUID(bytes_le=root_key_raw),
            kdf_algorithm=kdf_algo,
            kdf_parameters=kdf_param,
            secret_algorithm=secret_algo,
            secret_parameters=secret_param,
            private_key_length=priv_key_len,
            public_key_length=publ_key_len,
            domain_name=domain,
            forest_name=forest,
            l1_key=l1_key,
            l2_key=l2_key,
        )


def get_cbt_data(server):
    """Build the tls-server-end-point channel bindings for ``server``.

    The certificate presented on port 5986 is hashed with the hash algorithm
    of its own signature, except that MD5 and SHA-1 are replaced by SHA-256.
    """
    pem = ssl.get_server_certificate((server, 5986))
    der = ssl.PEM_cert_to_DER_cert(pem)
    certificate = x509.load_der_x509_certificate(der, default_backend())

    algorithm = certificate.signature_hash_algorithm
    if algorithm.name in ['md5', 'sha1']:
        algorithm = hashes.SHA256()

    hasher = hashes.Hash(algorithm, default_backend())
    hasher.update(der)
    certificate_hash = hasher.finalize()

    return spnego.channel_bindings.GssChannelBindings(
        application_data=b"tls-server-end-point:" + certificate_hash,
    )


def _smb_recv_exact(smb_socket, length):
    """Read exactly ``length`` bytes, looping because recv() may return fewer."""
    chunks = []
    remaining = length
    while remaining > 0:
        chunk = smb_socket.recv(remaining)
        if not chunk:
            raise Exception("SMB socket closed with %d bytes still expected" % remaining)

        chunks.append(chunk)
        remaining -= len(chunk)

    return b"".join(chunks)


def smb_send(smb_socket, message_id, command, data, session_id=0, tree_id=0):
    """Send one SMB2 request and return the parsed response.

    ``data`` is the command specific body appended to the 64 byte SMB2
    header. The whole message is framed with a 4 byte big-endian length
    prefix, as is the response that is read back.

    Returns an SMBHeader. Its ``data`` is an SMBNegotiateResponse for command
    0, an SMBSessionResponse for command 1 and the raw body bytes otherwise.

    Raises an Exception when the response status is not zero or the socket is
    closed before the full response has been received.
    """
    header = b"\xFESMB"  # ProtocolId
    header += struct.pack("<H", 64)  # StructureSize
    header += struct.pack("<H", 1)  # CreditCharge
    header += b"\x00\x00\x00\x00"  # Status
    header += struct.pack("<H", command)  # Command
    header += struct.pack("<H", 1)  # CreditRequest
    header += struct.pack("<I", 0)  # Flags
    header += struct.pack("<I", 0)  # NextCommand
    header += struct.pack("<Q", message_id)  # MessageId
    header += b"\x00\x00\x00\x00"  # Reserved
    header += struct.pack("<I", tree_id)  # TreeId
    header += struct.pack("<Q", session_id)  # SessionId
    header += b"\x00" * 16  # Signature
    header += data

    # send() is allowed to transmit only part of the buffer; sendall() keeps
    # going until everything has been written.
    smb_socket.sendall(struct.pack(">I", len(header)) + header)

    # Likewise recv(n) can return fewer than n bytes on a stream socket.
    payload_len = struct.unpack(">I", _smb_recv_exact(smb_socket, 4))[0]
    payload = _smb_recv_exact(smb_socket, payload_len)

    status = struct.unpack("<I", payload[8:12])[0]
    command = struct.unpack("<H", payload[12:14])[0]
    session_id = struct.unpack("<Q", payload[40:48])[0]
    data = payload[64:]

    if status != 0:
        raise Exception("SMB Exception: 0x{:02X}".format(status))

    if command == 0:
        data = smb_unpack_negotiate(data)

    elif command == 1:
        data = smb_unpack_session(data)

    return SMBHeader(status=status, command=command, session_id=session_id, data=data)


def smb_negotiate_request(smb_socket, message_id):
    """Send an SMB2 NEGOTIATE offering only the 3.0.2 dialect and return the response."""
    fields = [
        struct.pack("<H", 36),  # StructureSize
        struct.pack("<H", 1),  # DialectCount
        struct.pack("<H", 1),  # SecurityMode - SMB2_NEGOTIATE_SIGNING_ENABLED
        b"\x00\x00",  # Reserved
        struct.pack("<I", 0x00000040),  # Capabilities - SMB2_GLOBAL_CAP_ENCRYPTION
        uuid.uuid4().bytes,  # ClientGuid
        b"\x00" * 8,  # Left as zero - the 8 bytes between ClientGuid and Dialects
        struct.pack("<H", 0x0302),  # Dialects - [SMB 3.0.2]
    ]

    return smb_send(smb_socket, message_id, 0x0000, b"".join(fields))


def smb_unpack_negotiate(data):
    """Parse an SMB2 NEGOTIATE response body into an SMBNegotiateResponse.

    ``data`` starts right after the 64 byte SMB2 header, which is why 64 is
    subtracted from the security buffer offset found in the message.
    """
    (dialect,) = struct.unpack("<H", data[4:6])
    (offset,) = struct.unpack("<H", data[56:58])
    (length,) = struct.unpack("<H", data[58:60])

    start = offset - 64
    # An empty security buffer is reported as None rather than b"".
    token = data[start:start + length] if length else None

    return SMBNegotiateResponse(dialect=dialect, buffer=token)


def smb_session_setup(smb_socket, message_id, token):
    """Send an SMB2 SESSION_SETUP carrying ``token`` and return the response."""
    fixed_part = struct.pack(
        "<HBBIIHHQ",
        25,  # StructureSize
        0,  # Flags
        0x01,  # SecurityMode - SMB2_NEGOTIATE_SIGNING_ENABLED
        0,  # Capabilities
        0,  # Channel
        64 + 24,  # SecurityBufferOffset - SMB2 header plus the fixed part of this body
        len(token),  # SecurityBufferLength
        0,  # PreviousSessionId
    )

    return smb_send(smb_socket, message_id, 0x0001, fixed_part + token)


def smb_unpack_session(data):
    """Parse an SMB2 SESSION_SETUP response body into an SMBSessionResponse.

    ``data`` starts right after the 64 byte SMB2 header, which is why 64 is
    subtracted from the security buffer offset found in the message.
    """
    (flags,) = struct.unpack("<H", data[2:4])
    (offset,) = struct.unpack("<H", data[4:6])
    (length,) = struct.unpack("<H", data[6:8])

    start = offset - 64
    # An empty security buffer is reported as None rather than b"".
    token = data[start:start + length] if length else None

    return SMBSessionResponse(flags=flags, buffer=token)


def winrm_run(context, expected_header, server, command, arguments=None, ssl=False):
    """Run ``command`` on ``server`` over WinRM and return (rc, stdout, stderr).

    ``context`` is the security context used to authenticate and, over plain
    HTTP, to encrypt the messages. ``expected_header`` is the authentication
    scheme name that must have been used on the wire; the run fails with an
    AssertionError otherwise. Note that the ``ssl`` flag shadows the ssl
    module inside this function.

    The remote shell is always deleted, even when the command fails.
    """
    session = requests.Session()
    session.auth = HTTPWinRMAuth(context)
    warnings.simplefilter('ignore', category=InsecureRequestWarning)
    session.headers = {'User-Agent': 'pyspnego_client'}

    if ssl:
        endpoint = 'https://%s:5986/wsman' % server
        session.verify = False

    else:
        endpoint = 'http://%s:5985/wsman' % server

        # The context must be fully set up before any payload can be
        # encrypted, so authenticate first with an empty request.
        bootstrap = session.prepare_request(requests.Request('POST', endpoint, data=None))
        session.send(bootstrap).raise_for_status()

    # Extra attributes read back by wsman_envelope.
    session.endpoint = endpoint
    session.session_id = str(uuid.uuid4()).upper()

    shell_id = wsman_create(session)
    try:
        cmd_id = wsman_command(session, shell_id, command, arguments)
        rc, stdout, stderr = wsman_receive(session, shell_id, cmd_id)
        wsman_signal(
            session,
            shell_id,
            cmd_id,
            'http://schemas.microsoft.com/wbem/wsman/1/windows/shell/signal/Terminate',
        )

        if stderr.startswith('#< CLIXML'):
            # Skip the '#< CLIXML\r\n' prefix by searching for the 2nd '<'.
            clixml = ET.fromstring(stderr[stderr.index('<', 2):])
            namespace = clixml.tag.replace("Objs", "")[1:-1]

            messages = [entry.text for entry in clixml.findall("{%s}S[@S='Error']" % namespace)]
            stderr = "".join(messages).replace('_x000D_', '\r').replace('_x000A_', '\n')

        # Make sure the requested authentication protocol was actually tested.
        assert session.auth.header == expected_header

        return rc, stdout, stderr

    finally:
        wsman_delete(session, shell_id)


def wsman_command(http, shell_id, command, arguments=None):
    """Start ``command`` in the shell ``shell_id`` and return the new command id."""
    ns = WSMAN_NS['rsp']

    cmd_line = ET.Element('{%s}CommandLine' % ns)
    cmd_element = ET.SubElement(cmd_line, '{%s}Command' % ns)
    cmd_element.text = command

    # One Arguments element per argument, in the order given.
    for value in (arguments or []):
        arg_element = ET.SubElement(cmd_line, '{%s}Arguments' % ns)
        arg_element.text = value

    response = wsman_envelope(
        'http://schemas.microsoft.com/wbem/wsman/1/windows/shell/Command',
        http,
        body=cmd_line,
        selector_set={'ShellId': shell_id},
        option_set={'WINRS_SKIP_CMD_SHELL': False},
    )

    command_id = response.find('s:Body/rsp:CommandResponse/rsp:CommandId', WSMAN_NS)
    return command_id.text


def wsman_create(http):
    """Create a remote shell with stdin/stdout/stderr streams and return its id."""
    ns = WSMAN_NS['rsp']

    shell = ET.Element('{%s}Shell' % ns)
    for tag, streams in (('InputStreams', 'stdin'), ('OutputStreams', 'stdout stderr')):
        stream_element = ET.SubElement(shell, '{%s}%s' % (ns, tag))
        stream_element.text = streams

    response = wsman_envelope(
        'http://schemas.xmlsoap.org/ws/2004/09/transfer/Create',
        http,
        body=shell,
        option_set={'WINRS_CODEPAGE': 65001},
    )

    shell_id = response.find('s:Body/rsp:Shell/rsp:ShellId', WSMAN_NS)
    return shell_id.text


def wsman_delete(http, shell_id):
    """Delete the remote shell identified by ``shell_id``."""
    selectors = {'ShellId': shell_id}
    wsman_envelope('http://schemas.xmlsoap.org/ws/2004/09/transfer/Delete', http, selector_set=selectors)


def wsman_receive(http, shell_id, command_id):
    """Poll the command output until it is done.

    Returns ``(rc, stdout, stderr)`` with both streams decoded as UTF-8.
    """
    ns = WSMAN_NS['rsp']
    buffers = {name: io.BytesIO() for name in ('stdout', 'stderr')}

    while True:
        receive = ET.Element('{%s}Receive' % ns)
        desired = ET.SubElement(receive, '{%s}DesiredStream' % ns, attrib={'CommandId': command_id})
        desired.text = 'stdout stderr'

        response = wsman_envelope(
            'http://schemas.microsoft.com/wbem/wsman/1/windows/shell/Receive',
            http,
            body=receive,
            selector_set={'ShellId': shell_id},
            option_set={'WSMAN_CMDSHELL_OPTION_KEEPALIVE': True},
        )

        # Each Stream element holds a base64 chunk for the stream named by
        # its Name attribute; empty elements carry no data.
        for stream in response.findall('s:Body/rsp:ReceiveResponse/rsp:Stream', WSMAN_NS):
            if not stream.text:
                continue

            buffers[stream.attrib['Name']].write(base64.b64decode(stream.text))

        state = response.find('s:Body/rsp:ReceiveResponse/rsp:CommandState', WSMAN_NS)
        if not state.attrib['State'].endswith('Done'):
            continue

        rc = int(state.find('rsp:ExitCode', WSMAN_NS).text)
        stdout = buffers['stdout'].getvalue().decode('utf-8')
        stderr = buffers['stderr'].getvalue().decode('utf-8')

        return rc, stdout, stderr


def wsman_signal(http, shell_id, command_id, code):
    """Send the signal ``code`` to the command ``command_id`` of a shell."""
    ns = WSMAN_NS['rsp']

    signal = ET.Element('{%s}Signal' % ns, attrib={'CommandId': command_id})
    code_element = ET.SubElement(signal, '{%s}Code' % ns)
    code_element.text = code

    wsman_envelope(
        'http://schemas.microsoft.com/wbem/wsman/1/windows/shell/Signal',
        http,
        body=signal,
        selector_set={'ShellId': shell_id},
    )


def wsman_envelope(action, http, selector_set=None, option_set=None, body=None):
    """Send one WSMan SOAP request and return the parsed response envelope.

    ``http`` is the requests session prepared by winrm_run; its ``endpoint``,
    ``session_id`` and ``auth.context`` attributes are used. ``selector_set``
    and ``option_set`` are optional dicts turned into the SelectorSet and
    OptionSet headers. ``body`` is an optional element placed inside the SOAP
    Body.

    Over plain HTTP the payload is encrypted with the security context before
    it is sent and the response is decrypted again; over HTTPS the XML is
    exchanged as is.

    Raises requests.HTTPError for a non successful HTTP status.
    """
    s = WSMAN_NS['s']
    wsa = WSMAN_NS['wsa']
    wsman = WSMAN_NS['wsman']
    wsmv = WSMAN_NS['wsmv']
    xml = WSMAN_NS['xml']
    understand = '{%s}mustUnderstand' % s

    envelope = ET.Element('{%s}Envelope' % s)
    header = ET.SubElement(envelope, '{%s}Header' % s)

    ET.SubElement(header, '{%s}Action' % wsa, attrib={understand: 'true'}).text = action
    ET.SubElement(header, '{%s}SessionId' % wsmv, attrib={understand: 'false'}).text = 'uuid:%s' % http.session_id.upper()
    ET.SubElement(header, '{%s}To' % wsa).text = http.endpoint
    ET.SubElement(header, '{%s}MaxEnvelopeSize' % wsman, attrib={understand: 'true'}).text = '153600'
    ET.SubElement(header, '{%s}MessageID' % wsa).text = 'uuid:%s' % str(uuid.uuid4()).upper()
    ET.SubElement(header, '{%s}OperationTimeout' % wsman).text = 'PT30S'

    reply_to = ET.SubElement(header, '{%s}ReplyTo' % wsa)
    ET.SubElement(
        reply_to, '{%s}Address' % wsa, attrib={understand: "true"}
    ).text = 'http://schemas.xmlsoap.org/ws/2004/08/addressing/role/anonymous'

    ET.SubElement(
        header, '{%s}ResourceURI' % wsman, attrib={understand: 'true'}
    ).text = 'http://schemas.microsoft.com/wbem/wsman/1/windows/shell/cmd'

    for e in ['DataLocale', 'Locale']:
        ET.SubElement(
            header, '{%s}%s' % (wsmv, e), attrib={understand: 'false', '{%s}lang' % xml: 'en-US'}
        )

    for set_value, name, option_name in [(selector_set, 'SelectorSet', 'Selector'),
                                         (option_set, 'OptionSet', 'Option')]:
        if not set_value:
            continue

        set_element = ET.SubElement(header, '{%s}%s' % (wsman, name))
        if name == 'OptionSet':
            set_element.attrib = {understand: 'true'}

        for key, value in set_value.items():
            ET.SubElement(set_element, '{%s}%s' % (wsman, option_name), Name=key).text = str(value)

    envelope_body = ET.SubElement(envelope, '{%s}Body' % s)
    if body is not None:
        envelope_body.append(body)

    content = ET.tostring(envelope, encoding='utf-8', method='xml')
    boundary = 'Encrypted Boundary'

    if http.endpoint.startswith('http://'):
        auth_protocol = 'CredSSP' if http.auth.context.protocol == 'credssp' else 'SPNEGO'
        protocol = 'application/HTTP-%s-session-encrypted' % auth_protocol

        # CredSSP payloads are split into 16KiB chunks that are each wrapped
        # on their own, everything else is wrapped in a single piece.
        max_size = 16384 if auth_protocol == 'CredSSP' else len(content)
        chunks = [content[i:i + max_size] for i in range(0, len(content), max_size)]
        encrypted_chunks = []
        for chunk in chunks:
            wrap_header, wrapped_data, padding_length = http.auth.context.wrap_winrm(chunk)
            wrapped_data = struct.pack("<i", len(wrap_header)) + wrap_header + wrapped_data

            # The advertised length describes this chunk, not the whole
            # message, so it has to be derived from the chunk itself.
            msg_length = str(len(chunk) + padding_length)

            mime_header = "\r\n".join([
                '--%s' % boundary,
                '\tContent-Type: %s' % protocol,
                '\tOriginalContent: type=application/soap+xml;charset=UTF-8;Length=%s' % msg_length,
                '--%s' % boundary,
                '\tContent-Type: application/octet-stream',
                '',
            ])
            encrypted_chunks.append(mime_header.encode() + wrapped_data)

        content_sub_type = 'multipart/encrypted' if len(encrypted_chunks) == 1 \
            else 'multipart/x-multi-encrypted'
        content_type = '%s;protocol="%s";boundary="%s"' % (content_sub_type, protocol, boundary)
        content = b"".join(encrypted_chunks) + ("--%s--\r\n" % boundary).encode()

    else:
        content_type = 'application/soap+xml;charset=UTF-8'

    headers = {
        'Content-Length': str(len(content)),
        'Content-Type': content_type,
    }

    request = http.prepare_request(requests.Request('POST', http.endpoint, data=content, headers=headers))
    response = http.send(request)
    response.raise_for_status()
    content = response.content

    content_type = response.headers.get('content-type', '')
    if content_type.startswith(('multipart/encrypted;', 'multipart/x-multi-encrypted;')):
        # Accept the boundary wrapped in either single or double quotes.
        boundary = re.search(r'''boundary=['"](.*)['"]''', content_type).group(1)
        escaped_boundary = re.escape(boundary)
        parts = re.compile((r"--\s*%s\r\n" % escaped_boundary).encode()).split(content)
        parts = list(filter(None, parts))

        # The parts alternate between a MIME header and its encrypted payload.
        content = b""
        for i in range(0, len(parts), 2):
            mime_part = parts[i].strip()
            payload = parts[i + 1]

            expected_length = int(mime_part.split(b"Length=")[1])

            # remove the end MIME block if it exists
            payload = re.sub((r'--\s*%s--\r\n$' % escaped_boundary).encode(), b'', payload)

            # What is left is a 4 byte little-endian header length, the
            # security header and finally the encrypted data.
            wrapped_data = payload.replace(b"\tContent-Type: application/octet-stream\r\n", b"")
            header_length = struct.unpack("<i", wrapped_data[:4])[0]
            wrap_header = wrapped_data[4:4 + header_length]
            wrapped_data = wrapped_data[4 + header_length:]

            unwrapped_data = http.auth.context.unwrap_winrm(wrap_header, wrapped_data)
            assert len(unwrapped_data) == expected_length
            content += unwrapped_data

    return ET.fromstring(content)


def create_pdu(
    packet_type: int,
    packet_flags: int,
    call_id: int,
    header_data: bytes,
    *,
    stub_data: t.Optional[bytes] = None,
    sec_trailer: t.Optional[SecTrailer] = None,
) -> bytes:
    """Assemble a connection-oriented RPC PDU.

    The result is the 16 byte common header followed by ``header_data``, the
    optional ``stub_data`` and, when ``sec_trailer`` is given, the 8 byte
    security trailer plus its auth data.
    """
    # https://pubs.opengroup.org/onlinepubs/9629399/toc.pdf
    # 12.6.3 Connection-oriented PDU Data Types - PDU Header
    auth_length = len(sec_trailer.data) if sec_trailer else 0

    # Everything after the common header is built first so that the fragment
    # length is known when the header is packed.
    payload = bytearray()
    payload += header_data
    payload += stub_data or b""

    if sec_trailer:
        payload += struct.pack(
            "<BBBBI",
            sec_trailer.type,
            sec_trailer.level,
            sec_trailer.pad_length,
            0,  # Auth Rsrvd
            sec_trailer.context_id,
        )
        payload += sec_trailer.data

    common_header = struct.pack(
        "<BBBB4sHHI",
        5,  # Version
        0,  # Version minor
        packet_type,
        packet_flags,
        b"\x10\x00\x00\x00",  # Data Representation
        16 + len(payload),  # Fragment length - covers the whole PDU
        auth_length,
        call_id,
    )

    return common_header + bytes(payload)


def create_bind(
    service: t.Tuple[uuid.UUID, int, int],
    syntaxes: t.List[bytes],
    auth_data: t.Optional[bytes] = None,
    sign_header: bool = False,
    auth_level: int = 6,
    auth_type: int = 9,
) -> bytes:
    """Build a Bind PDU (packet type 11) for ``service``.

    One presentation context is emitted per entry of ``syntaxes``, numbered
    from 0, each pairing the service with that already serialised transfer
    syntax. When ``auth_data`` is given it is attached as the security
    trailer using ``auth_type`` and ``auth_level``. ``sign_header`` sets the
    0x4 bit in the packet flags.
    """
    abstract_syntax = service[0].bytes_le + struct.pack("<HH", service[1], service[2])

    context_data = bytearray()
    for idx, transfer_syntax in enumerate(syntaxes):
        context_data += struct.pack("<H", idx)  # Context id
        context_data += b"\x01\x00"  # Num transfer syntaxes + reserved
        context_data += abstract_syntax
        context_data += transfer_syntax

    bind_data = bytearray()
    bind_data += b"\xd0\x16"  # Max Xmit Frag
    bind_data += b"\xd0\x16"  # Max Recv Frag
    bind_data += b"\x00\x00\x00\x00"  # Assoc Group
    # Num context items (1 byte) + 3 reserved bytes. Derived from the number
    # of syntaxes so that the count always matches what is actually sent.
    bind_data += struct.pack("<B3x", len(syntaxes))
    bind_data += context_data

    sec_trailer: t.Optional[SecTrailer] = None
    if auth_data:
        sec_trailer = SecTrailer(
            type=auth_type,
            level=auth_level,
            pad_length=0,
            context_id=0,
            data=auth_data,
        )

    return create_pdu(
        packet_type=11,
        packet_flags=0x03 | (0x4 if sign_header else 0x0),
        call_id=1,
        header_data=bytes(bind_data),
        sec_trailer=sec_trailer,
    )


def create_alter_context(
    service: t.Tuple[uuid.UUID, int, int],
    token: bytes,
    sign_header: bool = False,
    auth_level: int = 6,
    auth_type: int = 9,
) -> bytes:
    """Build an Alter Context PDU (packet type 14) carrying ``token``.

    A single presentation context with id 1 is sent, pairing ``service`` with
    the NDR64 transfer syntax. ``token`` is attached as the security trailer
    using ``auth_type`` and ``auth_level``. ``sign_header`` sets the 0x4 bit
    in the packet flags.
    """
    ctx1 = b"\x01\x00\x01\x00"  # Context id 1, 1 transfer syntax
    ctx1 += service[0].bytes_le
    ctx1 += struct.pack("<H", service[1])
    # The second value is service[2], the same pair create_bind sends.
    ctx1 += struct.pack("<H", service[2])
    ctx1 += NDR64[0].bytes_le + struct.pack("<H", NDR64[1]) + struct.pack("<H", NDR64[2])

    alter_context_data = bytearray()
    alter_context_data += b"\xd0\x16"  # Max Xmit Frag
    alter_context_data += b"\xd0\x16"  # Max Recv Frag
    alter_context_data += b"\x00\x00\x00\x00"  # Assoc Group
    alter_context_data += b"\x01\x00\x00\x00"  # Num context items
    alter_context_data += ctx1

    auth_data = SecTrailer(
        type=auth_type,
        level=auth_level,
        pad_length=0,
        context_id=0,
        data=token,
    )

    return create_pdu(
        packet_type=14,
        packet_flags=0x03 | (0x4 if sign_header else 0x0),
        call_id=1,
        header_data=bytes(alter_context_data),
        sec_trailer=auth_data,
    )


def create_request(
    opnum: int,
    data: bytes,
    ctx: t.Optional[spnego.ContextProxy] = None,
    sign_header: bool = False,
    auth_level: int = 6,
    auth_type: int = 9,
) -> bytes:
    """Build a Request PDU (packet type 0) for ``opnum`` with ``data`` as stub.

    Without ``ctx`` the stub data is sent as is. With ``ctx`` a verification
    trailer is appended to the stub data, the result is padded to a multiple
    of 16 bytes, encrypted with ``ctx.wrap_iov`` and the resulting signature
    is stored in the security trailer. ``sign_header`` selects whether the
    PDU header and security trailer are passed to the context as sign_only
    or as data_readonly buffers.
    """
    # Add Verification trailer to data
    # MS-RPCE 2.2.2.13 Verification Trailer
    # https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-rpce/0e9fea61-1bff-4478-9bfe-a3b6d8b64ac3
    if ctx:
        pcontext = bytearray()
        pcontext += ISD_KEY[0].bytes_le
        pcontext += struct.pack("<H", ISD_KEY[1]) + struct.pack("<H", ISD_KEY[2])
        pcontext += NDR64[0].bytes_le
        pcontext += struct.pack("<H", NDR64[1]) + struct.pack("<H", NDR64[2])

        verification_trailer = bytearray()
        verification_trailer += b"\x8a\xe3\x13\x71\x02\xf4\x36\x71"  # Signature

        verification_trailer += b"\x02\x40"  # Trailer Command - PCONTEXT + End
        verification_trailer += struct.pack("<H", len(pcontext))
        verification_trailer += pcontext

        # The verification trailer is added on a 4 byte boundary of the stub data
        data_padding = -len(data) % 4
        data += b"\x00" * data_padding

        data += verification_trailer
        # The alloc hint is taken before the auth padding below is appended.
        alloc_hint = len(data)
        auth_padding = -len(data) % 16
        data += b"\x00" * auth_padding

    else:
        alloc_hint = len(data)

    request_data = bytearray()
    request_data += struct.pack("<I", alloc_hint)
    request_data += struct.pack("<H", 1)  # Context id
    request_data += struct.pack("<H", opnum)

    if ctx:
        header_length = ctx.query_message_sizes().header

        # The auth data is a zero filled placeholder of the size the context
        # reports; it is overwritten with the real signature further down.
        sec_trailer = SecTrailer(
            type=auth_type,
            level=auth_level,
            pad_length=auth_padding,
            context_id=0,
            data=b"\x00" * header_length,
        )
        pdu_req = bytearray(
            create_pdu(
                packet_type=0,
                packet_flags=0x03,
                call_id=1,
                header_data=bytes(request_data),
                stub_data=data,
                sec_trailer=sec_trailer,
            )
        )

        sign_type = spnego.iov.BufferType.sign_only if sign_header else spnego.iov.BufferType.data_readonly
        # The 8 byte security trailer sits just before the signature placeholder.
        sec_trailer_data = pdu_req[-(header_length + 8) : -header_length]
        res = ctx.wrap_iov(
            [
                # First 24 bytes: 16 byte common header + 8 byte request header.
                (sign_type, bytes(pdu_req[:24])),
                data,
                (sign_type, bytes(sec_trailer_data)),
                spnego.iov.BufferType.header,
            ],
            encrypt=True,
            qop=None,
        )

        enc_data = res.buffers[1].data
        sig = res.buffers[3].data

        # Swap the plaintext stub and the placeholder for the wrapped values.
        data_view = memoryview(pdu_req)
        data_view[24 : 24 + len(data)] = enc_data
        data_view[-header_length:] = sig

        return bytes(pdu_req)

    else:
        return create_pdu(
            packet_type=0,
            packet_flags=0x03,
            call_id=1,
            header_data=bytes(request_data),
            stub_data=data,
        )


def get_fault_pdu_error(data: memoryview) -> int:
    """Return the status value stored in a fault PDU.

    The status is read as a little-endian unsigned 32-bit integer located at
    bytes 24-27 of ``data``.
    """
    (fault_status,) = struct.unpack_from("<I", data, 24)
    return fault_status


def parse_bind_ack(data: bytes) -> t.Optional[bytes]:
    """Extract the authentication blob from a bind_ack PDU.

    Args:
        data: The raw PDU bytes received from the server.

    Returns:
        The trailing ``auth_length`` bytes of the PDU, or ``None`` when the
        header reports an auth length of 0.

    Raises:
        Exception: The PDU is a fault (packet type 3).
        AssertionError: The packet type is neither a fault nor 12.
    """
    pdu = memoryview(data)

    (pdu_type,) = struct.unpack_from("B", pdu, 2)
    if pdu_type == 3:
        fault_status = get_fault_pdu_error(pdu)
        raise Exception(f"Receive Fault PDU: 0x{fault_status:08X}")

    assert pdu_type == 12

    # The auth length lives in bytes 10-11 of the common header.
    (auth_length,) = struct.unpack_from("<H", pdu, 10)
    if not auth_length:
        return None

    return pdu[-auth_length:].tobytes()


def parse_alter_context(data: bytes) -> bytes:
    """Extract the authentication blob from an alter_context response PDU.

    Args:
        data: The raw PDU bytes received from the server.

    Returns:
        The trailing ``auth_length`` bytes of the PDU, or ``b""`` when the
        header reports an auth length of 0.

    Raises:
        Exception: The PDU is a fault (packet type 3).
        AssertionError: The packet type is neither a fault nor 15.
    """
    pdu = memoryview(data)

    (pdu_type,) = struct.unpack_from("B", pdu, 2)
    if pdu_type == 3:
        fault_status = get_fault_pdu_error(pdu)
        raise Exception(f"Receive Fault PDU: 0x{fault_status:08X}")

    assert pdu_type == 15

    # The auth length lives in bytes 10-11 of the common header.
    (auth_length,) = struct.unpack_from("<H", pdu, 10)
    if not auth_length:
        return b""

    return pdu[-auth_length:].tobytes()


def parse_response(
    data: bytes,
    ctx: t.Optional[spnego.ContextProxy] = None,
    sign_header: bool = False,
) -> bytes:
    """Return the stub data carried by a response PDU.

    Args:
        data: The raw PDU bytes received from the server.
        ctx: Optional security context. When given, the stub data is passed
            through ``ctx.unwrap_iov`` together with the PDU header and the
            security trailer.
        sign_header: Use the ``sign_only`` buffer type for the header and
            security trailer buffers instead of ``data_readonly``.

    Returns:
        The stub data. With ``ctx`` this is the unwrapped stub with the
        number of padding bytes named in the security trailer removed;
        without ``ctx`` it is the stub bytes exactly as received.

    Raises:
        Exception: The PDU is a fault (packet type 3).
        AssertionError: The packet type is not 2 or the fragment length in
            the header does not match ``len(data)``.
    """
    pdu = memoryview(data)

    (pdu_type,) = struct.unpack_from("B", pdu, 2)
    if pdu_type == 3:  # Fault PDU
        fault_status = get_fault_pdu_error(pdu)
        raise Exception(f"Receive Fault PDU: 0x{fault_status:08X}")

    assert pdu_type == 2

    # Fragment length and auth length sit side by side at bytes 8-11.
    frag_length, auth_length = struct.unpack_from("<HH", pdu, 8)
    assert len(pdu) == frag_length

    if auth_length:
        # The 8 byte security trailer precedes the auth blob itself.
        trailer_size = auth_length + 8
        auth_data = pdu[-trailer_size:]
        stub_data = pdu[24 : len(pdu) - trailer_size]
        (padding,) = struct.unpack_from("B", auth_data, 2)
    else:
        auth_data = memoryview(b"")
        stub_data = pdu[24:]
        padding = 0

    if not ctx:
        return stub_data.tobytes()

    if sign_header:
        sign_type = spnego.iov.BufferType.sign_only
    else:
        sign_type = spnego.iov.BufferType.data_readonly

    iov_buffers = [
        (sign_type, data[:24]),
        stub_data.tobytes(),
        (sign_type, auth_data[:8].tobytes()),
        (spnego.iov.BufferType.header, auth_data[8:].tobytes()),
    ]
    unwrapped = ctx.unwrap_iov(iov_buffers)
    decrypted_stub = unwrapped.buffers[1].data

    return decrypted_stub[: len(decrypted_stub) - padding]


def create_ept_map_request(
    service: t.Tuple[uuid.UUID, int, int],
    data_rep: t.Tuple[uuid.UUID, int, int],
    protocol: int = 0x0B,  # TCP/IP
    port: int = 135,
    address: int = 0,
) -> t.Tuple[int, bytes]:
    """Build the stub data for an endpoint mapper ``ept_map`` call.

    The tower in the request has five floors: the service interface, the
    data representation syntax, the RPC protocol, the port and the address.

    Args:
        service: Interface UUID followed by two 16-bit integers (presumably
            the major and minor version).
        data_rep: Transfer syntax UUID in the same tuple shape as ``service``.
        protocol: Protocol identifier used for the third floor.
        port: Value of the port floor, packed big-endian.
        address: Value of the address floor, packed big-endian.

    Returns:
        A tuple of the opnum (always 3) and the marshalled request bytes.
    """
    # MS-RPCE 2.2.1.2.5 ept_map Method
    # https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-rpce/ab744583-430e-4055-8901-3c6bc007e791
    # void ept_map(
    #     [in] handle_t hEpMapper,
    #     [in, ptr] UUID* obj,
    #     [in, ptr] twr_p_t map_tower,
    #     [in, out] ept_lookup_handle_t* entry_handle,
    #     [in, range(0,500)] unsigned long max_towers,
    #     [out] unsigned long* num_towers,
    #     [out, ptr, size_is(max_towers), length_is(*num_towers)]
    #         twr_p_t* ITowers,
    #     [out] error_status* status
    # );
    def encode_floor(protocol_id: int, lhs: bytes, rhs: bytes) -> bytes:
        # The LHS length field also counts the protocol identifier byte.
        return b"".join(
            (
                struct.pack("<HB", len(lhs) + 1, protocol_id),
                lhs,
                struct.pack("<H", len(rhs)),
                rhs,
            )
        )

    def encode_syntax_floor(syntax: t.Tuple[uuid.UUID, int, int]) -> bytes:
        return encode_floor(
            0x0D,
            syntax[0].bytes_le + struct.pack("<H", syntax[1]),
            struct.pack("<H", syntax[2]),
        )

    floors: t.List[bytes] = [
        encode_syntax_floor(service),
        encode_syntax_floor(data_rep),
        encode_floor(protocol, b"", b"\x00\x00"),
        encode_floor(0x07, b"", struct.pack(">H", port)),
        encode_floor(0x09, b"", struct.pack(">I", address)),
    ]

    tower = struct.pack("<H", len(floors)) + b"".join(floors)
    # Pad the 4 byte length prefix plus the tower up to a multiple of 8 bytes.
    tower_padding = -(len(tower) + 4) % 8

    request = b"".join(
        (
            b"\x01" + (b"\x00" * 23),  # Blank UUID pointer with referent id 1
            b"\x02\x00\x00\x00\x00\x00\x00\x00",  # Tower referent id 2
            struct.pack("<Q", len(tower)),
            struct.pack("<I", len(tower)),
            tower,
            b"\x00" * tower_padding,
            b"\x00" * 20,  # Context handle
            struct.pack("<I", 4),  # Max towers
        )
    )

    return 3, request


def parse_ept_map_response(data: bytes) -> t.List[Tower]:
    """Decode the towers returned by an ``ept_map`` call.

    Args:
        data: The stub data of the ``ept_map`` response.

    Returns:
        One ``Tower`` per tower entry found in the response.

    Raises:
        AssertionError: The trailing return code is not 0, a tower does not
            contain 5 floors, a floor has an unexpected protocol identifier,
            or the number of parsed towers differs from ``num_towers``.
    """

    def read_floor(buf: memoryview) -> t.Tuple[int, int, memoryview, memoryview]:
        # Layout: u16 LHS length (counting the protocol byte), protocol byte,
        # LHS data, u16 RHS length, RHS data.
        (lhs_size,) = struct.unpack("<H", buf[:2])
        protocol_id = buf[2]
        rhs_header = lhs_size + 2
        lhs_value = buf[3:rhs_header]

        (rhs_size,) = struct.unpack("<H", buf[rhs_header : rhs_header + 2])
        floor_end = rhs_header + rhs_size + 2
        rhs_value = buf[rhs_header + 2 : floor_end]

        return floor_end, protocol_id, lhs_value, rhs_value

    def read_syntax_floor(buf: memoryview) -> t.Tuple[int, t.Tuple[uuid.UUID, int, int]]:
        # UUID floors carry the UUID plus a u16 in the LHS and a u16 in the RHS.
        consumed, protocol_id, lhs_value, rhs_value = read_floor(buf)
        assert protocol_id == 0x0D
        syntax = (
            uuid.UUID(bytes_le=lhs_value[:16].tobytes()),
            struct.unpack("<H", lhs_value[16:])[0],
            struct.unpack("<H", rhs_value)[0],
        )

        return consumed, syntax

    remaining = memoryview(data)

    (return_code,) = struct.unpack("<I", remaining[-4:])
    assert return_code == 0
    (num_towers,) = struct.unpack("<I", remaining[20:24])
    # Bytes 24-31 and 32-39 hold the tower array max count and offset, which
    # are not needed here.
    (tower_count,) = struct.unpack("<Q", remaining[40:48])

    # Skip the 8 byte referent id of every tower to reach the tower data.
    remaining = remaining[48 + 8 * tower_count :]

    towers: t.List[Tower] = []
    for _ in range(tower_count):
        (tower_length,) = struct.unpack("<Q", remaining[:8])
        trailing_pad = -(tower_length + 4) % 8
        (floor_count,) = struct.unpack("<H", remaining[12:14])
        assert floor_count == 5
        remaining = remaining[14:]

        consumed, service = read_syntax_floor(remaining)
        remaining = remaining[consumed:]

        consumed, data_rep = read_syntax_floor(remaining)
        remaining = remaining[consumed:]

        consumed, protocol, _lhs, _rhs = read_floor(remaining)
        remaining = remaining[consumed:]
        assert protocol == 0x0B

        consumed, floor_proto, _lhs, port_raw = read_floor(remaining)
        remaining = remaining[consumed:]
        assert floor_proto == 0x07
        (port,) = struct.unpack(">H", port_raw)

        consumed, floor_proto, _lhs, addr_raw = read_floor(remaining)
        remaining = remaining[consumed:]
        assert floor_proto == 0x09
        (addr,) = struct.unpack(">I", addr_raw)

        towers.append(
            Tower(
                service=service,
                data_rep=data_rep,
                protocol=protocol,
                port=port,
                addr=addr,
            )
        )
        remaining = remaining[trailing_pad:]

    assert len(towers) == num_towers

    return towers


def rpc_get_key(
    dc: str,
    ctx: spnego.ContextProxy,
    target_sd: bytes,
    root_key_id: t.Optional[uuid.UUID],
    l0: int,
    l1: int,
    l2: int,
    auth_level: str = "privacy_integrity",
    auth_protocol: str = "negotiate",
) -> GroupKeyEnvelope:
    """Send a ``GetKeyRequest`` to ``dc`` over an authenticated RPC bind.

    The endpoint mapper on port 135 is queried for the port of the
    ``ISD_KEY`` interface. A second connection to that port is then bound
    using tokens produced by ``ctx`` and the request is sent wrapped by
    ``ctx``.

    Args:
        dc: Host to connect to.
        ctx: Security context used to authenticate and wrap the exchange.
        target_sd: Passed through to ``GetKeyRequest``.
        root_key_id: Passed through to ``GetKeyRequest``.
        l0: Passed through to ``GetKeyRequest``.
        l1: Passed through to ``GetKeyRequest``.
        l2: Passed through to ``GetKeyRequest``.
        auth_level: ``none``, ``integrity``, ``privacy`` or
            ``privacy_integrity``. The last one is ``privacy`` with header
            signing enabled.
        auth_protocol: ``negotiate``, ``ntlm`` or ``kerberos``.

    Returns:
        The value produced by ``GetKeyRequest.unpack_response``.
    """
    # privacy_integrity is not a wire level of its own, it is the privacy
    # level with the header signing flag turned on.
    sign_header = auth_level == "privacy_integrity"
    level_name = "privacy" if sign_header else auth_level

    level_id = {
        "none": 0,  # RPC_C_AUTHN_LEVEL_DEFAULT
        "integrity": 5,  # RPC_C_AUTHN_LEVEL_PKT_INTEGRITY
        "privacy": 6,  # RPC_C_AUTHN_LEVEL_PKT_PRIVACY
    }[level_name]
    type_id = {
        "negotiate": 0x09,  # RPC_C_AUTHN_GSS_NEGOTIATE
        "ntlm": 0x0A,  # RPC_C_AUTHN_WINNT
        "kerberos": 0x10,  # RPC_C_AUTHN_GSS_KERBEROS
    }[auth_protocol]
    auth_options = {
        "sign_header": sign_header,
        "auth_level": level_id,
        "auth_type": type_id,
    }

    bind_syntaxes = [
        syntax[0].bytes_le + struct.pack("<HH", syntax[1], syntax[2])
        for syntax in (NDR, NDR64, BIND_TIME_FEATURE_NEGOTIATION)
    ]

    # Find the dynamic endpoint port for the ISD service.
    with socket.create_connection((dc, 135)) as epm_sock:
        epm_sock.sendall(create_bind(EMP, bind_syntaxes, sign_header=False))
        parse_bind_ack(epm_sock.recv(4096))

        opnum, map_request = create_ept_map_request(ISD_KEY, NDR)
        epm_sock.sendall(create_request(opnum, map_request))

        ept_response = parse_response(epm_sock.recv(4096))
        isd_towers = parse_ept_map_response(ept_response)
        assert len(isd_towers) > 0
        isd_port = isd_towers[0].port

    out_token = ctx.step()
    assert out_token

    with socket.create_connection((dc, isd_port)) as isd_sock:
        bind_pdu = create_bind(
            ISD_KEY,
            bind_syntaxes,
            auth_data=out_token,
            **auth_options,
        )
        isd_sock.sendall(bind_pdu)
        in_token = parse_bind_ack(isd_sock.recv(4096))

        out_token = ctx.step(in_token)

        while out_token:
            # For odd numbered legs we should be sending the MS-RPCE rpc_auth3
            # PDU. This doesn't seem to be necessary in testing though and
            # doing so just adds more complication, use the alter context PDU
            # like normal.
            alter_pdu = create_alter_context(ISD_KEY, out_token, **auth_options)
            isd_sock.sendall(alter_pdu)
            in_token = parse_alter_context(isd_sock.recv(4096))

            # An empty token from the server ends the exchange.
            out_token = ctx.step(in_token) if in_token else None

        assert ctx.complete
        assert not out_token

        get_key_req = GetKeyRequest(target_sd, root_key_id, l0, l1, l2)
        request_pdu = create_request(
            get_key_req.opnum,
            get_key_req.pack(),
            ctx=ctx,
            **auth_options,
        )
        isd_sock.sendall(request_pdu)

        get_key_resp = parse_response(isd_sock.recv(4096), ctx=ctx, sign_header=sign_header)
        return GetKeyRequest.unpack_response(get_key_resp)


@pytest.mark.parametrize('service', ['HTTP', 'cifs'])
def test_kerberos_authentication(service, monkeypatch):
    """Authenticate a client against a local server and exchange wrapped data.

    After the token exchange the negotiated protocol, session key and client
    principal are checked, then an IOV message is wrapped and unwrapped in
    each direction.
    """
    if os.name == 'nt':
        if not IS_SYSTEM:
            pytest.skip("Cannot run Kerberos server tests when not running as SYSTEM")

    else:
        monkeypatch.setenv('KRB5_KTNAME', '/etc/%s.keytab' % service)

    c = spnego.client(USERNAME, PASSWORD, hostname=HOST_FQDN, service=service)
    s = spnego.server()

    server_token = None
    while not c.complete:
        client_token = c.step(server_token)
        if client_token:
            server_token = s.step(client_token)
        else:
            break

    expected_username = USERNAME
    if os.name == 'nt':
        # SSPI will report the user in the Netlogon form, use PowerShell to convert the UPN to the Netlogon user.
        translate_script = '''
$username = '{0}'
([Security.Principal.NTAccount]$username).Translate(
    [Security.Principal.SecurityIdentifier]
).Translate([Security.Principal.NTAccount]).Value
'''.format(USERNAME)
        pwsh_command = base64.b64encode(translate_script.encode('utf-16-le')).decode()
        pwsh = subprocess.Popen(['powershell', '-EncodedCommand', pwsh_command], stdout=subprocess.PIPE)
        expected_username = pwsh.communicate()[0].strip().decode()

    assert s.client_principal == expected_username
    assert s.session_key == c.session_key
    assert c.negotiated_protocol == 'kerberos'
    assert s.negotiated_protocol == 'kerberos'

    def check_iov_round_trip(sender, receiver):
        # Wrap random data on one side and make sure the peer recovers it.
        plaintext = os.urandom(16)
        wrapped = sender.wrap_iov([spnego.iov.BufferType.header, plaintext, spnego.iov.BufferType.padding])
        assert wrapped.encrypted
        assert len(wrapped.buffers) == 3
        assert wrapped.buffers[1].data != plaintext

        unwrapped = receiver.unwrap_iov(wrapped.buffers)
        assert unwrapped.encrypted
        assert unwrapped.qop == 0
        assert len(unwrapped.buffers) == 3
        assert unwrapped.buffers[1].data == plaintext

    check_iov_round_trip(c, s)
    check_iov_round_trip(s, c)


@pytest.mark.skipif(os.name == 'nt', reason='Windows doesnt support keytabs')
@pytest.mark.parametrize('protocol, set_principal', [
    ('kerberos', False),
    ('kerberos', True),
    ('negotiate', False),
    ('negotiate', True),
])
def test_kerberos_authentication_with_keytab(protocol, set_principal, monkeypatch):
    """Authenticate with a keytab credential, with and without a principal."""
    monkeypatch.setenv('KRB5_KTNAME', '/etc/HTTP.keytab')

    # negotiate is requested by leaving the protocol kwarg out altogether.
    auth_kwargs = {} if protocol == 'negotiate' else {'protocol': protocol}

    keytab_kwargs = {'keytab': os.path.expanduser('~/user.keytab')}
    if set_principal:
        keytab_kwargs['principal'] = USERNAME
    keytab = spnego.KerberosKeytab(**keytab_kwargs)

    c = spnego.client(keytab, hostname=HOST_FQDN, service='HTTP', **auth_kwargs)
    s = spnego.server(**auth_kwargs)

    server_token = None
    while not c.complete:
        client_token = c.step(server_token)
        if client_token:
            server_token = s.step(client_token)
        else:
            break

    assert s.client_principal == USERNAME
    assert s.session_key == c.session_key
    assert c.negotiated_protocol == 'kerberos'
    assert s.negotiated_protocol == 'kerberos'


@pytest.mark.parametrize('protocol, use_cache_cred', [
    ('kerberos', False),
    ('kerberos', True),
    ('negotiate', False),
    ('negotiate', True),
])
@pytest.mark.skipif(os.name == 'nt', reason='Due to how its run this doesnt really work for Windows')
def test_kerberos_authentication_with_cache(protocol, use_cache_cred, monkeypatch, tmp_path):
    """Authenticate using a credential cache populated by ``kinit``.

    A temporary ccache is filled by running ``kinit`` with ``KRB5CCNAME``
    pointing at it. The client then uses it either through an explicit
    ``KerberosCCache`` credential (``use_cache_cred``) or implicitly through
    the ``KRB5CCNAME`` environment variable of this process.
    """
    monkeypatch.setenv('KRB5_KTNAME', '/etc/HTTP.keytab')

    kinit_args = ['kinit']
    if KERBEROS_PROVIDER == 'Heimdal':
        kinit_args.append('--password-file=STDIN')

    kinit_args.append(USERNAME)

    # Keep the path as a str so the subprocess environment, the monkeypatched
    # environment and the credential object all receive the same type.
    tmp_ccache = str(tmp_path / 'test_ccache')
    open(tmp_ccache, mode='wb').close()
    kinit_env = os.environ.copy()
    kinit_env['KRB5CCNAME'] = tmp_ccache

    cred = None
    if use_cache_cred:
        cred = spnego.KerberosCCache(ccache=tmp_ccache)

    else:
        monkeypatch.setenv('KRB5CCNAME', tmp_ccache)

    process = subprocess.Popen(kinit_args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                                env=kinit_env)
    stdout, stderr = process.communicate(PASSWORD.encode() + b'\n')
    if process.returncode != 0:
        raise Exception("Failed to get Kerberos credential %d: %s %s" % (process.returncode, stdout, stderr))

    auth_kwargs = {}
    if protocol != 'negotiate':
        auth_kwargs['protocol'] = protocol

    c = spnego.client(cred, hostname=HOST_FQDN, service='HTTP', **auth_kwargs)
    s = spnego.server(**auth_kwargs)

    in_token = None
    while not c.complete:
        out_token = c.step(in_token)
        if not out_token:
            break

        in_token = s.step(out_token)

    assert s.client_principal == USERNAME
    assert s.session_key == c.session_key
    assert c.negotiated_protocol == 'kerberos'
    assert s.negotiated_protocol == 'kerberos'


@pytest.mark.skipif(KERBEROS_PROVIDER != 'MIT', reason="Cannot test on Windows and Heimdal does not work")
def test_winrm_rc4_wrapping(monkeypatch):
    """Run a WinRM command with Kerberos restricted to arcfour-hmac-md5."""
    context_kwargs = dict(hostname=WIN_SERVER_TRUSTED, service='HTTP')
    if os.name != 'nt' or IS_SYSTEM:
        context_kwargs.update(username=USERNAME, password=PASSWORD)

    # The enctype settings are named differently between the two providers.
    if KERBEROS_PROVIDER == 'Heimdal':
        krb5_conf = b'''[libdefaults]
default_etypes = arcfour-hmac-md5'''

    else:
        krb5_conf = b'''[libdefaults]
default_tkt_enctypes = arcfour-hmac-md5
default_tgs_enctypes = arcfour-hmac-md5'''

    with tempfile.NamedTemporaryFile() as temp_cfg:
        temp_cfg.write(krb5_conf)
        temp_cfg.flush()

        # The temporary config is listed ahead of the system one.
        monkeypatch.setenv('KRB5_CONFIG', '%s:/etc/krb5.conf' % temp_cfg.name)
        c = spnego.client(**context_kwargs)
        winrm_run(c, 'Negotiate', WIN_SERVER_TRUSTED, 'hostname.exe')


def test_winrm_no_delegate():
    """Check a network path is inaccessible over WinRM without delegation."""
    context_kwargs = dict(hostname=WIN_SERVER_TRUSTED, service='HTTP')
    if os.name != 'nt' or IS_SYSTEM:
        context_kwargs.update(username=USERNAME, password=PASSWORD)

    c = spnego.client(**context_kwargs)

    script_text = '''$ErrorActionPreference = 'Stop'
(Get-Item -LiteralPath \\\\{0}\\c$\\temp\\test_integration.py).FullName
'''.format(WIN_SERVER_UNTRUSTED)
    pwsh_script = base64.b64encode(script_text.encode('utf-16-le')).decode()

    result = winrm_run(c, 'Negotiate', WIN_SERVER_TRUSTED, 'powershell.exe', ['-EncodedCommand', pwsh_script])
    rc, stdout, stderr = result

    # Without the delegate flag being set in the context we will fail to access any network paths
    assert rc == 1
    assert stdout == ''
    assert 'Access is denied' in stderr


def test_winrm_delegate():
    """Check a network path is accessible over WinRM when delegating."""
    context_kwargs = dict(
        hostname=WIN_SERVER_TRUSTED,
        service='HTTP',
        context_req=spnego.ContextReq.default | spnego.ContextReq.delegate,
    )
    if os.name != 'nt' or IS_SYSTEM:
        context_kwargs.update(username=USERNAME, password=PASSWORD)

    c = spnego.client(**context_kwargs)

    script_text = '''$ErrorActionPreference = 'Stop'
(Get-Item -LiteralPath \\\\{0}\\c$\\temp\\test_integration.py).FullName
'''.format(WIN_SERVER_UNTRUSTED)
    pwsh_script = base64.b64encode(script_text.encode('utf-16-le')).decode()

    result = winrm_run(c, 'Negotiate', WIN_SERVER_TRUSTED, 'powershell.exe', ['-EncodedCommand', pwsh_script])
    rc, stdout, stderr = result

    expected_path = '\\\\{0}\\c$\\temp\\test_integration.py'.format(WIN_SERVER_UNTRUSTED)
    assert rc == 0
    assert stdout.lower().strip() == expected_path.lower()
    assert stderr == ''


@pytest.mark.skipif(os.name == 'nt', reason='Windows always sets ok as delegate so nothing to test there')
def test_winrm_ok_as_delegate():
    """Delegate to a host that is not trusted for delegation in AD."""
    # In the environment setup, this host is not trusted for delegation in AD. This test shows that this flag isn't
    # used for GSSAPI unless set by the krb5.conf file.
    c = spnego.client(
        hostname=WIN_SERVER_UNTRUSTED,
        service='HTTP',
        context_req=spnego.ContextReq.default | spnego.ContextReq.delegate,
        username=USERNAME,
        password=PASSWORD,
    )

    script_text = '''$ErrorActionPreference = 'Stop'
(Get-Item -LiteralPath \\\\{0}\\c$\\temp).FullName
'''.format(WIN_SERVER_TRUSTED)
    pwsh_script = base64.b64encode(script_text.encode('utf-16-le')).decode()

    result = winrm_run(c, 'Negotiate', WIN_SERVER_UNTRUSTED, 'powershell.exe', ['-EncodedCommand', pwsh_script])
    rc, stdout, stderr = result

    expected_path = '\\\\{0}\\c$\\temp'.format(WIN_SERVER_TRUSTED)
    assert rc == 0
    assert stdout.lower().strip() == expected_path.lower()
    assert stderr == ''


@pytest.mark.parametrize('option', [
    spnego.NegotiateOptions.none,
    spnego.NegotiateOptions.use_negotiate,
])
def test_winrm_cbt(option):
    """Run a WinRM command over TLS with channel bindings supplied."""
    cbt = get_cbt_data(WIN_SERVER_TRUSTED)

    on_windows = os.name == 'nt'
    if on_windows and option == spnego.NegotiateOptions.use_negotiate:
        pytest.skip("The Python negotiate provider is not for use in Windows")

    if not on_windows and KERBEROS_PROVIDER == 'MIT':
        # There's a bug in MIT Negotiate where it doesn't pass along the channel bindings. This has been fixed but
        # is not present in any released version yet so just use Kerberos for this test.
        protocol = 'kerberos'
    else:
        protocol = 'negotiate'

    c = spnego.client(
        USERNAME,
        PASSWORD,
        hostname=WIN_SERVER_TRUSTED,
        service='HTTP',
        channel_bindings=cbt,
        protocol=protocol,
        options=option,
    )
    winrm_run(c, 'Negotiate', WIN_SERVER_TRUSTED, 'hostname.exe', ssl=True)


# FUTURE: Enable TLSv1.3 tests once Windows supports it
@pytest.mark.parametrize('target, protocol, tls_protocol', [
    (WIN_SERVER_TRUSTED, 'kerberos', None),
    (WIN_SERVER_TRUSTED, 'kerberos', "TLSv1.2"),
    # (WIN_SERVER_TRUSTED, 'kerberos', "TLSv1.3"),
    # Connecting with an IP breaks Kerberos authentication.
    (WIN_SERVER_TRUSTED_IP, 'ntlm', None),
    (WIN_SERVER_TRUSTED_IP, 'ntlm', "TLSv1.2"),
    # (WIN_SERVER_TRUSTED_IP, 'ntlm', "TLSv1.3"),
])
def test_winrm_credssp(target, protocol, tls_protocol):
    """Run a WinRM command with CredSSP, optionally pinning the TLS version."""
    client_kwargs = {}
    if tls_protocol:
        tls_context = spnego.tls.default_tls_context()

        default_opt = (
            ssl.Options.OP_NO_SSLv2
            | ssl.Options.OP_NO_SSLv3
            | ssl.Options.OP_NO_TLSv1
            | ssl.Options.OP_NO_TLSv1_1
            | ssl.Options.OP_NO_TLSv1_2
            | ssl.Options.OP_NO_TLSv1_3
        )

        version_attr = 'TLSv1_2' if tls_protocol == "TLSv1.2" else 'TLSv1_3'
        try:
            ver = getattr(ssl.TLSVersion, version_attr)
            tls_context.context.minimum_version = ver
            tls_context.context.maximum_version = ver
        except (ValueError, AttributeError):
            # Fall back to disabling every protocol except the wanted one.
            remove_opt = getattr(ssl.Options, 'OP_NO_' + version_attr)
            tls_context.context.options |= (default_opt & ~remove_opt)

        client_kwargs["credssp_tls_context"] = tls_context

    c = spnego.client(USERNAME, PASSWORD, hostname=target, service='HTTP', protocol='credssp', **client_kwargs)

    script_text = '''$ErrorActionPreference = 'Stop'
(Get-Item -LiteralPath \\\\{0}\\c$\\temp).FullName
'''.format(WIN_SERVER_UNTRUSTED)
    pwsh_script = base64.b64encode(script_text.encode('utf-16-le')).decode()
    rc, stdout, stderr = winrm_run(c, 'CredSSP', target, 'powershell.exe', ['-EncodedCommand', pwsh_script])

    expected_path = '\\\\{0}\\c$\\temp'.format(WIN_SERVER_UNTRUSTED)
    assert rc == 0
    assert stdout.lower().strip() == expected_path.lower()
    assert stderr == ''
    assert c.negotiated_protocol == protocol

    # Python 3.6 on Windows seems to always report it as TLSv1.0 but on packet inspection it really is correct.
    # Just ignoring that particular version with this check.
    if tls_protocol and (os.name != 'nt' or sys.version_info > (3, 7)):
        assert c._tls_object.cipher()[1] == tls_protocol


def test_smb_auth():
    """Authenticate an SMB session with the tokens produced by the client."""
    with socket.create_connection((WIN_SERVER_TRUSTED, 445), timeout=10) as s:
        mid = 0
        negotiate_response = smb_negotiate_request(s, mid)
        mid += 1

        # Explicit credentials are only passed when not on Windows or when
        # running as SYSTEM.
        credential = (USERNAME, PASSWORD) if os.name != 'nt' or IS_SYSTEM else ()
        c = spnego.client(*credential, hostname=WIN_SERVER_TRUSTED, service='cifs')

        out_token = c.step(negotiate_response.data.buffer)

        session_response = smb_session_setup(s, mid, out_token)
        mid += 1
        out_token = c.step(session_response.data.buffer)

        assert out_token is None
        assert c.complete
        assert c.negotiated_protocol == 'kerberos'
        assert isinstance(c.session_key, bytes)

        # SMB2 LOGOFF
        smb_send(s, mid, 0x0002, b"\x04\x00\x00\x00", session_id=session_response.session_id)
        mid += 1


@pytest.mark.parametrize('protocol', ["kerberos", "negotiate", "negotiate-ntlm", "ntlm"])
def test_ldap_no_integrity(protocol: str) -> None:
    """Bind to LDAPS with the ``no_integrity`` context request set."""
    tls = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    tls.load_default_certs()

    # Connecting over LDAPS will allow us to test out no_integrity works as
    # MS LDAP doesn't allow integrity/confidentiality over an existing LDAPS
    # connection.

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) as sock:
        sock.settimeout(10)

        with tls.wrap_socket(sock, server_hostname=WIN_DC) as ssock:
            ssock.connect((WIN_DC, 636))

            # negotiate-ntlm is the negotiate protocol aimed at the IP address
            # instead of the hostname.
            via_ip = protocol == "negotiate-ntlm"
            auth_kwargs = {
                "service": "ldap",
                "context_req": spnego.ContextReq.mutual_auth | spnego.ContextReq.no_integrity,
                "hostname": WIN_DC_IP if via_ip else WIN_DC,
                "protocol": "negotiate" if via_ip else protocol,
            }

            credential = (USERNAME, PASSWORD) if os.name != 'nt' or IS_SYSTEM else ()
            c = spnego.client(*credential, **auth_kwargs)

            ldap = sansldap.LDAPClient()
            in_progress_codes = [
                sansldap.LDAPResultCode.SUCCESS,
                sansldap.LDAPResultCode.SASL_BIND_IN_PROGRESS,
            ]

            server_token = None
            while not c.complete:
                client_token = c.step(in_token=server_token)
                if not client_token:
                    break

                ldap.bind_sasl("GSS-SPNEGO", cred=client_token)
                ssock.write(ldap.data_to_send())

                msg = ldap.receive(ssock.recv(4096))[0]
                assert isinstance(msg, sansldap.BindResponse)
                if msg.result.result_code not in in_progress_codes:
                    raise Exception(f"Failed to bind: {msg.result.diagnostics_message}")

                server_token = msg.server_sasl_creds

            if msg.result.result_code != sansldap.LDAPResultCode.SUCCESS:
                raise Exception(f"Failed to bind {msg.result.result_code.name}: {msg.result.diagnostics_message}")

            # LDAP Whoami
            ldap.extended_request("1.3.6.1.4.1.4203.1.11.3")
            ssock.write(ldap.data_to_send())

            whoami = ldap.receive(ssock.recv(4096))[0]
            assert isinstance(whoami, sansldap.ExtendedResponse)
            assert whoami.value.startswith(b"u:SPNEGO\\")


@pytest.mark.parametrize(
    'protocol, auth_level',
    [
        (proto, level)
        for proto in ("kerberos", "negotiate", "negotiate-ntlm", "ntlm")
        for level in ("privacy_integrity", "privacy")
    ],
)
def test_dce_rpc_auth(protocol: str, auth_level: str) -> None:
    """Call ``rpc_get_key`` against the DC for each protocol and auth level."""
    # negotiate-ntlm is the negotiate protocol aimed at the IP address
    # instead of the hostname.
    via_ip = protocol == "negotiate-ntlm"
    auth_kwargs = {
        "service": "host",
        "context_req": spnego.ContextReq.default | spnego.ContextReq.dce_style,
        "hostname": WIN_DC_IP if via_ip else WIN_DC,
        "protocol": "negotiate" if via_ip else protocol,
    }

    if protocol == 'ntlm' or os.name != 'nt' or IS_SYSTEM:
        auth_kwargs['username'] = USERNAME
        auth_kwargs['password'] = PASSWORD

    if protocol == 'ntlm':
        auth_kwargs['options'] = spnego.NegotiateOptions.use_ntlm

    c = spnego.client(**auth_kwargs)

    # O:SYG:SYD:(A;;CCDC;;;WD)(A;;DC;;;WD)
    target_sd = (
        b"\x01\x00\x04\x80\x44\x00\x00\x00"
        b"\x50\x00\x00\x00\x00\x00\x00\x00"
        b"\x14\x00\x00\x00\x02\x00\x30\x00"
        b"\x02\x00\x00\x00\x00\x00\x14\x00"
        b"\x03\x00\x00\x00\x01\x01\x00\x00"
        b"\x00\x00\x00\x01\x00\x00\x00\x00"
        b"\x00\x00\x14\x00\x02\x00\x00\x00"
        b"\x01\x01\x00\x00\x00\x00\x00\x01"
        b"\x00\x00\x00\x00\x01\x01\x00\x00"
        b"\x00\x00\x00\x05\x12\x00\x00\x00"
        b"\x01\x01\x00\x00\x00\x00\x00\x05"
        b"\x12\x00\x00\x00"
    )

    envelope = rpc_get_key(
        WIN_DC,
        c,
        target_sd,
        root_key_id=None,
        l0=-1,
        l1=-1,
        l2=-1,
        auth_level=auth_level,
        auth_protocol=auth_kwargs["protocol"],
    )
    assert isinstance(envelope, GroupKeyEnvelope)
