"""
Tests for the various modules in utils.
"""

#c Copyright 2008-2024, the GAVO project <gavo@ari.uni-heidelberg.de>
#c
#c This program is free software, covered by the GNU GPL.  See the
#c COPYING file in the source distribution.


import os
import pathlib

from gavo.helpers import testhelpers

from gavo import base
from gavo import utils
from gavo.utils import algotricks
from gavo.utils import codetricks
from gavo.utils import stanxml
from gavo.utils import typeconversions


class TopoSortTest(testhelpers.VerboseTest):
	"""tests for algotricks.topoSort on edge lists.
	"""
	def testEmpty(self):
		result = algotricks.topoSort([])
		self.assertEqual(result, [])

	def testSimpleGraph(self):
		edges = [(1, 2), (2, 3), (3, 4)]
		self.assertEqual(algotricks.topoSort(edges), [1, 2, 3, 4])

	def testComplexGraph(self):
		# redundant transitive edges must not disturb the ordering
		edges = [
			(1, 2), (2, 3), (1, 3),
			(3, 4), (1, 4), (2, 4)]
		self.assertEqual(algotricks.topoSort(edges), [1, 2, 3, 4])

	def testCyclicGraph(self):
		cycle = [(1, 2), (2, 1)]
		self.assertRaisesWithMsg(ValueError,
			"Graph not acyclic, cycle: 2->1",
			algotricks.topoSort,
			(cycle,))


class PrefixTest(testhelpers.VerboseTest, metaclass=testhelpers.SamplesBasedAutoTest):
	"""tests for utils.commonPrefixLength.

	Each sample is (first string, second string, expected length of the
	common prefix).
	"""
	samples = [
		("abc", "abd", 2),
		("abc", "a", 1),
		("abc", "", 0),
		("", "abc", 0),
		("a", "abc", 1),
		("z", "abc", 0),
	]

	def _runTest(self, args):
		left, right, expected = args
		found = utils.commonPrefixLength(left, right)
		self.assertEqual(found, expected)


class IdManagerTest(testhelpers.VerboseTest):
	"""tests for utils.IdManagerMixin: id creation, lookup in both
	directions, and handling of suggested ids.
	"""
	def setUp(self):
		self.im = utils.IdManagerMixin()

	def testNoDupe(self):
		# Without a suggestion, the id is derived from the object's id();
		# asking for a second id for the same object is an error.
		testob = IdManagerTest
		self.assertEqual(self.im.makeIdFor(testob),
			utils.intToFunnyWord(id(testob)))
		self.assertRaises(ValueError,
			self.im.makeIdFor,
			testob)

	def testRetrieve(self):
		testob = "abc"
		theId = self.im.makeIdFor(testob)
		self.assertEqual(self.im.getIdFor(testob), theId)

	def testRefRes(self):
		testob = "abc"
		theId = self.im.makeIdFor(testob)
		self.assertEqual(self.im.getForId(theId), testob)

	def testUnknownOb(self):
		self.assertRaises(utils.NotFoundError, self.im.getIdFor, 1)

	def testUnknownId(self):
		self.assertRaises(utils.NotFoundError, self.im.getForId, "abc")

	def testSuggestion(self):
		testob = object()
		givenId = self.im.makeIdFor(testob, "ob1")
		self.assertEqual(givenId, "ob1")
		# NOTE(review): "ob1/" apparently is normalised to the already-taken
		# "ob1", so a numbered variant is handed out instead.
		testob2 = object()
		id2 = self.im.makeIdFor(testob2, "ob1/")
		self.assertEqual(id2, "ob1-02")
		# assertIs reports both objects on failure, unlike assertTrue(a is b)
		self.assertIs(self.im.getForId("ob1"), testob)
		self.assertIs(self.im.getForId("ob1-02"), testob2)


