"""
Tests for the formal forms code.  Most of this is taken from formal's
source tree and is thus covered by the liberal license in imp/formal/LICENSE.
"""

#c Copyright 2008-2024, the GAVO project <gavo@ari.uni-heidelberg.de>
#c
#c This program is free software, covered by the GNU GPL.  See the
#c COPYING file in the source distribution.


from datetime import date, time
import re
import unittest

from gavo.helpers import testhelpers

from gavo import formal
from gavo.formal import converters, validation, types, util



class TestConverters(unittest.TestCase):
    """Tests for the converters in ``formal.converters``.

    ``fromType`` turns a Python value into its widget representation and
    ``toType`` turns that representation back into a Python value.  Input
    that cannot be converted must raise
    ``validation.FieldValidationError``.
    """

    def test_null(self):
        """NullConverter passes values through unchanged."""
        c = converters.NullConverter(None)
        self.assertEqual(c.fromType('foo'), 'foo')
        self.assertEqual(c.toType('foo'), 'foo')

    def test_integerToString(self):
        """Integers round-trip through strings; blank strings give None."""
        c = converters.IntegerToStringConverter(None)
        self.assertEqual(c.fromType(None), None)
        self.assertEqual(c.fromType(1), '1')
        self.assertEqual(c.fromType(0), '0')
        self.assertEqual(c.fromType(-1), '-1')
        self.assertEqual(c.toType(''), None)
        self.assertEqual(c.toType(' '), None)
        self.assertEqual(c.toType('1'), 1)
        self.assertEqual(c.toType('0'), 0)
        self.assertEqual(c.toType('-1'), -1)
        # Non-integral input is rejected rather than truncated.
        self.assertRaises(validation.FieldValidationError, c.toType, '1.1')
        self.assertRaises(validation.FieldValidationError, c.toType, 'foo')

    def test_floatToString(self):
        """Floats round-trip through strings; integral values print
        without a decimal point."""
        c = converters.FloatToStringConverter(None)
        self.assertEqual(c.fromType(None), None)
        self.assertEqual(c.fromType(1), '1')
        self.assertEqual(c.fromType(0), '0')
        self.assertEqual(c.fromType(-1), '-1')
        self.assertEqual(c.fromType(1.5), '1.5')
        self.assertEqual(c.toType(''), None)
        self.assertEqual(c.toType(' '), None)
        self.assertEqual(c.toType('1'), 1)
        self.assertEqual(c.toType('0'), 0)
        self.assertEqual(c.toType('-1'), -1)
        self.assertEqual(c.toType('-1.5'), -1.5)
        self.assertRaises(validation.FieldValidationError, c.toType, 'foo')

    def test_decimalToString(self):
        """Decimals round-trip through strings without losing digits."""
        from decimal import Decimal
        c = converters.DecimalToStringConverter(None)
        self.assertEqual(c.fromType(None), None)
        self.assertEqual(c.fromType(Decimal("1")), '1')
        self.assertEqual(c.fromType(Decimal("0")), '0')
        self.assertEqual(c.fromType(Decimal("-1")), '-1')
        self.assertEqual(c.fromType(Decimal("1.5")), '1.5')
        self.assertEqual(c.toType(''), None)
        self.assertEqual(c.toType(' '), None)
        self.assertEqual(c.toType('1'), Decimal("1"))
        self.assertEqual(c.toType('0'), Decimal("0"))
        self.assertEqual(c.toType('-1'), Decimal("-1"))
        self.assertEqual(c.toType('-1.5'), Decimal("-1.5"))
        self.assertEqual(c.toType('-1.863496'), Decimal("-1.863496"))
        self.assertRaises(validation.FieldValidationError, c.toType, 'foo')

    def test_booleanToString(self):
        """Booleans map to 'True'/'False'; other strings are rejected."""
        c = converters.BooleanToStringConverter(None)
        self.assertEqual(c.fromType(False), 'False')
        self.assertEqual(c.fromType(True), 'True')
        self.assertEqual(c.fromType(None), None)
        self.assertEqual(c.toType('False'), False)
        self.assertEqual(c.toType('True'), True)
        self.assertEqual(c.toType(''), None)
        self.assertEqual(c.toType('  '), None)
        self.assertRaises(validation.FieldValidationError, c.toType, 'foo')

    def test_dateToString(self):
        """Dates use the YYYY-MM-DD layout; other layouts are rejected."""
        c = converters.DateToStringConverter(None)
        self.assertEqual(c.fromType(date(2005, 5, 6)), '2005-05-06')
        self.assertEqual(c.fromType(date(2005, 1, 1)), '2005-01-01')
        self.assertEqual(c.toType(''), None)
        self.assertEqual(c.toType(' '), None)
        self.assertEqual(c.toType('2005-05-06'), date(2005, 5, 6))
        self.assertEqual(c.toType('2005-01-01'), date(2005, 1, 1))
        self.assertRaises(validation.FieldValidationError, c.toType, 'foo')
        self.assertRaises(validation.FieldValidationError, c.toType, '2005')
        self.assertRaises(validation.FieldValidationError, c.toType, '01/01/2005')
        self.assertRaises(validation.FieldValidationError, c.toType, '01-01-2005')

    def test_timeToString(self):
        """Times print as HH:MM:SS; seconds are optional on input."""
        c = converters.TimeToStringConverter(None)
        self.assertEqual(c.fromType(time(12, 56)), '12:56:00')
        self.assertEqual(c.fromType(time(10, 12, 24)), '10:12:24')
        self.assertEqual(c.toType(''), None)
        self.assertEqual(c.toType(' '), None)
        self.assertEqual(c.toType('12:56'), time(12, 56))
        self.assertEqual(c.toType('12:56:00'), time(12, 56))
        self.assertEqual(c.toType('10:12:24'), time(10, 12, 24))
        self.assertRaises(validation.FieldValidationError, c.toType, 'foo')
        self.assertRaises(validation.FieldValidationError, c.toType, '10')
        self.assertRaises(validation.FieldValidationError, c.toType, '10-12')

    def test_dateToTuple(self):
        """Dates map to (year, month, day) tuples; malformed tuples are
        rejected."""
        c = converters.DateToDateTupleConverter(None)
        self.assertEqual(c.fromType(date(2005, 5, 6)), (2005, 5, 6))
        self.assertEqual(c.fromType(date(2005, 1, 1)), (2005, 1, 1))
        self.assertEqual(c.toType((2005, 5, 6)), date(2005, 5, 6))
        self.assertEqual(c.toType((2005, 1, 1)), date(2005, 1, 1))
        # A plain string instead of a (y, m, d) tuple must be rejected.
        self.assertRaises(validation.FieldValidationError, c.toType, 'foo')
        # Too few components.
        self.assertRaises(validation.FieldValidationError, c.toType, (2005,))
        self.assertRaises(validation.FieldValidationError, c.toType, (2005,10))
        # Components in the wrong order (day 2005 is out of range).
        self.assertRaises(validation.FieldValidationError, c.toType, (1, 1, 2005))

