#!/usr/bin/env python
# vim:fileencoding=utf-8
# License: Apache 2.0 Copyright: 2017, Kovid Goyal <kovid at kovidgoyal.net>

from __future__ import absolute_import, division, print_function, unicode_literals

import codecs
import os
import re
import subprocess
import unittest

from html5_parser import check_bom, check_for_meta_charset, parse
from lxml.etree import _Comment

from . import MATHML, SVG, XHTML, XLINK, XML, TestCase

# Absolute location of this test module, its directory, and the
# html5lib-tests data directory located next to it.
self_path = os.path.abspath(__file__)
base = os.path.split(self_path)[0]
html5lib_tests_path = os.path.join(base, 'html5lib-tests')


class TestData(object):
    """Iterate over the test cases stored in an html5lib ``.dat`` file.

    Every line starting with ``#`` opens a section named by the rest of the
    line (stripped); following lines are accumulated into that section. A new
    ``#data`` heading closes the current test case. Each test case is yielded
    as a dict of section name -> section text with one trailing newline removed.
    """

    def __init__(self, filename):
        with open(filename, 'rb') as stream:
            payload = stream.read()
        # Files under an "encoding" directory are decoded as latin1, except
        # test-yahoo-jp.dat; everything else is decoded as UTF-8.
        in_encoding_dir = '/encoding/' in filename.replace(os.sep, '/')
        if in_encoding_dir and os.path.basename(filename) != 'test-yahoo-jp.dat':
            codec = 'latin1'
        else:
            codec = 'utf-8'
        self.lines = payload.decode(codec).splitlines()

    def __iter__(self):
        current = {}
        section = None
        for line in self.lines:
            title = self.is_section_heading(line)
            if not title:
                # Body line: only recorded once some section has been opened.
                if section is not None:
                    current[section] += line + '\n'
                continue
            if title == 'data' and current:
                # Drop one extra newline (from the blank separator line) off the
                # previous test's last section before emitting that test.
                current[section] = current[section][:-1]
                yield self.normalize(current)
                current = {}
            section = title
            current[section] = ''
        if current:
            yield self.normalize(current)

    def is_section_heading(self, line):
        """Return the heading text when *line* starts with ``#``, else False.

        The returned text is stripped, so a bare ``#`` yields the falsy ''.
        """
        if not line.startswith("#"):
            return False
        return line[1:].strip()

    def normalize(self, data):
        """Return a new dict with at most one trailing newline cut from each value."""
        return {
            name: (text[:-1] if text.endswith('\n') else text)
            for name, text in data.items()
        }


def serialize_construction_output(root, fragment_context):
    """Serialize a parsed tree in the html5lib tree-construction output format.

    :param root: the element returned by ``parse()``.
    :param fragment_context: the fragment context the input was parsed with,
        or a false value for a full-document parse. For fragments only the
        text and children of ``root`` are serialized (and no doctype);
        otherwise the doctype, ``root`` itself and its sibling comments are.
    :return: the serialized tree as one newline-joined string whose lines are
        prefixed with ``|`` and indented by nesting depth.
    """
    tree = root.getroottree()
    lines = []
    if tree.docinfo.doctype and not fragment_context:
        di = tree.docinfo
        if di.public_id or di.system_url:
            # Never format a missing (None) identifier as the literal string
            # "None"; an empty quoted value is written instead.
            d = '<!DOCTYPE {} "{}" "{}">'.format(
                di.root_name, di.public_id or '', di.system_url or '')
        else:
            d = '<!DOCTYPE {}>'.format(di.root_name)
        lines.append('| ' + d)

    # Namespace URI -> prefix printed before tag/attribute names.
    NAMESPACE_PREFIXES = {XHTML: '', SVG: 'svg ', MATHML: 'math ', XLINK: 'xlink ', XML: 'xml '}

    def add(level, *a):
        lines.append('|' + ' ' * level + ''.join(a))

    def serialize_tag(name, level):
        # Tags without a {namespace} are printed with a "None " prefix.
        ns = 'None '
        if name.startswith('{'):
            ns, name = name[1:].rpartition('}')[::2]
            ns = NAMESPACE_PREFIXES.get(ns, ns)
        add(level, '<', ns, name, '>')
        return ns + name

    def serialize_attr_name(name):
        ns = ''
        if name.startswith('{'):
            ns, name = name[1:].rpartition('}')[::2]
            ns = NAMESPACE_PREFIXES.get(ns, ns)
        return ns + name

    def serialize_attr(name, val, level):
        level += 2
        add(level, serialize_attr_name(name), '=', '"', val, '"')

    def serialize_text(text, level=0):
        # level 0 means "top level of a fragment", printed at indent 1.
        add((level + 2) if level else 1, '"', text, '"')

    def serialize_comment(node, level=1):
        add(level, '<!-- ', node.text or '', ' -->')

    def serialize_node(node, level=1):
        name = serialize_tag(node.tag, level)
        # Attributes are listed sorted by their printed (prefixed) name.
        for attr in sorted(node.keys(), key=serialize_attr_name):
            serialize_attr(attr, node.get(attr), level)
        if name == 'template':
            # Children of an HTML <template> are nested under a "content" line.
            level += 2
            add(level, 'content')
        if node.text:
            serialize_text(node.text, level)
        for child in node:
            if isinstance(child, _Comment):
                serialize_comment(child, level + 2)
            else:
                serialize_node(child, level + 2)
            if child.tail:
                serialize_text(child.tail, level)

    if fragment_context:
        if root.text:
            serialize_text(root.text)
        for node in root.iterchildren():
            if isinstance(node, _Comment):
                serialize_comment(node)
            else:
                serialize_node(node)
            if node.tail:
                serialize_text(node.tail)
    else:
        # itersiblings(preceding=True) walks backwards from root, so reverse
        # it to emit the leading comments in document order.
        for c in reversed(tuple(root.itersiblings(preceding=True))):
            serialize_comment(c)
        serialize_node(root)
        for c in root.itersiblings():
            serialize_comment(c)
    output = '\n'.join(lines)
    # gumbo does not fix single carriage returns generated by entities and it
    # does not lowercase unknown tags
    output = output.replace('\r', '\n').replace('<BAR>', '<bar>')
    return output


