from contextlib import contextmanager
import os
from stone.backend import CodeBackend
from stone.backends.swift_helpers import (
    fmt_class,
    fmt_func,
    fmt_obj,
    fmt_type,
    fmt_var,
    fmt_objc_type,
    mapped_list_info,
)

from stone.ir import (
    Boolean,
    Bytes,
    Float32,
    Float64,
    Int32,
    Int64,
    List,
    Map,
    String,
    Timestamp,
    UInt32,
    UInt64,
    Void,
    is_list_type,
    is_map_type,
    is_timestamp_type,
    is_union_type,
    is_user_defined_type,
    unwrap_nullable,
    is_nullable_type,
)

# Maps Stone IR type classes to the name of the Swift serializer that handles
# them. Looked up by ``data_type.__class__`` in fmt_serial_type/fmt_serial_obj
# below, which fall back to ``fmt_class(data_type.name)`` for classes that are
# not listed here (user-defined types are handled separately by those helpers).
_serial_type_table = {
    Boolean: 'BoolSerializer',
    Bytes: 'NSDataSerializer',
    Float32: 'FloatSerializer',
    Float64: 'DoubleSerializer',
    Int32: 'Int32Serializer',
    Int64: 'Int64Serializer',
    List: 'ArraySerializer',
    Map: 'DictionarySerializer',
    String: 'StringSerializer',
    Timestamp: 'NSDateSerializer',
    UInt32: 'UInt32Serializer',
    UInt64: 'UInt64Serializer',
    Void: 'VoidSerializer',
}

# Maps Stone IR type classes to the accessor suffix (e.g. '.int32Value')
# appended to an Objective-C side value when it is converted back to Swift in
# SwiftBaseBackend._objc_init_args_to_swift. An empty string means the value is
# passed through without an accessor. Presumably numeric/boolean values are
# boxed as NSNumber on the ObjC side (hence the table name) — confirm.
_nsnumber_type_table = {
    Boolean: '.boolValue',
    Bytes: '',
    Float32: '.floatValue',
    Float64: '.doubleValue',
    Int32: '.int32Value',
    Int64: '.int64Value',
    List: '',
    String: '',
    Timestamp: '',
    UInt32: '.uint32Value',
    UInt64: '.uint64Value',
    Void: '',
    Map: '',
}

# Comment banner emitted at the top of generated Swift files; it ends with a
# blank line so the following content is visually separated.
stone_warning = (
    '///\n'
    '/// Copyright (c) 2016 Dropbox, Inc. All rights reserved.\n'
    '///\n'
    '/// Auto-generated by Stone, do not modify.\n'
    '///\n'
    '\n'
)

# This will be at the top of the generated file: the banner followed by the
# Foundation import and a blank line.
base = stone_warning + 'import Foundation\n\n'


# Placeholder text, presumably used where a definition has no documentation.
undocumented = '(no description)'