class LoadModuleTest(testhelpers.VerboseTest):
	"""tests for utils.loadInternalObject, which resolves an attribute
	of a module given relative to the gavo package.
	"""
	def testLoading(self):
		ob = utils.loadInternalObject("utils.codetricks", "loadPythonModule")
		self.assertTrue(callable(ob))

	def testNotLoading(self):
		# a module that does not exist surfaces as ImportError
		self.assertRaises(ImportError, utils.loadInternalObject, "noexist", "u")

	def testBadName(self):
		# an existing module lacking the name surfaces as AttributeError
		self.assertRaises(AttributeError, utils.loadInternalObject,
			"utils.codetricks", "noexist")


class CachedGetterTest(testhelpers.VerboseTest):
	"""tests for utils.CachedGetter: the factory result is cached, and an
	isAlive predicate can force it to be rebuilt.
	"""
	def testNormal(self):
		getter = utils.CachedGetter(lambda seed: [seed], 3)
		self.assertEqual(getter(), [3])
		# the same object is handed out each time, so mutations persist
		cached = getter()
		cached.append(4)
		self.assertEqual(getter(), [3, 4])

	def testMortal(self):
		getter = utils.CachedGetter(lambda seed: [seed], 3,
			isAlive=lambda value: len(value) < 3)
		getter().append(4)
		self.assertEqual(getter(), [3, 4])
		# growing to three elements makes isAlive fail; a fresh value results
		getter().append(5)
		self.assertEqual(getter(), [3])


class SimpleTextTest(testhelpers.VerboseTest):
	"""tests for utils.iterSimpleText, which yields (line number, text)
	pairs for the logical lines of a simple text file.
	"""
	def testFeatures(self):
		# The sample exercises comment lines (even one ending in a backslash),
		# stripping of leading whitespace, skipping of empty lines, and
		# backslash continuation with a comment interleaved.  The expected
		# numbers show a continued line is reported under the number of the
		# physical line it ends on (8, not 5).
		# Do not reformat the literal: its tabs, blanks and line breaks are
		# what is being tested.
		with testhelpers.testFile("test.txt",
				r"""# Test File\
	this is stripped
An empty line is ignored

Contin\
uation lines \
# (a comment in between is ok)
  are concatenated
""")    as fName:
			with open(fName) as f:
				res = list(utils.iterSimpleText(f))

		self.assertEqual(res, [
			(2, "this is stripped"),
			(3, "An empty line is ignored"),
			(8, "Continuation lines are concatenated")])

	def testNoTrailingBackslash(self):
		# A file whose last line ends in a backslash (an unfinished
		# continuation) must raise SourceParseError naming that line.
		with testhelpers.testFile("test.txt",
				"""No
non-finished\\
continuation\\""") as fName:
			with open(fName) as f:
				self.assertRaisesWithMsg(utils.SourceParseError,
					"At line 3: File ends with a backslash",
					lambda f: list(utils.iterSimpleText(f)),
					(f,))


class ToVOTableTypeTest(testhelpers.VerboseTest,
		metaclass=testhelpers.SamplesBasedAutoTest):
	"""tests for typeconversions.sqltypeToVOTable.

	Each sample is (SQL type name, expected triple); the triples look like
	VOTable (datatype, arraysize, xtype).
	"""
	def _runTest(self, sample):
		sqlType, voTableType = sample
		self.assertEqual(
			typeconversions.sqltypeToVOTable(sqlType),
			voTableType)
	
	samples = [
		("double precision", ('double', None, None)),
		("text", ('char', "*", None)),
		("char", ('char', '1', None)),
		("unicode", ('unicodeChar', "*", None)),
		("double precision[2]", ('double', '2', None)),
# 5
		("timestamp", ("char", "19", "timestamp")),
		("spoint", ("double", "2", "point")),
		("int4range", ("int", "2", "interval")),
		("timestamp[5]", ("char", "19x5", "timestamp")),
		("scircle[5]", ("double", "3x5", "circle")),
# 10
		("char[1]", ("char", "1", None)),
		("char[12][]", ("char", "12x*", None)),
		# we should probably flag this as invalid
		("char(12)[*]", ("char", "12x*", None)),
		("varchar(*)", ("char", "*", None)),
		# we probably shouldn't let the following parse
		("varchar[13][15](*)", ("char", "13x15x*", None)),
	]