class BaseTest(TestCase):
    """Base class whose test methods are generated from html5lib ``.dat`` data."""

    @classmethod
    def data_for_test(cls, test, expected='document'):
        """Unpack one parsed test dict.

        :param test: a dict of section name -> text, as yielded by TestData.
        :param expected: name of the section holding the expected result.
        :return: ``(fragment_context, input_html, expected_value, error_lines)``
            where missing sections come back as None and errors is a list of lines.
        """
        fragment = test.get('document-fragment')
        source = test.get('data')
        wanted = test.get(expected)
        error_lines = test.get('errors', '').split('\n')
        return fragment, source, wanted, error_lines

    @classmethod
    def add_single(cls, test_name, num, test, expected):
        """Attach a ``test_<test_name>_<num>`` method to *cls* for one test case.

        The generated method forwards to ``self.implementation``.
        """
        fragment, source, wanted, error_lines = cls.data_for_test(test, expected)

        # The test data is bound through default arguments of the generated method.
        def generated(
                self,
                inner_html=fragment,
                html=source,
                expected=wanted,
                errors=error_lines):
            return self.implementation(inner_html, html, expected, errors, test_name)

        method_name = str('test_%s_%d' % (test_name, num))
        generated.__name__ = method_name
        setattr(cls, method_name, generated)


class ConstructionTests(BaseTest):
    """Tree-construction tests: parse each input and compare the serialized tree."""

    @classmethod
    def check_test(cls, fragment_context, html, expected, errors, test_name):
        """Return a human-readable reason to skip a test case, or None to run it.

        ``fragment_context`` is accepted but not consulted.
        """
        if test_name == 'isindex' or html == '<!doctype html><isindex type="hidden">':
            reason = (
                'gumbo and html5lib differ on <isindex> parsing'
                ' and I cannot be bothered to figure out who is right')
        elif test_name == 'menuitem-element':
            reason = (
                'gumbo and html5lib differ on <menuitem> parsing'
                ' and I cannot be bothered to figure out who is right')
        elif 'search-element' in test_name:
            reason = (
                'No idea what the <search> element is. In any case the tests only differ in'
                ' indentation, so skipping')
        elif re.search(r'^\| +<noscript>$', expected, flags=re.MULTILINE) is not None:
            reason = '<noscript> is always parsed with scripting off by gumbo'
        elif '<thisisasillytestelementnametomakesurecrazytagnamesareparsedcorrectly>' in expected:
            reason = 'gumbo unlike html5lib, does not lowercase unknown tag names'
        elif any('expected-doctype-name-but' in entry or 'unknown-doctype' in entry
                 for entry in errors):
            reason = 'gumbo auto-corrects malformed doctypes'
        else:
            reason = None
        return reason

    def implementation(self, fragment_context, html, expected, errors, test_name):
        """Parse *html* and assert its serialized tree equals *expected*.

        Raises unittest.SkipTest when check_test() returns a skip reason.
        """
        # Spaces in the fragment context are turned into ':' before parsing.
        context = fragment_context.replace(' ', ':') if fragment_context else fragment_context
        skip_reason = self.check_test(context, html, expected, errors, test_name)
        if skip_reason is not None:
            raise unittest.SkipTest(skip_reason)

        root = parse(
            html, namespace_elements=True, sanitize_names=False,
            fragment_context=context)
        received = serialize_construction_output(root, fragment_context=context)
        from lxml.etree import tostring

        details = [
            '\n\nTest name:', test_name,
            '\nInput:', html,
            '\nExpected:', expected,
            '\nReceived:', received,
            '\nOutput tree:', tostring(root, encoding='unicode'),
        ]
        self.ae(expected, received, '\n'.join(details) + '\n')
        # TODO: Check error messages, when there's full error support.


