# -*- coding: utf-8 -*-
#
# Copyright (C) 2024 Jun Omae <jun66j5@gmail.com>
# All rights reserved.
#
# This software is licensed as described in the file COPYING, which
# you should have received as part of this distribution.

import os
import shutil
import tempfile
import unittest

from trac.test import EnvironmentStub, MockRequest
from trac.util import create_file
from trac.web.api import RequestDone
from trac.web.auth import LoginModule as TracLoginModule
from trac.web.chrome import Chrome
from trac.web.main import RequestDispatcher
from trac.wiki.web_ui import WikiModule

from ..api import AccountManager
from ..svnserve import SvnServePasswordStore
from ..web_ui import LoginModule
from . import makeSuite


class LoginTestCase(unittest.TestCase):
    """Tests for AccountManager's ``LoginModule`` driven through Trac's
    ``RequestDispatcher``.

    The stock ``trac.web.auth.LoginModule`` is disabled, and credentials are
    checked by ``SvnServePasswordStore`` against a password file written to
    a temporary environment directory.
    """

    # svnserve-style password file: user ``john`` with password ``pass``
    # listed in the ``[users]`` section.
    authz_content = '[users]\n' \
                    'john = pass\n'
    # Referer sent with login POSTs; a successful login is expected to
    # redirect back to this URL.
    referer = 'http://example.org/trac.cgi/wiki/WikiStart'

    def setUp(self):
        self.tmpdir = tempfile.mkdtemp(prefix='trac-testdir-')
        # Register cleanups as soon as each resource exists.  Unlike
        # tearDown(), cleanups also run when setUp() fails part-way, and
        # every cleanup runs even if an earlier one raises, so neither the
        # temporary directory nor the environment is leaked.  Cleanups run
        # LIFO: the environment is shut down before its directory is
        # removed.
        self.addCleanup(shutil.rmtree, self.tmpdir)
        self.env = EnvironmentStub(
            path=self.tmpdir,
            enable=[Chrome, WikiModule, AccountManager, SvnServePasswordStore,
                    LoginModule],
            disable=[TracLoginModule])
        self.addCleanup(self.env.shutdown)
        self.config = self.env.config
        self.config.set('trac', 'use_chunked_encoding', 'disabled')
        authz_file = os.path.join(self.tmpdir, 'svnserve-authz.txt')
        create_file(authz_file, self.authz_content)
        self.config.set('account-manager', 'password_store',
                        'SvnServePasswordStore')
        self.config.set('account-manager', 'password_file', authz_file)
        self.mod = LoginModule(self.env)
        self.dispatcher = RequestDispatcher(self.env)

    def _create_req(self, **kwargs):
        """Return a ``MockRequest`` for ``self.env``.

        Accepts the keyword arguments of ``MockRequest`` plus ``referer``,
        which is set as ``HTTP_REFERER`` when truthy.  ``REMOTE_USER`` is
        removed and the ``authname`` callback is routed to
        ``RequestDispatcher.authenticate`` so the user name is determined
        by the dispatcher (e.g. from the request's cookies) rather than
        preset by the mock.
        """
        referer = kwargs.pop('referer', None)
        req = MockRequest(self.env, **kwargs)
        req.environ.pop('REMOTE_USER', None)
        if referer:
            req.environ['HTTP_REFERER'] = referer
        req.callbacks['authname'] = self.dispatcher.authenticate
        return req

    def _create_req_login(self, username='', password='', rememberme=None,
                          referer=None, cookie=''):
        """Return a ``POST /login`` request carrying the given credentials.

        ``rememberme`` is only added to the form arguments when it is not
        ``None``; ``cookie`` is passed through as the request's Cookie
        header value.
        """
        args = {'username': username, 'password': password}
        if rememberme is not None:
            args['rememberme'] = rememberme
        return self._create_req(method='POST', path_info='/login', args=args,
                                referer=referer, cookie=cookie)

    def _to_incookie(self, req):
        """Serialise ``req.outcookie`` as a ``name=value; ...`` string.

        Used to send the cookies set by one response with the next request.
        """
        return '; '.join('{}={}'.format(key, c.value)
                         for key, c in req.outcookie.items())

    def test_login_success(self):
        # Valid credentials: redirect to the referer with a non-persistent
        # trac_auth cookie, which then authenticates a follow-up request.
        req = self._create_req_login(username='john', password='pass',
                                     referer=self.referer)
        self.assertRaises(RequestDone, self.dispatcher.dispatch, req)
        self.assertEqual(self.referer, req.headers_sent.get('Location'))
        self.assertIn('trac_auth', req.outcookie)
        self.assertTrue(req.outcookie['trac_auth'].value)
        self.assertFalse(req.outcookie['trac_auth'].get('expires'))
        self.assertNotIn('trac_auth_session', req.outcookie)

        cookie = self._to_incookie(req)
        req = self._create_req(path_info='/wiki/WikiStart', cookie=cookie)
        self.assertRaises(RequestDone, self.dispatcher.dispatch, req)
        for item in req.chrome['nav']['metanav']:
            if item['name'] == 'login':
                self.assertIn('>john<', str(item['label']))
                break
        else:
            self.fail('Missing login item in metanav')

    def test_login_failure(self):
        # A wrong password renders an error message instead of logging in.
        req = self._create_req_login(username='john', password='????')
        self.assertRaises(RequestDone, self.dispatcher.dispatch, req)
        self.assertIn(b'Invalid username or password',
                      req.response_sent.getvalue())

    def test_login_persist(self):
        # With persistent sessions and "remember me", both auth cookies get
        # an expiry.  A follow-up request made with cookie_refresh_pct=100
        # is expected to receive a new trac_auth value of the same length
        # (presumably 100 forces the cookie to be refreshed).
        self.config.set('account-manager', 'persistent_sessions', 'enabled')
        self.config.set('account-manager', 'cookie_refresh_pct', '-100')
        req = self._create_req_login(username='john', password='pass',
                                     rememberme='1', referer=self.referer)
        self.assertRaises(RequestDone, self.dispatcher.dispatch, req)
        self.assertIn('trac_auth', req.outcookie)
        self.assertTrue(req.outcookie['trac_auth'].get('expires'))
        self.assertIn('trac_auth_session', req.outcookie)
        self.assertEqual('1', req.outcookie['trac_auth_session'].value)
        self.assertTrue(req.outcookie['trac_auth_session'].get('expires'))
        old_session_id = req.outcookie['trac_auth'].value
        cookie = self._to_incookie(req)

        self.config.set('account-manager', 'cookie_refresh_pct', '100')
        req = self._create_req(path_info='/wiki/WikiStart', cookie=cookie)
        self.assertRaises(RequestDone, self.dispatcher.dispatch, req)
        self.assertIn('trac_auth', req.outcookie)
        new_session_id = req.outcookie['trac_auth'].value
        self.assertNotEqual(old_session_id, new_session_id)
        self.assertEqual(len(old_session_id), len(new_session_id))


def test_suite():
    """Build and return the suite of all ``LoginTestCase`` tests."""
    return unittest.TestSuite([makeSuite(LoginTestCase)])


if __name__ == '__main__':
    # When run as a script, execute the suite returned by test_suite().
    unittest.main(defaultTest='test_suite')
