#!/usr/bin/env python
"""Generates the pamqp/commands.py, pamqp/constants.py and
pamqp/exceptions.py modules used as a foundation for AMQP communication.

"""
import copy
import dataclasses
import functools
import json
import keyword
import logging
import pathlib
import sys
import textwrap
import typing

import lxml.etree, lxml.objectify
import requests
from yapf import yapf_api

__author__ = 'Gavin M. Roy'
__email__ = 'gavinmroy@gmail.com'
__since__ = '2011-03-31'

LOGGER = logging.getLogger(__name__)

# Specification input files. The paths are relative to the current working
# directory (presumably the project root -- confirm how the script is run).
CODEGEN_DIR = pathlib.Path('./codegen/')
# AMQP classes that are left out of the generated commands and INDEX_MAPPING.
CODEGEN_IGNORE_CLASSES = ['access']
CODEGEN_JSON = CODEGEN_DIR / 'amqp-rabbitmq-0.9.1.json'
CODEGEN_XML = CODEGEN_DIR / 'amqp0-9-1.xml'
EXTENSIONS_XML = CODEGEN_DIR / 'extensions.xml'

# Generated pamqp.commands module; COMMANDS_HEADER seeds its output buffer.
COMMANDS = pathlib.Path('./pamqp/commands.py')
COMMANDS_HEADER = '''"""
The classes inside :mod:`pamqp.commands` allow for the automatic marshaling
and unmarshaling of AMQP method frames and
:class:`Basic.Properties <pamqp.commands.Basic.Properties>`. In addition the
command classes contain information that designates if they are synchronous
commands and if so, what the expected responses are. Each commands arguments
are detailed in the class and are listed in the attributes property.

.. note:: All AMQ classes and methods extend :class:`pamqp.base.Frame`.

"""
# Auto-generated, do not edit this file.
import datetime
import typing
import warnings

from pamqp import base, common, constants

'''

# Generated pamqp.constants module; CONSTANTS_HEADER seeds its output buffer.
CONSTANTS = pathlib.Path('./pamqp/constants.py')
CONSTANTS_HEADER = '''
# Auto-generated, do not edit this file.
import re
'''

# Generated pamqp.exceptions module. The header defines the base exception
# hierarchy that the generated AMQP error classes extend.
EXCEPTIONS = pathlib.Path('./pamqp/exceptions.py')
EXCEPTIONS_HEADER = '''
# Auto-generated, do not edit this file.


class PAMQPException(Exception):
    """Base exception for all pamqp specific exceptions."""


class UnmarshalingException(PAMQPException):
    """Raised when a frame is not able to be unmarshaled."""
    def __str__(self) -> str:  # pragma: nocover
        return 'Could not unmarshal {} frame: {}'.format(
            self.args[0], self.args[1])


class AMQPError(PAMQPException):
    """Base exception for all AMQP errors."""


class AMQPSoftError(AMQPError):
    """Base exception for all AMQP soft errors."""


class AMQPHardError(AMQPError):
    """Base exception for all AMQP hard errors."""

'''

# Upstream locations of the specification files (presumably used by the
# loaders defined further down to fetch missing files -- verify).
CODEGEN_JSON_URL = ('https://raw.githubusercontent.com/rabbitmq/'
                    'rabbitmq-codegen/master/amqp-rabbitmq-0.9.1.json')
CODEGEN_XML_URL = ('https://raw.githubusercontent.com/rabbitmq/'
                   'rabbitmq-website/master/site/resources/specs/'
                   'amqp0-9-1.extended.xml')

# Order of element kinds in spec XML lookups (consumer not shown here).
XPATH_ORDER = ['class', 'constant', 'method', 'field']

# AMQP data type -> annotation text written into generated code.
# _arg_annotation refines 'table' (for ``arguments``) and 'timestamp'.
AMQ_TYPE_TO_ANNOTATION = {
    'bit': 'bool',
    'long': 'int',
    'longlong': 'int',
    'longstr': 'str',
    'octet': 'int',
    'short': 'int',
    'shortstr': 'str',
    'table': 'common.FieldTable',
    'timestamp': 'common.Timestamp',
}


@dataclasses.dataclass
class Domain:
    """An AMQP domain definition as used by the code generator.

    ``type`` selects the underlying data type, ``label``/``documentation``
    feed generated docstrings, ``max_length`` and ``regex`` trigger
    generated validation, ``nullable`` affects how table values are
    initialised, and ``default_value`` supplies argument defaults.
    """
    name: str  # domain name as referenced by method arguments
    type: str  # underlying AMQP data type (e.g. ``short`` or ``table``)
    documentation: typing.Optional[str]  # long-form description, if any
    label: typing.Optional[str]  # short description, if any
    nullable: bool  # False makes table-typed args default to ``{}``
    regex: typing.Optional[str]  # pattern a value must fully match
    max_length: typing.Optional[int]  # maximum length of a value
    default_value: typing.Optional[str]  # default used for arguments


