#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import unicode_literals

import mmap
import sys

from maxminddb.compat import byte_from_int, int_from_byte
from maxminddb.decoder import Decoder

if sys.version_info[:2] == (2, 6):
    import unittest2 as unittest
else:
    import unittest

if sys.version_info[0] == 2:
    unittest.TestCase.assertRaisesRegex = unittest.TestCase.assertRaisesRegexp
    unittest.TestCase.assertRegex = unittest.TestCase.assertRegexpMatches


class TestDecoder(unittest.TestCase):
    def test_arrays(self):
        arrays = {
            b'\x00\x04': [],
            b'\x01\x04\x43\x46\x6f\x6f': ['Foo'],
            b'\x02\x04\x43\x46\x6f\x6f\x43\xe4\xba\xba': ['Foo', '人'],
        }
        self.validate_type_decoding('arrays', arrays)

    def test_boolean(self):
        booleans = {
            b"\x00\x07": False,
            b"\x01\x07": True,
        }
        self.validate_type_decoding('booleans', booleans)

    def test_double(self):
        doubles = {
            b"\x68\x00\x00\x00\x00\x00\x00\x00\x00": 0.0,
            b"\x68\x3F\xE0\x00\x00\x00\x00\x00\x00": 0.5,
            b"\x68\x40\x09\x21\xFB\x54\x44\x2E\xEA": 3.14159265359,
            b"\x68\x40\x5E\xC0\x00\x00\x00\x00\x00": 123.0,
            b"\x68\x41\xD0\x00\x00\x00\x07\xF8\xF4": 1073741824.12457,
            b"\x68\xBF\xE0\x00\x00\x00\x00\x00\x00": -0.5,
            b"\x68\xC0\x09\x21\xFB\x54\x44\x2E\xEA": -3.14159265359,
            b"\x68\xC1\xD0\x00\x00\x00\x07\xF8\xF4": -1073741824.12457,
        }
        self.validate_type_decoding('double', doubles)

    def test_float(self):
        floats = {
            b"\x04\x08\x00\x00\x00\x00": 0.0,
            b"\x04\x08\x3F\x80\x00\x00": 1.0,
            b"\x04\x08\x3F\x8C\xCC\xCD": 1.1,
            b"\x04\x08\x40\x48\xF5\xC3": 3.14,
            b"\x04\x08\x46\x1C\x3F\xF6": 9999.99,
            b"\x04\x08\xBF\x80\x00\x00": -1.0,
            b"\x04\x08\xBF\x8C\xCC\xCD": -1.1,
            b"\x04\x08\xC0\x48\xF5\xC3": -3.14,
            b"\x04\x08\xC6\x1C\x3F\xF6": -9999.99
        }
        self.validate_type_decoding('float', floats)

    def test_int32(self):
        int32 = {
            b"\x00\x01": 0,
            b"\x04\x01\xff\xff\xff\xff": -1,
            b"\x01\x01\xff": 255,
            b"\x04\x01\xff\xff\xff\x01": -255,
            b"\x02\x01\x01\xf4": 500,
            b"\x04\x01\xff\xff\xfe\x0c": -500,
            b"\x02\x01\xff\xff": 65535,
            b"\x04\x01\xff\xff\x00\x01": -65535,
            b"\x03\x01\xff\xff\xff": 16777215,
            b"\x04\x01\xff\x00\x00\x01": -16777215,
            b"\x04\x01\x7f\xff\xff\xff": 2147483647,
            b"\x04\x01\x80\x00\x00\x01": -2147483647,
        }
        self.validate_type_decoding('int32', int32)

    def test_map(self):
        maps = {
            b'\xe0': {},
            b'\xe1\x42\x65\x6e\x43\x46\x6f\x6f': {
                'en': 'Foo'
            },
            b'\xe2\x42\x65\x6e\x43\x46\x6f\x6f\x42\x7a\x68\x43\xe4\xba\xba': {
                'en': 'Foo',
                'zh': '人'
            },
            (b'\xe1\x44\x6e\x61\x6d\x65\xe2\x42\x65\x6e'
             b'\x43\x46\x6f\x6f\x42\x7a\x68\x43\xe4\xba\xba'): {
                 'name': {
                     'en': 'Foo',
                     'zh': '人'
                 }
             },
            (b'\xe1\x49\x6c\x61\x6e\x67\x75\x61\x67\x65\x73'
             b'\x02\x04\x42\x65\x6e\x42\x7a\x68'): {
                 'languages': ['en', 'zh']
             },
        }
        self.validate_type_decoding('maps', maps)

    def test_pointer(self):
        pointers = {
            b'\x20\x00': 0,
            b'\x20\x05': 5,
            b'\x20\x0a': 10,
            b'\x23\xff': 1023,
            b'\x28\x03\xc9': 3017,
            b'\x2f\xf7\xfb': 524283,
            b'\x2f\xff\xff': 526335,
            b'\x37\xf7\xf7\xfe': 134217726,
            b'\x37\xff\xff\xff': 134744063,
            b'\x38\x7f\xff\xff\xff': 2147483647,
            b'\x38\xff\xff\xff\xff': 4294967295,
        }
        self.validate_type_decoding('pointers', pointers)

    strings = {
        b"\x40": '',
        b"\x41\x31": '1',
        b"\x43\xE4\xBA\xBA": '人',
        (b"\x5b\x31\x32\x33\x34"
         b"\x35\x36\x37\x38\x39\x30\x31\x32\x33\x34\x35"
         b"\x36\x37\x38\x39\x30\x31\x32\x33\x34\x35\x36\x37"):
        '123456789012345678901234567',
        (b"\x5c\x31\x32\x33\x34"
         b"\x35\x36\x37\x38\x39\x30\x31\x32\x33\x34\x35"
         b"\x36\x37\x38\x39\x30\x31\x32\x33\x34\x35\x36"
         b"\x37\x38"): '1234567890123456789012345678',
        (b"\x5d\x00\x31\x32\x33"
         b"\x34\x35\x36\x37\x38\x39\x30\x31\x32\x33\x34"
         b"\x35\x36\x37\x38\x39\x30\x31\x32\x33\x34\x35"
         b"\x36\x37\x38\x39"): '12345678901234567890123456789',
        (b"\x5d\x01\x31\x32\x33"
         b"\x34\x35\x36\x37\x38\x39\x30\x31\x32\x33\x34"
         b"\x35\x36\x37\x38\x39\x30\x31\x32\x33\x34\x35"
         b"\x36\x37\x38\x39\x30"): '123456789012345678901234567890',
        b'\x5e\x00\xd7' + 500 * b'\x78': 'x' * 500,
        b'\x5e\x06\xb3' + 2000 * b'\x78': 'x' * 2000,
        b'\x5f\x00\x10\x53' + 70000 * b'\x78': 'x' * 70000,
    }

    def test_string(self):
        self.validate_type_decoding('string', self.strings)

    def test_byte(self):
        # Python 2.6 doesn't support dictionary comprehension
        b = dict((byte_from_int(0xc0 ^ int_from_byte(k[0])) + k[1:],
                  v.encode('utf-8')) for k, v in self.strings.items())
        self.validate_type_decoding('byte', b)

    def test_uint16(self):
        uint16 = {
            b"\xa0": 0,
            b"\xa1\xff": 255,
            b"\xa2\x01\xf4": 500,
            b"\xa2\x2a\x78": 10872,
            b"\xa2\xff\xff": 65535,
        }
        self.validate_type_decoding('uint16', uint16)

    def test_uint32(self):
        uint32 = {
            b"\xc0": 0,
            b"\xc1\xff": 255,
            b"\xc2\x01\xf4": 500,
            b"\xc2\x2a\x78": 10872,
            b"\xc2\xff\xff": 65535,
            b"\xc3\xff\xff\xff": 16777215,
            b"\xc4\xff\xff\xff\xff": 4294967295,
        }
        self.validate_type_decoding('uint32', uint32)

    def generate_large_uint(self, bits):
        ctrl_byte = b'\x02' if bits == 64 else b'\x03'
        uints = {
            b'\x00' + ctrl_byte: 0,
            b'\x02' + ctrl_byte + b'\x01\xf4': 500,
            b'\x02' + ctrl_byte + b'\x2a\x78': 10872,
        }
        for power in range(bits // 8 + 1):
            expected = 2**(8 * power) - 1
            input = byte_from_int(power) + ctrl_byte + (b'\xff' * power)
            uints[input] = expected
        return uints

    def test_uint64(self):
        self.validate_type_decoding('uint64', self.generate_large_uint(64))

    def test_uint128(self):
        self.validate_type_decoding('uint128', self.generate_large_uint(128))

    def validate_type_decoding(self, type, tests):
        for input, expected in tests.items():
            self.check_decoding(type, input, expected)

    def check_decoding(self, type, input, expected, name=None):

        name = name or expected
        db = mmap.mmap(-1, len(input))
        db.write(input)

        decoder = Decoder(db, pointer_test=True)
        (
            actual,
            _, ) = decoder.decode(0)

        if type in ('float', 'double'):
            self.assertAlmostEqual(expected, actual, places=3, msg=type)
        else:
            self.assertEqual(expected, actual, type)

    def test_real_pointers(self):
        with open('tests/data/test-data/maps-with-pointers.raw',
                  'r+b') as db_file:
            mm = mmap.mmap(db_file.fileno(), 0)
            decoder = Decoder(mm, 0)

            self.assertEqual(({
                'long_key': 'long_value1'
            }, 22), decoder.decode(0))

            self.assertEqual(({
                'long_key': 'long_value2'
            }, 37), decoder.decode(22))

            self.assertEqual(({
                'long_key2': 'long_value1'
            }, 50), decoder.decode(37))

            self.assertEqual(({
                'long_key2': 'long_value2'
            }, 55), decoder.decode(50))

            self.assertEqual(({
                'long_key': 'long_value1'
            }, 57), decoder.decode(55))

            self.assertEqual(({
                'long_key2': 'long_value2'
            }, 59), decoder.decode(57))

            mm.close()