class SwiftBaseBackend(CodeBackend):
    """Wrapper class over Stone generator for Swift logic.

    Provides shared helpers for rendering Swift function signatures, argument
    lists, ObjC-to-Swift conversion expressions, doc references and output
    files. Abstract methods of ``CodeBackend`` are intentionally left to
    subclasses (hence the pylint disable below).
    """
    # pylint: disable=abstract-method

    @contextmanager
    def function_block(self, func, args, return_type=None):
        """Emit ``func(args)`` (plus ``-> return_type``) as the header of a block.

        The caller's generated body is emitted inside ``self.block`` (the
        ``CodeBackend`` block helper) while the ``with`` statement is active.

        :param func: function head text, emitted verbatim before the parentheses.
        :param args: the already-rendered argument list.
        :param return_type: optional return type; appended as `` -> return_type``
            when truthy.
        """
        signature = '{}({})'.format(func, args)
        if return_type:
            signature += ' -> {}'.format(return_type)
        with self.block(signature):
            yield

    def _func_args(self, args_list, newlines=False, force_first=False, not_init=False):
        """Render ``(name, value)`` pairs as a comma-separated argument list.

        Pairs whose value is ``None`` are skipped; they still count as the
        "first" pair for the ``force_first`` / ``not_init`` handling.

        :param args_list: iterable of ``(name, value)`` pairs.
        :param newlines: when true, each separator also breaks the line and
            re-indents to the current indent level.
        :param force_first: when true, the first name is doubled
            (``name name: value``) unless its value carries a default.
        :param not_init: when true, only the first value is emitted (no label).
        :return: the rendered argument string.
        """
        out = []
        for index, (name, value) in enumerate(args_list):
            if value is None:
                # Nothing to render. Checking this first also keeps the
                # `'=' in value` test below from raising TypeError on None.
                continue
            is_first = index == 0
            # this is a temporary hack -- injected client-side args
            # do not have a separate field for default value. Right now,
            # default values are stored along with the type, e.g.
            # `Bool = True` is a type, hence this check.
            if is_first and force_first and '=' not in value:
                name = '{0} {0}'.format(name)

            if is_first and not_init:
                out.append('{}'.format(value))
            else:
                out.append('{}: {}'.format(name, value))
        sep = ', '
        if newlines:
            sep += '\n' + self.make_indent()
        return sep.join(out)

    def _struct_init_args(self, data_type, namespace=None):  # pylint: disable=unused-argument
        """Return ``(name, swift_type)`` pairs for a struct initializer.

        Fields with a Stone default get `` = <default>`` appended to the type
        (``.tag`` for union-typed defaults); other nullable fields get `` = nil``.

        :param data_type: the struct whose ``all_fields`` are rendered.
        :param namespace: unused; kept for interface compatibility.
        """
        args = []
        for field in data_type.all_fields:
            name = fmt_var(field.name)
            value = fmt_type(field.data_type)
            # Distinct local name: do not shadow the `data_type` parameter.
            field_type, nullable = unwrap_nullable(field.data_type)

            if field.has_default:
                if is_union_type(field_type):
                    default = '.{}'.format(fmt_var(field.default.tag_name))
                else:
                    default = fmt_obj(field.default)
                value += ' = {}'.format(default)
            elif nullable:
                value += ' = nil'
            args.append((name, value))
        return args

    def _objc_init_args(self, data_type, include_defaults=True):
        """Return ``(name, objc_type)`` pairs for an ObjC-facing initializer.

        :param data_type: the type whose ``all_fields`` are rendered.
        :param include_defaults: when false, fields that have a default or are
            nullable are omitted.
        """
        args = []
        for field in data_type.all_fields:
            name = fmt_var(field.name)
            value = fmt_objc_type(field.data_type)
            _, nullable = unwrap_nullable(field.data_type)

            if not include_defaults and (field.has_default or nullable):
                continue

            args.append((name, value))
        return args

    # NOTE: the misspelled name ("defualts") is kept because callers and
    # subclasses elsewhere may reference it.
    def _objc_no_defualts_func_args(self, data_type, args_data=None):
        """Render ``name: name`` call arguments for the required parameters only.

        Fields with a default or a nullable type are skipped. If ``args_data``
        is given, its second element is iterated as records of exactly four
        entries. For each record, the first entry is the name and the third is
        the type; a ``name: name`` argument is added when that type is not
        nullable.
        """
        args = []
        for field in data_type.all_fields:
            name = fmt_var(field.name)
            _, nullable = unwrap_nullable(field.data_type)
            if field.has_default or nullable:
                continue
            args.append((name, name))

        if args_data is not None:
            _, type_data_list = tuple(args_data)
            extra_args = [tuple(type_data[:-1]) for type_data in type_data_list]
            for name, _, extra_type in extra_args:
                if not is_nullable_type(extra_type):
                    args.append((name, name))

        return self._func_args(args)

    def _objc_init_args_to_swift(self, data_type, args_data=None, include_defaults=True):
        """Render ``name: <expr>`` arguments converting ObjC wrapper values to Swift.

        Each field's expression depends on its type:
        - A primitive gets its ``_nsnumber_type_table`` accessor appended
          (optional-chained when nullable).
        - A list is mapped element-wise.
        - A map of user-defined values uses ``mapValues { $0.swift }``.
        - A user-defined type uses ``_objc_swift_var_name``.

        Any ``args_data`` extras are passed through as ``name: name``.
        """
        args = []
        for field in data_type.all_fields:
            name = fmt_var(field.name)
            field_data_type, nullable = unwrap_nullable(field.data_type)
            if not include_defaults and (field.has_default or nullable):
                continue
            # Default to '' so a class missing from the table never renders
            # as the literal text 'None' in the generated Swift.
            nsnumber_type = _nsnumber_type_table.get(field_data_type.__class__, '')
            value = '{}{}{}'.format(name,
                                    '?' if nullable and nsnumber_type else '',
                                    nsnumber_type)
            if is_list_type(field_data_type):
                _, prefix, suffix, list_data_type, _ = mapped_list_info(field_data_type)
                list_nsnumber_type = _nsnumber_type_table.get(list_data_type.__class__)

                if not is_user_defined_type(list_data_type) and not list_nsnumber_type:
                    # Elements need no conversion: pass the list through.
                    value = name
                else:
                    if is_user_defined_type(list_data_type):
                        element = '$0.{}'.format(self._objc_swift_var_name(list_data_type))
                    else:
                        element = '$0{}'.format(list_nsnumber_type)
                    value = '{}{}.map {}{{ {} }}{}'.format(name,
                                                          '?' if nullable else '',
                                                          prefix,
                                                          element,
                                                          suffix)
            elif is_map_type(field_data_type):
                if is_user_defined_type(field_data_type.value_data_type):
                    # NOTE(review): lists use _objc_swift_var_name() for
                    # user-defined elements, but map values always use
                    # `.swift` — confirm this is intended for struct subtypes.
                    value = '{}{}.mapValues {{ $0.swift }}'.format(name,
                                                                   '?' if nullable else '')
            elif is_user_defined_type(field_data_type):
                value = '{}{}.{}'.format(name,
                                         '?' if nullable else '',
                                         self._objc_swift_var_name(field_data_type))

            args.append((name, value))

        if args_data is not None:
            _, type_data_list = tuple(args_data)
            extra_args = [tuple(type_data[:-1]) for type_data in type_data_list]
            for name, _, _ in extra_args:
                args.append((name, name))

        return self._func_args(args)

    def _objc_swift_var_name(self, data_type):
        """Return the wrapper property name exposing the Swift value of *data_type*.

        Unions, and types without a parent, use ``swift``. Otherwise the name
        is ``sub`` followed by one ``Sub`` for each further user-defined
        ancestor: 1 level gives ``subSwift``, 2 give ``subSubSwift``, and so on.
        """
        parent_type = data_type.parent_type
        uw_parent_type, _ = unwrap_nullable(parent_type)
        sub_count = 1 if parent_type else 0
        while is_user_defined_type(uw_parent_type) and parent_type.parent_type:
            sub_count += 1
            parent_type = parent_type.parent_type
            uw_parent_type, _ = unwrap_nullable(parent_type)

        if sub_count == 0 or is_union_type(data_type):
            return 'swift'
        return 'sub' + 'Sub' * (sub_count - 1) + 'Swift'

    def _docf(self, tag, val):
        """Format a Stone doc reference ``:tag:`val``` for Swift documentation.

        ``route`` references become function names. An optional ``:N`` suffix
        on the value selects the route version; the default is 1.

        ``field`` references become ``field in Class``, or just the field
        name when no class is given.

        Every other tag is returned unchanged.
        """
        if tag == 'route':
            if ':' in val:
                val, version = val.split(':', 1)
                version = int(version)
            else:
                version = 1
            return fmt_func(val, version)
        if tag == 'field':
            if '.' in val:
                cls, field = val.split('.')
                return '{} in {}'.format(fmt_var(field), fmt_class(cls))
            return fmt_var(val)
        # 'type', 'val', 'link' and any other tag are passed through verbatim.
        return val

    def _write_output_in_target_folder(self, output, file_name):
        """Write *output* (UTF-8) to ``file_name`` inside ``self.target_folder_path``.

        The target folder, and any missing parent folders, are created first.
        """
        folder = self.target_folder_path
        # makedirs(exist_ok=True) creates missing parents too, and avoids the
        # race between an exists() check and mkdir().
        os.makedirs(folder, exist_ok=True)
        full_path = os.path.join(folder, file_name)
        with open(full_path, "w", encoding='utf-8') as fh:
            fh.write(output)

