"""
Copyright (c) 2023 Proton AG

This file is part of Proton.

Proton is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

Proton is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with ProtonVPN.  If not, see <https://www.gnu.org/licenses/>.
"""
from dataclasses import dataclass
from typing import Awaitable, List
import aiohttp
from ..exceptions import *
from .aiohttp import AiohttpTransport

import json, base64, struct, time, asyncio, random, itertools

from urllib.parse import urlparse

from ..api import sync_wrapper
from .utils.dns import DNSParser, DNSResponseError


@dataclass()
class AlternativeRoutingDNSQueryAnswer:
    """A single alternative-routing server domain obtained from a DNS query.

    Instances are produced from successful DNS answers; ``expiration_time``
    is an absolute timestamp after which the entry is treated as stale.
    """

    # Absolute expiry timestamp in seconds (computed as time.time() + TTL).
    expiration_time: float
    # Alternative-routing domain carried by the DNS record.
    domain: str


class AlternativeRoutingTransport(AiohttpTransport):
    """Transport that reaches the API through alternative-routing domains.

    The alternative domains are discovered with a DNS TXT query sent over
    DNS-over-HTTPS to the public resolvers listed in ``DNS_PROVIDERS``, for a
    name derived from the regular API host (see ``_compute_ar_domain``).
    Answers are cached in ``self._alternative_routes`` with their expiration
    time, and the entry that expires last is used as the request base URL.
    """

    # Each entry: (IPv4 addresses, IPv6 addresses, DNS-over-HTTPS path).
    DNS_PROVIDERS = [
        # dns.google
        (("8.8.4.4", "8.8.8.8"), ("2001:4860:4860::8844", "2001:4860:4860::8888"), '/dns-query'),
        # dns11.quad9.net
        (("149.112.112.11", "9.9.9.11"), ("2620:fe::fe:11", "2620:fe::11"), '/dns-query'),
    ]

    # Not referenced by the methods of this class (DNS wire parsing is
    # delegated to DNSParser); kept for backward compatibility.
    STRUCT_REPLY_COUNTS = struct.Struct('>HHHH')
    STRUCT_REC_FORMAT = struct.Struct('>HHIH')

    # Delay in seconds between successive rounds of DNS requests.
    DELAY_DNS_REQUEST = 2
    # Overall deadline in seconds for resolving the alternative routes.
    TIMEOUT_DNS_REQUEST = 10

    @classmethod
    def _get_priority(cls):
        """Return this transport's priority value (5).

        NOTE(review): how the value is interpreted (e.g. whether a higher
        value is preferred) is decided by the caller, not visible here.
        """
        return 5

    def __init__(self, session):
        super().__init__(session)
        # Cached AlternativeRoutingDNSQueryAnswer entries; after each refresh
        # they are sorted so the latest-expiring entry comes first.
        self._alternative_routes = []

    @classmethod
    def _compute_ar_domain(cls, host):
        """Return the DNS name (as ``bytes``) to query for ``host``.

        The name is ``b'd'`` followed by the unpadded base32 encoding of the
        ASCII host, under ``protonpro.xyz``; e.g. ``"a"`` gives
        ``b"dME.protonpro.xyz"``.

        :raises UnicodeEncodeError: if ``host`` is not pure ASCII.
        """
        return b'd' + base64.b32encode(host.encode('ascii')).strip(b'=') + b".protonpro.xyz"

    async def _async_dns_query(
            self, domain, dns_server_ip, dns_server_path, delay=0
    ) -> List[AlternativeRoutingDNSQueryAnswer]:
        """Resolve the alternative-routing name for ``domain`` over DoH.

        :param domain: API host the alternative route is looked up for.
        :param dns_server_ip: address of the DoH server; IPv6 literals must
            already be wrapped in brackets so they form a valid URL host.
        :param dns_server_path: HTTP path of the DoH endpoint.
        :param delay: seconds to wait before sending the query.
        :return: one answer per record in the DNS reply, each carrying its
            absolute expiration time (now + TTL).
        :raises ProtonAPINotReachable: if the DNS reply cannot be parsed.
        """
        if delay > 0:
            await asyncio.sleep(delay)

        ardomain = self._compute_ar_domain(domain)
        dns_request = DNSParser.build_query(ardomain, qtype=16, qclass=1)  # TXT IN
        doh_url = f'https://{dns_server_ip}{dns_server_path}'

        async with aiohttp.ClientSession() as session:
            async with session.post(
                doh_url,
                headers={"Content-Type": "application/dns-message"},
                data=dns_request,
            ) as response:
                reply_data = await response.content.read()

        try:
            dns_answers = DNSParser.parse(reply_data)
        except DNSResponseError as e:
            raise ProtonAPINotReachable(str(e)) from e

        now = time.time()
        # DNSParser yields (ttl, value) pairs; turn each relative TTL into an
        # absolute expiration timestamp.
        return [
            AlternativeRoutingDNSQueryAnswer(expiration_time=now + ttl, domain=value)
            for ttl, value in dns_answers
        ]

    @property
    def _http_domain(self):
        """Network location (host[:port]) of the parent transport's base URL."""
        return urlparse(super().http_base_url).netloc

    async def _get_alternative_routes(self):
        """Refresh ``self._alternative_routes`` by querying the DNS providers.

        The DoH servers are shuffled and queried in rounds spaced
        ``DELAY_DNS_REQUEST`` seconds apart. Each round runs one IPv4 and one
        IPv6 server concurrently. Waiting stops as soon as one query returns
        answers, all queries have finished, or ``TIMEOUT_DNS_REQUEST`` seconds
        have elapsed. Any query still running is then cancelled.

        On success, the new answers are merged with the cached ones that are
        still valid; a new answer replaces a cached entry for the same domain.
        The list is then sorted by expiration time, latest first.

        :raises ProtonAPINotAvailable: if any finished query raised it; this
            takes precedence over successful answers.
        :raises ProtonAPINotReachable: if no answer was obtained and there is
            no cached route to fall back to.
        """
        # Build a random order of DNS servers and query them in that order,
        # simultaneously on IPv4 and IPv6.
        choices_ipv4 = []
        choices_ipv6 = []
        for dns_server_ipv4s, dns_server_ipv6s, dns_server_path in self.DNS_PROVIDERS:
            choices_ipv4.extend((ip, dns_server_path) for ip in dns_server_ipv4s)
            choices_ipv6.extend((ip, dns_server_path) for ip in dns_server_ipv6s)

        random.shuffle(choices_ipv4)
        random.shuffle(choices_ipv6)

        domain = self._http_domain
        pending = []
        for i, (ipv4, ipv6) in enumerate(itertools.zip_longest(choices_ipv4, choices_ipv6)):
            delay = i * self.DELAY_DNS_REQUEST
            # Don't schedule rounds that would start after the overall deadline.
            if delay > self.TIMEOUT_DNS_REQUEST:
                break

            if ipv4 is not None:
                pending.append(asyncio.create_task(
                    self._async_dns_query(domain, ipv4[0], ipv4[1], delay=delay)))
            if ipv6 is not None:
                # IPv6 literals must be bracketed inside a URL.
                pending.append(asyncio.create_task(
                    self._async_dns_query(domain, f'[{ipv6[0]}]', ipv6[1], delay=delay)))

        results_ok = []
        results_fail = []
        proton_api_not_available_errors = []
        # Monotonic clock so wall-clock adjustments can't stretch the deadline.
        deadline = time.monotonic() + self.TIMEOUT_DNS_REQUEST
        try:
            while pending and not results_ok:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    # Deadline reached: stop waiting instead of polling until
                    # every straggling request finishes on its own.
                    break
                done, pending = await asyncio.wait(
                    pending, timeout=remaining, return_when=asyncio.FIRST_COMPLETED)
                for task in done:
                    try:
                        results_ok += task.result()
                    except ProtonAPINotAvailable as e:
                        # The resolution went through, but it explicitly failed.
                        proton_api_not_available_errors.append(e)
                    except Exception as e:
                        # A single server failing is tolerated; the others may succeed.
                        results_fail.append(e)
        finally:
            # Runs on every exit path, including cancellation of this coroutine,
            # so no DNS query outlives this call.
            for task in pending:
                task.cancel()

        if pending:
            # Let the cancelled tasks unwind (closing their aiohttp sessions) and
            # retrieve their outcomes so asyncio doesn't log unretrieved exceptions.
            await asyncio.gather(*pending, return_exceptions=True)

        if proton_api_not_available_errors:
            raise proton_api_not_available_errors[0]

        if not results_ok:
            if self._alternative_routes:
                # We have routes, but we could not resolve new ones: keep the old ones.
                return
            # No routes, and we failed to get new ones.
            raise ProtonAPINotReachable("Couldn't resolve any alternative routing names") from (
                results_fail[0] if results_fail else None)

        now = time.time()
        fresh_domains = {answer.domain for answer in results_ok}
        # Drop cached routes that were just re-resolved (avoids duplicates) or
        # that have expired.
        self._alternative_routes = [
            route for route in self._alternative_routes
            if route.domain not in fresh_domains and route.expiration_time >= now
        ]
        self._alternative_routes += results_ok
        # Latest expiration first: http_base_url uses element 0.
        self._alternative_routes.sort(key=lambda route: route.expiration_time, reverse=True)

    @property
    def http_base_url(self):
        """Base URL built from the current preferred alternative route.

        The host is the first cached alternative-routing domain; the path is
        taken from the parent transport's base URL.

        :raises ProtonAPINotReachable: if no route is cached.
        """
        if not self._alternative_routes:
            raise ProtonAPINotReachable("AlternativeRouting transport doesn't have any route")

        path = urlparse(super().http_base_url).path
        return f'https://{self._alternative_routes[0].domain}{path}'

    @property
    def tls_pinning_hashes(self):
        """TLS pinning hashes used for alternative-routing connections.

        Read from ``self._environment.tls_pinning_hashes_ar``. NOTE(review):
        ``_environment`` is presumably set up by the base transport; confirm.
        """
        return self._environment.tls_pinning_hashes_ar

    async def async_api_request(
        self, endpoint,
        jsondata=None, data=None, additional_headers=None,
        method=None, params=None, return_raw=False
    ):
        """Perform an API request through an alternative route.

        Routes are (re)resolved first when none is cached or the preferred
        one has expired. The request itself is delegated unchanged to the
        parent transport.
        """
        if not self._alternative_routes or self._alternative_routes[0].expiration_time < time.time():
            await self._get_alternative_routes()

        return await super().async_api_request(
            endpoint, jsondata, data, additional_headers, method,
            params, return_raw=return_raw)