class TestForm(unittest.TestCase):
    """Checks on how Form validates the names of the fields added to it."""

    def test_fieldName(self):
        form = formal.Form()
        # A plain identifier is accepted...
        form.addField('foo', formal.String())
        # ...but names containing whitespace are refused with ValueError.
        for badName in ('spaceAtTheEnd ', 'got a space in it'):
            self.assertRaises(ValueError, form.addField, badName, formal.String())


class TestValidators(unittest.TestCase):
    """Checks that validators end up attached to field types."""

    def testHasValidator(self):
        lengthLimited = formal.String(
            validators=[validation.LengthValidator(max=10)])
        self.assertEqual(
            lengthLimited.hasValidator(validation.LengthValidator), True)

    def testRequired(self):
        # required=True is expected to attach a RequiredValidator without
        # it being listed explicitly, and to show up in .required.
        mandatory = formal.String(required=True)
        self.assertEqual(
            mandatory.hasValidator(validation.RequiredValidator), True)
        self.assertEqual(mandatory.required, True)


class TestCreation(unittest.TestCase):
    """Checks how the ``immutable`` flag of field types is initialised."""

    def _checkImmutableFlag(self, fieldClass, default):
        # Without a keyword the class-level default applies; an explicit
        # immutable= keyword always wins over it.
        self.assertEqual(fieldClass().immutable, default)
        for explicit in (False, True):
            self.assertEqual(fieldClass(immutable=explicit).immutable, explicit)

    def test_immutablility(self):
        self._checkImmutableFlag(formal.String, False)

    def test_immutablilityOverride(self):
        class String(formal.String):
            immutable = True
        self._checkImmutableFlag(String, True)


