# -*- coding: utf-8 -*-

import decimal
from decimal import Decimal as D
from hashlib import md5
import pickle
import sys
import unittest
import uuid

import pytest

import pymssql
from .helpers import get_sql_server_version, mssql_server_required
from .helpers import drop_table, mssqlconn, clear_table, config, eq_, pymssqlconn


def get_bytes_buffer():
    """Return a new, empty in-memory binary buffer (an ``io.BytesIO``)."""
    import io
    fresh_buffer = io.BytesIO()
    return fresh_buffer


def typeeq(v1, v2):
    """Assert, through ``eq_``, that ``v1`` and ``v2`` have the same exact type."""
    first_type, second_type = type(v1), type(v2)
    eq_(first_type, second_type)


# Lines of the CREATE TABLE statement for the shared test table; they are
# joined with newlines by build_create_table_query().  One list entry per
# SQL line: every entry needs its own trailing comma in the Python list,
# otherwise adjacent string literals are silently concatenated.
create_test_table_sql = [
    'CREATE TABLE pymssql (',
    '    pk_id int IDENTITY (1, 1) NOT NULL,',
    '    real_no real,',
    '    float_no float,',
    '    money_no money,',
    '    data_bit bit,',
    '    comment_vch varchar(50),',
    '    comment_nvch nvarchar(50),',
    '    comment_text text,',
    '    comment_ntext ntext,',
    '    data_image image,',
    '    data_binary varbinary(40),',
    '    decimal_no decimal(38,2),',
    '    decimal_no2 decimal(38,10),',
    '    numeric_no numeric(38,8),',
    '    uuid uniqueidentifier',
    ')',
]


def build_create_table_query(conn):
    """Return the CREATE TABLE statement for the test table as one string.

    The entries of ``create_test_table_sql`` are joined with newlines.
    ``conn`` is accepted but not used by this function.
    """
    return '\n'.join(create_test_table_sql)


