# coding: utf-8
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# -------------------------------------------------------------------------
import jwt
import pytest
from azure.messaging.webpubsubservice import WebPubSubServiceClient
from azure.core.credentials import AzureKeyCredential

try:
    from urlparse import urlparse
except ImportError:
    from urllib.parse import urlparse


def _decode_token(client, token, path="/client/hubs/hub"):
    """Verify ``token`` with the client's access key and return the decoded claims.

    The audience passed to the decoder is the client's configured endpoint
    followed by ``path``; callers supply a different ``path`` for tokens that
    were issued for a non-default client protocol.
    """
    config = client._config
    expected_audience = config.endpoint + path
    return jwt.decode(token, config.credential.key, algorithms=["HS256"], audience=expected_audience)


access_key = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGH"

# (connection string, endpoint expected after parsing) pairs. The templates
# cover both URL schemes, an explicit Port segment, and AccessKey appearing
# before Endpoint.
test_cases = [
    (template.format(access_key), expected_endpoint)
    for template, expected_endpoint in (
        ("Endpoint=https://host;AccessKey={};Version=1.0;", "https://host"),
        ("Endpoint=http://host;AccessKey={};Version=1.0;", "http://host"),
        ("Endpoint=http://host;AccessKey={};Version=1.0;Port=8080;", "http://host:8080"),
        ("AccessKey={};Endpoint=http://host;Version=1.0;", "http://host"),
    )
]


@pytest.mark.parametrize("connection_string,endpoint", test_cases)
def test_parse_connection_string(connection_string, endpoint):
    """A parsed connection string yields the expected endpoint and an access-key credential."""
    config = WebPubSubServiceClient.from_connection_string(connection_string, "hub")._config
    credential = config.credential
    assert config.endpoint == endpoint
    assert isinstance(credential, AzureKeyCredential)
    assert credential.key == access_key


# (user_id, roles, groups) combinations: everything omitted, empty lists,
# a single value, repeated values, several distinct values, and empty strings.
# Roles and groups are built separately so each case holds two distinct objects.
test_cases = [
    (None, None, None),
    ("ab", [], []),
    ("ab", ["a"], ["a"]),
    ("ab", ["a"] * 3, ["a"] * 3),
    ("ab", list("abc"), list("abc")),
    ("ab", "", ""),
]


@pytest.mark.parametrize("user_id,roles,groups", test_cases)
def test_generate_uri_contains_expected_payloads_dto(user_id, roles, groups):
    """The access-token response carries the requested user, roles and groups.

    Checks the keys of the returned mapping, that the URL's query string embeds
    the same token, and that the decoded claims match the arguments. A falsy
    argument must leave its claim absent or empty.
    """
    connection_string = "Endpoint=http://localhost;Port=8080;AccessKey={};Version=1.0;".format(access_key)
    client = WebPubSubServiceClient.from_connection_string(connection_string, "hub")
    minutes_to_expire = 5
    response = client.get_client_access_token(
        user_id=user_id, roles=roles, minutes_to_expire=minutes_to_expire, groups=groups
    )
    assert response
    assert len(response) == 3
    assert set(response.keys()) == {"baseUrl", "url", "token"}
    raw_token = response["token"]
    assert "access_token={}".format(raw_token) == urlparse(response["url"]).query

    claims = _decode_token(client, raw_token)
    assert claims["aud"] == "{}/client/hubs/hub".format(client._config.endpoint)

    # The requested lifetime must be honoured to within five seconds.
    lifetime = claims["exp"] - claims["iat"]
    requested_seconds = minutes_to_expire * 60
    assert requested_seconds - 5 <= lifetime <= requested_seconds + 5

    # A truthy argument must round-trip into its claim; otherwise the claim is missing or empty.
    for claim_name, requested_value in (("sub", user_id), ("role", roles), ("webpubsub.group", groups)):
        if requested_value:
            assert claims[claim_name] == requested_value
        else:
            assert not claims.get(claim_name)


# (connection string, hub, expected client URL) triples for the default client
# protocol: an http endpoint maps to ws://, https to wss://, and an explicit
# Port segment is carried into the URL.
test_cases = [
    (template.format(access_key), "hub", expected_url)
    for template, expected_url in (
        ("Endpoint=http://localhost;Port=8080;AccessKey={};Version=1.0;", "ws://localhost:8080/client/hubs/hub"),
        ("Endpoint=https://a;AccessKey={};Version=1.0;", "wss://a/client/hubs/hub"),
        ("Endpoint=http://a;AccessKey={};Version=1.0;", "ws://a/client/hubs/hub"),
    )
]