class TestValidate(unittest.TestCase):
    """Tests for ``validate`` on the basic field types.

    The common contract exercised here: None validates to None, ``missing``
    supplies the value used when none is given, and ``required=True``
    makes a missing value raise ``formal.FieldValidationError``.
    """

    def testString(self):
        # An empty string counts as "no value"; whitespace does not unless
        # strip=True is requested.
        self.assertEqual(formal.String().validate(None), None)
        self.assertEqual(formal.String().validate(''), None)
        self.assertEqual(formal.String().validate(' '), ' ')
        self.assertEqual(formal.String().validate('foo'), 'foo')
        self.assertEqual(formal.String(strip=True).validate(' '), None)
        self.assertEqual(formal.String(strip=True).validate(' foo '), 'foo')
        self.assertEqual(formal.String(missing='bar').validate('foo'), 'foo')
        self.assertEqual(formal.String(missing='bar').validate(''), 'bar')
        self.assertEqual(formal.String(strip=True, missing='').validate(' '), '')
        self.assertEqual(formal.String(missing='foo').validate('bar'), 'bar')
        self.assertRaises(formal.FieldValidationError, formal.String(required=True).validate, '')
        self.assertRaises(formal.FieldValidationError, formal.String(required=True).validate, None)

    def testInteger(self):
        self.assertEqual(formal.Integer().validate(None), None)
        self.assertEqual(formal.Integer().validate(0), 0)
        self.assertEqual(formal.Integer().validate(1), 1)
        self.assertEqual(formal.Integer().validate(-1), -1)
        self.assertEqual(formal.Integer(missing=1).validate(None), 1)
        self.assertEqual(formal.Integer(missing=1).validate(2), 2)
        self.assertRaises(formal.FieldValidationError, formal.Integer(required=True).validate, None)

    def testFloat(self):
        self.assertEqual(formal.Float().validate(None), None)
        self.assertEqual(formal.Float().validate(0), 0.0)
        self.assertEqual(formal.Float().validate(0.0), 0.0)
        self.assertEqual(formal.Float().validate(.1), 0.1)
        self.assertEqual(formal.Float().validate(1), 1.0)
        self.assertEqual(formal.Float().validate(-1), -1.0)
        self.assertEqual(formal.Float().validate(-1.86), -1.86)
        self.assertEqual(formal.Float(missing=1.0).validate(None), 1.0)
        self.assertEqual(formal.Float(missing=1.0).validate(2.0), 2.0)
        self.assertRaises(formal.FieldValidationError, formal.Float(required=True).validate, None)

    def testDecimal(self):
        # Local name Decimal is decimal.Decimal; the field type is
        # formal.Decimal.
        from decimal import Decimal
        self.assertEqual(formal.Decimal().validate(None), None)
        self.assertEqual(formal.Decimal().validate(Decimal('0')), Decimal('0'))
        self.assertEqual(formal.Decimal().validate(Decimal('0.0')), Decimal('0.0'))
        self.assertEqual(formal.Decimal().validate(Decimal('.1')), Decimal('0.1'))
        self.assertEqual(formal.Decimal().validate(Decimal('1')), Decimal('1'))
        self.assertEqual(formal.Decimal().validate(Decimal('-1')), Decimal('-1'))
        self.assertEqual(formal.Decimal().validate(Decimal('-1.86')),
                Decimal('-1.86'))
        self.assertEqual(formal.Decimal(missing=Decimal("1.0")).validate(None),
                Decimal("1.0"))
        self.assertEqual(formal.Decimal(missing=Decimal("1.0")).validate(Decimal("2.0")),
                Decimal("2.0"))
        self.assertRaises(formal.FieldValidationError, formal.Decimal(required=True).validate, None)

    def testBoolean(self):
        self.assertEqual(formal.Boolean().validate(None), None)
        self.assertEqual(formal.Boolean().validate(True), True)
        self.assertEqual(formal.Boolean().validate(False), False)
        self.assertEqual(formal.Boolean(missing=True).validate(None), True)
        # False is a real value and must not be replaced by missing.
        self.assertEqual(formal.Boolean(missing=True).validate(False), False)

    def testDate(self):
        self.assertEqual(formal.Date().validate(None), None)
        self.assertEqual(formal.Date().validate(date(2005,1,1)), date(2005,1,1))
        self.assertEqual(formal.Date(missing=date(2005,1,2)).validate(None), date(2005,1,2))
        self.assertEqual(formal.Date(missing=date(2005,1,2)).validate(date(2005,1,1)), date(2005,1,1))
        self.assertRaises(formal.FieldValidationError, formal.Date(required=True).validate, None)

    def testTime(self):
        self.assertEqual(formal.Time().validate(None), None)
        self.assertEqual(formal.Time().validate(time(12,30,30)), time(12,30,30))
        self.assertEqual(formal.Time(missing=time(12,30,30)).validate(None), time(12,30,30))
        self.assertEqual(formal.Time(missing=time(12,30,30)).validate(time(12,30,31)), time(12,30,31))
        self.assertRaises(formal.FieldValidationError, formal.Time(required=True).validate, None)

    def test_sequence(self):
        # For a required Sequence, an empty list counts as missing too.
        self.assertEqual(formal.Sequence(formal.String()).validate(None), None)
        self.assertEqual(formal.Sequence(formal.String()).validate(['foo']), ['foo'])
        self.assertEqual(formal.Sequence(formal.String(), missing=['foo']).validate(None), ['foo'])
        self.assertEqual(formal.Sequence(formal.String(), missing=['foo']).validate(['bar']), ['bar'])
        self.assertRaises(formal.FieldValidationError, formal.Sequence(formal.String(), required=True).validate, None)
        self.assertRaises(formal.FieldValidationError, formal.Sequence(formal.String(), required=True).validate, [])