class EncodingTests(BaseTest):
    """Encoding-sniffing tests driven by the html5lib ``encoding`` data files."""

    def implementation(self, fragment_context, html, expected, errors, test_name):
        """Check that the encoding sniffed from *html* matches *expected*.

        Inputs marked as starting with a UTF-8 BOM only check check_bom();
        all others fall back from BOM to <meta charset> to windows-1252.
        """
        if '<!-- Starts with UTF-8 BOM -->' in html:
            # Replace the first three characters with real UTF-8 BOM bytes.
            with_bom = b'\xef\xbb\xbf' + html[3:].encode('ascii')
            self.assertIs(check_bom(with_bom), codecs.BOM_UTF8)
            return
        if '''document.write('<meta charset="ISO-8859-' + '2">')''' in html:
            raise unittest.SkipTest('buggy html5lib test')
        encoded = html.encode('utf-8')
        detected = check_bom(encoded) or check_for_meta_charset(encoded) or 'windows-1252'
        text_type = type('')
        parts = ['\n\nInput:', html, '\nExpected:', expected, '\nReceived:', detected]
        message = '\n'.join(text_type(part) for part in parts)
        self.ae(expected.lower(), detected, message + '\n')


def html5lib_test_files(group, tests_path=None):
    """Yield the paths of the ``.dat`` files in one html5lib-tests group.

    :param group: sub-directory of the html5lib-tests data, e.g.
        ``'tree-construction'`` or ``'encoding'``.
    :param tests_path: root of the html5lib-tests data; defaults to the
        module-level ``html5lib_tests_path``.

    Yields nothing when the group directory does not exist (including when the
    html5lib-tests directory itself is present but empty). Paths are yielded
    in sorted order so test generation is deterministic.
    """
    if tests_path is None:
        tests_path = html5lib_tests_path
    group_dir = os.path.join(tests_path, group)
    # Checking the group directory (not just the root) avoids an OSError from
    # listdir when the root exists but the group is missing.
    if not os.path.isdir(group_dir):
        return
    # os.listdir() order is arbitrary; sort for reproducible runs.
    for name in sorted(os.listdir(group_dir)):
        if name.endswith('.dat'):
            yield os.path.join(group_dir, name)


def load_suite(group, case_class, expected='document', data_class=TestData):
    """Generate test methods on *case_class* for every ``.dat`` file in *group*.

    :param expected: section name holding each test's expected value.
    :param data_class: iterable-of-test-dicts reader constructed from a path.
    :return: a unittest suite loaded from *case_class*.
    """
    for path in html5lib_test_files(group):
        stem = os.path.basename(path).rpartition('.')[0]
        # Test cases are numbered from 1 within each file.
        for number, case in enumerate(data_class(path), 1):
            case_class.add_single(stem, number, case, expected)
    return unittest.defaultTestLoader.loadTestsFromTestCase(case_class)


class MemLeak(BaseTest):
    """Run every runnable tree-construction input through an external leak checker."""

    @unittest.skipUnless('MEMLEAK_EXE' in os.environ, 'memleak check exe not available')
    def test_asan_memleak(self):
        """Feed each test input on stdin to $MEMLEAK_EXE and require exit status 0.

        Inputs that ConstructionTests.check_test() would skip are skipped here too.
        """
        memleak_exe = os.environ['MEMLEAK_EXE']
        env = os.environ.copy()
        # Do not forward ASAN_OPTIONS from this process to the child.
        env.pop('ASAN_OPTIONS', None)
        for path in html5lib_test_files('tree-construction'):
            test_name = os.path.basename(path).rpartition('.')[0]
            # Number from 1 so failures line up with the generated
            # test_<name>_<num> methods (load_suite uses i + 1).
            for num, test in enumerate(TestData(path), 1):
                inner_html, html, expected, errors = ConstructionTests.data_for_test(test)
                bad = ConstructionTests.check_test(inner_html, html, expected, errors, test_name)
                if bad is not None:
                    continue
                p = subprocess.Popen([memleak_exe], stdin=subprocess.PIPE, env=env)
                # communicate() waits for exit, so returncode is populated afterwards.
                p.communicate(html.encode('utf-8'))
                self.ae(p.returncode, 0, 'The test {}-{} failed'.format(test_name, num))


def find_tests():
    """Lazily yield the suites: tree construction, encoding, then the memleak check."""
    data_suites = (
        ('tree-construction', ConstructionTests, 'document'),
        ('encoding', EncodingTests, 'encoding'),
    )
    for group, case_class, expected in data_suites:
        yield load_suite(group, case_class, expected=expected)
    yield unittest.defaultTestLoader.loadTestsFromTestCase(MemLeak)