@pytest.mark.parametrize("connection_string,hub,expected_url", test_cases)
def test_generate_url_use_same_kid_with_same_key(connection_string, hub, expected_url):
    """Two tokens issued by the same client agree on URL, audience and (nearly) on timestamps.

    Both generated URLs must equal ``expected_url`` once the query string is
    stripped. The decoded tokens must each hold exactly three claims, share the
    HTTP(S) form of ``expected_url`` as audience, and have ``iat``/``exp``
    values less than five seconds apart.
    """
    client = WebPubSubServiceClient.from_connection_string(connection_string, hub)
    url_1 = client.get_client_access_token()["url"]
    url_2 = client.get_client_access_token()["url"]

    assert url_1.split("?")[0] == url_2.split("?")[0] == expected_url

    # Assumes the query string consists solely of "access_token=<jwt>".
    token_1 = urlparse(url_1).query[len("access_token=") :]
    token_2 = urlparse(url_2).query[len("access_token=") :]

    decoded_token_1 = _decode_token(client, token_1)
    decoded_token_2 = _decode_token(client, token_2)

    # Swap only the leading ws/wss scheme for http/https (count=1); an unbounded
    # replace would also rewrite any "ws" inside the host or hub name.
    expected_audience = expected_url.replace("ws", "http", 1)

    assert len(decoded_token_1) == len(decoded_token_2) == 3
    assert decoded_token_1["aud"] == decoded_token_2["aud"] == expected_audience
    assert abs(decoded_token_1["iat"] - decoded_token_2["iat"]) < 5
    assert abs(decoded_token_1["exp"] - decoded_token_2["exp"]) < 5


# Plain connection strings (one parameter per case, not tuples): with an
# explicit port, over https, and over http.
test_cases = [
    template.format(access_key)
    for template in (
        "Endpoint=http://localhost;Port=8080;AccessKey={};Version=1.0;",
        "Endpoint=https://a;AccessKey={};Version=1.0;",
        "Endpoint=http://a;AccessKey={};Version=1.0;",
    )
]


@pytest.mark.parametrize("connection_string", test_cases)
def test_pass_in_jwt_headers(connection_string):
    """A ``kid`` supplied through ``jwt_headers`` shows up in the issued token's header."""
    expected_kid = "1234567890"
    client = WebPubSubServiceClient.from_connection_string(connection_string, "hub")
    response = client.get_client_access_token(jwt_headers={"kid": expected_kid})
    header = jwt.get_unverified_header(response["token"])
    assert header["kid"] == expected_kid


# (connection string, hub, expected client URL) triples for the MQTT client
# protocol, whose URLs use the /clients/mqtt/hubs/<hub> path.
test_cases = [
    (template.format(access_key), "hub", expected_url)
    for template, expected_url in (
        (
            "Endpoint=http://localhost;Port=8080;AccessKey={};Version=1.0;",
            "ws://localhost:8080/clients/mqtt/hubs/hub",
        ),
        ("Endpoint=https://a;AccessKey={};Version=1.0;", "wss://a/clients/mqtt/hubs/hub"),
        ("Endpoint=http://a;AccessKey={};Version=1.0;", "ws://a/clients/mqtt/hubs/hub"),
    )
]


@pytest.mark.parametrize("connection_string,hub,expected_url", test_cases)
def test_generate_mqtt_token(connection_string, hub, expected_url):
    """Requesting an MQTT token yields the MQTT client URL and a matching audience.

    The generated URL, minus its query string, must equal ``expected_url``. The
    decoded token must hold exactly three claims with the HTTP(S) form of
    ``expected_url`` as its audience.
    """
    client = WebPubSubServiceClient.from_connection_string(connection_string, hub)
    url_1 = client.get_client_access_token(client_protocol="MQTT")["url"]

    assert url_1.split("?")[0] == expected_url

    # Assumes the query string consists solely of "access_token=<jwt>".
    token_1 = urlparse(url_1).query[len("access_token=") :]

    decoded_token_1 = _decode_token(client, token_1, path="/clients/mqtt/hubs/hub")

    # Swap only the leading ws/wss scheme for http/https (count=1); an unbounded
    # replace would also rewrite any "ws" inside the host or hub name.
    expected_audience = expected_url.replace("ws", "http", 1)

    assert len(decoded_token_1) == 3
    assert decoded_token_1["aud"] == expected_audience


# (connection string, hub, expected client URL) triples for the SocketIO client
# protocol, whose URLs use the /clients/socketio/hubs/<hub> path.
test_cases = [
    (template.format(access_key), "hub", expected_url)
    for template, expected_url in (
        (
            "Endpoint=http://localhost;Port=8080;AccessKey={};Version=1.0;",
            "ws://localhost:8080/clients/socketio/hubs/hub",
        ),
        ("Endpoint=https://a;AccessKey={};Version=1.0;", "wss://a/clients/socketio/hubs/hub"),
        ("Endpoint=http://a;AccessKey={};Version=1.0;", "ws://a/clients/socketio/hubs/hub"),
    )
]


@pytest.mark.parametrize("connection_string,hub,expected_url", test_cases)
def test_generate_socketio_token(connection_string, hub, expected_url):
    """Requesting a SocketIO token yields the SocketIO client URL and a matching audience.

    The generated URL, minus its query string, must equal ``expected_url``. The
    decoded token must hold exactly three claims with the HTTP(S) form of
    ``expected_url`` as its audience.
    """
    client = WebPubSubServiceClient.from_connection_string(connection_string, hub)
    url_1 = client.get_client_access_token(client_protocol="SocketIO")["url"]

    assert url_1.split("?")[0] == expected_url

    # Assumes the query string consists solely of "access_token=<jwt>".
    token_1 = urlparse(url_1).query[len("access_token=") :]

    decoded_token_1 = _decode_token(client, token_1, path="/clients/socketio/hubs/hub")

    # Swap only the leading ws/wss scheme for http/https (count=1); an unbounded
    # replace would also rewrite any "ws" inside the host or hub name.
    expected_audience = expected_url.replace("ws", "http", 1)

    assert len(decoded_token_1) == 3
    assert decoded_token_1["aud"] == expected_audience