class TestUtil(unittest.TestCase):
    """Tests for helpers in ``formal.util``."""

    def test_validIdentifier(self):
        # Each candidate name paired with whether it should be accepted.
        expectations = [
            ('foo', True),
            ('_foo', True),
            ('_foo_', True),
            ('foo2', True),
            ('Foo', True),
            (' foo', False),
            ('foo ', False),
            ('9', False),
        ]
        for candidate, expected in expectations:
            self.assertEqual(util.validIdentifier(candidate), expected)


class TestRequired(unittest.TestCase):
    """Tests for ``validation.RequiredValidator``."""

    def test_required(self):
        validator = validation.RequiredValidator()
        # A present value passes silently...
        validator.validate(types.String(), 'bar')
        # ...while None is reported as FieldRequiredError.
        self.assertRaises(validation.FieldRequiredError,
            validator.validate, types.String(), None)


class TestRange(unittest.TestCase):
    """Tests for ``validation.RangeValidator``; both bounds are inclusive."""

    # NOTE(review): accepted values are validated against an Integer field
    # and rejected ones against a String field; presumably the validator
    # ignores the field type -- confirm.

    def _assertAccepts(self, validator, *values):
        # Each value must validate without raising.
        for value in values:
            validator.validate(types.Integer(), value)

    def _assertRejects(self, validator, *values):
        # Each value must raise FieldValidationError.
        for value in values:
            self.assertRaises(validation.FieldValidationError,
                validator.validate, types.String(), value)

    def test_range(self):
        # Giving neither bound is a construction error.
        self.assertRaises(AssertionError, validation.RangeValidator)
        validator = validation.RangeValidator(min=5, max=10)
        self._assertAccepts(validator, None, 5, 7.5, 10)
        self._assertRejects(validator, 0, 4, -5, 11)

    def test_rangeMin(self):
        validator = validation.RangeValidator(min=5)
        self._assertAccepts(validator, None, 5, 10)
        self._assertRejects(validator, 0, 4, -5)

    def test_rangeMax(self):
        validator = validation.RangeValidator(max=5)
        self._assertAccepts(validator, None, -5, 0, 5)
        self._assertRejects(validator, 6, 10)