@pytest.mark.mssql_server_required
class TestTypes(unittest.TestCase):
    """Column type round-trip tests using the connection from ``mssqlconn()``.

    Each test inserts a single value into one column of the test table,
    selects that column back and checks the type and value of the result.
    """

    tname = 'pymssql'

    @classmethod
    def setup_class(cls):
        """Open the shared connection and create the test table.

        ``drop_table`` is called first, then the DDL built by
        ``build_create_table_query`` is executed.
        """
        cls.conn = mssqlconn()
        drop_table(cls.conn, cls.tname)
        ddl_str = build_create_table_query(cls.conn)
        cls.conn.execute_non_query(ddl_str)

    def setUp(self):
        """Call ``clear_table`` before every test.

        ``insert_and_select`` asserts that exactly one row comes back, so
        rows from earlier tests must not be present.
        """
        clear_table(self.conn, self.tname)

    def hasheq(self, v1, v2):
        """Assert that ``v1`` and ``v2`` have the same MD5 hex digest.

        Values that expose an ``encode`` attribute (text strings) are encoded
        as UTF-8 first; anything else is handed to ``md5`` unchanged.
        """
        if hasattr(v1, 'encode'):
            v1 = v1.encode('utf-8')
        if hasattr(v2, 'encode'):
            v2 = v2.encode('utf-8')

        hd1 = md5(v1).hexdigest()
        hd2 = md5(v2).hexdigest()
        assert hd1 == hd2, '%s (%s) != %s (%s)' % (v1, hd1, v2, hd2)

    def insert_and_select(self, cname, value, vartype, params_as_dict=False):
        """Insert ``value`` into column ``cname`` and return it as read back.

        :param cname: column to write to and to select from.
        :param value: value passed as the query parameter.
        :param vartype: format character used in the placeholder, e.g.
            ``'s'`` yields ``%s`` (or ``%(value)s`` with ``params_as_dict``).
        :param params_as_dict: when true, pass the parameter as
            ``dict(value=value)`` with a named placeholder instead of
            positionally.
        :returns: the ``cname`` entry of the single row selected back.
        """
        if params_as_dict:
            inssql = 'insert into %s (%s) values (%%(value)%s)' % (self.tname, cname, vartype)
            self.conn.execute_non_query(inssql, dict(value=value))
        else:
            inssql = 'insert into %s (%s) values (%%%s)' % (self.tname, cname, vartype)
            self.conn.execute_non_query(inssql, value)
        # Read from self.tname, the same table the insert above wrote to,
        # rather than a hard-coded table name.
        self.conn.execute_query('select %s from %s' % (cname, self.tname))
        rows = tuple(self.conn)
        eq_(len(rows), 1)
        return rows[0][cname]

    def _assert_image_roundtrip(self, text_len):
        """Store a pickled list in ``data_image`` and check it comes back intact.

        The pickled list is ``[1, 2, 'a' * text_len]``.
        """
        buf = get_bytes_buffer()
        longstr = 'a' * text_len
        pickle.dump([1, 2, longstr], buf, -1)
        testval = buf.getvalue()
        colval = self.insert_and_select('data_image', testval, 's')
        typeeq(testval, colval)
        self.hasheq(testval, colval)
        # colval is expected to be the blob this test just wrote itself.
        tlist = pickle.loads(colval)
        eq_(tlist, [1, 2, longstr])

    def test_varchar(self):
        """A plain ASCII string round-trips through a varchar column."""
        testval = 'foobar'
        colval = self.insert_and_select('comment_vch', testval, 's')
        typeeq(testval, colval)
        self.hasheq(testval, colval)

    def test_varchar_hex(self):
        """A string that looks like a hex literal is stored as text."""
        testval = '0xf00'
        colval = self.insert_and_select('comment_vch', testval, 's')
        typeeq(testval, colval)
        self.hasheq(testval, colval)

    def test_varchar_unicode(self):
        """A non-ASCII string round-trips through a varchar column."""
        testval = 'foobär'
        colval = self.insert_and_select('comment_vch', testval, 's')
        typeeq(testval, colval)
        eq_(testval, colval)

    def test_nvarchar_unicode(self):
        """A non-ASCII string round-trips through an nvarchar column."""
        testval = 'foobär'
        colval = self.insert_and_select('comment_nvch', testval, 's')
        typeeq(testval, colval)
        eq_(testval, colval)

    def test_binary_bytearray(self):
        """A ``bytearray`` parameter is read back as equal ``bytes``."""
        bindata = '{z\n\x03\x07\x194;\x034lE4ISo'.encode('ascii')
        colval = self.insert_and_select('data_binary', bytearray(bindata), 's')
        typeeq(bindata, colval)
        eq_(bindata, colval)

    def test_binary_Binary(self):
        """
        A ``pymssql.Binary`` parameter is read back as equal ``bytes``.

        See https://github.com/pymssql/pymssql/issues/504
        """
        bindata = '{z\n\x03\x07\x194;\x034lE4ISo'.encode('ascii')
        colval = self.insert_and_select('data_binary', pymssql.Binary(bindata), 's')
        typeeq(bindata, colval)
        eq_(bindata, colval)

    def test_image(self):
        """A pickled payload with a 4000-character string round-trips."""
        self._assert_image_roundtrip(4000)

    def test_image_gt_4KB(self):
        """
            test_image_gt_4KB

            By default, SQL server sets TEXTSIZE = 4096 bytes.  We up that by
            default and want to make sure it applies.
        """
        self._assert_image_roundtrip(5000)

    def test_decimal(self):
        """A value with extra digits is rounded down to the column scale."""
        # test rounding down
        origval = D('1.2345')
        expect = D('1.23')
        colval = self.insert_and_select('decimal_no', origval, 's')
        typeeq(expect, colval)
        eq_(expect, colval)

    def test_decimal_context_protection(self):
        """Handling a decimal must not change the default context precision."""
        self.insert_and_select('decimal_no', D('1.2345'), 's')

        # make sure our manipulation of the decimal values doesn't affect the
        # default decimal context
        eq_(decimal.getcontext().prec, config.orig_decimal_prec)

    def test_decimal_rounding_up(self):
        """A value with extra digits is rounded up to the column scale."""
        # test rounding up
        origval = D('1.235')
        expect = D('1.24')
        colval = self.insert_and_select('decimal_no', origval, 's')
        typeeq(expect, colval)
        eq_(expect, colval)

    def test_decimal_smaller_precision(self):
        """A value with fewer digits than the column scale is zero-padded."""
        # smaller precision than column
        origval = D('1.2345')
        expect = D('1.2345000000')
        colval = self.insert_and_select('decimal_no2', origval, 's')
        typeeq(expect, colval)
        eq_(expect, colval)

    def test_numeric(self):
        """A numeric column returns a ``Decimal`` like a decimal column does."""
        # should be handled the same as a decimal column, so only one test
        origval = D('1.2345')
        expect = D('1.23450000')
        colval = self.insert_and_select('numeric_no', origval, 's')
        typeeq(expect, colval)
        eq_(expect, colval)

    def test_float_precision(self):
        """A float column returns a float equal to the one inserted."""
        #origval and expect are not exactly the same, but they test
        # equal and that is what we are getting at.  They should have
        # the same value out to the 16th digit.
        origval = 1.23456789012345670
        expect = 1.23456789012345671
        colval = self.insert_and_select('float_no', origval, 's')
        typeeq(expect, colval)
        eq_(expect, colval)

    def test_uuid_passed_as_string(self):
        """
        Test the case when the application passes a string representation of
        the uuid.UUID data type to _mssql.
        """
        origval = uuid.uuid4()
        stestval = str(origval)
        colval = self.insert_and_select('uuid', stestval, 's')
        typeeq(origval, colval)
        eq_(origval, colval)

    def test_uuid_passed_as_python_datatype(self):
        """
        Test the case when the application passes an instance of the uuid.UUID
        data type to _mssql to confirm it can handle it.
        """
        origval = uuid.uuid4()
        colval = self.insert_and_select('uuid', origval, 's')
        typeeq(origval, colval)
        eq_(origval, colval)


