# -*- coding: utf-8 -*-

import pytest

from tests.utils import TestCaseBase

import sqlparse
from sqlparse.exceptions import SQLParseError


class TestFormat(TestCaseBase):

    def test_keywordcase(self):
        sql = 'select * from bar; -- select foo\n'
        res = sqlparse.format(sql, keyword_case='upper')
        self.ndiffAssertEqual(res, 'SELECT * FROM bar; -- select foo\n')
        res = sqlparse.format(sql, keyword_case='capitalize')
        self.ndiffAssertEqual(res, 'Select * From bar; -- select foo\n')
        res = sqlparse.format(sql.upper(), keyword_case='lower')
        self.ndiffAssertEqual(res, 'select * from BAR; -- SELECT FOO\n')
        self.assertRaises(SQLParseError, sqlparse.format, sql,
                          keyword_case='foo')

    def test_identifiercase(self):
        sql = 'select * from bar; -- select foo\n'
        res = sqlparse.format(sql, identifier_case='upper')
        self.ndiffAssertEqual(res, 'select * from BAR; -- select foo\n')
        res = sqlparse.format(sql, identifier_case='capitalize')
        self.ndiffAssertEqual(res, 'select * from Bar; -- select foo\n')
        res = sqlparse.format(sql.upper(), identifier_case='lower')
        self.ndiffAssertEqual(res, 'SELECT * FROM bar; -- SELECT FOO\n')
        self.assertRaises(SQLParseError, sqlparse.format, sql,
                          identifier_case='foo')
        sql = 'select * from "foo"."bar"'
        res = sqlparse.format(sql, identifier_case="upper")
        self.ndiffAssertEqual(res, 'select * from "foo"."bar"')

    def test_strip_comments_single(self):
        sql = 'select *-- statement starts here\nfrom foo'
        res = sqlparse.format(sql, strip_comments=True)
        self.ndiffAssertEqual(res, 'select * from foo')
        sql = 'select * -- statement starts here\nfrom foo'
        res = sqlparse.format(sql, strip_comments=True)
        self.ndiffAssertEqual(res, 'select * from foo')
        sql = 'select-- foo\nfrom -- bar\nwhere'
        res = sqlparse.format(sql, strip_comments=True)
        self.ndiffAssertEqual(res, 'select from where')
        self.assertRaises(SQLParseError, sqlparse.format, sql,
                          strip_comments=None)

    def test_strip_comments_multi(self):
        sql = '/* sql starts here */\nselect'
        res = sqlparse.format(sql, strip_comments=True)
        self.ndiffAssertEqual(res, 'select')
        sql = '/* sql starts here */ select'
        res = sqlparse.format(sql, strip_comments=True)
        self.ndiffAssertEqual(res, 'select')
        sql = '/*\n * sql starts here\n */\nselect'
        res = sqlparse.format(sql, strip_comments=True)
        self.ndiffAssertEqual(res, 'select')
        sql = 'select (/* sql starts here */ select 2)'
        res = sqlparse.format(sql, strip_comments=True)
        self.ndiffAssertEqual(res, 'select (select 2)')

    def test_strip_ws(self):
        f = lambda sql: sqlparse.format(sql, strip_whitespace=True)
        s = 'select\n* from      foo\n\twhere  ( 1 = 2 )\n'
        self.ndiffAssertEqual(f(s), 'select * from foo where (1 = 2)')
        s = 'select -- foo\nfrom    bar\n'
        self.ndiffAssertEqual(f(s), 'select -- foo\nfrom bar')
        self.assertRaises(SQLParseError, sqlparse.format, s,
                          strip_whitespace=None)

    def test_preserve_ws(self):
        # preserve at least one whitespace after subgroups
        f = lambda sql: sqlparse.format(sql, strip_whitespace=True)
        s = 'select\n* /* foo */  from bar '
        self.ndiffAssertEqual(f(s), 'select * /* foo */ from bar')

    def test_notransform_of_quoted_crlf(self):
        # Make sure that CR/CR+LF characters inside string literals don't get
        # affected by the formatter.

        s1 = "SELECT some_column LIKE 'value\r'"
        s2 = "SELECT some_column LIKE 'value\r'\r\nWHERE id = 1\n"
        s3 = "SELECT some_column LIKE 'value\\'\r' WHERE id = 1\r"
        s4 = "SELECT some_column LIKE 'value\\\\\\'\r' WHERE id = 1\r\n"

        f = lambda x: sqlparse.format(x)

        # Because of the use of
        self.ndiffAssertEqual(f(s1), "SELECT some_column LIKE 'value\r'")
        self.ndiffAssertEqual(f(s2), "SELECT some_column LIKE 'value\r'\nWHERE id = 1\n")
        self.ndiffAssertEqual(f(s3), "SELECT some_column LIKE 'value\\'\r' WHERE id = 1\n")
        self.ndiffAssertEqual(f(s4), "SELECT some_column LIKE 'value\\\\\\'\r' WHERE id = 1\n")

    def test_outputformat(self):
        sql = 'select * from foo;'
        self.assertRaises(SQLParseError, sqlparse.format, sql,
                          output_format='foo')


