# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations

import enum
import struct

from bumble import core, crypto, device, gatt, gatt_client

# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
SET_IDENTITY_RESOLVING_KEY_LENGTH = 16


class SirkType(enum.IntEnum):
    '''Coordinated Set Identification Service - 5.1 Set Identity Resolving Key.'''

    ENCRYPTED = 0x00
    PLAINTEXT = 0x01


class MemberLock(enum.IntEnum):
    '''Coordinated Set Identification Service - 5.3 Set Member Lock.'''

    UNLOCKED = 0x01
    LOCKED = 0x02


# -----------------------------------------------------------------------------
# Crypto Toolbox
# -----------------------------------------------------------------------------
def s1(m: bytes) -> bytes:
    '''
    Coordinated Set Identification Service - 4.3 s1 SALT generation function.
    '''
    return crypto.aes_cmac(m[::-1], bytes(16))[::-1]


def k1(n: bytes, salt: bytes, p: bytes) -> bytes:
    '''
    Coordinated Set Identification Service - 4.4 k1 derivation function.
    '''
    t = crypto.aes_cmac(n[::-1], salt[::-1])
    return crypto.aes_cmac(p[::-1], t)[::-1]


def sef(k: bytes, r: bytes) -> bytes:
    '''
    Coordinated Set Identification Service - 4.5 SIRK encryption function sef.

    SIRK decryption function sdf shares the same algorithm. The only difference is that argument r is:
      * Plaintext in encryption
      * Cipher in decryption
    '''
    return crypto.xor(k1(k, s1(b'SIRKenc'[::-1]), b'csis'[::-1]), r)


def sih(k: bytes, r: bytes) -> bytes:
    '''
    Coordinated Set Identification Service - 4.7 Resolvable Set Identifier hash function sih.
    '''
    return crypto.e(k, r + bytes(13))[:3]


def generate_rsi(sirk: bytes) -> bytes:
    '''
    Coordinated Set Identification Service - 4.8 Resolvable Set Identifier generation operation.
    '''
    prand = crypto.generate_prand()
    return sih(sirk, prand) + prand


# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class CoordinatedSetIdentificationService(gatt.TemplateService):
    UUID = gatt.GATT_COORDINATED_SET_IDENTIFICATION_SERVICE

    set_identity_resolving_key: bytes
    set_identity_resolving_key_characteristic: gatt.Characteristic[bytes]
    coordinated_set_size_characteristic: gatt.Characteristic[bytes] | None = None
    set_member_lock_characteristic: gatt.Characteristic[bytes] | None = None
    set_member_rank_characteristic: gatt.Characteristic[bytes] | None = None

    def __init__(
        self,
        set_identity_resolving_key: bytes,
        set_identity_resolving_key_type: SirkType,
        coordinated_set_size: int | None = None,
        set_member_lock: MemberLock | None = None,
        set_member_rank: int | None = None,
    ) -> None:
        if len(set_identity_resolving_key) != SET_IDENTITY_RESOLVING_KEY_LENGTH:
            raise core.InvalidArgumentError(
                f'Invalid SIRK length {len(set_identity_resolving_key)}, expected {SET_IDENTITY_RESOLVING_KEY_LENGTH}'
            )

        characteristics = []

        self.set_identity_resolving_key = set_identity_resolving_key
        self.set_identity_resolving_key_type = set_identity_resolving_key_type
        self.set_identity_resolving_key_characteristic = gatt.Characteristic(
            uuid=gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC,
            properties=gatt.Characteristic.Properties.READ
            | gatt.Characteristic.Properties.NOTIFY,
            permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
            value=gatt.CharacteristicValue(read=self.on_sirk_read),
        )
        characteristics.append(self.set_identity_resolving_key_characteristic)

        if coordinated_set_size is not None:
            self.coordinated_set_size_characteristic = gatt.Characteristic(
                uuid=gatt.GATT_COORDINATED_SET_SIZE_CHARACTERISTIC,
                properties=gatt.Characteristic.Properties.READ
                | gatt.Characteristic.Properties.NOTIFY,
                permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
                value=struct.pack('B', coordinated_set_size),
            )
            characteristics.append(self.coordinated_set_size_characteristic)

        if set_member_lock is not None:
            self.set_member_lock_characteristic = gatt.Characteristic(
                uuid=gatt.GATT_SET_MEMBER_LOCK_CHARACTERISTIC,
                properties=gatt.Characteristic.Properties.READ
                | gatt.Characteristic.Properties.NOTIFY
                | gatt.Characteristic.Properties.WRITE,
                permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
                | gatt.Characteristic.Permissions.WRITEABLE,
                value=struct.pack('B', set_member_lock),
            )
            characteristics.append(self.set_member_lock_characteristic)

        if set_member_rank is not None:
            self.set_member_rank_characteristic = gatt.Characteristic(
                uuid=gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC,
                properties=gatt.Characteristic.Properties.READ
                | gatt.Characteristic.Properties.NOTIFY,
                permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
                value=struct.pack('B', set_member_rank),
            )
            characteristics.append(self.set_member_rank_characteristic)

        super().__init__(characteristics)

    async def on_sirk_read(self, connection: device.Connection) -> bytes:
        if self.set_identity_resolving_key_type == SirkType.PLAINTEXT:
            sirk_bytes = self.set_identity_resolving_key
        else:
            if connection.transport == core.PhysicalTransport.LE:
                key = await connection.device.get_long_term_key(
                    connection_handle=connection.handle, rand=b'', ediv=0
                )
            else:
                key = await connection.device.get_link_key(connection.peer_address)

            if not key:
                raise core.InvalidOperationError('LTK or LinkKey is not present')

            sirk_bytes = sef(key, self.set_identity_resolving_key)

        return bytes([self.set_identity_resolving_key_type]) + sirk_bytes

    def get_advertising_data(self) -> bytes:
        return bytes(
            core.AdvertisingData(
                [
                    (
                        core.AdvertisingData.RESOLVABLE_SET_IDENTIFIER,
                        generate_rsi(self.set_identity_resolving_key),
                    ),
                ]
            )
        )


# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
    SERVICE_CLASS = CoordinatedSetIdentificationService

    set_identity_resolving_key: gatt_client.CharacteristicProxy[bytes]
    coordinated_set_size: gatt_client.CharacteristicProxy[bytes] | None = None
    set_member_lock: gatt_client.CharacteristicProxy[bytes] | None = None
    set_member_rank: gatt_client.CharacteristicProxy[bytes] | None = None

    def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
        self.service_proxy = service_proxy

        self.set_identity_resolving_key = service_proxy.get_characteristics_by_uuid(
            gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC
        )[0]

        if characteristics := service_proxy.get_characteristics_by_uuid(
            gatt.GATT_COORDINATED_SET_SIZE_CHARACTERISTIC
        ):
            self.coordinated_set_size = characteristics[0]

        if characteristics := service_proxy.get_characteristics_by_uuid(
            gatt.GATT_SET_MEMBER_LOCK_CHARACTERISTIC
        ):
            self.set_member_lock = characteristics[0]

        if characteristics := service_proxy.get_characteristics_by_uuid(
            gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC
        ):
            self.set_member_rank = characteristics[0]

    async def read_set_identity_resolving_key(self) -> tuple[SirkType, bytes]:
        '''Reads SIRK and decrypts if encrypted.'''
        response = await self.set_identity_resolving_key.read_value()
        if len(response) != SET_IDENTITY_RESOLVING_KEY_LENGTH + 1:
            raise core.InvalidPacketError('Invalid SIRK value')

        sirk_type = SirkType(response[0])
        if sirk_type == SirkType.PLAINTEXT:
            sirk = response[1:]
        else:
            connection = self.service_proxy.client.connection
            device = connection.device
            if connection.transport == core.PhysicalTransport.LE:
                key = await device.get_long_term_key(
                    connection_handle=connection.handle, rand=b'', ediv=0
                )
            else:
                key = await device.get_link_key(connection.peer_address)

            if not key:
                raise core.InvalidOperationError('LTK or LinkKey is not present')

            sirk = sef(key, response[1:])

        return (sirk_type, sirk)