@pytest.mark.mssql_server_required
class TestTypesPymssql(unittest.TestCase):
    """UUID round-trip tests run through cursors of ``pymssqlconn()``."""

    tname = 'pymssql'

    @classmethod
    def setup_class(cls):
        """Open the shared connection and create the test table.

        ``drop_table`` is called first on the wrapped ``_conn`` object, then
        the DDL from ``build_create_table_query`` is run through a cursor.
        """
        cls.conn = pymssqlconn()
        drop_table(cls.conn._conn, cls.tname)
        create_stmt = build_create_table_query(cls.conn._conn)
        with cls.conn.cursor() as cursor:
            cursor.execute(create_stmt)

    def setUp(self):
        """Call ``clear_table`` on the test table before every test."""
        clear_table(self.conn._conn, self.tname)

    def insert_and_select(self, cname, value, vartype):
        """Insert ``value`` into column ``cname`` and return it as read back.

        ``vartype`` is the format character placed after ``%`` in the
        placeholder (e.g. ``'s'``).  Exactly one row is expected from the
        select; its first field is returned.
        """
        insert_stmt = 'insert into %s (%s) values (%%%s)' % (self.tname, cname, vartype)
        select_stmt = 'select %s from %s' % (cname, self.tname)
        with self.conn.cursor() as cursor:
            cursor.execute(insert_stmt, value)
            cursor.execute(select_stmt)
            fetched = cursor.fetchall()
            eq_(len(fetched), 1)
            only_row = fetched[0]
        return only_row[0]

    def test_uuid_passed_as_string(self):
        """
        The uuid.UUID type isn't supported by the pyformat paramstyle (that
        only supports '%s' and is the style supported by pymssql).
        Here the application passes the string representation of a UUID;
        the column value read back must be an equal uuid.UUID instance.
        """
        expected = uuid.uuid4()
        returned = self.insert_and_select('uuid', str(expected), 's')
        typeeq(expected, returned)
        eq_(expected, returned)

    def test_uuid_passed_as_python_datatype(self):
        """
        The uuid.UUID type isn't supported by the pyformat paramstyle (that
        only supports '%s' and is the style supported by pymssql).
        Here the application passes a uuid.UUID instance directly, to confirm
        it is handled; the column value read back must equal it.
        """
        expected = uuid.uuid4()
        returned = self.insert_and_select('uuid', expected, 's')
        typeeq(expected, returned)
        eq_(expected, returned)