class ToVOTableErrorTest(testhelpers.VerboseTest,
		metaclass=testhelpers.SamplesBasedAutoTest):
	"""tests that typeconversions.sqltypeToVOTable rejects malformed or
	unknown SQL types with a helpful message.

	Each sample is (bad SQL type, expected error message).
	"""
	samples = [
		("varchar[1", "No VOTable type for varchar[1"),
		("vanqual", "No VOTable type for vanqual"),
		("char[][]",
			"Arrays may only have variable length in the last dimension"),
	]

	def _runTest(self, sample):
		badType, expectedMessage = sample
		self.assertRaisesWithMsg(Exception,
			expectedMessage,
			typeconversions.sqltypeToVOTable,
			(badType,))


class NoModuleAliasingTest(testhelpers.VerboseTest):
	"""checks that codetricks.loadPythonModule keeps same-named modules in
	different directories apart and out of the global import machinery.
	"""
	def testAliasing(self):
		inputsDir = base.getConfig("inputsDir")
		with testhelpers.testFile("klotz.py", "sentinel = 1",
					inDir=os.path.join(inputsDir, "mod1")) as firstSrc, \
				testhelpers.testFile("klotz.py", "sentinel = 2",
					inDir=os.path.join(inputsDir, "mod2")) as secondSrc:
			# loadPythonModule is passed the path with its ".py" stripped
			firstMod, _ = codetricks.loadPythonModule(firstSrc[:-3])
			secondMod, _ = codetricks.loadPythonModule(secondSrc[:-3])
			self.assertEqual(firstMod.sentinel, 1)
			self.assertEqual(secondMod.sentinel, 2)

			try:
				import klotz  #noflake: supposed to fail
			except ImportError:
				# this must fail in order to keep "local" modules from polluting
				# global imports.
				pass
			else:
				self.fail("loadPythonModule messed up sys.modules or sys.path.")


class StanXMLTest(testhelpers.VerboseTest):
	"""tests for basic stanxml element behaviour: content rules, text
	handling, nillable elements, exceptions as content, and attributes.
	"""
	class Model(object):
		class MEl(stanxml.Element):
			_local = True
		class Root(MEl):
			_childSequence = ["Child", "Nilble"]
		class Child(MEl):
			_childSequence = ["Foo", None]
		class Other(MEl):
			pass
		class Nilble(stanxml.NillableMixin, MEl):
			_a_restatt = None

	def testNoTextContent(self):
		# NOTE(review): Root's _childSequence has no None entry (Child's does),
		# which apparently is what forbids text content here.
		M = self.Model
		self.assertRaises(stanxml.ChildNotAllowed, lambda:M.Root["abc"])
	
	def testTextContent(self):
		# non-ASCII text comes out UTF-8 encoded
		M = self.Model
		data = M.Root[M.Child["a\xA0bc"]]
		self.assertEqual(data.render(), b'<Root><Child>a\xc2\xa0bc</Child></Root>')

	def testRetrieveText(self):
		# text_ holds the trailing text chunk; the earlier one is discarded
		M = self.Model
		data = M.Other["thrown away", M.Other["mixed"], " remaining "]
		self.assertEqual(data.text_, " remaining ")

	def testNillableNil(self):
		# an empty nillable element is rendered with xsi:nil and pulls in
		# the xsi namespace declaration
		M = self.Model
		rendered = M.Root[M.Nilble()].render()
		self.assertIn(b'xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"',
			rendered)
		self.assertIn(b'Nilble xsi:nil="true"', rendered)
	
	def testNillableNonNil(self):
		# with content, neither xsi:nil nor the xsi declaration appears
		M = self.Model
		rendered = M.Root[M.Nilble["Value"]].render()
		self.assertIn(b"<Nilble>Value</Nilble>", rendered)
		self.assertNotIn(b'xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"',
			rendered)
	
	def testNillableAttribute(self):
		M = self.Model
		rendered = M.Root[M.Nilble(restatt="x")].render()
		self.assertIn(b'<Nilble restatt="x" xsi:nil="true"></Nilble>', rendered)

	def testSerialisingException(self):
		try:
			_ = 1/0
		except Exception as raised:
			# bind outside the except clause, which unbinds its target on exit
			e = raised

		M = self.Model
		rendered = M.Root[M.Child[e]].render()
		self.assertIn(
			b"<Child>EXCEPTION:\nZeroDivisionError: division by zero",
			rendered)

	def testLocalAttribute(self):
		# explicitly added attributes come out in sorted order
		M = self.Model
		doc = M.Root[M.Child["x"]]
		doc.addAttribute("xmlns:foo", "urn:foo")
		doc.addAttribute("xmlns:bar", "urn:bar")

		self.assertEqual(
			doc.render(),
			b'<Root xmlns:bar="urn:bar" xmlns:foo="urn:foo"><Child>x</Child></Root>')


