from flask import json
from authlib.common.security import generate_token
from authlib.common.urls import urlparse, url_decode
from authlib.oauth2.rfc6749 import grants
from authlib.oauth2.rfc7636 import (
    CodeChallenge as _CodeChallenge,
    create_s256_code_challenge,
)
from .models import db, User, Client
from .models import CodeGrantMixin, save_authorization_code
from .oauth2_server import TestCase
from .oauth2_server import create_authorization_server


class AuthorizationCodeGrant(CodeGrantMixin, grants.AuthorizationCodeGrant):
    TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_basic', 'client_secret_post', 'none']

    def save_authorization_code(self, code, request):
        return save_authorization_code(code, request)


class CodeChallenge(_CodeChallenge):
    SUPPORTED_CODE_CHALLENGE_METHOD = ['plain', 'S256', 'S128']


class CodeChallengeTest(TestCase):
    def prepare_data(self, token_endpoint_auth_method='none'):
        server = create_authorization_server(self.app)
        server.register_grant(
            AuthorizationCodeGrant,
            [CodeChallenge(required=True)]
        )

        user = User(username='foo')
        db.session.add(user)
        db.session.commit()

        client_secret = ''
        if token_endpoint_auth_method != 'none':
            client_secret = 'code-secret'

        client = Client(
            user_id=user.id,
            client_id='code-client',
            client_secret=client_secret,
        )
        client.set_client_metadata({
            'redirect_uris': ['https://a.b'],
            'scope': 'profile address',
            'token_endpoint_auth_method': token_endpoint_auth_method,
            'response_types': ['code'],
            'grant_types': ['authorization_code'],
        })
        self.authorize_url = (
            '/oauth/authorize?response_type=code'
            '&client_id=code-client'
        )
        db.session.add(client)
        db.session.commit()

    def test_missing_code_challenge(self):
        self.prepare_data()
        rv = self.client.get(self.authorize_url + '&code_challenge_method=plain')
        self.assertIn(b'Missing', rv.data)

    def test_has_code_challenge(self):
        self.prepare_data()
        rv = self.client.get(self.authorize_url + '&code_challenge=abc')
        self.assertEqual(rv.data, b'ok')

    def test_invalid_code_challenge_method(self):
        self.prepare_data()
        suffix = '&code_challenge=abc&code_challenge_method=invalid'
        rv = self.client.get(self.authorize_url + suffix)
        self.assertIn(b'Unsupported', rv.data)

    def test_supported_code_challenge_method(self):
        self.prepare_data()
        suffix = '&code_challenge=abc&code_challenge_method=plain'
        rv = self.client.get(self.authorize_url + suffix)
        self.assertEqual(rv.data, b'ok')

    def test_trusted_client_without_code_challenge(self):
        self.prepare_data('client_secret_basic')
        rv = self.client.get(self.authorize_url)
        self.assertEqual(rv.data, b'ok')

        rv = self.client.post(self.authorize_url, data={'user_id': '1'})
        self.assertIn('code=', rv.location)

        params = dict(url_decode(urlparse.urlparse(rv.location).query))

        code = params['code']
        headers = self.create_basic_header('code-client', 'code-secret')
        rv = self.client.post('/oauth/token', data={
            'grant_type': 'authorization_code',
            'code': code,
        }, headers=headers)
        resp = json.loads(rv.data)
        self.assertIn('access_token', resp)

    def test_missing_code_verifier(self):
        self.prepare_data()
        url = self.authorize_url + '&code_challenge=foo'
        rv = self.client.post(url, data={'user_id': '1'})
        self.assertIn('code=', rv.location)

        params = dict(url_decode(urlparse.urlparse(rv.location).query))
        code = params['code']
        rv = self.client.post('/oauth/token', data={
            'grant_type': 'authorization_code',
            'code': code,
            'client_id': 'code-client',
        })
        resp = json.loads(rv.data)
        self.assertIn('Missing', resp['error_description'])

    def test_trusted_client_missing_code_verifier(self):
        self.prepare_data('client_secret_basic')
        url = self.authorize_url + '&code_challenge=foo'
        rv = self.client.post(url, data={'user_id': '1'})
        self.assertIn('code=', rv.location)

        params = dict(url_decode(urlparse.urlparse(rv.location).query))
        code = params['code']
        headers = self.create_basic_header('code-client', 'code-secret')
        rv = self.client.post('/oauth/token', data={
            'grant_type': 'authorization_code',
            'code': code,
        }, headers=headers)
        resp = json.loads(rv.data)
        self.assertIn('Missing', resp['error_description'])

    def test_plain_code_challenge_invalid(self):
        self.prepare_data()
        url = self.authorize_url + '&code_challenge=foo'
        rv = self.client.post(url, data={'user_id': '1'})
        self.assertIn('code=', rv.location)

        params = dict(url_decode(urlparse.urlparse(rv.location).query))
        code = params['code']
        rv = self.client.post('/oauth/token', data={
            'grant_type': 'authorization_code',
            'code': code,
            'code_verifier': 'bar',
            'client_id': 'code-client',
        })
        resp = json.loads(rv.data)
        self.assertIn('Invalid', resp['error_description'])

    def test_plain_code_challenge_failed(self):
        self.prepare_data()
        url = self.authorize_url + '&code_challenge=foo'
        rv = self.client.post(url, data={'user_id': '1'})
        self.assertIn('code=', rv.location)

        params = dict(url_decode(urlparse.urlparse(rv.location).query))
        code = params['code']
        rv = self.client.post('/oauth/token', data={
            'grant_type': 'authorization_code',
            'code': code,
            'code_verifier': generate_token(48),
            'client_id': 'code-client',
        })
        resp = json.loads(rv.data)
        self.assertIn('failed', resp['error_description'])

    def test_plain_code_challenge_success(self):
        self.prepare_data()
        code_verifier = generate_token(48)
        url = self.authorize_url + '&code_challenge=' + code_verifier
        rv = self.client.post(url, data={'user_id': '1'})
        self.assertIn('code=', rv.location)

        params = dict(url_decode(urlparse.urlparse(rv.location).query))
        code = params['code']
        rv = self.client.post('/oauth/token', data={
            'grant_type': 'authorization_code',
            'code': code,
            'code_verifier': code_verifier,
            'client_id': 'code-client',
        })
        resp = json.loads(rv.data)
        self.assertIn('access_token', resp)

    def test_s256_code_challenge_success(self):
        self.prepare_data()
        code_verifier = generate_token(48)
        code_challenge = create_s256_code_challenge(code_verifier)
        url = self.authorize_url + '&code_challenge=' + code_challenge
        url += '&code_challenge_method=S256'

        rv = self.client.post(url, data={'user_id': '1'})
        self.assertIn('code=', rv.location)

        params = dict(url_decode(urlparse.urlparse(rv.location).query))
        code = params['code']
        rv = self.client.post('/oauth/token', data={
            'grant_type': 'authorization_code',
            'code': code,
            'code_verifier': code_verifier,
            'client_id': 'code-client',
        })
        resp = json.loads(rv.data)
        self.assertIn('access_token', resp)

    def test_not_implemented_code_challenge_method(self):
        self.prepare_data()
        url = self.authorize_url + '&code_challenge=foo'
        url += '&code_challenge_method=S128'

        rv = self.client.post(url, data={'user_id': '1'})
        self.assertIn('code=', rv.location)

        params = dict(url_decode(urlparse.urlparse(rv.location).query))
        code = params['code']
        self.assertRaises(
            RuntimeError, self.client.post, '/oauth/token',
            data={
                'grant_type': 'authorization_code',
                'code': code,
                'code_verifier': generate_token(48),
                'client_id': 'code-client',
            }
        )