def fmt_serial_type(data_type):
    """Return the Swift serializer *type* name for a Stone data type.

    User-defined types map to ``<Namespace>.<Type>Serializer``.

    Other types use ``_serial_type_table``, falling back to
    ``fmt_class(name)``. Lists and maps additionally carry their element
    serializer types as generic arguments.

    Any nullable type collapses to plain ``NullableSerializer``: the inner
    name is still computed but is not returned.
    """
    inner_type, is_nullable = unwrap_nullable(data_type)

    if not is_user_defined_type(inner_type):
        serializer = _serial_type_table.get(inner_type.__class__,
                                            fmt_class(inner_type.name))
        if is_list_type(inner_type):
            serializer += '<{}>'.format(fmt_serial_type(inner_type.data_type))
        if is_map_type(inner_type):
            serializer += '<{}, {}>'.format(
                fmt_serial_type(inner_type.key_data_type),
                fmt_serial_type(inner_type.value_data_type))
    else:
        serializer = '{}.{}Serializer'.format(fmt_class(inner_type.namespace.name),
                                              fmt_class(inner_type.name))

    if is_nullable:
        return 'NullableSerializer'
    return serializer


def fmt_serial_obj(data_type):
    """Return a Swift expression that instantiates the serializer for *data_type*.

    The expression depends on the type:
    - User-defined: ``<Namespace>.<Type>Serializer()``.
    - List: ``ArraySerializer(<element serializer>)``.
    - Map: ``DictionarySerializer(<value serializer>)``.
    - Timestamp: ``NSDateSerializer("<format>")``.
    - Any other type: ``Serialization._<Name>``.

    Nullable types wrap the result in ``NullableSerializer(...)``.
    """
    inner_type, is_nullable = unwrap_nullable(data_type)

    if is_user_defined_type(inner_type):
        expr = '{}.{}Serializer()'.format(fmt_class(inner_type.namespace.name),
                                          fmt_class(inner_type.name))
    else:
        serializer = _serial_type_table.get(inner_type.__class__,
                                            fmt_class(inner_type.name))
        if is_list_type(inner_type):
            expr = '{}({})'.format(serializer, fmt_serial_obj(inner_type.data_type))
        elif is_map_type(inner_type):
            expr = '{}({})'.format(serializer, fmt_serial_obj(inner_type.value_data_type))
        elif is_timestamp_type(inner_type):
            expr = '{}("{}")'.format(serializer, inner_type.format)
        else:
            expr = 'Serialization._{}'.format(serializer)

    if is_nullable:
        return 'NullableSerializer({})'.format(expr)
    return expr