class Codegen:

    # Class-level fallbacks; __init__ shadows both on the instance with
    # values derived from the version fields of the codegen JSON.
    AMQP_VERSION = '0-9-1'
    DEPRECATION_WARNING = 'Deprecated'
    # Selects _add_line's default wrap width: 120 columns when True, 79
    # otherwise (presumably because the output is reformatted with yapf).
    YAPF = True

    def __init__(self):
        """Load the specification sources and derive version strings.

        The codegen JSON, the codegen XML and the extensions XML are loaded
        in that order, the output buffer starts empty, and the class-level
        ``AMQP_VERSION`` / ``DEPRECATION_WARNING`` defaults are shadowed by
        instance values built from the JSON version numbers.
        """
        self._amqp_json = self._load_codegen_json()
        self._codegen_xml = self._load_codegen_xml()
        self._extensions_xml = self._load_extensions_xml()
        self._output_buffer: typing.List[str] = []
        version_keys = ('major-version', 'minor-version', 'revision')
        self.AMQP_VERSION = '-'.join(
            str(self._amqp_json[key]) for key in version_keys)
        self.DEPRECATION_WARNING = (
            f'This command is deprecated in AMQP {self.AMQP_VERSION}')

    def build(self):
        """Generate the commands, constants and exceptions modules, in order."""
        for generate in (self._build_commands,
                         self._build_constants,
                         self._build_exceptions):
            generate()

    def _add_comment(self, value: str, indent: int = 0,
                     prefix: str = '# ') -> None:
        """Append ``value`` to the output buffer as wrapped comment lines.

        Each line starts with ``indent`` spaces followed by ``prefix`` and is
        wrapped to fit in 79 columns. Text beginning with a reST field or
        directive marker (``:`` or ``..``) gets a hanging indent: its
        continuation lines are shifted four further columns. An empty
        ``value`` appends a single blank line.

        :param value: The text to emit
        :param indent: Number of spaces before the prefix
        :param prefix: Text placed before each line (``''`` for docstrings)

        """
        if not value:
            self._add_line()
            return
        first_indent = prefix.rjust(indent + len(prefix))
        next_indent = first_indent
        if value.startswith(':') or value.startswith('..'):
            # Hanging indent for ``:param x: ...`` / ``.. note::`` blocks.
            next_indent = prefix.rjust(indent + 4 + len(prefix))
        # Wrapping in one pass keeps every character. Re-slicing the source
        # after the first wrapped line dropped a character whenever that
        # line ended in the middle of a word too long to fit.
        wrapper = textwrap.TextWrapper(
            width=79,
            initial_indent=first_indent,
            subsequent_indent=next_indent)
        for line in wrapper.wrap(value):
            self._add_line(line, width=79)

    def _add_documentation(self, label: str, documentation: str,
                           indent: int) -> None:
        """Write a docstring consisting of ``label`` and optional body text.

        Without ``documentation`` a one-line docstring holding only the
        label is emitted. Otherwise every command name from
        :meth:`_commands` that appears surrounded by spaces is rewritten as
        a ``:class:`` reference (keys processed in reverse sorted order), and
        the text is written line by line between the label line and the
        closing quotes, separated by blank lines.
        """
        if not documentation:
            self._add_line(f'"""{label}"""', indent)
            return
        replacements = sorted(self._commands().items(), reverse=True)
        for command, reference in replacements:
            documentation = documentation.replace(
                f' {command} ', f' :class:`{reference}` ')
        self._add_line(f'"""{label}', indent)
        self._add_line()
        for text in documentation.split('\n'):
            self._add_comment(text, indent, '')
        self._add_line()
        self._add_line('"""', indent)

    def _add_function(self, name: str, args: list, indent: int) -> None:
        """Emit the ``def`` header of a generated method.

        A blank line precedes the header. With no ``args`` the method takes
        only ``self``. Otherwise each argument goes on its own line, aligned
        after ``def name(``, as ``pyname: annotation = default``. Arguments
        whose default is None, and an ``arguments`` table defaulting to
        ``{}``, default to ``None`` and get a ``typing.Optional`` annotation.
        The header always ends with ``-> None:``.
        """
        self._add_line()
        if not args:
            self._add_line(f'def {name}(self) -> None:', indent)
            return
        self._add_line(f'def {name}(self,', indent)
        column = indent + len(f'def {name}(')
        final = len(args) - 1
        for position, arg in enumerate(args):
            annotation = self._arg_annotation(arg)
            default = self._arg_default(arg)
            if arg['name'] == 'arguments' and default == '{}':
                default = 'None'
            if default == 'None' or default is None:
                annotation = f'typing.Optional[{annotation}]'
            ending = ') -> None:' if position == final else ','
            parameter = self._arg_name(arg['pyname'])
            self._add_line(f'{parameter}: {annotation} = {default}{ending}',
                           column)

    def _add_line(self, value: str = '', indent: int = 0,
                  secondary_indent: int = 0,
                  width: typing.Optional[int] = None) -> None:
        """Append ``value`` to the output buffer, wrapping long lines.

        :param value: Text to append; an empty value appends a blank line
        :param indent: Number of spaces placed before the first line
        :param secondary_indent: Spaces placed before continuation lines;
            ``indent`` is used when this is 0
        :param width: Maximum line width. When omitted, 120 columns are used
            if :attr:`YAPF` is enabled and 79 otherwise.

        """
        if not value:
            self._output_buffer.append(value)
            return
        # The parentheses matter: ``width or 120 if self.YAPF else 79``
        # parses as ``(width or 120) if self.YAPF else 79``, which silently
        # ignored an explicit width whenever YAPF was disabled.
        line_width = width or (120 if self.YAPF else 79)
        wrapper = textwrap.TextWrapper(
            width=line_width,
            drop_whitespace=True,
            initial_indent=' ' * indent,
            break_on_hyphens=False,
            subsequent_indent=' ' * (secondary_indent or indent))
        # A distinct name avoids shadowing the ``value`` parameter.
        self._output_buffer.extend(wrapper.wrap(value.rstrip()))

    def _all_defaults(self, arguments: list) -> bool:
        """Return True when every argument can be omitted by the caller.

        Table-typed arguments always count as defaulted; any other argument
        needs a non-None value from :meth:`_arg_default`. Every argument is
        inspected (no short-circuit), so ``_arg_type`` records the resolved
        type on each of them.
        """
        def has_default(arg: dict) -> bool:
            annotation = AMQ_TYPE_TO_ANNOTATION[self._arg_type(arg)]
            if annotation in {'common.FieldTable', 'common.Arguments'}:
                return True
            return self._arg_default(arg) is not None

        # Materialise the list first so all() cannot skip any argument.
        verdicts = [has_default(arg) for arg in arguments]
        return all(verdicts)

    def _arg_annotation(self, arg: dict) -> str:
        """Return the annotation text used for ``arg`` in generated code.

        Timestamps become ``datetime.datetime``, and a table argument named
        ``arguments`` becomes ``common.Arguments``. Everything else uses the
        AMQ_TYPE_TO_ANNOTATION mapping unchanged.
        """
        annotation = AMQ_TYPE_TO_ANNOTATION[self._arg_type(arg)]
        if annotation == 'common.Timestamp':
            return 'datetime.datetime'
        if annotation == 'common.FieldTable' and arg['name'] == 'arguments':
            return 'common.Arguments'
        return annotation

    def _arg_default(self, arg: dict, repr_=True) -> typing.Optional[str]:
        """Return the default value for ``arg``, or None when it has none.

        Sources are consulted in this order:

        1. a ``default-value`` attribute on the matching
           ``//rabbitmq/field`` element of the extensions XML;
        2. the argument's own ``default-value``;
        3. the default of the domain the argument references.

        With ``repr_`` enabled, the value is returned as its ``repr`` (ready
        to paste into generated code), except for ``short``/``bit`` values,
        which are returned unchanged.
        """
        matches = self._extensions_xml.xpath(
            '//rabbitmq/field[@name="{}"]'.format(arg['name']))
        extension_default = (
            matches[0].attrib.get('default-value') if matches else None)
        if extension_default is not None:
            quote = repr_ and arg['type'] not in ('short', 'bit')
            return repr(extension_default) if quote else extension_default

        own_default = arg.get('default-value')
        if own_default is not None:
            quote = repr_ and arg['type'] not in ('short', 'bit')
            return repr(own_default) if quote else own_default

        if 'domain' not in arg:
            return None
        domain = self._domain(arg['domain'])
        if (not domain or domain.name != arg['domain']
                or domain.default_value is None):
            return None
        # NOTE(review): unlike the branches above, the ``short`` test reads
        # the domain's type while the ``bit`` test reads the argument's.
        quote = repr_ and domain.type != 'short' and arg['type'] != 'bit'
        return repr(domain.default_value) if quote else domain.default_value

    @staticmethod
    def _arg_name(name: str) -> str:
        """Returns a valid python argument name for the AMQP argument"""
        value = name.replace('-', '_')
        if value in keyword.kwlist:
            LOGGER.debug('%s is in the keyword list', value)
            value += '_'
        return value

    def _arg_type(self, arg: dict) -> str:
        """Resolve and return the AMQP data type of an argument.

        When the argument references a domain, the first matching entry of
        the JSON ``domains`` list supplies the type, and that type is also
        stored back into ``arg['type']``. Otherwise an explicit ``type`` is
        used.

        :raises ValueError: if no type can be determined
        """
        if 'domain' in arg:
            missing = object()
            resolved = next(
                (data_type
                 for domain_name, data_type in self._amqp_json['domains']
                 if domain_name == arg['domain']),
                missing)
            if resolved is not missing:
                arg['type'] = resolved
        if 'type' not in arg:
            raise ValueError('Unknown argument type')
        return arg['type']

    def _build_command_basic_properties(self, class_name: str,
                                        class_id: int,
                                        properties: list) -> None:
        """Emit the nested ``Properties`` class for a content-bearing class.

        Each entry of ``properties`` gains a ``pyname`` key in place; the
        AMQP ``type`` property becomes ``message_type``. The generated class
        gets:

        - a docstring;
        - ``__annotations__`` and ``__slots__``;
        - the ``flags`` map of property presence bits;
        - frame metadata;
        - per-attribute unmarshaling types;
        - an ``__init__`` that calls ``validate()``.

        :param class_name: AMQP class name the properties belong to
        :param class_id: AMQP class id, emitted as ``frame_id`` and ``index``
        :param properties: property definitions from the codegen JSON

        """
        for prop in properties:
            pyname = self._arg_name(prop['name'])
            prop['pyname'] = 'message_type' if pyname == 'type' else pyname

        indent = 4
        self._add_line('class Properties(base.BasicProperties):', indent)
        indent += 4
        self._add_comment('"""Content Properties', indent, '')
        self._add_line()
        self._add_line(
            '.. Note:: The AMQP property type is named ``message_type`` as to '
            'not conflict with the Python ``type`` keyword',
            indent, indent + 10, width=79)
        self._add_line()
        for arg in properties:
            label = self._label({'class': class_name, 'field': arg['name']})
            if label:
                line = ':param {}: {}'.format(arg['pyname'], label)
                self._add_line(line.strip(), indent, indent + 4, width=79)
            else:
                self._add_line(':param {}: Deprecated, must be empty'.format(
                    arg['pyname']), indent, indent + 4, width=79)
            if arg['name'] == 'headers':
                self._add_line(
                    ':type headers: typing.Optional['
                    ':const:`~pamqp.common.FieldTable`]',
                    indent, indent + 4, width=79)
        self._add_line(':raises: ValueError', indent)
        self._add_line()
        self._add_line('"""', indent)

        # Entries are comma-separated; the final entry has no trailing comma.
        last = len(properties) - 1
        self._add_line('__annotations__: typing.Dict[str, object] = {',
                       indent)
        for offset, arg in enumerate(properties):
            self._add_line('{!r}: {}{}'.format(
                arg['pyname'], self._arg_annotation(arg),
                '' if offset == last else ','), indent + 4)
        self._add_line('}', indent)

        self._add_line(
            '__slots__: typing.List[str] = [  # AMQ Properties Attributes',
            indent)
        for offset, arg in enumerate(properties):
            self._add_line('{!r}{}'.format(
                arg['pyname'], '' if offset == last else ','), indent + 4)
        self._add_line(']', indent)
        self._add_line()

        # Presence flags: the first property owns bit 15 and each following
        # property the next lower bit. ``flags = {`` and the closing ``}``
        # are attached by position, so a single property yields one
        # well-formed line instead of a duplicated key.
        self._add_comment('Flag values for marshaling / unmarshaling', indent)
        if not properties:
            self._add_line('flags = {}', indent)
        for offset, arg in enumerate(properties):
            entry = "'{}': {}{}".format(
                arg['pyname'], 1 << (15 - offset),
                '}' if offset == last else ',')
            if offset == 0:
                self._add_line('flags = {' + entry, indent)
            else:
                self._add_line(entry, indent + 9)
        self._add_line()
        self._add_line('frame_id = {}  # AMQP Frame ID'.format(class_id),
                       indent)
        self._add_line(
            'index = 0x%04X  # pamqp Mapping Index' % class_id, indent)
        self._add_line("name = '{}.Properties'".format(
            self._pep8_class_name(class_name)), indent)
        self._add_line()

        self._add_comment('Class Attribute Types for unmarshaling', indent)
        for arg in properties:
            self._add_line("_{} = '{}'".format(
                arg['pyname'], self._arg_type(arg)), indent)

        self._add_function('__init__', properties, indent)
        indent += 4
        self._add_line('"""Initialize the {}.Properties class"""'.format(
            self._pep8_class_name(class_name)), indent)
        for arg in properties:
            self._add_line('self.{} = {}'.format(
                arg['pyname'], arg['pyname']), indent)
        self._add_line('self.validate()', indent)
        self._add_line()

    def _build_command_map(self) -> None:
        """Append the ``INDEX_MAPPING`` dict literal to the output buffer.

        Each method of every AMQP class not listed in CODEGEN_IGNORE_CLASSES
        is keyed by ``class_id << 16 | method_id`` (formatted ``0x%08X``)
        and mapped to its generated ``Class.Method`` name.
        """
        self._add_line()
        self._add_comment('AMQP Class.Method Index Mapping')
        self._add_line('INDEX_MAPPING = {')
        entries = []
        for amqp_class in self._amqp_json['classes']:
            if amqp_class['name'] in CODEGEN_IGNORE_CLASSES:
                continue
            for method in amqp_class['methods']:
                key = amqp_class['id'] << 16 | method['id']
                entries.append('0x%08X: %s.%s' % (
                    key,
                    self._pep8_class_name(amqp_class['name']),
                    self._pep8_class_name(method['name'])))
        # Entries are comma-separated and the last has no trailing comma.
        # Deciding per entry (rather than patching entries[-1]) keeps an
        # empty mapping from raising IndexError.
        final = len(entries) - 1
        for offset, entry in enumerate(entries):
            self._add_line(entry if offset == final else entry + ',', 4)
        self._add_line('}')

    def _build_command_method(self, class_name: str, class_id: int,
                              method: dict) -> None:
        """Emit the nested command class for a single AMQP method.

        Writes, in order:

        - the class header;
        - a docstring built from the spec documentation, notes, one
          ``:param`` entry per argument and ``:raises`` entries;
        - ``__annotations__`` and ``__slots__``;
        - frame metadata (``frame_id``, ``index``, ``name``,
          ``synchronous`` and, for synchronous methods,
          ``valid_responses``);
        - per-attribute unmarshaling types;
        - an ``__init__``, plus a ``validate`` method when any argument
          needs checking.

        :param class_name: AMQP class name (e.g. ``exchange``)
        :param class_id: AMQP class id, combined with the method id to form
            the pamqp mapping index
        :param method: method definition from the codegen JSON. Its
            ``arguments`` list is deep-copied before ``type`` and
            ``pyname`` keys are added to the copies.

        """
        # NOTE(review): needs_linefeed starts out True, so the later
        # ``needs_linefeed = True`` assignments change nothing and every
        # method with arguments gets a blank entry before its :param lines.
        # Confirm whether False was the intended initial value.
        deprecated, needs_linefeed = False, True
        # Work on a copy: ``type``/``pyname`` keys are added further down.
        arguments = copy.deepcopy(method['arguments'])
        method_xml = self._search_xml(
            {'class': class_name, 'method': method['name']})
        indent = 4
        self._add_line('class {}(base.Frame):'.format(
            self._pep8_class_name(method['name'])), indent)
        # Docstring body lines, joined with newlines before being written.
        documentation = []
        docs = self._documentation(
            {'class': class_name, 'method': method['name']})
        if docs is not None and docs.strip():
            documentation.append(docs.strip())
            needs_linefeed = True

        for arg in arguments:
            name = self._arg_name(arg['name'])
            if name == 'type' and class_name == 'exchange':
                documentation.append(
                    '\n.. note:: The AMQP type argument is referred to as '
                    '"{}_type" to not conflict with the Python type '
                    'keyword.'.format(class_name))
                needs_linefeed = True

        # Note the deprecation warning in the docblock
        if method_xml and method_xml[0].attrib.get('deprecated'):
            deprecated = True
            documentation.append(
                '\n.. deprecated:: {}'.format(self.DEPRECATION_WARNING))
            needs_linefeed = True

        if arguments and needs_linefeed:
            documentation.append('')

        # One :param (and, for table/timestamp types, :type) per argument.
        for arg in arguments:
            # _arg_type also stores the resolved type in arg['type'], which
            # the unmarshaling attribute lines below read back.
            arg_type = AMQ_TYPE_TO_ANNOTATION[self._arg_type(arg)]
            name = self._arg_name(arg['name'])
            default = self._arg_default(arg, False)
            if name == 'arguments' or arg_type in {'common.FieldTable',
                                                   'common.Arguments'}:
                default = '{}'
            domain = self._domain(arg.get('domain', arg['name']))
            if name == 'type' and class_name == 'exchange':
                name = 'exchange_type'
            search_path = {
                'class': class_name,
                'method': method['name'],
                'field': arg['name']}
            arg_doc = self._documentation(search_path)
            # Prefer the field label, then the field docs, then the domain's
            # label or documentation.
            label = self._label(search_path) or arg_doc or ''
            if not label \
                    and domain and (domain.label or domain.documentation):
                label = domain.label or domain.documentation
            if default is not None:
                if default == '':
                    default = "''"
                if label:
                    label = '{}\n    - Default: ``{}``'.format(label, default)
                else:
                    label = 'Default: ``{}``'.format(default)
            documentation.append(':param {}: {}'.format(name,  label))
            if name == 'arguments':
                documentation.append(
                    ':type {}: :const:`~pamqp.common.Arguments`'.format(name))
            elif arg_type.startswith('common'):
                documentation.append(
                    ':type {}: :const:`~pamqp.{}`'.format(name, arg_type))

        # Arguments whose domain has a length or regex constraint get
        # checks in the generated validate(), so ValueError is documented.
        raises = set()
        for arg in arguments:
            domain = self._domain(arg.get('domain', arg['name']))
            if domain and (domain.max_length or domain.regex):
                raises.add('ValueError')
        if raises:
            for exc_name in raises:
                documentation.append(
                    ':raises {}: when an argument fails to validate'.format(
                        exc_name))

        label = self._label(
            {'class': class_name, 'method': method['name']}) \
            or 'Undocumented function'

        # Class body indentation for everything that follows.
        indent = 8
        self._add_documentation(label, '\n'.join(documentation), indent)

        # __annotations__; this loop also assigns arg['pyname'], which the
        # __slots__, attribute-type and __init__ sections below rely on.
        if not len(arguments):
            self._add_line('__annotations__: typing.Dict[str, object] = {}',
                           indent)
        else:
            self._add_line('__annotations__: typing.Dict[str, object] = {',
                           indent)
            for offset, arg in enumerate(arguments):
                name = self._arg_name(arg['name'])
                if name == 'type' and class_name == 'exchange':
                    name = 'exchange_type'
                arguments[offset]['pyname'] = name
                if offset == len(method['arguments']) - 1:
                    self._add_line(
                        '{!r}: {}'.format(name, self._arg_annotation(arg)),
                        indent + 4)
                else:
                    self._add_line(
                        '{!r}: {},'.format(name, self._arg_annotation(arg)),
                        indent + 4)
            self._add_line('}', indent)
        if not len(arguments):
            self._add_line(
                '__slots__: typing.List[str] = []  # AMQ Method Attributes',
                indent)
        else:
            self._add_line(
                '__slots__: typing.List[str] = [  # AMQ Method Attributes',
                indent)
            for offset, arg in enumerate(arguments):
                if offset == len(method['arguments']) - 1:
                    self._add_line('{!r}'.format(arg['pyname']), indent + 4)
                else:
                    self._add_line('{!r},'.format(arg['pyname']), indent + 4)
            self._add_line(']', indent)
        self._add_line()
        self._add_line('frame_id = %i  # AMQP Frame ID' % method['id'], indent)
        index_value = class_id << 16 | method['id']
        self._add_line(
            'index = 0x%08X  # pamqp Mapping Index' % index_value, indent)
        self._add_line("name = '{}.{}'".format(
            self._pep8_class_name(class_name),
            self._pep8_class_name(method['name'])), indent)
        self._add_line(
            'synchronous = {}  # Indicates if this is a synchronous AMQP '
            'method'.format(method.get('synchronous', False)), indent)

        if method.get('synchronous'):
            # Responses come from the XML <response> elements; otherwise
            # fall back to the ``<Method>Ok`` naming convention.
            responses = []
            if method_xml:
                responses = [
                    "'{}.{}'".format(
                        self._pep8_class_name(class_name),
                        self._pep8_class_name(response.attrib['name']))
                    for response in method_xml[0].iter('response')]
            if not responses:
                responses = ["'{}.{}Ok'".format(
                    self._pep8_class_name(class_name),
                    self._pep8_class_name(method['name']))]
            line = 'valid_responses = [{}]'.format(', '.join(responses))
            # Short lines carry the description as a trailing comment,
            # longer ones get it as a comment line above.
            if len(line) <= 36:
                self._add_line('{}  # Valid responses to this method'.format(
                    line), indent)
            else:
                self._add_comment('Valid responses to this method', indent)
                self._add_line(line, indent)

        if arguments:
            self._add_line()
            self._add_comment('Class Attribute Types for unmarshaling', indent)
        for arg in arguments:
            self._add_line("_{} = '{}'".format(arg['pyname'], arg['type']),
                           indent)
        if len(arguments):
            self._add_function('__init__', arguments, indent)
            indent += 4
            self._add_line(
                '"""Initialize the :class:`{}.{}` class"""'.format(
                    self._pep8_class_name(class_name),
                    self._pep8_class_name(method['name'])), indent)

            # Function
            all_defaults = self._all_defaults(arguments)
            for arg in arguments:
                domain = self._domain(arg.get('domain', arg['name']))
                default = self._arg_default(arg)
                # Table values fall back to an empty dict; other defaults are
                # applied with ``or`` only when some argument lacks a default.
                if ((domain and not domain.nullable and domain.type == 'table')
                        or arg['name'] == 'arguments'
                        or (arg['type'] == 'table' and not default)):
                    self._add_line('self.{} = {} or {{}}'.format(
                        arg['pyname'], arg['pyname']), indent)
                elif default is not None \
                        and not all_defaults \
                        and default != '':
                    self._add_line('self.{} = {} or {}'.format(
                        arg['pyname'], arg['pyname'], default), indent)
                else:
                    self._add_line('self.{} = {}'.format(
                        arg['pyname'], arg['pyname']), indent)

            if deprecated:
                self._add_line('warnings.warn(', indent)
                self._add_line(
                    'constants.DEPRECATION_WARNING, '
                    'category=DeprecationWarning)', indent + 4)

            # validate() is generated for fixed-value (deprecated/reserved)
            # fields and for arguments whose domain has length/regex limits.
            add_validate = False
            for arg in arguments:
                domain = self._domain(arg.get('domain', arg['name']))
                if ((arg['pyname'] in {'capabilities', 'channel_id',
                                       'cluster_id', 'known_hosts',
                                       'out_of_band', 'internal',
                                       'insist', 'ticket'})
                    or domain and domain.max_length is not None
                    or domain and domain.regex is not None):
                    add_validate = True
                    break

            if add_validate:
                self._add_line('self.validate()', indent)

            # Back to class-body indentation for the validate() method.
            indent -=4

            def format_deprecated_value(
                    field_name: str,
                    operator: str,
                    expected_value: typing.Union[str, bool, int],
                    error_value: str) -> None:
                """Emit an ``if``/``raise ValueError`` pair for one field.

                The generated check rejects a non-None value for which
                ``self.<field_name> <operator> <expected_value>`` holds.
                ``indent`` is read from the enclosing scope at call time,
                i.e. after the ``indent += 4`` below.
                """
                self._add_line(
                    "if self.{} is not None and self.{} {} {}:".format(
                        field_name, field_name, operator, expected_value),
                    indent)
                self._add_line(
                    "raise ValueError('{} must be {}')".format(
                        field_name, error_value), indent + 4)

            if add_validate:
                self._add_function('validate', [], indent)
                indent += 4
                self._add_comment(
                    '"""Validate the frame data ensuring all domains or '
                    'attributes adhere to the protocol specification.',
                    indent, '')
                self._add_line()
                self._add_line(':raises ValueError: on validation error', indent)
                self._add_line()
                self._add_line('"""', indent)
                for arg in arguments:
                    domain = self._domain(arg.get('domain', arg['name']))
                    # Fields with a fixed expected value are checked against
                    # it and skip the domain checks below.
                    if arg['pyname'] in {'capabilities', 'cluster_id',
                                         'known_hosts'}:
                        format_deprecated_value(
                            arg['pyname'], '!=', "''", 'empty')
                        continue
                    elif arg['name'] in {'internal', 'insist'}:
                        format_deprecated_value(arg['pyname'], 'is not',
                                                'False', 'False')
                        continue
                    elif arg['pyname'] == 'ticket':
                        format_deprecated_value(arg['pyname'], '!=', '0', '0')
                        continue
                    elif arg['pyname'] in {'channel_id', 'out_of_band'}:
                        format_deprecated_value(
                            arg['pyname'], '!=', "'0'", '0')
                        continue
                    if domain and domain.max_length is not None:
                        self._add_line(
                            'if self.{} is not None and '
                            'len(self.{}) > {}:'.format(
                                arg['pyname'], arg['pyname'],
                                domain.max_length),
                            indent)
                        line = ("raise ValueError('Max length "
                                "exceeded for {}')".format(arg['pyname']))
                        self._add_line(line, indent + 4)
                    if domain and domain.regex is not None:
                        self._add_line(
                            'if self.{} is not None and '
                            "not constants.DOMAIN_REGEX['{}']"
                            ".fullmatch(self.{}):".format(
                                arg['pyname'], domain.name, arg['pyname']),
                            indent)
                        line = 'raise ValueError(' \
                               "'Invalid value for {}')".format(arg['pyname'])
                        self._add_line(line, indent + 4)

        self._add_line()

    def _build_commands(self):
        """Generate the commands module (COMMANDS).

        The buffer starts with COMMANDS_HEADER. Each AMQP class not listed
        in CODEGEN_IGNORE_CLASSES then gets a container class containing:

        - its docstring (only when the spec documents the class);
        - ``__slots__``, ``frame_id`` and ``index``;
        - one nested class per method;
        - a nested ``Properties`` class when the class defines properties.

        The INDEX_MAPPING table is appended last and the file is written.
        """
        LOGGER.info('Generating %s', COMMANDS)
        self._output_buffer = [COMMANDS_HEADER]
        selected = [entry['name'] for entry in self._amqp_json['classes']
                    if entry['name'] not in CODEGEN_IGNORE_CLASSES]
        body_indent = 4
        for class_name in selected:
            definition = self._class_definition(class_name)
            self._add_line()
            self._add_line('class %s:' % self._pep8_class_name(class_name))

            documentation = self._documentation({'class': class_name})
            label = self._label({'class': class_name}) or 'Undefined label'
            if documentation:
                self._add_documentation(label, documentation, body_indent)

            class_id = definition['id']
            self._add_line('__slots__: typing.List[str] = []', body_indent)
            self._add_line()
            self._add_line(f'frame_id = {class_id}  # AMQP Frame ID',
                           body_indent)
            self._add_line('index = 0x%08X  # pamqp Mapping Index' % (
                class_id << 16), body_indent)
            self._add_line()

            for method_definition in definition['methods']:
                self._build_command_method(class_name, class_id,
                                           method_definition)

            properties = (definition['properties']
                          if 'properties' in definition else None)
            if properties:
                self._build_command_basic_properties(class_name, class_id,
                                                     properties)

        self._build_command_map()
        self._write_file(COMMANDS)

    def _build_constants(self):
        """Generate the constants module (CONSTANTS).

        Writes, in order:

        - the protocol prefix and version;
        - RabbitMQ connection defaults;
        - the AMQP constants that carry no error class;
        - frame constants absent from the spec files;
        - the DATA_TYPES / DOMAINS / DOMAIN_REGEX tables built from
          :meth:`_domains`;
        - the deprecation warning text.
        """
        LOGGER.info('Generating %s', CONSTANTS)

        def _build_domain_output() -> typing.Tuple[typing.List[str],
                                                   typing.List[str],
                                                   typing.List[str]]:
            """Return the source lines of DATA_TYPES, DOMAINS, DOMAIN_REGEX.

            Domains whose name equals their type are data types. All others
            go into DOMAINS, and those with a regex also into DOMAIN_REGEX.
            The opening token is placed on the first entry, and only the
            separator of the last entry becomes the closing token. This
            leaves commas inside labels or regex quantifiers intact, and an
            empty table renders as an empty literal instead of raising
            IndexError.
            """
            data_type_domains, mapped_domains, regex_domains = [], [], []
            for domain in self._domains():
                if domain.name == domain.type:
                    data_type_domains.append(domain)
                else:
                    mapped_domains.append(domain)
                    if domain.regex:
                        regex_domains.append(domain)

            # 'DATA_TYPES = [' is 14 characters wide, the same as the padding
            # of the following entries, so the trailing comments stay
            # aligned. The last entry's comma becomes a space.
            data_types = []
            last = len(data_type_domains) - 1
            for offset, domain in enumerate(data_type_domains):
                opener = 'DATA_TYPES = [' if offset == 0 else ' ' * 14
                separator = ' ' if offset == last else ','
                value = "{}'{}'{}".format(opener, domain.name, separator)
                comment = '{} {}'.format('#'.rjust(28 - len(value)),
                                         domain.label)
                data_types.append('{}{}'.format(value, comment))
            if data_types:
                data_types.append(']')
            else:
                data_types = ['DATA_TYPES = []']

            def _dict_lines(opener: str,
                            entries: typing.List[str]) -> typing.List[str]:
                """Render ``entries`` as the lines of a dict literal."""
                if not entries:
                    return [opener + '}']
                final = len(entries) - 1
                return ['{}{}{}'.format(opener if offset == 0 else ' ' * 11,
                                        entry,
                                        '}' if offset == final else ',')
                        for offset, entry in enumerate(entries)]

            domains = _dict_lines(
                'DOMAINS = {',
                ["'{}': '{}'".format(domain.name, domain.type)
                 for domain in mapped_domains])
            dom_regex = _dict_lines(
                'DOMAIN_REGEX = {',
                ["'{}': re.compile(r'{}')".format(domain.name, domain.regex)
                 for domain in regex_domains])
            return data_types, domains, dom_regex

        self._output_buffer = [CONSTANTS_HEADER]
        self._add_comment('AMQP Protocol Frame Prefix')
        self._add_line("AMQP = b'AMQP'")
        self._add_line()

        self._add_comment('AMQP Protocol Version')
        self._add_line('VERSION = ({}, {}, {})'.format(
            self._amqp_json['major-version'],
            self._amqp_json['minor-version'],
            self._amqp_json['revision']))
        self._add_line()

        # Defaults
        self._add_comment('RabbitMQ Defaults')
        self._add_line("DEFAULT_HOST = 'localhost'")
        self._add_line('DEFAULT_PORT = {}'.format(self._amqp_json['port']))
        self._add_line("DEFAULT_USER = 'guest'")
        self._add_line("DEFAULT_PASS = 'guest'")
        self._add_line("DEFAULT_VHOST = '/'")
        self._add_line()

        # Constants without a ``class`` key; error classes become exceptions.
        self._add_comment('AMQP Constants')
        for constant in self._amqp_json['constants']:
            if 'class' not in constant:
                documentation = self._documentation(
                    {'constant': constant['name'].lower()})
                if documentation:
                    self._add_comment(documentation)
                self._add_line('{} = {}'.format(
                    self._dashify(constant['name']), constant['value']))
        self._add_line()

        self._add_comment('Not included in the spec XML or JSON files.')
        self._add_line("FRAME_END_CHAR = b'\\xce'")
        self._add_line('FRAME_HEADER_SIZE = 7')
        self._add_line('FRAME_MAX_SIZE = 131072')
        self._add_line()

        data_type_output, domain_output, domain_regex = _build_domain_output()

        self._add_comment('AMQP data types')
        self._output_buffer += data_type_output
        self._add_line()

        self._add_comment('AMQP domains')
        self._output_buffer += domain_output
        self._add_line()

        self._add_comment('AMQP domain patterns')
        self._output_buffer += domain_regex
        self._add_line()

        self._add_comment('Other constants')
        self._add_line("DEPRECATION_WARNING = '{}'".format(
            self.DEPRECATION_WARNING))
        self._add_line()

        self._write_file(CONSTANTS)

    def _build_exceptions(self) -> None:
        """Generate :data:`EXCEPTIONS` from the AMQP error constants.

        Each JSON constant that carries a ``class`` key becomes an
        ``AMQP<Name>`` exception class. It extends ``AMQPSoftError`` for
        ``soft-error`` and ``AMQPHardError`` for ``hard-error``. A
        ``CLASS_MAPPING`` dict from error code to generated class follows
        the classes.

        :raises ValueError: when a constant's ``class`` is neither
            ``soft-error`` nor ``hard-error``

        """
        LOGGER.info('Generating %s', EXCEPTIONS)
        base_classes = {'soft-error': 'AMQPSoftError',
                        'hard-error': 'AMQPHardError'}

        self._output_buffer = [EXCEPTIONS_HEADER]
        errors = {}
        for constant in self._amqp_json['constants']:
            if 'class' not in constant:
                continue  # Class-less constants are plain values, not errors
            extends = base_classes.get(constant['class'])
            if extends is None:
                # Format the message ourselves; passing the value as a second
                # ValueError argument left the ``%s`` placeholder unfilled.
                raise ValueError(
                    'Unexpected class: {}'.format(constant['class']))
            class_name = self._classify(constant['name'])
            self._add_line('class AMQP{}({}):'.format(class_name, extends))
            self._add_line('    """')
            documentation = self._documentation(
                {'constant': constant['name'].lower()})
            if documentation:
                self._add_comment(documentation, 4, '')
            elif extends == 'AMQPSoftError':
                self._add_line('    Undocumented AMQP Soft Error')
            else:
                self._add_line('    Undocumented AMQP Hard Error')
            self._add_line()
            self._add_line('    """')
            self._add_line("    name = '%s'" % constant['name'])
            self._add_line('    value = %i' % constant['value'])
            self._add_line()
            self._add_line()
            errors[constant['value']] = class_name

        self._add_comment('AMQP Error code to class mapping')
        if errors:
            error_lines = [
                '          {}: AMQP{},'.format(error_code, class_name)
                for error_code, class_name in errors.items()]
            # Turn the indented entries into a dict literal: the first line
            # opens it, and the last line's trailing comma closes it.
            error_lines[0] = error_lines[0].replace(
                '          ', 'CLASS_MAPPING = {')
            error_lines[-1] = error_lines[-1].replace(',', '}')
        else:
            # Without this branch an empty error set raised IndexError.
            error_lines = ['CLASS_MAPPING = {}']
        self._output_buffer += error_lines

        self._write_file(EXCEPTIONS)

    def _class_definition(self, name: str) -> dict:
        """Iterates through classes trying to match the name against what was
        passed in.

        """
        for definition in self._amqp_json['classes']:
            if definition['name'] == name:
                for method in definition['methods']:
                    for index, ar in enumerate(method['arguments']):
                        method['arguments'][index].update(
                            self._lookup_field(
                                name, method['name'], ar['name']))
                return definition

    @staticmethod
    def _classify(name: str) -> str:
        """Replace the AMQP constant with a more pythonic classname"""
        return ''.join([part.title() for part in name.split('-')])

    @functools.lru_cache()
    def _commands(self) -> typing.Dict[str, str]:
        """Map documentation command names to their generated class paths.

        Keys look like ``'Basic.Get-Ok'`` (class name title-cased, method
        parts title-cased and rejoined with ``-``). Values drop the dashes
        (``'Basic.GetOk'``). ``Basic.Properties`` is always included.
        Classes listed in ``CODEGEN_IGNORE_CLASSES`` are skipped.

        """
        commands = {'Basic.Properties': 'Basic.Properties'}
        # dict.fromkeys de-duplicates while keeping the JSON order. The
        # previous set() iteration order depends on str hash randomization,
        # so the resulting mapping's order changed from run to run.
        class_names = dict.fromkeys(
            c['name'] for c in self._amqp_json['classes']
            if c['name'] not in CODEGEN_IGNORE_CLASSES)
        for name in class_names:
            definition = self._class_definition(name)
            for method in definition['methods']:
                method_name = '-'.join(
                    part.title() for part in method['name'].split('-'))
                command = '{}.{}'.format(name.title(), method_name)
                # str.replace is a no-op for names without a dash.
                commands[command] = command.replace('-', '')
        return commands

    @staticmethod
    def _dashify(value: str) -> str:
        """Replace ``-`` with ``_`` for the passed in text"""
        return value.replace('-', '_')

    def _documentation(self, search_path: dict) -> typing.Optional[str]:
        """Find the documentation in the xpath"""
        def strip_whitespace(value: typing.List[lxml.etree.Element]) -> str:
            return ' '.join(
                [dl.strip() for dl in value[0].text.split('\n')]).strip()

        node = self._search_xml(search_path, 'doc')
        if node:
            return strip_whitespace(node)
        elif 'field' in search_path:  # Look elsewhere
            domain = self._domain(search_path['field'])
            if domain and domain.documentation:
                return domain.documentation
            node = self._search_xml(search_path, 'doc')
            if node:
                return strip_whitespace(node)
            node = self._search_xml(
                {'field': search_path['field']}, 'doc', True)
            if node:
                return strip_whitespace(node)
        LOGGER.warning('Could not find documentation for %r', search_path)

    @functools.lru_cache()
    def _domain(self, value: str) -> typing.Optional[Domain]:
        for domain in self._domains():
            if domain.name == value:
                return domain

    @functools.lru_cache()
    def _domains(self) -> typing.List[Domain]:
        """Build a :class:`Domain` for every domain listed in the JSON spec.

        Each JSON entry supplies the domain name (index 0) and type
        (index 1). The codegen XML supplies documentation, label and
        ``default-value``. An extensions-XML ``default-value`` overrides the
        codegen one. ``<assert>`` children from both XML files set
        ``max_length``, ``nullable`` and ``regex``. Extensions are applied
        last, so they win.

        """
        values = []
        for value in self._amqp_json['domains']:
            name = value[0]
            domain = self._codegen_xml.xpath(
                '//amqp/domain[@name="{}"]'.format(name))
            extension = self._extensions_xml.xpath(
                '//rabbitmq/domain[@name="{}"]'.format(name))
            doc = self._codegen_xml.xpath(
                '//amqp/domain[@name="{}"]/doc'.format(name))
            kwargs = {
                'name': name,
                'type': value[1],
                # lxml yields ``None`` as the text of a <doc/> without text.
                'documentation': self._remove_extra_whitespace(doc[0].text)
                if doc and doc[0].text is not None else None,
                'label': self._remove_extra_whitespace(
                    domain[0].attrib.get('label', '') if domain else ''),
                'nullable': True,
                'regex': None,
                'max_length': None,
                'default_value': domain[0].attrib.get('default-value', None)
                if domain else None
            }
            if extension \
                    and extension[0].attrib.get('default-value') is not None:
                kwargs['default_value'] = \
                    extension[0].attrib.get('default-value')

            self._apply_domain_assertions(
                self._codegen_xml.xpath(
                    '//amqp/domain[@name="{}"]/assert'.format(name)),
                kwargs)
            self._apply_domain_assertions(
                self._extensions_xml.xpath(
                    '//rabbitmq/domain[@name="{}"]/assert'.format(name)),
                kwargs)

            values.append(Domain(**kwargs))
        return values

    @staticmethod
    def _apply_domain_assertions(assertions: list, kwargs: dict) -> None:
        """Apply ``<assert>`` constraints to the Domain *kwargs* in place.

        ``check="length"`` sets ``max_length``, ``check="notnull"`` clears
        ``nullable``, and ``check="regexp"`` sets ``regex``. Other checks
        are ignored.

        """
        for assertion in assertions or []:
            check = assertion.attrib.get('check')
            if check == 'length':
                kwargs['max_length'] = int(assertion.attrib['value'])
            elif check == 'notnull':
                kwargs['nullable'] = False
            elif check == 'regexp':
                kwargs['regex'] = assertion.attrib['value']

    def _label(self, search_path: dict) -> typing.Optional[str]:
        """Attempt to return the short label documentation"""
        node = self._search_xml(search_path)
        if node and 'label' in node[0].attrib:
            return '{}{}'.format(
                node[0].attrib['label'][0:1].upper(),
                node[0].attrib['label'][1:])
        elif node and node[0].text:
            return '{}{}'.format(node[0].text.strip()[0:1].upper(),
                                 node[0].text.strip()[1:].strip())
        if 'field' in search_path:  # Look in domains and extensions
            domain = self._domain(search_path['field'])
            if domain and domain.label:
                return '{}{}'.format(domain.label[0:1].upper(),
                                     domain.label[1:].strip())
            node = self._search_xml(search_path, only_extensions=True)
            if node and 'label' in node[0].attrib:
                return '{}{}'.format(
                    node[0].attrib['label'][0:1].upper(),
                    node[0].attrib['label'][1:])
            node = self._search_xml({'field': search_path['field']})
            if node and 'label' in node[0].attrib:
                return '{}{}'.format(
                    node[0].attrib['label'][0:1].upper(),
                    node[0].attrib['label'][1:])

        LOGGER.warning('Could not find label for %r', search_path)

    @staticmethod
    def _load_codegen_json(timeout: float = 30.0) -> dict:
        """Return the parsed codegen JSON, downloading it first if absent.

        When :data:`CODEGEN_JSON` does not exist, it is fetched from
        ``CODEGEN_JSON_URL``. A non-OK response is printed and the process
        exits with status 1.

        :param timeout: Seconds to wait on the HTTP download. ``requests``
            never times out unless one is given.

        """
        if not CODEGEN_JSON.exists():
            print('Downloading codegen JSON file to %s.' % CODEGEN_JSON)
            response = requests.get(CODEGEN_JSON_URL, timeout=timeout)
            if not response.ok:
                print('Error downloading JSON file: {}'.format(response))
                sys.exit(1)
            # Store the payload verbatim instead of decoding it and
            # re-encoding with the platform's default text encoding.
            CODEGEN_JSON.write_bytes(response.content)
        with CODEGEN_JSON.open('r', encoding='utf-8') as handle:
            return json.load(handle)

    @staticmethod
    def _load_codegen_xml(timeout: float = 30.0) -> lxml.etree.Element:
        """Return the ``<amqp>`` root of the codegen XML, downloading if absent.

        When :data:`CODEGEN_XML` does not exist, it is fetched from
        ``CODEGEN_XML_URL``. A non-OK response is printed and the process
        exits with status 1.

        :param timeout: Seconds to wait on the HTTP download. ``requests``
            never times out unless one is given.

        """
        if not CODEGEN_XML.exists():
            print('Downloading codegen XML file.')
            response = requests.get(CODEGEN_XML_URL, timeout=timeout)
            if not response.ok:
                print('Error downloading XML file: {}'.format(response))
                sys.exit(1)
            # Keep the raw bytes so the XML declaration stays authoritative.
            CODEGEN_XML.write_bytes(response.content)

        # Parse from a binary handle so lxml decodes the document itself
        # rather than relying on the locale encoding of a text-mode handle.
        with CODEGEN_XML.open('rb') as handle:
            return lxml.etree.parse(handle).xpath('//amqp')[0]

    @staticmethod
    def _load_extensions_xml() -> lxml.etree.Element:
        """Return the ``<rabbitmq>`` root element of :data:`EXTENSIONS_XML`."""
        # Binary mode lets lxml decode the document per its XML declaration
        # instead of the platform default encoding used by text mode.
        with EXTENSIONS_XML.open('rb') as handle:
            return lxml.etree.parse(handle).xpath('//rabbitmq')[0]

    def _lookup_field(self, class_name: str, method: str, field: str) -> dict:
        field_def = {}
        result = self._codegen_xml.xpath(
            '//amqp/class[@name="{}"]/method[@name="{}"]/'
            'field[@name="{}"]'.format(class_name, method, field))
        if result:
            for attribute in result[0].attrib:
                field_def[attribute] = result[0].attrib[attribute]
        result = self._extensions_xml.xpath(
            '//rabbitmq/class[@name="{}"]/method[@name="{}"]/'
            'field[@name="{}"]'.format(class_name, method, field))
        if result:
            for attribute in result[0].attrib:
                field_def[attribute] = result[0].attrib[attribute]
        return field_def

    @staticmethod
    def _pep8_class_name(value: str) -> str:
        """Returns a class name in the proper case per PEP8"""
        return_value = []
        parts = value.split('-')
        for part in parts:
            return_value.append(part[0:1].upper() + part[1:])
        return ''.join(return_value)

    @staticmethod
    def _remove_extra_whitespace(value: str) -> str:
        return ' '.join([dl.strip() for dl in value.split('\n')]).strip()

    def reset_output_buffer(self) -> None:
        self._output_buffer = []

    def _search_xml(self, search_path: dict,
                    suffix: typing.Optional[str] = None,
                    only_extensions: bool = False) \
            -> typing.List[lxml.etree.Element]:
        """Return the XML elements matching *search_path*.

        The XPath is built from the keys of *search_path* that appear in
        ``XPATH_ORDER``, in that order, as ``key[@name="value"]`` steps.
        *suffix* is appended as a final step when given. Matches from the
        extensions XML are appended after those from the codegen XML. With
        *only_extensions*, only the extensions XML is searched. Callers
        treat an empty result as "not found".

        """
        steps = ['%s[@name="%s"]' % (key, search_path[key])
                 for key in XPATH_ORDER if key in search_path]
        if suffix:
            steps.append(suffix)
        xpath = '/'.join(steps)
        if only_extensions:
            return self._extensions_xml.xpath(xpath)
        matches = self._codegen_xml.xpath(xpath)
        extension = self._extensions_xml.xpath(xpath)
        if extension:
            matches.extend(extension)
        return matches

    def _write_file(self, path: pathlib.Path) -> None:
        if self.YAPF:
            code = yapf_api.FormatCode(
                '\n'.join(self._output_buffer), style_config='pep8')[0]
        else:
            code = '\n'.join(self._output_buffer)
        with path.open('w') as handle:
            handle.write(code)


if __name__ == '__main__':
    # Configure root logging at DEBUG level, then run the code generator.
    logging.basicConfig(level=logging.DEBUG)
    generator = Codegen()
    generator.build()