class StanXMLNamespaceTest(testhelpers.VerboseTest):
	"""tests for namespace prefix handling when rendering stanxml trees.
	"""

	# These registrations run once, when the class body is executed at import
	# time.  The element classes below refer to these prefixes; the third
	# argument appears to be a schema location (it shows up in
	# xsi:schemaLocation in testSchemaLocation).
	stanxml.registerPrefix("ns1", "http://bar.com", None)
	stanxml.registerPrefix("ns0", "http://foo.com", None)
	stanxml.registerPrefix("foo", "http://bori.ng", "http://schema.is.here")

	class E(object):
		# Elements in ns1 with _local set render without a prefix (see the
		# <A>/<B> output below); NOTE(review): _mayBeEmpty presumably is what
		# allows the self-closing <B/> form.
		class LocalElement(stanxml.Element):
			_prefix = "ns1"
			_local = _mayBeEmpty = True
		class A(LocalElement):
			_a_x = None
		class B(LocalElement):
			_a_y = None
		# Elements in ns0 render with the ns0: prefix.
		class NSElement(stanxml.Element):
			_prefix = "ns0"
		class C(NSElement):
			_a_z = "ab"
		# D's attribute u is emitted as foo:u, and the extra "foo" prefix
		# makes the root declare xmlns:foo (see testAdditionalPrefixes).
		class D(NSElement):
			_a_u = "x"
			_name_a_u = "foo:u"
			_additionalPrefixes = frozenset(["foo"])

	def testTraversal(self):
		# apply() calls record for each node; record recurses through childIter
		tree = self.E.A[self.E.B, self.E.B, self.E.A]
		def record(node, content, attrDict, childIter):
			return (node.name_,
				[c.apply(record) for c in childIter])
		self.assertEqual(tree.apply(record),
			('A', [('B', []), ('B', []), ('A', [])]))
	
	def testSimpleRender(self):
		tree = self.E.A[self.E.B, self.E.B, self.E.A]
		self.assertEqual(testhelpers.cleanXML(tree.render()),
			'<A><B/><B/><A/></A>')

	def testRenderWithText(self):
		# All used namespaces are declared on the root element, even the one
		# of the unprefixed local elements.
		E = self.E
		tree = E.A[E.C["arg"], E.C(z="c")[E.B["muss"], E.A]]
		self.assertEqual(tree.render(),
			b'<A xmlns:ns0="http://foo.com" xmlns:ns1="http://bar.com"><ns0:C z="ab">arg</ns0:C>'
				b'<ns0:C z="c"><B>muss</B><A/></ns0:C></A>')

	def testAdditionalPrefixes(self):
		tree = self.E.C[self.E.D["xy"]]
		self.assertEqual(tree.render(includeSchemaLocation=False),
			b'<ns0:C xmlns:foo="http://bori.ng" xmlns:ns0="http://foo.com" z="ab"><ns0:D foo:u="x">xy</ns0:D></ns0:C>')

	def testSchemaLocation(self):
		# By default the schema location registered for "foo" is emitted.
		tree = self.E.D["xy"]
		self.assertEqual(tree.render(),
			b'<ns0:D foo:u="x" xmlns:foo="http://bori.ng" xmlns:ns0="http://'
			b'foo.com" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" '
			b'xsi:schemaLocation="http://bori.ng http://schema.is.here">xy</ns0:D>')

	def testEmptyPrefix(self):
		# prefixForEmpty makes ns0 the default namespace, so C is unprefixed.
		tree = self.E.C["bar"]
		self.assertEqual(tree.render(prefixForEmpty="ns0"),
			b'<C xmlns:ns0="http://foo.com" xmlns="http://foo.com" z="ab">bar</C>')