class TestFormatReindent(TestCaseBase):

    def test_option(self):
        self.assertRaises(SQLParseError, sqlparse.format, 'foo',
                          reindent=2)
        self.assertRaises(SQLParseError, sqlparse.format, 'foo',
                          indent_tabs=2)
        self.assertRaises(SQLParseError, sqlparse.format, 'foo',
                          reindent=True, indent_width='foo')
        self.assertRaises(SQLParseError, sqlparse.format, 'foo',
                          reindent=True, indent_width=-12)

    def test_stmts(self):
        f = lambda sql: sqlparse.format(sql, reindent=True)
        s = 'select foo; select bar'
        self.ndiffAssertEqual(f(s), 'select foo;\n\nselect bar')
        s = 'select foo'
        self.ndiffAssertEqual(f(s), 'select foo')
        s = 'select foo; -- test\n select bar'
        self.ndiffAssertEqual(f(s), 'select foo; -- test\n\nselect bar')

    def test_keywords(self):
        f = lambda sql: sqlparse.format(sql, reindent=True)
        s = 'select * from foo union select * from bar;'
        self.ndiffAssertEqual(f(s), '\n'.join(['select *',
                                               'from foo',
                                               'union',
                                               'select *',
                                               'from bar;']))

    def test_keywords_between(self):  # issue 14
        # don't break AND after BETWEEN
        f = lambda sql: sqlparse.format(sql, reindent=True)
        s = 'and foo between 1 and 2 and bar = 3'
        self.ndiffAssertEqual(f(s), '\n'.join(['',
                                               'and foo between 1 and 2',
                                               'and bar = 3']))

    def test_parenthesis(self):
        f = lambda sql: sqlparse.format(sql, reindent=True)
        s = 'select count(*) from (select * from foo);'
        self.ndiffAssertEqual(f(s),
                              '\n'.join(['select count(*)',
                                         'from',
                                         '  (select *',
                                         '   from foo);',
                                         ])
                              )

    def test_where(self):
        f = lambda sql: sqlparse.format(sql, reindent=True)
        s = 'select * from foo where bar = 1 and baz = 2 or bzz = 3;'
        self.ndiffAssertEqual(f(s), ('select *\nfrom foo\n'
                                     'where bar = 1\n'
                                     '  and baz = 2\n'
                                     '  or bzz = 3;'))
        s = 'select * from foo where bar = 1 and (baz = 2 or bzz = 3);'
        self.ndiffAssertEqual(f(s), ('select *\nfrom foo\n'
                                     'where bar = 1\n'
                                     '  and (baz = 2\n'
                                     '       or bzz = 3);'))

    def test_join(self):
        f = lambda sql: sqlparse.format(sql, reindent=True)
        s = 'select * from foo join bar on 1 = 2'
        self.ndiffAssertEqual(f(s), '\n'.join(['select *',
                                               'from foo',
                                               'join bar on 1 = 2']))
        s = 'select * from foo inner join bar on 1 = 2'
        self.ndiffAssertEqual(f(s), '\n'.join(['select *',
                                               'from foo',
                                               'inner join bar on 1 = 2']))
        s = 'select * from foo left outer join bar on 1 = 2'
        self.ndiffAssertEqual(f(s), '\n'.join(['select *',
                                               'from foo',
                                               'left outer join bar on 1 = 2']
                                              ))
        s = 'select * from foo straight_join bar on 1 = 2'
        self.ndiffAssertEqual(f(s), '\n'.join(['select *',
                                               'from foo',
                                               'straight_join bar on 1 = 2']
                                              ))

    def test_identifier_list(self):
        f = lambda sql: sqlparse.format(sql, reindent=True)
        s = 'select foo, bar, baz from table1, table2 where 1 = 2'
        self.ndiffAssertEqual(f(s), '\n'.join(['select foo,',
                                               '       bar,',
                                               '       baz',
                                               'from table1,',
                                               '     table2',
                                               'where 1 = 2']))
        s = 'select a.*, b.id from a, b'
        self.ndiffAssertEqual(f(s), '\n'.join(['select a.*,',
                                               '       b.id',
                                               'from a,',
                                               '     b']))

    def test_identifier_list_with_functions(self):
        f = lambda sql: sqlparse.format(sql, reindent=True)
        s = ("select 'abc' as foo, coalesce(col1, col2)||col3 as bar,"
             "col3 from my_table")
        self.ndiffAssertEqual(f(s), '\n'.join(
            ["select 'abc' as foo,",
             "       coalesce(col1, col2)||col3 as bar,",
             "       col3",
             "from my_table"]))

    def test_case(self):
        f = lambda sql: sqlparse.format(sql, reindent=True)
        s = 'case when foo = 1 then 2 when foo = 3 then 4 else 5 end'
        self.ndiffAssertEqual(f(s), '\n'.join(['case',
                                               '    when foo = 1 then 2',
                                               '    when foo = 3 then 4',
                                               '    else 5',
                                               'end']))

    def test_case2(self):
        f = lambda sql: sqlparse.format(sql, reindent=True)
        s = 'case(foo) when bar = 1 then 2 else 3 end'
        self.ndiffAssertEqual(f(s), '\n'.join(['case(foo)',
                                               '    when bar = 1 then 2',
                                               '    else 3',
                                               'end']))

    def test_nested_identifier_list(self):  # issue4
        f = lambda sql: sqlparse.format(sql, reindent=True)
        s = '(foo as bar, bar1, bar2 as bar3, b4 as b5)'
        self.ndiffAssertEqual(f(s), '\n'.join(['(foo as bar,',
                                               ' bar1,',
                                               ' bar2 as bar3,',
                                               ' b4 as b5)']))

    def test_duplicate_linebreaks(self):  # issue3
        f = lambda sql: sqlparse.format(sql, reindent=True)
        s = 'select c1 -- column1\nfrom foo'
        self.ndiffAssertEqual(f(s), '\n'.join(['select c1 -- column1',
                                               'from foo']))
        s = 'select c1 -- column1\nfrom foo'
        r = sqlparse.format(s, reindent=True, strip_comments=True)
        self.ndiffAssertEqual(r, '\n'.join(['select c1',
                                            'from foo']))
        s = 'select c1\nfrom foo\norder by c1'
        self.ndiffAssertEqual(f(s), '\n'.join(['select c1',
                                               'from foo',
                                               'order by c1']))
        s = 'select c1 from t1 where (c1 = 1) order by c1'
        self.ndiffAssertEqual(f(s), '\n'.join(['select c1',
                                               'from t1',
                                               'where (c1 = 1)',
                                               'order by c1']))

    def test_keywordfunctions(self):  # issue36
        f = lambda sql: sqlparse.format(sql, reindent=True)
        s = 'select max(a) b, foo, bar'
        self.ndiffAssertEqual(f(s), '\n'.join(['select max(a) b,',
                                               '       foo,',
                                               '       bar']))

    def test_identifier_and_functions(self):  # issue45
        f = lambda sql: sqlparse.format(sql, reindent=True)
        s = 'select foo.bar, nvl(1) from dual'
        self.ndiffAssertEqual(f(s), '\n'.join(['select foo.bar,',
                                               '       nvl(1)',
                                               'from dual']))