class TestLength(unittest.TestCase):
    """Tests for ``validation.LengthValidator``; both bounds are inclusive."""

    def _assertAccepts(self, validator, *values):
        # Each value must validate without raising.
        for value in values:
            validator.validate(types.String(), value)

    def _assertRejects(self, validator, *values):
        # Each value must raise FieldValidationError.
        for value in values:
            self.assertRaises(validation.FieldValidationError,
                validator.validate, types.String(), value)

    def test_length(self):
        # Giving neither bound is a construction error.
        self.assertRaises(AssertionError, validation.LengthValidator)
        validator = validation.LengthValidator(min=5, max=10)
        self._assertAccepts(validator, None, '12345', '1234567', '1234567890')
        self._assertRejects(validator, '', '1234', '12345678901')

    def test_lengthMin(self):
        validator = validation.LengthValidator(min=5)
        self._assertAccepts(validator, None, '12345', '1234567890')
        self._assertRejects(validator, '', '1234')

    def test_lengthMax(self):
        validator = validation.LengthValidator(max=5)
        self._assertAccepts(validator, None, '1', '12345', '123')
        self._assertRejects(validator, '123456', '1234567890')


class TestPattern(unittest.TestCase):
    """Tests for ``validation.PatternValidator`` given either a pattern
    string or a precompiled regular expression."""

    def _checkThreeToFiveDigits(self, validator):
        # Shared expectations for a validator built from ^[0-9]{3,5}$.
        # None is accepted without being matched.
        validator.validate(types.String(), None)
        for good in ('123', '12345'):
            validator.validate(types.String(), good)
        for bad in (' 123', '1', 'foo'):
            self.assertRaises(validation.FieldValidationError,
                validator.validate, types.String(), bad)

    def test_pattern(self):
        self._checkThreeToFiveDigits(
            validation.PatternValidator('^[0-9]{3,5}$'))

    def test_regex(self):
        self._checkThreeToFiveDigits(
            validation.PatternValidator(re.compile('^[0-9]{3,5}$')))


class TestRender(unittest.TestCase):
    """Rendering tests for form widgets."""

    def test_hiddenboolean(self):
        # A Hidden widget over a Boolean field whose value is None must
        # produce an empty value attribute.
        w = formal.Hidden(formal.Boolean("t"))
        # The result can't be fully rendered because it still contains
        # unfilled slots, so inspect the unrendered tag directly.
        tag = w.render(None, "t", {"t": None}, None)
        self.assertEqual(tag.attributes["value"], "")


if __name__ == "__main__":
    # Hand over to GAVO's test runner, passing TestConverters (presumably
    # the suite run when none is named on the command line -- confirm in
    # gavo.helpers.testhelpers).
    testhelpers.main(TestConverters)