class IAUDesignationTest(testhelpers.VerboseTest,
		metaclass=testhelpers.SamplesBasedAutoTest):
	"""tests for utils.makeIAUId.

	Each sample is (positional arguments for makeIAUId, expected designation).
	The trailing optional arguments control the number of decimals on the
	RA and Dec parts, respectively.
	"""
	samples = [
		(("TJ", 0, 0), "TJ000000+000000"),
		(("TJ", 0, 0, 1), "TJ000000.0+000000"),
		(("TJ", 0, 0, 0, 1), "TJ000000+000000.0"),
		(("BB", 34.13333, 23.24722, 1, 0), "BB021631.9+231449"),
		(("BB", 34.13333, -23.24722, 1, 0), "BB021631.9-231449"),
		(("BB", 34.13333, -23.24722, 0, 0), "BB021631-231449"),
		(("BB", 34.13333, -23.24722, 0, 1), "BB021631-231449.9"),
	]

	def _runTest(self, sample):
		callArgs, expectedId = sample
		self.assertEqual(utils.makeIAUId(*callArgs), expectedId)


pP = pathlib.Path  # short alias for the pathlib-based samples below

class PathOpsTest(testhelpers.VerboseTest):
	"""tests for utils.getRelativePath with both str and pathlib.Path
	arguments.
	"""
	def testRelPathBasicStr(self):
		self.assertEqual(
			utils.getRelativePath("/foo/bar/baz", "/foo"),
			"bar/baz")

	def testRelPathBasicPath(self):
		self.assertEqual(
			utils.getRelativePath(
				pP("/foo/bar/baz"), pP("/foo")),
			pP("bar/baz"))

	def testRelPathTrailingSlashStr(self):
		self.assertEqual(
			utils.getRelativePath("/foo/bar/baz", "/foo/"),
			"bar/baz")

	def testRelPathTrailingSlashPath(self):
		self.assertEqual(
			utils.getRelativePath(
				pP("/foo/bar/baz"), pP("/foo/")),
			pP("bar/baz"))

	def testNoRelPathStr(self):
		self.assertRaises(ValueError,
			utils.getRelativePath,
			"/foo/bar/baz", "/bar")

	def testNoRelPathPath(self):
		self.assertRaises(ValueError,
			utils.getRelativePath,
			pP("/foo/bar/baz"), pP("/bar"))

	def testIlliberalChars(self):
		# The root must really be a prefix of the full path; with a
		# non-matching root, the ValueError would come from the prefix check
		# and the character check would never be exercised.
		self.assertRaises(ValueError,
			utils.getRelativePath,
			pP("/foo/bar/baz+quux"), pP("/foo"),
			liberalChars=False)

	def testLiberalChars(self):
		# with liberalChars, the same "+"-containing path is accepted
		self.assertEqual(
			utils.getRelativePath(
				pP("/foo/bar/baz+quux"), pP("/foo"),
				liberalChars=True),
			pP("bar/baz+quux"))

	def testIdenticalStr(self):
		self.assertEqual(
			utils.getRelativePath(
				"/foo/bar", "/foo/bar"),
			"")

	def testIdenticalPath(self):
		self.assertEqual(
			utils.getRelativePath(
				pP('/home/msdemlei/_gavo_test/inputs'),
				pP('/home/msdemlei/_gavo_test/inputs')),
			pP("."))


if __name__=="__main__":
	# Hands CachedGetterTest to testhelpers.main when the file is run directly.
	# NOTE(review): presumably a leftover from focused debugging; confirm
	# whether testhelpers.main should get a different default here.
	testhelpers.main(CachedGetterTest)