class TestOutputFormat(TestCaseBase):

    def test_python(self):
        sql = 'select * from foo;'
        f = lambda sql: sqlparse.format(sql, output_format='python')
        self.ndiffAssertEqual(f(sql), "sql = 'select * from foo;'")
        f = lambda sql: sqlparse.format(sql, output_format='python',
                                        reindent=True)
        self.ndiffAssertEqual(f(sql), ("sql = ('select * '\n"
                                       "       'from foo;')"))

    def test_php(self):
        sql = 'select * from foo;'
        f = lambda sql: sqlparse.format(sql, output_format='php')
        self.ndiffAssertEqual(f(sql), '$sql = "select * from foo;";')
        f = lambda sql: sqlparse.format(sql, output_format='php',
                                        reindent=True)
        self.ndiffAssertEqual(f(sql), ('$sql  = "select * ";\n'
                                       '$sql .= "from foo;";'))

    def test_sql(self):  # "sql" is an allowed option but has no effect
        sql = 'select * from foo;'
        f = lambda sql: sqlparse.format(sql, output_format='sql')
        self.ndiffAssertEqual(f(sql), 'select * from foo;')


def test_format_column_ordering():  # issue89
    sql = 'select * from foo order by c1 desc, c2, c3;'
    formatted = sqlparse.format(sql, reindent=True)
    expected = '\n'.join(['select *',
                          'from foo',
                          'order by c1 desc,',
                          '         c2,',
                          '         c3;'])
    assert formatted == expected


def test_truncate_strings():
    sql = 'update foo set value = \'' + 'x' * 1000 + '\';'
    formatted = sqlparse.format(sql, truncate_strings=10)
    assert formatted == 'update foo set value = \'xxxxxxxxxx[...]\';'
    formatted = sqlparse.format(sql, truncate_strings=3, truncate_char='YYY')
    assert formatted == 'update foo set value = \'xxxYYY\';'


def test_truncate_strings_invalid_option():
    pytest.raises(SQLParseError, sqlparse.format,
                  'foo', truncate_strings='bar')
    pytest.raises(SQLParseError, sqlparse.format,
                  'foo', truncate_strings=-1)
    pytest.raises(SQLParseError, sqlparse.format,
                  'foo', truncate_strings=0)


@pytest.mark.parametrize('sql', ['select verrrylongcolumn from foo',
                                 'select "verrrylongcolumn" from "foo"'])
def test_truncate_strings_doesnt_truncate_identifiers(sql):
    formatted = sqlparse.format(sql, truncate_strings=2)
    assert formatted == sql
