"""
Tests for ADQL parsing and reasoning about query results.
"""

#c Copyright 2008-2024, the GAVO project <gavo@ari.uni-heidelberg.de>
#c
#c This program is free software, covered by the GNU GPL.  See the
#c COPYING file in the source distribution.


import datetime
import functools
import re
import unittest

from gavo.helpers import testhelpers

from gavo import adql
from gavo import base
from gavo import stc
from gavo import rsc
from gavo import rscdef
from gavo import rscdesc #noflake: for registration
from gavo import utils
from gavo.adql import annotations
from gavo.adql import fieldinfos
from gavo.adql import morphpg
from gavo.adql import nodes
from gavo.adql import ufunctions  #noflake: for registration
from gavo.protocols import adqlglue
from gavo.protocols import tap
from gavo.stc import tapstc
from gavo.utils import pgsphere
from gavo.utils import parsetricks

import tresc

MS = base.makeStruct

class Error(Exception):
	"""Signals test trouble that is not a plain assertion failure.

	Within this module, it reports statements that send the parser into
	runaway recursion.
	"""


# The resources below are used elsewhere (e.g., taptest).
class _ADQLQuerier(testhelpers.TestResource):
	"""A test resource giving an object that runs ADQL queries through TAP.

	The object is built on the shared test database connection; its
	queryADQL method returns query results as in-memory tables.
	"""
	resources = [("conn", tresc.dbConnection)]

	def make(self, deps):
		class ADQLQuerier(base.UnmanagedQuerier):
			def queryADQL(self, query, timeout=10, maxrec=20000):
				"""returns an in-memory table with the result of query.

				The result's meta_ is taken over from the TAP result table.
				The TAP result table is cleaned up even when copying its rows
				fails, so no database resources are left behind.
				"""
				qtable = adqlglue.runTAPQuery(
					query, timeout, self.connection, [], maxrec, False)
				try:
					res = rsc.InMemoryTable(
						qtable.tableDef,
						rows=list(qtable))
					res.meta_ = qtable.meta_
				finally:
					qtable.cleanup()
				return res

		return ADQLQuerier(deps["conn"])

	def cleanup(self, deps):
		# NOTE(review): this expects a mapping with a "conn" key (like make's
		# deps); confirm how testhelpers.TestResource invokes this hook.
		deps["conn"].rollback()
	

# Module-level instance of the ADQL querier test resource.
adqlQuerier = _ADQLQuerier()


class _ADQLTestTable(testhelpers.TestResource):
	"""A test resource providing the ADQLTest data from the test RD.

	Making the resource also publishes the test RD to TAP; cleaning it
	drops the data's tables again and commits that.
	"""
	resources = [("conn", tresc.dbConnection)]

	def make(self, deps):
		self.rd = testhelpers.getTestRD()
		conn = deps["conn"]
		data = rsc.makeData(self.rd.getById("ADQLTest"), connection=conn)
		tap.publishToTAP(self.rd, conn)
		return data

	def clean(self, ds):
		allTables = list(ds.tables.values())
		conn = allTables[0].connection
		# discard any pending (possibly failed) transaction before dropping
		conn.rollback()
		ds.dropTables(rsc.parseNonValidating)
		conn.commit()
adqlTestTable = _ADQLTestTable()



class MatchLimitTest(testhelpers.VerboseTest, metaclass=testhelpers.SamplesBasedAutoTest):
	"""tests for adqlglue._updateMatchLimits.

	Each sample is (TOP clause, maxrec, hard limit, expected setLimit of
	the parsed tree, expected overflow limit returned by _updateMatchLimits).
	"""
	def _runTest(self, sample):
		topClause, maxrec, hardLimit, expectedInQuery, expectedLimit = sample
		tree = parseWithArtificialTable(
			"SELECT %s * FROM spatial"%topClause)
		overflowLimit = adqlglue._updateMatchLimits(tree, maxrec, hardLimit)
		self.assertEqual(tree.setLimit, expectedInQuery)
		self.assertEqual(overflowLimit, expectedLimit)
	
	samples = [
		("", 2000, 200000, 2000, 2000),
		("TOP 10", 2000, 200000, 10, 2000),
		("TOP 10000", 2000, 200000, 2001, 2001),
		("TOP 30000", None, 200000, 20001, 20001),
		("TOP 30000", None, None, 20001, 20001),
# 05
		("TOP 10000", 99999999999, None, 10000, 20000000),
		# NOTE(review): this sample duplicates the previous one; perhaps a
		# different case was intended here.
		("TOP 10000", 99999999999, None, 10000, 20000000),
		("", 99999999999, None, 20000000, 20000000),
		("", 99999999999, 50000, 50000, 50000),
		("TOP 1", 1, 50000, 1, 2),
	]


class _SymbolsParseTestBase(testhelpers.VerboseTest):
	"""A base for tests checking whether single grammar symbols parse.

	setUp makes the raw ADQL symbols available as self.symbols.  The
	assertion helpers require the symbol to consume the entire literal
	(the symbol is followed by StringEnd).
	"""
	def setUp(self):
		self.symbols, _ = adql.getRawGrammar()

	def _assertParses(self, symbol, literal):
		"""fails the test unless literal parses completely as symbol.
		"""
		try:
			(self.symbols[symbol]+parsetricks.StringEnd()).parseString(literal)
		except (adql.ParseException, adql.ParseSyntaxException):
			# ParseSyntaxException is not a ParseException subclass, so it
			# must be caught explicitly to be reported as a test failure.
			raise AssertionError("%s doesn't parse %s but should."%(symbol,
				repr(literal)))

	def _assertDoesntParse(self, symbol, literal):
		"""fails the test if literal parses completely as symbol.
		"""
		try:
			(self.symbols[symbol]+parsetricks.StringEnd()).parseString(literal)
		except (adql.ParseException, adql.ParseSyntaxException):
			pass
		else:
			raise AssertionError("%s parses %s but shouldn't."%(symbol,
				repr(literal)))


class _GoodExamplesBase(_SymbolsParseTestBase, metaclass=testhelpers.SamplesBasedAutoTest):
	"""A base for sample-driven tests; samples are (symbol, literal) pairs
	that must parse.
	"""
	def _runTest(self, sample):
		check = self._assertParses
		check(*sample)


class _BadExamplesBase(_SymbolsParseTestBase, metaclass=testhelpers.SamplesBasedAutoTest):
	"""A base for sample-driven tests; samples are (symbol, literal) pairs
	that must not parse.
	"""
	def _runTest(self, sample):
		check = self._assertDoesntParse
		check(*sample)


class MiscBadSymbolsTest(_BadExamplesBase):
	"""(symbol, literal) pairs where literal must not parse as symbol.
	"""
	samples = [
		("dateValueExpression", "TIMESTAMP('1992-03-01', a)"),
		("dateValueExpression", "TIMESTAMP(b, a)"),
		("castSpecification", "cast(char)"),
		("castSpecification", "cast(x as y)"),
		("castSpecification", "cast(x as national bank)"),
# 5
		("numericValueExpression", "BITWISE_NOT()"),
		("numericValueExpression", "BITWISE_AND(a)"),
		("numericValueExpression", "BITWISE_AND(a, b, c)"),
		("joinedTable", "foo natural join (t1, t2)"),
		("joinedTable", "(t1 join t2) using (foo)"),
# 10
		("joinedTable", "(t1 join t2) on (t1.foo=t2.bar)"),
		("possiblyAliasedTable", "t1 tablesample(a)"),
		("numericValueFunction", "arr_map(POWER, a)"),
		("caseExpression", "case x 'a' then 8 when 'b' then 9 end"),
		("caseExpression", "case x when 'a'=x then 8 end"),
# 15
		("caseExpression", "case when 'a' then 8 end"),
	]


class MiscGoodSymbolsTest(_GoodExamplesBase):
	"""(symbol, literal) pairs where literal must parse completely as symbol.
	"""
	samples = [
		("searchCondition", "5+9<'b' || 'foo'"),
		("delimitedIdentifier", '"a"'),
		("delimitedIdentifier", '"a""b"'),
		("comparisonPredicate", '"ja ja"<"Umph"'),
		("comparisonPredicate", "a<b"),
# 5
		("comparisonPredicate", "'a'<'b'"),
		("comparisonPredicate", "'a'<'b' || 'foo'"),
		("comparisonPredicate", "5+9<'b' || 'foo'"),
		("characterStringLiteral", "'abc'"),
		("characterValueExpression", "'abc' || 'def'"),
# 10
		("stringValueExpression", "'abc' || 'def' || '78%%'"),
		("dateValueExpression", "TIMESTAMP('1992-03-01' || 'T12:33')"),
		("dateValueExpression", "TIMESTAMP(b)"),
		("castSpecification", "CAST(x+23.0 AS INTEGER)"),
		("castSpecification", "CAST('230' AS BIGINT)"),
# 15
		("castSpecification", "CAST(230 AS NATIONAL   CHAR ( 10 ))"),
		("castSpecification", 'CAST("My stupid col" || \'x\' AS CHAR(1230))'),
		("castSpecification", "CAST(230  AS CHAR)"),
		("castSpecification", "CAST(230+PI() AS   REAL)"),
		("castSpecification", "CAST(SQRT(honk) AS DOUBLE PRECISION)"),
# 20
		("castSpecification", "CAST('2017-02-30' AS TIMESTAMP)"),
		("castSpecification", "CAST(NULL AS TIMESTAMP)"),
		("numericValueExpression", "BitWISE_NOT(x)"),
		("numericValueExpression", "BITWISE_and(x, 2)"),
		("numericValueExpression", "BITWISE_OR(x, y+2)"),
# 25
		("numericValueExpression", "BITWISE_XOR(5, 2)"),
		("withSpecification", "WITH foobar as (select a,b,c from x),"
			" knatter as (select cos(d)+13 as foo from y)"),
		("valueExpressionPrimary", "arr[15]"),
		("valueExpressionPrimary", "arr[ROUND(x/10)+3]"),
		("derivedColumn", "98x"),
#30
		("derivedColumn", "(A+B)X"),
		("possiblyAliasedTable", '"gnott" as g tablesample (0.1)'),
		("possiblyAliasedTable", '"gnott" tablesample(1e-7)'),
		("setGeneratingFunction", "generate_series ( 3 , 4 )"),
		("castSpecification", "CAST(foo as char(*))"),
#35
		("castSpecification", "CAST(foo as varchar(4))"),
		("coalesceExpression", "COALESCE(x)"),
		("coalesceExpression", "COALESCE(round(x/10+3), NULL, 3)"),
		("numericPrimary", "PI()"),
		("castSpecification", "CAST(NULL AS DOUBLE PRECISION[])"),
#40
		("castSpecification", "CAST(NULL AS CHAR(39)[30])"),
		("numericValueFunction", "arr_map(sin(x), a)"),
		("numericValueFunction", "arr_map(4*power(x, 3)-3, a)"),
		("orderByClause", "order by z"),
		("orderByClause", "order by 23"),
#45
		("orderByClause", "order by z desc"),
		("orderByClause", "order by z desc, x asc"),
		("orderByClause", "order by z-1 desc, log10(x) asc"),
		("groupByClause", "group by z"),
		("groupByClause", "group by z, s"),
#50
		("groupByClause", "group by z[3], s+z"),
		("havingClause", "having x=z AND 7<u"),
		("caseExpression", "case x when 'a' then 8 when 'b' then 9 end"),
		("caseExpression", "case x when 'a' then 8*9 else 8*9 end"),
		("caseExpression", "case when x='a' then power(y,2)"
			" when 4<y then y-2 else 'u' end"),
# 55
	]


class GoodBooleanTermsTest(_GoodExamplesBase):
	"""search conditions using BETWEEN, IN, LIKE/ILIKE and IS [NOT] NULL
	(with keywords in mixed case) that must parse.
	"""
	samples = [
		("searchCondition", "z BETWEEN 8 AND 9"),
		("searchCondition", "z BETWEEN 'a' AND 'b'"),
		("searchCondition", "z BEtWEEN x+8 AnD x*8"),
		("searchCondition", "z NOT BETWEEN x+8 AND x*8"),
		("searchCondition", "z iN (a)"),
		("searchCondition", "z NoT In (a)"),
		("searchCondition", "z NOT IN (a, 4, 'xy')"),
		("searchCondition", "z IN (select x from foo)"),
		("searchCondition", "u LIKE '%'"),
		("searchCondition", "u NoT LiKE '%'"),
		("searchCondition", "u ILIKE '%'"),
		("searchCondition", "u Not ILIKE '%'"),
		("searchCondition", "u || 'foo' NOT LIKE '%'"),
		("searchCondition", "u NOT LIKE '%' || 'xy'"),
		("searchCondition", "k IS NULL"),
		("searchCondition", "k IS NOT NULL"),
	]


class BadGeometriesTest(_BadExamplesBase):
	"""malformed geometry constructors and functions that must not parse
	as the given symbol.
	"""
	samples = [
		("point", "POINT(x,y,z)"),
		("circle", "circle('ICRS', x)"),
		("circle", "circle(5, y)"),
		("geometryExpression", "circle('ICRS', x)"),
		("polygon", "POLYGON(2, 3)"),
#5
		("polygon", "POLYGON(2, 4, 3)"),
		("polygon", "POLYGON('', 2, 4, 3)"),
		("polygon", "POLYGON(POINT(2, 4), POINT(3, 5), 3, 6)"),
		("region", "REGION(23, 'CIRCLE ICRS 2 3 4)"),
		("booleanTerm", "Point('fk5',2,3)"),
#10
		("booleanTerm", "CIRCLE('fk5', 2, 3)=x"),
		("booleanTerm", "POLYGON('fk5', 2, 3, 3, 0, 23, 0, 45)=x"),
		("booleanTerm", "CENTROID(3)=x"),
		("booleanTerm", "CENTROID(COUNT(*))=x"),
		("nonPredicateGeometryFunction", "DISTANCE()"),
#15
		("nonPredicateGeometryFunction", "DISTANCE(a, b, c)"),
		("nonPredicateGeometryFunction", "DISTANCE(POINT(a,b), c, d)"),
		("nonPredicateGeometryFunction", "DISTANCE(POINT(a,b), POINT(c, d), e)"),
		("nonPredicateGeometryFunction", "DISTANCE(a, CENTROID(CIRCLE(b, c, d)))"),
	]


class GoodGeometriesTest(_GoodExamplesBase):
	"""geometry constructors and functions (with and without coordinate
	system literals) that must parse as the given symbol.
	"""
	samples = [
		("point", "pOint('ICRS', x,y)"),
		("point", "point(NULL, x,y)"),
		("point", "POINT(x,y)"),
		("circle", "circle('ICRS', x,y, r)"),
		("circle", "CIRCLE(NULL, 1,2, 4)"),
#5
		("circle", "CIRCLE(1,2, 4)"),
		("circle", "CIRCLE('', c, r)"),
		("circle", "CIRCLE(c, 5)"),
		("polygon", "POLYGON(NULL, 1, 2, 4, 3, 4, 4)"),
		("polygon", "POLYGON('', 1, 2, 4, 3, 4, 4)"),
#10
		("polygon", "POLYGON(1, 2, 4, 3, 4, 4)"),
		("polygon", "POLYGON(a, b, c)"),
		("polygon", "POLYGON(POINT(2,3), b, c)"),
		("polygon", "POLYGON(a, b, POINT(2,3))"),
		("box", "BOX(1, 2, 0.2, 0.1)"),
#15
		("box", "BOX('GALACTIC', 1, 2, 0.2, 0.1)"),
		("region", "REGION('CIRCLE ICRS 2 3 4)')"),
		("geometryExpression", "CIRCLE('ICRS', 1,2, 4)"),
		("predicateGeometryFunction",
			"Contains(pOint('ICRS', x,y),CIRCLE('ICRS', 1,2, 4))"),
		("booleanTerm", "Point(NULL, 2, 3)=x"),
#20
		("booleanTerm", "Point('fk5', 2, 3)=x"),
		("booleanTerm", "CIRCLE('fk5', 2, 3, 3)=x"),
		("booleanTerm", "box('fk5', 2, 3, 3, 0)=x"),
		("booleanTerm", "POLYGON('fk5', 2, 3, 3, 0, 23, 0, 45, 34)=x"),
		("booleanTerm", "REGION('mainfranken')=x"),
#25
		("booleanTerm", "CENTROID(CIRCLE('fk4', 2, 3, 3))=x"),
		("nonPredicateGeometryFunction", "DISTANCE(a, b)"),
		("nonPredicateGeometryFunction", "DISTANCE(POINT(a, b), POINT(c,d))"),
		("nonPredicateGeometryFunction", "DISTANCE(a, POINT(c, d))"),
	]


@unittest.skipUnless(testhelpers.hasUDF("GAVO_MOCUNION"), "pgsphere too old")
class GoodMOCTest(_GoodExamplesBase):
	"""MOC expressions that must parse; skipped unless the GAVO_MOCUNION
	UDF is available.
	"""
	samples = [
		("geometryValueExpression", "MOC('3/2,3')"),
		("moc", "MOC(3, point('ICRS', 3, 4))"),
		("moc", "MOC(3, POLYGON(3, 4, 5, 6, 7,8))"),
		("moc", "MOC(ordr, CIRCLE(3, 4, 8))"),
		("moc", "MOC(ordr, s_region)"),
		("moc", "MOC(ordr, MOC('4/6,7'))"),
		("moc", "MOC(ordr, MOC(4, CIRCLE(3,4,5)))"),
		("moc", "MOC(ordr, gavo_moc_union(MOC(3, CIRCLE(3,4,5)), MOC('0/1')))"),
	]


class GoodStatementTests(_GoodExamplesBase):
	"""complete ADQL statements that must parse as the statement symbol.
	"""
	# each literal below is paired with the "statement" symbol
	samples = [("statement", s) for s in [
		"SELECT ivo_foo(6) || 9\n from Y",
		"SELECT ivo_foo('11') || 11 from Y",
		"SELECT 'text'*1323\n from Y",
		"select * from foo where 1=CONTAINS(\n"
		"  REGION('Union ICRS (Position 1 2 Intersection"
		"   (circle  1 2 3 box 1 2 3 4 circle 30 40 2))'),\n"
		"  REGION('circle GALACTIC 1 2 3'))",
		"SELECT count(*) FROM x\nWHERE COALESCE(a, b, c) = 0",
# 5
		"select top 4* from (select * from amanda.nucand)x",
		"select * from bar order   by foo asc, 2desc",
		"(SELECT TOP 10 id, ra, dec FROM atable ORDER BY id ASC)"
		" UNION (SELECT TOP 10 id, ra, dec FROM atable ORDER BY id DESC)",
	]]


class BadStatementTests(_BadExamplesBase):
	"""complete ADQL statements that must not parse as the statement symbol.
	"""
	samples = [
		# geometries cannot be string-concatenated
		("statement", "SELECT point('icrs', 3, 4) || 'aaaa'\n from Y"),
	]


class _ADQLParsesTest(testhelpers.VerboseTest):
	"""an abstract base for tests checking whether ADQL expressions parse.

	setUp makes the raw ADQL grammar available as self.grammar.  A
	RuntimeError (which includes RecursionError) during parsing is
	reported as Error, since it indicates runaway recursion in the grammar.
	"""
	def setUp(self):
		_, self.grammar = adql.getRawGrammar()
		testhelpers.VerboseTest.setUp(self)

	def _assertGoodADQL(self, statement):
		"""fails the test unless statement parses with self.grammar.
		"""
		try:
			self.grammar.parseString(statement)
		except (adql.ParseException, adql.ParseSyntaxException):
			raise AssertionError("%s doesn't parse but should."%statement)
		except RuntimeError as ex:
			raise Error("%s causes an infinite recursion"%statement) from ex

	def _assertBadADQL(self, statement):
		"""fails the test if statement parses with self.grammar.
		"""
		try:
			self.assertRaisesVerbose(
				(adql.ParseException, adql.ParseSyntaxException),
				self.grammar.parseString, (statement,),
				"Parses but shouldn't: %s"%statement)
		except RuntimeError as ex:
			raise Error("%s causes an infinite recursion"%statement) from ex


class NodeTest(_ADQLParsesTest):
	"""tests for nodes.getChildOfType on a list of select list items.
	"""
	def setUp(self):
		self.grammar = adql.getGrammar()

	def _getSelectFields(self):
		# the three derivedColumn nodes of a simple query's select list
		parsed = self.grammar.parseString("SELECT a, b, c FROM x")
		return parsed[0].children[0].children[0].selectList.selectFields

	def testNoChild(self):
		selectFields = self._getSelectFields()
		self.assertRaisesWithMsg(
			adql.NoChild,
			"No ab child found in [<ADQL Node derivedColumn>,"
			" <ADQL Node derivedColumn>, <ADQL Node derivedColumn>]",
			nodes.getChildOfType,
			(selectFields, "ab"))

	def testMoreChild(self):
		selectFields = self._getSelectFields()
		self.assertRaisesWithMsg(
			adql.MoreThanOneChild,
			"Multiple derivedColumn children found in [<ADQL Node derivedColumn>,"
			" <ADQL Node derivedColumn>, <ADQL Node derivedColumn>]",
			nodes.getChildOfType,
			(selectFields, "derivedColumn"))


class NakedParseTest(_ADQLParsesTest):
	"""tests for plain parsing (without tree building).

	Each test hands a list of statements to one of the helpers below,
	which check every statement in turn.
	"""
	def _assertParse(self, correctStatements):
		"""asserts that every statement in correctStatements parses."""
		for stmt in correctStatements:
			self._assertGoodADQL(stmt)

	def _assertDontParse(self, badStatements):
		"""asserts that no statement in badStatements parses."""
		for stmt in badStatements:
			self._assertBadADQL(stmt)

	def testPlainSelects(self):
		"""tests that some elementary select expressions parse.
		"""
		self._assertParse([
				"SELECT x FROM y",
				"SELECT x FROM y WHERE z=0",
				"SELECT x, v FROM y WHERE z=0 AND v>2",
				"SELECT 89 FROM X",
				"SELECT 89 FROM X AS Y",
				"SELECT 89 FROM X Y",
			])

	def testDelimited(self):
		"""tests that delimited (double-quoted) identifiers parse."""
		self._assertParse([
			'SELECT "f-bar", "c""ho" FROM "nons-ak" WHERE "ja ja"<"Umph"'])

	def testSimpleSyntaxErrors(self):
		"""tests for rejection of gross syntactic errors.
		"""
		self._assertDontParse([
				"W00T",
				"SELECT A",
				"SELECT A FROM",
				"SELECT A FROM B WHERE",
				"SELECT FROM",
				"SELECT 89! FROM z",
			])

	def testCaseInsensitivity(self):
		"""tests for case being ignored in SQL keywords.
		"""
		self._assertParse([
				"select z as U From n",
				"seLect z AS U FROM n",
			])

	def testJoins(self):
		"""tests for JOIN syntax.
		"""
		self._assertParse([
			"select x from t1, t2",
			"select x from t1, t2, t3",
			"select x from t1, t2, t3 WHERE t1.x=t2.y",
			"select x from t1 JOIN t2",
			"select x from t1 NATURAL JOIN t2",
			"select x from t1 LEFT OUTER JOIN t2",
			"select x from t1 RIGHT OUTER JOIN t2",
			"select x from t1 FULL OUTER JOIN t2",
			"select x from t1 FULL OUTER JOIN t2 ON (x=y)",
			"select x from t1 FULL OUTER JOIN t2 USING (x,y)",
			"select x from t1 INNER JOIN (t2 JOIN t3)",
			"select x from (t1 JOIN t4) FULL OUTER JOIN (t2 JOIN t3)",
			"select x from t1 NATURAL JOIN t2, t3",
		])

	def testBadJoins(self):
		"""tests for syntax error detection in JOINs.
		"""
		self._assertDontParse([
			"select x from t1 JOIN",
			"select x from JOIN t1",
			"select x from t1 join JOIN t1",
			"select x from t1 NATURAL JOIN t2, t3 OUTER",
			"select x from t1 NATURAL JOIN t2, t3 ON",
			"select x from t1, t2, t3 ON",
		])

	def testBadDetritus(self):
		"""tests for syntax errors in ORDER BY and friends.
		"""
		self._assertDontParse([
			"select x from t1 having y",
		])

	def testBadBooleanTerms(self):
		"""tests for rejection of malformed BETWEEN, IN and IS terms."""
		p = "select x from y where "
		self._assertDontParse([
			p+"z BETWEEN",
			p+"z BETWEEN AND",
			p+"z BETWEEN AND 5",
			p+"z 7 BETWEEN 5 AND ",
			p+"x IN",
			p+"x IN 5",
			p+"x IN (23, 3,)",
			p+"x Is None",
		])
	
	def testsBadFunctions(self):
		"""tests for rejection of bad function calls.
		"""
		p = "select x from y where "
		self._assertDontParse([
			p+"ABS()<3",
			p+"ABS(y,z)<3",
			p+"ATAN2(x)<3",
			p+"PI==3",
		])
	
	def testFunkyIds(self):
		"""tests for parsing quoted identifiers.
		"""
		p = "select x from y where "
		self._assertParse([
			p+'"some weird column">0',
			p+'"some even ""weirder"" column">0',
			p+'"SELECT">0',
		])

	def testMiscGood(self):
		"""tests for parsing of various legal statements.
		"""
		self._assertParse([
			"select a, b from (select * from x) AS q",
		])

	def testMiscBad(self):
		"""tests for rejection of various bad statements.
		"""
		self._assertDontParse([
			"select a, b from (select * from x) q r",
			"select a, b from (select * from x)",
			"select x.y.z.a.b from a",
			"select x from a.b.c.d",
		])

	def testStringExpressionSelect(self):
		"""tests that string concatenation is accepted in the select list."""
		self._assertParse([
			"select m || 'ab' from q",])


class FunctionsParseTest(_ADQLParsesTest, metaclass=testhelpers.SamplesBasedAutoTest):
	"""tests that WHERE conditions using numeric functions parse.

	Each sample is appended to a fixed "select x from y where " prefix.
	"""
	def _runTest(self, sample):
		self._assertGoodADQL("select x from y where "+sample)

	samples = [
		"ABS(-3)<3",
		"ABS(-3.0)<3",
		"ABS(-3.0E4)<3",
		"ABS(-3.0e-4)<3",
		"ABS(x)<3",
		"ATAN2(-3.0e-4, 4.5)=x",
		"RAND(4)=x",
		"RAND()=x",
		"ROUND(23)=x",
		"ROUND(23,2)=x",
		"ROUND(PI(),2)=3.14",
		"POWER(x,10)=3.14",
		"POWER(10,x)=3.14",
	]


class SetExpressionsTest(_ADQLParsesTest, metaclass=testhelpers.SamplesBasedAutoTest):
	"""tests that statements combined with UNION, INTERSECT and EXCEPT
	(also inside subqueries) parse.
	"""
	def _runTest(self, sample):
		self._assertGoodADQL(sample)
	
	samples = [
		"select x from t1 union select x from t2",
		"select x from t1 intersect select x from t2",
		"select x from t1 except select x from t2",
		"select x from t1 where x>2 union select x from t2",
		"select * from t1 union select x from t2 intersect select x from t3"
			" except select x from t4",
# 5
		"select * from (select * from t1 except select * from t2) as q union"
			" select * from  t3",
		"select * from t1 union select foo from (select * from t2 except select * from t1) as q",
	]


class AsTreeTest(testhelpers.VerboseTest):
	"""tests for the nested-sequence representation returned by asTree().
	"""
	def testSimple(self):
		query = ("SELECT * FROM t WHERE 1=CONTAINS("
			"CIRCLE('ICRS', 4, 4, 2), POINT('', ra, dec))")
		tree = adql.parseToTree(query).asTree()
		# the select part holding the from and where clauses
		selectPart = tree[1][1][1]
		self.assertEqual(selectPart[1][1][0], 'possiblyAliasedTable')
		whereClause = selectPart[3]
		self.assertEqual(whereClause[0], 'whereClause')
		self.assertEqual(whereClause[1][2][1][0], 'circle')


class _TreeParseTestBase(testhelpers.VerboseTest):
	"""A base for tests inspecting ADQL parse trees.

	setUp makes the ADQL grammar available as self.grammar.
	"""
	def setUp(self):
		self.grammar = adql.getGrammar()

	def assertAttrs(self, c, **assertions):
		"""asserts that c has each attribute named in assertions with the
		given value.
		"""
		for attName, expected in assertions.items():
			# fetch once so the compared value and the message agree
			actual = getattr(c, attName)
			self.assertEqual(actual, expected,
				"%s, %s!=%s"%(attName, repr(actual), repr(expected)))


class TreeParseTest(_TreeParseTestBase):
	"""tests for properties of ADQL parse trees (select fields, table
	names, taint, limits, literals).
	"""
	def testSelectList(self):
		for q, e in [
			("select a from z", ["a"]),
			("select x.a from z", ["a"]),
			("select x.a, b from z", ["a", "b"]),
			('select "one weird name", b from z',
				[utils.QuotedName('one weird name'), "b"]),
		]:
			tree = self.grammar.parseString(q)[0]
			res = [c.name for c in tree.getSelectFields()]
			self.assertEqual(res, e,
				"Select list from %s: expected %s, got %s"%(q, e, res))

	def testSourceTables(self):
		for q, e in [
			("select * from z", ["z"]),
			("select * from z.x", ["z.x"]),
			("select * from z.x.y", ["z.x.y"]),
			("select * from z.x.y, a", ["z.x.y", "a"]),
			("select * from (select * from z) as q, a", ["q", "a"]),
		]:
			res = list(self.grammar.parseString(q)[0].getAllNames())
			self.assertEqual(res, e,
				"Source tables from %s: expected %s, got %s"%(q, e, res))

	def testSourceTablesJoin(self):
		for q, e in [
			("select * from z join x", ["z", "x"]),
			("select * from (select * from a,b, (select * from c,d) as q) as r join"
				"(select * from x,y) as p", ["r", "p"]),
		]:
			res = list(self.grammar.parseString(q)[0].getAllNames())
			self.assertEqual(res, e,
				"Source tables from %s: expected %s, got %s"%(q, e, res))

	def testContributingTables(self):
		q = ("select * from (select * from urks.a,b,"
			" (select * from c,monk.d) as q) as r join"
			" (select * from x,y) as p")
		self.assertEqual(self.grammar.parseString(q)[0].getContributingNames(),
			{'c', 'b', 'urks.a', 'q', 'p', 'r', 'y', 'x', 'monk.d'})

	def testAliasedColumn(self):
		q = "select foo+2 as fp2 from x"
		res = self.grammar.parseString(q)[0]
		field = list(res.getSelectFields())[0]
		self.assertEqual(field.name, "fp2")
	
	def testTainting(self):
		# computed columns are tainted, plain column references are not
		for q, (exName, exTaint) in [
			("select x from z", ("x", False)),
			("select x as u from z", ("u", False)),
			("select x+2 from z", (None, True)),
			('select x+2 as "99 Monkeys" from z', (utils.QuotedName("99 Monkeys"),
				True)),
			('select x+2 as " ""cute"" Monkeys" from z',
				(utils.QuotedName(' "cute" Monkeys'), True)),
		]:
			res = list(self.grammar.parseString(q)[0].getSelectFields())[0]
			self.assertEqual(res.tainted, exTaint, "Field taintedness wrong in %s"%
				q)
			if exName:
				self.assertEqual(res.name, exName)

	def testValueExpressionColl(self):
		t = adql.parseToTree("select x from z where 5+9>'gaga'||'bla'")
		compPred = t.whereClause.children[1]
		self.assertEqual(compPred.op1.type, "numericValueExpression")
		self.assertEqual(compPred.opr, ">")
		self.assertEqual(compPred.op2.type, "stringValueExpression")

	def testQualifiedStar(self):
		t = adql.parseToTree("select t1.*, s1.t2.* from t1, s1.t2, s2.t3")
		self.assertEqual(t.selectList.selectFields[0].type, "qualifiedStar")
		self.assertEqual(t.selectList.selectFields[0].sourceTable.qName,
			"t1")
		self.assertEqual(t.selectList.selectFields[1].sourceTable.qName,
			"s1.t2")

	def testBadSystem(self):
		self.assertRaises(adql.ParseSyntaxException,
			self.grammar.parseString, "select point('QUARK', 1, 2) from spatial")

	def testQuotedTableName(self):
		t = adql.parseToTree('select "abc-g".* from "abc-g" JOIN "select"')
		self.assertEqual(t.selectList.selectFields[0].sourceTable.name, "abc-g")
		self.assertEqual(t.selectList.selectFields[0].sourceTable.qName, '"abc-g"')

	def testQuotedSchemaName(self):
		t = adql.parseToTree('select * from "Murks Schema"."Murks Tabelle"')
		table = t.fromClause.tableReference
		self.assertEqual(table.tableName.name,
			utils.QuotedName("Murks Tabelle"))
		self.assertEqual(table.tableName.schema,
			utils.QuotedName("Murks Schema"))
	
	def testSetLimitInherited(self):
		t = adql.parseToTree('select top 3 * from t1 union'
			' select top 4 * from t2 except select * from t3')
		self.assertEqual(t.setLimit, 4)
	
	def testSetLimitDeep(self):
		t = adql.parseToTree(
			'select top 7 * from t1 union'
			' (select top 4 * from t2 except select * from t3)'
			' except (select top 30 x from t4 except select top 3 y from t5)')
		self.assertEqual(t.setLimit, 30)

	def testHexadecimal(self):
		t = adql.parseToTree(
			"select x-0xaf, -0x1fffffff from t1")
		sels = list(t.getSelectClauses())[0].selectList.selectFields
		# hex literals are flattened to their decimal values
		self.assertEqual(sels[0].flatten(), 'x - 175')
		self.assertEqual(sels[1].flatten(), '- 536870911')


class CircleTreeParseTest(_TreeParseTestBase):
	"""tests for the center, radius and cooSys attributes of parsed CIRCLEs.
	"""
	def _getParsed(self, circleLiteral):
		"""returns the expression node of circleLiteral used as the only
		select list item.
		"""
		t = adql.parseToTree('select %s from t1'%circleLiteral)
		return t.children[0].selectList.selectFields[0].expr

	def testColrefNosys(self):
		c = self._getParsed("CIRCLE(x, r)")
		self.assertAttrs(c, cooSys="")
		self.assertAttrs(c.radius, type="columnReference", name="r")
		self.assertAttrs(c.center, type="columnReference", name="x")

	def testColrefWithsys(self):
		c = self._getParsed("CIRCLE('ICRS', \"center\", \"radius\")")
		self.assertAttrs(c, cooSys="ICRS")
		self.assertAttrs(c.radius, type="columnReference", name="radius")
		self.assertAttrs(c.center, type="columnReference", name="center")

	def testSplitNosys(self):
		# three arguments: center given as separate x and y expressions
		c = self._getParsed("CIRCLE(a, d, r)")
		self.assertAttrs(c, cooSys="")
		self.assertAttrs(c.radius, type="columnReference", name="r")
		self.assertAttrs(c.center.x, type="columnReference", name="a")
		self.assertAttrs(c.center.y, type="columnReference", name="d")

	def testLiteralPoint(self):
		# a NULL coordinate system literal ends up as "UNKNOWN"
		c = self._getParsed("CIRCLE(NULL, POINT(a, d) ,r)")
		self.assertAttrs(c, cooSys="UNKNOWN")
		self.assertAttrs(c.radius, type="columnReference", name="r")
		self.assertAttrs(c.center.x, type="columnReference", name="a")
		self.assertAttrs(c.center.y, type="columnReference", name="d")

	def testLiteralPointNosys(self):
		# the POINT's system does not propagate to the circle's cooSys
		c = self._getParsed("CIRCLE(POINT('ICRS', a, d) ,r)")
		self.assertAttrs(c, cooSys="")
		self.assertAttrs(c.radius, type="columnReference", name="r")
		self.assertAttrs(c.center.x, type="columnReference", name="a")
		self.assertAttrs(c.center.y, type="columnReference", name="d")

	def testSplitWithExpression(self):
		c = self._getParsed("CIRCLE(ra+2, dec ,r-1)")
		self.assertAttrs(c, cooSys="")
		self.assertAttrs(c.radius, type="numericValueExpression")
		self.assertAttrs(c.center.x, type="numericValueExpression")
		self.assertAttrs(c.center.y, type="columnReference", name="dec")
		self.assertEqual(c.radius.children[1], '-')

	@unittest.skipUnless(testhelpers.hasUDF("IVO_EPOCH_PROP"), "pgsphere too old")
	def testGeoUDF(self):
		c = self._getParsed("CIRCLE(ivo_apply_pm(12, 13, 1e-7, -1e-7, 19) ,r)")
		self.assertEqual(c.flatten(),
			"CIRCLE(IVO_APPLY_PM(12, 13, 1e-7, - 1e-7, 19),r)")


class PolygonTreeParseTest(_TreeParseTestBase):
	"""tests for the points/coos and cooSys attributes of parsed POLYGONs.

	Polygons given as point expressions fill points (coos is None);
	polygons given as separate coordinates or POINT literals fill coos
	(points is None), except where a non-literal point forces points.
	"""
	def _getParsed(self, polygonLiteral):
		"""returns the expression node of polygonLiteral used as the only
		select list item.
		"""
		t = adql.parseToTree('select %s from t1'%polygonLiteral)
		return t.children[0].selectList.selectFields[0].expr

	def testColrefNosys(self):
		p = self._getParsed("POLYGON(a, b, c, d)")
		self.assertAttrs(p, type="polygon", cooSys="", coos=None)
		self.assertEqual(len(p.points), 4)
		self.assertAttrs(p.points[0], type="columnReference", name="a")
		self.assertAttrs(p.points[-1], type="columnReference", name="d")

	def testColrefWithsys(self):
		p = self._getParsed("polygon('ICRS', \"p 1\", \"p 2\", p3)")
		self.assertAttrs(p, cooSys="ICRS", coos=None)
		self.assertEqual(len(p.points), 3)
		self.assertAttrs(p.points[0], type="columnReference", name="p 1")
		self.assertAttrs(p.points[1], type="columnReference", name="p 2")
		self.assertAttrs(p.points[2], type="columnReference", name="p3")

	def testSplitNosys(self):
		p = self._getParsed("polygon(x1, y1, x2, y2, x3, y3)")
		self.assertAttrs(p, cooSys="", points=None)
		self.assertEqual(len(p.coos), 3)
		self.assertAttrs(p.coos[0][0], type="columnReference", name="x1")
		self.assertAttrs(p.coos[-1][1], type="columnReference", name="y3")

	def testLiteralPoints(self):
		p = self._getParsed("POLYGON(NULL, POINT(x1, y1),"
			" POINT('ICRS', x2, y2), POINT(x3,y3))")
		self.assertAttrs(p, cooSys="UNKNOWN", points=None)
		self.assertEqual(len(p.coos), 3)
		self.assertAttrs(p.coos[0][0], type="columnReference", name="x1")
		self.assertAttrs(p.coos[-1][1], type="columnReference", name="y3")

	def testLiteralPointNosys(self):
		# one column reference among POINT literals: everything goes to points
		p = self._getParsed("POLYGON(NULL, POINT(x1, y1),"
			" p, POINT(x3,y3))")
		self.assertAttrs(p, cooSys="UNKNOWN", coos=None)
		self.assertEqual(len(p.points), 3)
		self.assertAttrs(p.points[0].x, type="columnReference", name="x1")
		self.assertAttrs(p.points[-1].y, type="columnReference", name="y3")

	def testSplitWithExpression(self):
		p = self._getParsed(
			"polygon(ra+1, dec+1 ,ra+1, dec-1, ra-1, dec+1, ra-1, dec-1)")
		self.assertAttrs(p, cooSys="", points=None)
		self.assertAttrs(p.coos[0][0], type="numericValueExpression")
		self.assertEqual(p.coos[-1][1].children[1], "-")

	@unittest.skipUnless(testhelpers.hasUDF("IVO_EPOCH_PROP"), "pgsphere too old")
	def testGeoUDF(self):
		p = self._getParsed("polygon(ivo_apply_pm(12, 13, 1e-7, -1e-7, 19),"
			"ivo_apply_pm(12, 13, -1e-7, 1e-7, -19),"
			"ivo_apply_pm(12, 13, 1e-7, -1e-7, -19),"
			"ivo_apply_pm(12, 13, 1e-7, 1e-7, 19))")
		self.assertEqual(len(p.points), 4)
		self.assertEqual(p.flatten(),
			"POLYGON(IVO_APPLY_PM(12, 13, 1e-7, - 1e-7, 19),"
			" IVO_APPLY_PM(12, 13, - 1e-7, 1e-7, - 19),"
			" IVO_APPLY_PM(12, 13, 1e-7, - 1e-7, - 19),"
			" IVO_APPLY_PM(12, 13, 1e-7, 1e-7, 19))")


class DistanceParseTest(_TreeParseTestBase):
	"""tests that DISTANCE always ends up with two point-valued args,
	whether called with four coordinates or two point expressions.
	"""
	def _getParsed(self, distanceLiteral):
		"""returns the expression node of distanceLiteral used as the only
		select list item.
		"""
		t = adql.parseToTree('select %s from t1'%distanceLiteral)
		return t.children[0].selectList.selectFields[0].expr

	def testSplitArgs(self):
		# four coordinates are combined into two points
		df = self._getParsed("distance(1, dec1, ra2, dec2)")
		self.assertAttrs(df, funName="DISTANCE")
		self.assertEqual(len(df.args), 2)
		self.assertAttrs(df.args[0], type="point")
		self.assertAttrs(df.args[1].y, type="columnReference", name="dec2")

	def testPointArgs(self):
		df = self._getParsed("distance(p1, p2)")
		self.assertAttrs(df, funName="DISTANCE")
		self.assertEqual(len(df.args), 2)
		self.assertAttrs(df.args[0], type="columnReference", name="p1")
		self.assertAttrs(df.args[-1], type="columnReference", name="p2")

	def testPointArgsLiterals(self):
		df = self._getParsed("distance(point(ra1, dec1), point(ra2, dec2))")
		self.assertAttrs(df, funName="DISTANCE")
		self.assertEqual(len(df.args), 2)
		self.assertAttrs(df.args[0], type="point")
		self.assertAttrs(df.args[1].y, type="columnReference", name="dec2")

	def testPointArgsOneLiteral(self):
		df = self._getParsed("distance(point(ra1, dec1), p2)")
		self.assertAttrs(df, funName="DISTANCE")
		self.assertEqual(len(df.args), 2)
		self.assertAttrs(df.args[0], type="point")
		self.assertAttrs(df.args[0].x, type="columnReference", name="ra1")
		self.assertAttrs(df.args[1], type="columnReference", name="p2")


class ParseErrorTest(testhelpers.VerboseTest, metaclass=testhelpers.SamplesBasedAutoTest):
	"""tests for sensible error messages.

	Each sample is (query, regular expression); the query must fail to
	parse, and the regular expression must be found in the error message.
	"""

	def _runTest(self, sample):
		query, msgFragment = sample
		# pyparsing has changed its quotes over the time; cope with both
		# versions:
		msgFragment = msgFragment.replace('"', """['"]""")
		try:
			_ = adql.getGrammar().parseString(query, parseAll=True)
		except (adql.ParseException, adql.ParseSyntaxException) as ex:
			msg = str(ex)
			self.assertTrue(re.search(msgFragment, msg),
				"'%s' does not contain '%s'"%(msg, msgFragment))
		else:
			self.fail("'%s' parses but should not"%query)

	samples = [
		("", r'Expected SELECT.* \(at char 0'),
		("select mag from %s", r'Expected table reference.* \(at char 16'),
		("SELECT TOP foo FROM x", r'Expected unsigned integer.* \(at char 11'),
		("SELECT FROM x", r'Expected select list.* \(at char 7'),
		("SELECT x, FROM y", r'Expected select list item.* \(at char 10'),
#5
		("SELECT * FROM distinct", r'Expected table reference.* \(at char 14'),
		("SELECT DISTINCT FROM y", r'Expected select list.* \(at char 16'),
		("SELECT *", r'Expected FROM.* \(at char 8'),
		("SELECT * FROM y WHERE", r'Expected boolean expression.* \(at char 21'),
		("SELECT * FROM y WHERE y u 2",
			r'Expected boolean expression.* \(at char 24'),
# 10
		("SELECT * FROM y WHERE y < 2 AND",
			r'Expected boolean expression.* \(at char 31'),
		("SELECT * FROM y WHERE y < 2 OR",
			r'Expected boolean expression.* \(at char 30'),
		("SELECT * FROM y WHERE y IS 3", r'Expected NULL.* \(at char 27'),
		("SELECT * FROM y WHERE CONTAINS(a,b)",
			r'Expected boolean expression.* \(at char 35'),
		("SELECT * FROM y WHERE 1=CONTAINS(POINT('ICRS',x,COUNT)"
			" ,CIRCLE('ICRS',x,y,z))",
			r"Expected numeric expression, found '\)'  \(at char 53"),
# 15
		("SELECT * FROM (SELECT * FROM x)",
			r'Expected table reference.* \(at char 31'),
		("SELECT * FROM x WHERE EXISTS z", r'Expected subquery.* \(at char 29'),
		("SELECT POINT('junk', 3,4) FROM z",
			r'.*xpected "\)", found \',\'  \(at char 22'),
		("SELECT * from a join b on foo",
			r"Expected boolean expression.* \(at char 29"),
		("SELECT * from a OFFSET 20 join b on foo",
			r"Expected end of text.* \(at char 26"),
# 20
		("SELECT * from a natural join b OFFSET banana",
			r"Expected unsigned integer.* \(at char 38"),
		# NOTE(review): the pattern below contains unescaped regex
		# metacharacters ('|', '[...]', '(...)'), so re.search matches it
		# much more loosely than it reads.
		("select * from ivoa.obscore where 1=contains(point(s_ra, s_dec),"
			" circle(120, 30))",
			r'Expected {{Numeric expression "," Numeric expression "," - Numeric expression} | {{User defined function | {"POINT" - "(" [{coordinate system literal (ICRS, GALACTIC,...) - ","}] Numeric expression "," Numeric expression ")"} | column reference} "," - Numeric expression}}.* \(at char 78'),
		("CREATE TABLE grube.quatsch AS SELECT * from foo.bar",
			r"""Expected "tap_user", found 'g'  \(at char 13\)"""),
	]


class JoinTypeTest(testhelpers.VerboseTest,
		metaclass=testhelpers.SamplesBasedAutoTest):
	"""Tests for the join types reported by parsed joinedTable nodes.

	Each sample is a joined-table expression and the list of join types
	expected from an in-order walk (left operand, the join itself, right
	operand) of the possibly nested join tree.
	"""
	sym = adql.getSymbols()["joinedTable"]

	def _collectJoinTypes(self, joinedNode):
		"""returns the join types of joinedNode and of all joins nested in
		its operands, in left-to-right order.

		An operand is treated as a nested join if it has a leftOperand
		attribute.
		"""
		def walkOperand(operand):
			if hasattr(operand, "leftOperand"):
				return self._collectJoinTypes(operand)
			return []

		# + evaluates left to right, so the traversal order is preserved
		return (walkOperand(joinedNode.leftOperand)
			+ [joinedNode.getJoinType()]
			+ walkOperand(joinedNode.rightOperand))

	def _runTest(self, sample):
		query, expectedTypes = sample
		joinTree = self.sym.parseString(query)[0]
		self.assertEqual(self._collectJoinTypes(joinTree), expectedTypes)
	
	samples = [
		("a CROSS JOIN b", ["CROSS"]),
		("a join b", ["NATURAL"]),
		("a join b using (x)", ["USING"]),
		("a CROSS JOIN b CROSS JOIN c", ["CROSS", "CROSS"]),
		("a CROSS JOIN b join c", ["CROSS", "NATURAL"]),
#5
		("a join b cross join c", ["NATURAL", "CROSS"]),
		("a join b on (x=y) cross join c", ["CROSS", "CROSS"]),
		("a join b using (x,y) join c", ["USING", "NATURAL"]),
		("a join b using (x,y) join c using (z,v)", ["USING", "USING"]),
		("(a join b using (x,y)) join c using (z,v)", ["USING", "USING"]),
# 10
		("(a join b) cross join (c join d)", ["NATURAL", "CROSS", "NATURAL"]),
	]


# Parent tables for some of the columns below; they differ only in nrows.
_largeTable = MS(rscdef.TableDef, nrows=1000000)
_smallTable = MS(rscdef.TableDef, nrows=1000)

# Column lists of the fake tables (spatial, spatial2, misc, quoted, crazy,
# geo) that _MTH hands out by table name.
# The last column of spatial, gibtnet, is hidden; the star-select tests on
# spatial expect only the five columns before it.
spatialFields = [
	MS(rscdef.Column, name="dist", ucd="phys.distance", unit="m"),
	MS(rscdef.Column, name="width", ucd="phys.dim", unit="m"),
	MS(rscdef.Column, name="height", ucd="phys.dim", unit="km"),
	MS(rscdef.Column, name="ra1", ucd="pos.eq.ra", unit="deg",
		parent_=_smallTable, tablehead="Raw RA"),
	MS(rscdef.Column, name="ra2", ucd="pos.eq.ra", unit="rad"),
	MS(rscdef.Column, name="gibtnet", ucd="invalid", unit="junk", hidden=True),]
# spatial2 shares ra1 and dist with spatial (used by the NATURAL/USING
# join tests).
spatial2Fields = [
	MS(rscdef.Column, name="ra1", ucd="pos.eq.ra;meta.main", unit="deg",
		parent_=_largeTable),
	MS(rscdef.Column, name="dec", ucd="pos.eq.dec;meta.main", unit="deg"),
	MS(rscdef.Column, name="dist", ucd="phys.distance", unit="m"),
	MS(rscdef.Column, name="t", ucd="time.epoch", unit="h")]
miscFields = [
	MS(rscdef.Column, name="mass", ucd="phys.mass", unit="kg"),
	MS(rscdef.Column, name="mag", ucd="phot.mag", unit="mag"),
	MS(rscdef.Column, name="speed", ucd="phys.veloc", unit="km/s")]
# Delimited identifiers: a non-identifier character, an embedded double
# quote, mixed case, and an all-lowercase name.
quotedFields = [
	MS(rscdef.Column, name=utils.QuotedName("left-right"), ucd="mess",
		unit="bg"),
	MS(rscdef.Column, name=utils.QuotedName('inch"ing'), ucd="imperial.mess",
		unit="fin"),
	MS(rscdef.Column, name=utils.QuotedName('plAin'), ucd="boring.stuff",
		unit="pc"),
	MS(rscdef.Column, name=utils.QuotedName('alllower'), ucd="simple.case",
		unit="km"),]
# Assorted types, null literals and an array column; note that "mass"
# clashes (with a different UCD) with misc.mass.
crazyFields = [
	MS(rscdef.Column, name="ct", type="integer"),
	MS(rscdef.Column, name="wot", type="bigint",
		values=MS(rscdef.Values, nullLiteral="-1")),
	MS(rscdef.Column, name="wotb", type="bytea",
		values=MS(rscdef.Values, nullLiteral="254")),
	MS(rscdef.Column, name="mass", ucd="event;using.incense"),
	MS(rscdef.Column, name="name", type="unicode"),
	MS(rscdef.Column, name="version", type="text"),
	MS(rscdef.Column, name="flag", type="char"),
	MS(rscdef.Column, name="vals", type="real[]", ucd="some.value", unit="yr")]
# Geometry and timestamp columns.
geoFields = [
	MS(rscdef.Column, name="pt", type="spoint"),
	MS(rscdef.Column, name="coverage", type="spoly"),
	MS(rscdef.Column , name="dt", type="timestamp", ucd="time;obs"),
]

def _addSpatialSTC(sf, sf2, geo):
	"""attaches STC ASTs to the columns of the fake spatial, spatial2 and
	geo tables.

	sf, sf2, and geo are spatialFields, spatial2Fields, and geoFields; the
	assignments are by list position.  For sf and sf2, each column that
	gets an AST also has its stcUtype set to None; columns past the
	respective AST list (gibtnet in sf, t in sf2) are left alone.
	"""
	ast1 = stc.parseQSTCS('Position ICRS "ra1" "dec" Size "width" "height"')
	ast2 = stc.parseQSTCS('Position FK4 SPHER3 "ra2" "dec" "dist"')
	# XXX TODO: get utypes from ASTs
	# Pair each column with its AST by position so that every column gets
	# its own stc and stcUtype (spelling out the indices by hand led to
	# stcUtype being set on the wrong column).
	for columns, asts in [
			(sf, [ast2, ast1, ast1, ast1, ast2]),  # dist width height ra1 ra2
			(sf2, [ast1, ast1, ast2])]:            # ra1 dec dist
		for column, ast in zip(columns, asts):
			column.stc, column.stcUtype = ast, None
	ast3 = stc.parseQSTCS('Time TT BARYCENTER "dt" Position GALACTIC [pt]')
	geo[0].stc = ast3
	geo[1].stc = ast3
_addSpatialSTC(spatialFields, spatial2Fields, geoFields)


def _addFakeIndices(sf):
	assert sf[3].name=='ra1'
	sf[3].isIndexed = lambda: ['q3c']
# Give spatial.ra1 its fake q3c index at import time.
_addFakeIndices(spatialFields)


class _MTH(object):
	"""A stand-in for the table metadata lookup of the field info getter.

	It only knows the fake tables defined in this module.  Despite its
	name, getTableDefForTable returns the column list of the table (not a
	TableDef), or None for unknown table names.
	"""
	@classmethod
	def getTableDefForTable(cls, tableName):
		columnsByTable = dict(
			spatial=spatialFields,
			spatial2=spatial2Fields,
			misc=miscFields,
			quoted=quotedFields,
			crazy=crazyFields,
			geo=geoFields)
		return columnsByTable.get(tableName)


class _SampleFieldInfoGetter(adqlglue.DaCHSFieldInfoGetter):
	"""A DaCHSFieldInfoGetter whose mth attribute is the _MTH stand-in,
	so table lookups go to the fake tables of this module.
	"""
	def __init__(self, *args):
		# positional arguments are accepted for compatibility but ignored
		super().__init__()
		self.mth = _MTH


@functools.lru_cache(maxsize=None)
def _getFieldInfoGetter():
	"""returns the field info getter for the fake test tables.

	The result is memoised, so all callers share a single instance.
	"""
	return _SampleFieldInfoGetter()


class TestFieldInfoGetter(testhelpers.VerboseTest):
	"""Tests for looking up field infos through the sample field info getter.
	"""
	def testBasic(self):
		fieldInfos = _getFieldInfoGetter()("spatial2")
		self.assertEqual(len(fieldInfos), 4)
		timeEntry = fieldInfos[3]
		self.assertEqual(timeEntry[0], "t")
		timeInfo = timeEntry[1]
		self.assertEqual(timeInfo.unit, "h")
		# userData refers back to the originating column
		self.assertEqual(timeInfo.userData[0].unit, "h")

	def testNotFound(self):
		getter = _getFieldInfoGetter()
		self.assertRaisesWithMsg(
			base.NotFoundError,
			"table 'does.not.exist' could not be located in system"
			" and uploaded tables",
			getter,
			("does.not.exist",))


def parseWithArtificialTable(query):
	"""parses the ADQL query and annotates it against the fake test tables.

	The annotated parse tree is returned; the annotation context is
	discarded.
	"""
	tree = adql.getGrammar().parseString(query)[0]
	adql.annotate(tree, _getFieldInfoGetter())
	return tree


class TypecalcTest(testhelpers.VerboseTest,
		metaclass=testhelpers.SamplesBasedAutoTest):
	"""Tests for adql.getSubsumingType.

	Each sample is a list of SQL type names and the type
	adql.getSubsumingType is expected to return for it.
	"""
	def _runTest(self, sample):
		typeNames, expectedType = sample
		subsumingType = adql.getSubsumingType(typeNames)
		self.assertEqual(subsumingType, expectedType)
	
	samples = [
		(["double precision", "integer", "bigint"], 'double precision'),
		(["date", "timestamp", "timestamp"], 'timestamp'),
		(["date", "boolean", "smallint"], 'text'),
		(["box", "raw"], 'raw'),
		(["date", "time"], 'timestamp'),
# 5
		(["char(3)", "integer"], "text"),
		(["double precision", "char(3)"], "text"),
		(["integer[3]", "bigint"], "bigint[]"),
		(["integer", "smallint", "double precision[]"], "double precision[]"),
		(["integer[][]", "smallint", "double precision[]"], "double precision[]"),
# 10
		# Admittedly, the next one is plain wrong, but I'm relying on postgres
		# to reject such nonsense in the first place.
		(["double precision[340]", "char(3)"], "text"),
		(["boolean", "boolean"], "boolean"),
		(["boolean", "smallint"], "smallint"),
		(["sbox", "spoint"], "text"),
		(["sbox", "spoly"], "spoly"),
		(["sbox", "whacko"], "raw"),
	]


class ColumnTest(testhelpers.VerboseTest):
	"""Base class for tests checking the output columns of annotated queries.

	Queries are parsed with the ADQL grammar and annotated against
	self.fieldInfoGetter, which defaults to the getter for the fake tables
	of this module.  Subclasses may replace the getter in setUp.
	"""
	def setUp(self):
		testhelpers.VerboseTest.setUp(self)
		self.fieldInfoGetter = _getFieldInfoGetter()
		self.grammar = adql.getGrammar()

	def _getColSeqAndCtx(self, query):
		"""returns the output column sequence of query and the annotation
		context.

		The column sequence consists of (name, fieldInfo) pairs.
		"""
		tree = self.grammar.parseString(query)[0]
		ctx = adql.annotate(tree, self.fieldInfoGetter)
		return tree.fieldInfos.seq, ctx

	def _getColSeq(self, query):
		"""returns just the output column sequence of query."""
		return self._getColSeqAndCtx(query)[0]

	def _assertColumns(self, resultColumns, assertProperties):
		"""asserts that resultColumns match assertProperties.

		resultColumns is a sequence of (name, fieldInfo) pairs as returned
		by _getColSeq; assertProperties is a sequence of
		(type, unit, ucd, tainted) tuples, one per column.  A None in any
		slot skips that particular check.  Column names are not checked.
		"""
		self.assertEqual(len(resultColumns), len(assertProperties))
		for index, ((_, col), (colType, unit, ucd, taint)) in enumerate(zip(
				resultColumns, assertProperties)):
			if colType is not None:
				self.assertEqual(col.type, colType, "Type %d: %r != %r"%
					(index, col.type, colType))
			if unit is not None:
				self.assertEqual(col.unit, unit, "Unit %d: %r != %r"%
					(index, col.unit, unit))
			if ucd is not None:
				self.assertEqual(col.ucd, ucd, "UCD %d: %r != %r"%
					(index, col.ucd, ucd))
			if taint is not None:
				self.assertEqual(col.tainted, taint, "Taint %d: should be %s"%
					(index, taint))


class SelectClauseTest(ColumnTest):
	"""Output column inference for constant, plain, and starred select lists.
	"""
	# The columns of misc as a star select yields them.
	_miscStarColumns = (
		("real", "kg", "phys.mass", False),
		("real", "mag", "phot.mag", False),
		("real", "km/s", "phys.veloc", False))

	def testConstantSelect(self):
		expected = [
			("smallint", "", "", False),
			("text", "", "", False)]
		self._assertColumns(
			self._getColSeq("select 1, 'const' from spatial"), expected)

	def testConstantExprSelect(self):
		# expressions are tainted even when they only involve constants
		expected = [
			("double precision", "", "", True),
			("text", "", "", True)]
		self._assertColumns(
			self._getColSeq("select 1+0.1, 'const'||'ab' from spatial"), expected)

	def testConstantSelectWithAs(self):
		expected = [("double precision", "", "", True)]
		self._assertColumns(
			self._getColSeq("select 1+0.1 as x from spatial"), expected)

	def testSimpleColumn(self):
		expected = [("real", "kg", "phys.mass", False)]
		self._assertColumns(
			self._getColSeq("select mass from misc"), expected)

	def testBadRefRaises(self):
		query = "select x, foo.* from spatial, misc"
		self.assertRaises(adql.ColumnNotFound, self._getColSeq, query)

	def testQualifiedStarSingle(self):
		self._assertColumns(
			self._getColSeq("select misc.* from misc"),
			list(self._miscStarColumns))

	def testQualifiedStar(self):
		self._assertColumns(
			self._getColSeq("select misc.* from spatial, misc"),
			list(self._miscStarColumns))

	def testMixedQualifiedStar(self):
		query = ("select misc.*, dist, round(mass/10)"
			" from spatial, misc")
		expected = list(self._miscStarColumns)+[
			("real", "m", "phys.distance", False),
			("double precision", "kg", "phys.mass", True)]
		self._assertColumns(self._getColSeq(query), expected)

	def testAliasedStar(self):
		query = ("select misc.* from spatial join misc as foo"
			" on (spatial.dist=foo.mass)")
		self.assertEqual(len(self._getColSeq(query)), 3)

	def testFancyRounding(self):
		expected = [("double precision", "m", "phys.distance", True)]
		self._assertColumns(
			self._getColSeq("select round(dist, 2) from spatial"), expected)


class ColResTest(ColumnTest):
	"""tests for resolution of output columns from various expressions.

	This covers plain references, arithmetic, (aggregate) functions,
	geometries, subqueries, joins, and CASE/COALESCE expressions.
	"""
	def testSimpleSelect(self):
		cols = self._getColSeq("select width, height from spatial")
		self.assertEqual(cols[0][0], 'width')
		self.assertEqual(cols[1][0], 'height')
		wInfo = cols[0][1]
		self.assertEqual(wInfo.unit, "m")
		self.assertEqual(wInfo.ucd, "phys.dim")
		# the field info must refer to the very column object
		self.assertTrue(wInfo.userData[0] is spatialFields[1])

	def testNULL(self):
		cols = self._getColSeq("select NULL from spatial")
		# the type of a bare NULL is not checked here
		self._assertColumns(cols, [
			(None, "", "", False)])

	def testIgnoreCase(self):
		cols = self._getColSeq("select Width, hEiGHT from spatial")
		self._assertColumns(cols, [
			("real", "m", "phys.dim", False),
			("real", "km", "phys.dim", False),])

	def testStarSelect(self):
		cols = self._getColSeq("select * from spatial")
		# the hidden gibtnet column does not show up
		self._assertColumns(cols, [
			("real", "m", "phys.distance", False),
			("real", "m", "phys.dim", False),
			("real", "km", "phys.dim", False),
			("real", "deg", "pos.eq.ra", False),
			("real", "rad", "pos.eq.ra", False), ])

	def testStarSelectJoined(self):
		cols = self._getColSeq("select * from spatial, misc")
		self._assertColumns(cols, [
			("real", "m", "phys.distance", False),
			("real", "m", "phys.dim", False),
			("real", "km", "phys.dim", False),
			("real", "deg", "pos.eq.ra", False),
			("real", "rad", "pos.eq.ra", False),
			("real", "kg", "phys.mass", False),
			("real", "mag", "phot.mag", False),
			("real", "km/s", "phys.veloc", False)])

	def testDimlessSelect(self):
		cols = self._getColSeq("select 3+4 from spatial")
		self.assertEqual(cols[0][1].type, "smallint")
		self.assertEqual(cols[0][1].unit, "")
		self.assertEqual(cols[0][1].ucd, "")

	def testSimpleScalarExpression(self):
		cols = self._getColSeq("select 2+width, 2*height, height*2"
			" from spatial")
		self._assertColumns(cols, [
			("real", "m", "", True),
			("real", "km", "phys.dim", True),
			("real", "km", "phys.dim", True),])
		self.assertTrue(cols[1][1].userData[0] is spatialFields[2])

	def testFieldOperandExpression(self):
		cols = self._getColSeq("select width*height, width/speed, "
			"3*mag*height, mag+height, height+height from spatial, misc")
		self._assertColumns(cols, [
			("real", "m*km", "", True),
			("real", "m/(km/s)", "", True),
			("real", "mag*km", "", True),
			("real", "", "", True),
			("real", "km", "phys.dim", True)])

	def testMiscOperands(self):
		cols = self._getColSeq("select -3*mag from misc")
		self._assertColumns(cols, [
			("real", "mag", "phot.mag", True)])

	def testSetFunctions(self):
		cols = self._getColSeq("select AVG(mag), mAx(mag), max(2*mag),"
			" Min(Mag), sum(mag), count(mag), avg(3), count(*) from misc")
		# the UCD of avg(3) is not checked
		self._assertColumns(cols, [
			("double precision", "mag", "phot.mag;stat.mean", False),
			("real", "mag", "stat.max;phot.mag", False),
			("real", "mag", "stat.max;phot.mag", True),
			("real", "mag", "stat.min;phot.mag", False),
			("real", "mag", "phot.mag", False),
			("integer", "", "meta.number;phot.mag", False),
			("double precision", "", None, False),
			("integer", "", "meta.number", False)])

	def testNumericFunctions(self):
		cols = self._getColSeq("select acos(ra2), degrees(ra2), RadianS(ra1),"
			" PI(), ABS(width), Ceiling(Width), Truncate(height*2)"
			" from spatial")
		self._assertColumns(cols, [
			("double precision", "rad", "", True),
			("double precision", "deg", "pos.eq.ra", True),
			("double precision", "rad", "pos.eq.ra", True),
			("double precision", "", "", True),
			("double precision", "m", "phys.dim", True),
			("double precision", "m", "phys.dim", True),
			("double precision", "km", "phys.dim", True)])

	def testAggFunctions(self):
		cols = self._getColSeq("select max(ra1), min(ra1) from spatial")
		self._assertColumns(cols, [
			("real", "deg", "stat.max;pos.eq.ra", False),
			("real", "deg", "stat.min;pos.eq.ra", False)])

	def testPoint(self):
		cols = self._getColSeq("select point('ICRS', ra1, ra2) from spatial")
		self._assertColumns(cols, [
			("spoint", '', '', False)])
		self.assertTrue(cols[0][1].userData[0] is spatialFields[3])

	def testDistance(self):
		cols = self._getColSeq("select distance(point('galactic', 2, 3),"
			" point('ICRS', ra1, ra2)) from spatial")
		self._assertColumns(cols, [
			("double precision", 'deg', 'pos.angDistance', False)])

	def testCentroid(self):
		cols = self._getColSeq("select centroid(circle('galactic', ra1, ra2, 0.5))"
			" from spatial")
		self._assertColumns(cols, [
			("spoint", '', '', False)])

	def testArea(self):
		cols = self._getColSeq("select area(circle(ra1, ra2, 0.2))"
			" from spatial")
		self._assertColumns(cols, [
			("double precision", 'deg**2', 'phys.angSize', False)])

	def testParenExprs(self):
		cols = self._getColSeq("select (width+width)*height from spatial")
		self._assertColumns(cols, [
			("real", "m*km", "", True)])

	def testSubquery(self):
		cols = self._getColSeq("select q.p from (select ra2 as p from"
			" spatial) as q")
		self._assertColumns(cols, [
			("real", 'rad', 'pos.eq.ra', False)])

	def testSubqueryStar(self):
		cols = self._getColSeq("select p, speed, q.*"
			" from (select speed, mag as p from misc) as q")
		self._assertColumns(cols, [
			("real", "mag", "phot.mag", False),
			("real", "km/s", "phys.veloc", False),
			("real", "km/s", "phys.veloc", False),
			("real", "mag", "phot.mag", False)])

	def testJoin(self):
		cols = self._getColSeq("select dist, speed, 2*mass*height"
			" from spatial join misc on (mass>height)")
		self._assertColumns(cols, [
			("real", 'm', 'phys.distance', False),
			("real", 'km/s', 'phys.veloc', False),
			("real", 'kg*km', '', True),])

	def testWhereResolutionPlain(self):
		cols = self._getColSeq("select dist from spatial where exists"
			" (select * from misc where dist=misc.mass)")
		self._assertColumns(cols, [
			("real", 'm', 'phys.distance', False)])

	def testWhereResolutionWithAlias(self):
		cols = self._getColSeq("select dist from spatial as q where exists"
			" (select * from misc where q.dist=misc.mass)")
		self._assertColumns(cols, [
			("real", 'm', 'phys.distance', False)])

	def testErrorReporting(self):
		self.assertRaises(adql.ColumnNotFound, self._getColSeq,
			"select gnurks from spatial")

	def testExpressionWithUnicode(self):
		cols = self._getColSeq("select crazy.name||geo.pt from crazy, geo")
		self._assertColumns(cols, [
			("unicode", '', '', True)])

	def testIdenticalNames(self):
		# both spatial and spatial2 have ra1; u.ra1 must resolve to spatial2's
		cols = self._getColSeq("SELECT u.ra1 FROM spatial AS mine"
			" LEFT OUTER JOIN spatial2 as u"
			" ON (1=CONTAINS(POINT('', mine.ra1, mine.ra2),"
			"   CIRCLE('', u.ra1, u.dec, 1)))")
		self._assertColumns(cols, [
			("real", 'deg', 'pos.eq.ra;meta.main', False)])

	def testAliasedColumn(self):
		cols = self._getColSeq("SELECT foo, ra1 FROM ("
			"SELECT ra1 as foO, ra1 FROM spatial) as q")
		self._assertColumns(cols, [
			("real", 'deg', 'pos.eq.ra', False),
			("real", 'deg', 'pos.eq.ra', False)])

	def testCoalesce(self):
		cols = self._getColSeq("SELECT COALESCE(ra1, pt, '') as foo FROM"
			" spatial, geo")
		self._assertColumns(cols, [
			("real", 'deg', 'pos.eq.ra', True)])

	def testCaseWithCol(self):
		cols = self._getColSeq("SELECT CASE ra1 WHEN 3 Then  dist"
			" when 3 then height END as q FROM"
			" spatial")
		self._assertColumns(cols, [
			("real", 'm', 'phys.distance', True)])

	def testCaseWithColExpr(self):
		cols = self._getColSeq("SELECT CASE ra1 WHEN 2 THEN ra1*14 WHEN 3"
			" Then  dist else height END as q FROM"
			" spatial")
		self._assertColumns(cols, [
			("real", 'deg', 'pos.eq.ra', True)])

	def testCaseElse(self):
		cols = self._getColSeq("SELECT CASE ra1 WHEN 2 THEN 14 WHEN 3"
			" Then  18 else height END as q FROM"
			" spatial")
		self._assertColumns(cols, [
			("real", 'km', 'phys.dim', True)])

	def testNestedCase(self):
		cols = self._getColSeq("SELECT CASE ra1 WHEN 2 THEN"
			" case when dist/width<1 THEN dist ELSE width END WHEN 3"
			" Then  18 else height END as q FROM spatial")
		self._assertColumns(cols, [
			("real", 'm', 'phys.distance', True)])

	def testOnlyLiteralCases(self):
		cols = self._getColSeq("SELECT CASE WHEN ra1<2 THEN"
			" 'knall' ELSE 'fall' END as q FROM spatial")
		self._assertColumns(cols, [
			("text", '', '', True)])

	def testNULLLiteralInCase(self):
		cols = self._getColSeq("SELECT CASE WHEN ra1<2 THEN"
			" NULL ELSE 'fall' END as q FROM spatial")
		# the type is not checked here
		self._assertColumns(cols, [
			(None, '', '', True)])


class ExprColTest(ColumnTest):
	"""Output column inference for concatenation, UCDCOL, IN_UNIT,
	aggregate UDFs, array subscripts, and COUNT.
	"""
	def _assertConcatType(self, columnName, expectedType):
		# columnName of crazy, concatenated with a literal, aliased to cat
		cols = self._getColSeq(
			"select %s||'ab' as cat from crazy"%columnName)
		self._assertColumns(cols, [(expectedType, '', "", True)])

	def testCharConcat(self):
		self._assertConcatType("flag", "text")

	def testTextConcat(self):
		self._assertConcatType("version", "text")

	def testUnicodeConcat(self):
		self._assertConcatType("name", "unicode")

	def testUCDColSimple(self):
		query = "select UCDCOL('phys.mass') from misc"
		self._assertColumns(self._getColSeq(query), [
			("real", "kg", "phys.mass", False)])

	def testUCDColPattern(self):
		# the second UCDCOL uses a trailing wildcard
		query = ("select UCDCOL('phys.mass'), UCDCOL('phys.dist*')"
			" from misc join spatial on (dist=mass)")
		self._assertColumns(self._getColSeq(query), [
			("real", "kg", "phys.mass", False),
			("real", "m", "phys.distance", False)])

	def testUCDColFails(self):
		query = "select UCDCOL('phys.mass') from spatial"
		self.assertRaisesWithMsg(base.NotFoundError,
			"column matching ucd 'phys.mass' could not be located in from clause",
			self._getColSeq,
			(query,))
	
	def testInUnit(self):
		query = "select IN_UNIT(height, 'm') from spatial"
		self._assertColumns(self._getColSeq(query), [
			("real", "m", "phys.dim", False)])
	
	def testInUnitFailsIncompatibleUnit(self):
		query = "select IN_UNIT(height, 'rad') from spatial"
		self.assertRaisesWithMsg(adql.Error,
			"in_unit error: km and rad do not have the same SI base",
			self._getColSeq,
			(query,))

	def testInUnitFailsBogusUnit(self):
		query = "select IN_UNIT(height, 'krach schlagen') from spatial"
		self.assertRaisesWithMsg(adql.Error,
			"Bad unit passed to in_unit: 'krach schlagen' at col. 7",
			self._getColSeq,
			(query,))

	def testAggregateUDF(self):
		query = ("select gavo_histogram(mass, 0, 100, 10) as h"
			" from misc")
		self._assertColumns(self._getColSeq(query), [
			("integer[]", "", "stat.histogram;phys.mass", False)])

	def testArrayElement(self):
		query = ("select vals[5] as h"
			" from crazy")
		self._assertColumns(self._getColSeq(query), [
			("real", "yr", "some.value", True)])
	
	def testNonArraySubscript(self):
		query = "select width[0] from spatial"
		self.assertRaisesWithMsg(adql.Error,
			"Cannot subscript a non-array in width [ 0 ]",
			self._getColSeq,
			(query,))

	def testNoNumberSubscript(self):
		# subscripts on literals are rejected by the grammar itself
		query = "select 5[0] from spatial"
		self.assertRaisesWithMsg(base.ParseException,
			'Expected FROM, found \'[\'  (at char 8), (line:1, col:9)',
			self._getColSeq,
			(query,))

	def testCountFunction(self):
		query = "select count(pt) from geo"
		self._assertColumns(self._getColSeq(query), [
			("integer", "", "meta.number", False)])


class DelimitedColResTest(ColumnTest):
	"""tests for column resolution with delimited identifiers.

	Delimited references match case-sensitively and are not folded to
	regular identifiers; regular references still find all-lowercase
	delimited names, and delimited references find regular columns.
	"""
	def _firstColName(self, query):
		return self._getColSeq(query)[0][0]

	def testCaseSensitive(self):
		query = 'select "Inch""ing" from quoted'
		self.assertRaises(adql.ColumnNotFound, self._getColSeq, query)

	def testMixedCase(self):
		self.assertEqual(
			self._firstColName('select "plAin" from quoted'),
			utils.QuotedName("plAin"))

	def testNoFoldToRegular(self):
		query = 'select plain from quoted'
		self.assertRaises(adql.ColumnNotFound, self._getColSeq, query)

	def testDelimitedMatchesRegular(self):
		self.assertEqual(
			self._firstColName('select "mass" from misc'), "mass")

	def testConstantSelectWithAs(self):
		self.assertEqual(
			self._firstColName('select 1+0.1 as "x" from spatial'), "x")

	def testRegularMatchesDelmitied(self):
		self.assertEqual(
			self._firstColName('select alllower from quoted'), "alllower")

	def testSimpleStar(self):
		expected = [
			("real", 'bg', "mess", False),
			("real", 'fin', "imperial.mess", False),
			("real", 'pc', "boring.stuff", False),
			("real", 'km', "simple.case", False)]
		self._assertColumns(self._getColSeq("select * from quoted"), expected)
	
	def testSimpleJoin(self):
		query = ('select "inch""ing", "mass" from misc join'
			' quoted on ("left-right"=speed)')
		expected = [
			("real", 'fin', "imperial.mess", False),
			("real", 'kg', 'phys.mass', False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testPlainAndSubselect(self):
		query = ('select "inch""ing", alllower from ('
			'select TOP 5 * from quoted where alllower<"inch""ing") as q')
		expected = [
			("real", 'fin', "imperial.mess", False),
			("real", 'km', "simple.case", False)]
		self._assertColumns(self._getColSeq(query), expected)
	
	def testQuotedExpressions(self):
		# the UCD of the product is not checked
		query = 'select 4*alllower*"inch""ing" from quoted'
		self._assertColumns(self._getColSeq(query), [
			("real", 'km*fin', None, True)])

	def testReferencingRegulars(self):
		query = 'select "ra1" from spatial'
		self._assertColumns(self._getColSeq(query), [
			("real", 'deg', "pos.eq.ra", False)])


class JoinColResTest(ColumnTest):
	"""Output column resolution for the various kinds of joins.

	NATURAL and USING joins fold the common columns into one output
	column; ON and comma joins keep all columns.
	"""
	# NOTE(review): this duplicates ColResTest.testJoin.
	def testJoin(self):
		cols = self._getColSeq("select dist, speed, 2*mass*height"
			" from spatial join misc on (mass>height)")
		self._assertColumns(cols, [
			("real", 'm', 'phys.distance', False),
			("real", 'km/s', 'phys.veloc', False),
			("real", 'kg*km', '', True),])

	def testJoinStar(self):
		cols = self._getColSeq("select * from spatial as q join misc as p on"
			" (1=contains(point('ICRS', q.dist, q.width), circle('ICRS',"
			" p.mass, p.mag, 0.02)))")
		self._assertColumns(cols, [
			("real", 'm', 'phys.distance', False),
			("real", 'm', 'phys.dim', False),
			("real", 'km', 'phys.dim', False),
			("real", 'deg', 'pos.eq.ra', False),
			("real", 'rad', 'pos.eq.ra', False),
			("real", 'kg', 'phys.mass', False),
			("real", 'mag', 'phot.mag', False),
			("real", 'km/s', 'phys.veloc', False),
			])

	def testSubqueryJoin(self):
		cols = self._getColSeq("SELECT * FROM ("
			"SELECT ALL q.mass, spatial.ra1 FROM ("
			"  SELECT TOP 100 mass, mag FROM misc"
			"    WHERE speed BETWEEN 0 AND 1) AS q JOIN"
			"  spatial ON (mass=width)) AS f")
		self._assertColumns(cols, [
			("real", 'kg', 'phys.mass', False),
			("real", 'deg', 'pos.eq.ra', False)])

	def testAutoJoin(self):
		cols = self._getColSeq("SELECT * FROM misc JOIN"
			" (SELECT TOP 3 * FROM crazy) AS q ON (mag=q.ct)")
		physMass = cols[0]
		self.assertEqual(physMass[0], "mass")
		self.assertEqual(physMass[1].ucd, "phys.mass")
		# crazy.mass is crazy's fourth column, after misc's three columns
		crazyMass = cols[6]
		self.assertEqual(crazyMass[0], "mass")
		self.assertEqual(crazyMass[1].ucd, "event;using.incense")

	def testSelfUsingJoin(self):
		# the USING column mass shows up only once
		cols = self._getColSeq("SELECT * FROM "
			" misc JOIN misc AS u USING (mass)")
		self._assertColumns(cols, [
			("real", 'kg', 'phys.mass', False),
			("real", 'mag', 'phot.mag', False),
			("real", 'km/s', 'phys.veloc', False),
			("real", 'mag', 'phot.mag', False),
			("real", 'km/s', 'phys.veloc', False) ])

	def testExReferenceBad(self):
		self.assertRaises(adql.TableNotFound, self._getColSeq,
			"select foo.dist from spatial join misc on (mass>height)")

	def testExReference(self):
		cols = self._getColSeq("select a.dist, b.dist"
			" from spatial as a join spatial as b on (a.dist>b.dist)")
		self._assertColumns(cols, [
			("real", 'm', 'phys.distance', False),
			("real", 'm', 'phys.distance', False)])

	def testExReferenceMixed(self):
		# spatial is aliased as a, but spatial.dist still resolves
		cols = self._getColSeq("select spatial.dist, b.speed"
			" from spatial as a join misc as b on (a.dist>b.speed)")
		self._assertColumns(cols, [
			("real", 'm', 'phys.distance', False),
			("real", 'km/s', 'phys.veloc', False)])
	
	def testNaturalJoin(self):
		# ra1 and dist are common to spatial and spatial2 and appear once
		cols = self._getColSeq("SELECT * FROM"
			" spatial JOIN spatial2 WHERE dist<2")
		self._assertColumns(cols, [
			("real", "m", "phys.distance", False),
			("real", "m", "phys.dim", False),
			("real", "km", "phys.dim", False),
			("real", "deg", "pos.eq.ra", False),
			("real", "rad", "pos.eq.ra", False),
			("real", "deg", "pos.eq.dec;meta.main", False),
			("real", "h", "time.epoch", False)])

	def testNaturalJoinSubquery(self):
		cols = self._getColSeq("SELECT dist, width FROM"
			" spatial JOIN spatial2 WHERE dist IN (SELECT spatial2.dist FROM spatial2)")
		self._assertColumns(cols, [
			("real", "m", "phys.distance", False),
			("real", "m", "phys.dim", False)])

	def testUsingJoin1(self):
		# only ra1 is folded; dist shows up from both tables
		cols = self._getColSeq("SELECT * FROM"
			" spatial JOIN spatial2 USING (ra1)")
		self._assertColumns(cols, [
			("real", "m", "phys.distance", False),
			("real", "m", "phys.dim", False),
			("real", "km", "phys.dim", False),
			("real", "deg", "pos.eq.ra", False),
			("real", "rad", "pos.eq.ra", False),
			("real", "deg", "pos.eq.dec;meta.main", False),
			("real", "m", "phys.distance", False),
			("real", "h", "time.epoch", False)])

	def testUsingJoin2(self):
		cols = self._getColSeq("SELECT * FROM"
			" spatial JOIN spatial2 USING (ra1, dist)")
		self._assertColumns(cols, [
			("real", "m", "phys.distance", False),
			("real", "m", "phys.dim", False),
			("real", "km", "phys.dim", False),
			("real", "deg", "pos.eq.ra", False),
			("real", "rad", "pos.eq.ra", False),
			("real", "deg", "pos.eq.dec;meta.main", False),
			("real", "h", "time.epoch", False)])

	def testUsingJoin3(self):
		cols = self._getColSeq("SELECT ra1, dec, mass FROM"
			" spatial JOIN spatial2 USING (ra1, dist) JOIN misc ON (dist=mass)")
		self._assertColumns(cols, [
			("real", "deg", "pos.eq.ra", False),
			("real", "deg", "pos.eq.dec;meta.main", False),
			("real", "kg", "phys.mass", False),])

	def testUsingJoin4(self):
		cols = self._getColSeq("SELECT ra1, dec, mass FROM"
			" (SELECT * FROM spatial) as q JOIN spatial2"
			" USING (ra1, dist) JOIN misc ON (dist=mass)")
		self._assertColumns(cols, [
			("real", "deg", "pos.eq.ra", False),
			("real", "deg", "pos.eq.dec;meta.main", False),
			("real", "kg", "phys.mass", False),])
	
	def testCommaAll(self):
		# comma joins keep every column, even for repeated tables
		cols = self._getColSeq("SELECT * from spatial, spatial, misc")
		self.assertEqual([c[1].userData[0].name for c in cols], [
			'dist', 'width', 'height', 'ra1', 'ra2', 'dist', 'width',
			'height', 'ra1', 'ra2', 'mass', 'mag', 'speed'])

	def testHaving1(self):
		cols = self._getColSeq(
			"SELECT ct FROM crazy "
			"JOIN ("
			"  SELECT height FROM spatial"
			"  JOIN spatial2 ON (ra2=dist)"
			"  GROUP BY height"
			"  HAVING (height>avg(dist))) AS q "
			"ON (wot=height)")
		self._assertColumns(cols, [
			('integer', '', '', False)])


class SetColResTest(ColumnTest):
	"""Output columns of set operations (UNION, INTERSECT, EXCEPT) and of
	set-generating functions in the FROM clause.
	"""
	def testSimple(self):
		query = ("select dist, height from spatial"
			" union select dist, height from spatial")
		expected = [
			("real", 'm', 'phys.distance', False),
			("real", 'km', 'phys.dim', False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testStars(self):
		query = ("select 0 as i, misc.* from misc"
			" union select 1 as i, misc.* from misc")
		expected = [
			("smallint", '', '', False),
			("real", 'kg', 'phys.mass', False),
			("real", 'mag', 'phot.mag', False),
			("real", 'km/s', 'phys.veloc', False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testLengthFailure(self):
		query = ("select dist, width, height from spatial"
			" union select dist, height from spatial")
		self.assertRaisesWithMsg(adql.IncompatibleTables,
			"Operands in set operation have differing result tuple lengths.",
			self._getColSeq,
			(query,))

	def testName(self):
		query = ("select width, height from spatial"
			" union select dist, height from spatial")
		# NOTE(review): "if" in the expected message looks like a typo for
		# "in" in the library message; the test has to mirror it.
		self.assertRaisesWithMsg(adql.IncompatibleTables,
			"Operands if set operation have differing names.  First differing name: width vs. dist",
			self._getColSeq,
			(query,))

	def testAliasing(self):
		# the operands disagree on units and UCDs, so these come out empty
		query = ("select dist, height from spatial"
			" union select dist, ra1 as height from spatial2"
			" intersect select mass as dist, mag as height from misc"
			' except select "left-right" as dist, "plAin" as height from quoted')
		expected = [
			("real", '', '', False),
			("real", '', '', False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testNested(self):
		query = ("select dist, height from spatial"
			" union select ra1 as dist, dec as height from ("
			"   select ra1, dec, dist from spatial2"
			"   except select mag as ra1, mass as dec, speed as dist from misc) as q"
			"  where dist>2")
		expected = [
			("real", '', '', False),
			("real", '', '', False)]
		self._assertColumns(self._getColSeq(query), expected)

	def testSetGeneratingFunction(self):
		# without an alias, the column is named after the function
		cols = self._getColSeq("select * from generate_series(1, 4)")
		self._assertColumns(cols, [('integer', None, None, False)])
		self.assertEqual(cols[0][0], "generate_series")

	def testSetGeneratingFunctionAlias(self):
		cols = self._getColSeq("select * from generate_series(1, 4) as q")
		self._assertColumns(cols, [('integer', None, None, False)])
		self.assertEqual(cols[0][0], "q")

	def testSetGeneratingFunctionJoin(self):
		query = ("select mass, q from generate_series(1, 4) as q"
			" join misc on (q=speed)")
		cols = self._getColSeq(query)
		self._assertColumns(cols, [
			('real', 'kg', 'phys.mass', False),
			('integer', None, None, False)])
		self.assertEqual(cols[1][0], "q")


class CastColResTest(ColumnTest):
	"""Output columns of CAST expressions.

	Casts keep unit and UCD of their operand but change the type and taint
	the column.
	"""
	def testSimple(self):
		# (select item, expected column properties) pairs
		# NOTE(review): d_uni is used as an alias twice; probably harmless
		# here since names are not checked, but confirm it is intended.
		selectItems = [
			("dist", ("real", 'm', 'phys.distance', False)),
			("cast(dist as char(10)) as d_str",
				("text", 'm', 'phys.distance', True)),
			("cast(dist as CHAR) as d_chr",
				("char", 'm', 'phys.distance', True)),
			("cast(dist as natIonal char(13)) as d_uni",
				("unicode", 'm', 'phys.distance', True)),
			("cast(dist as national char) as d_uni",
				("unicode", 'm', 'phys.distance', True)),
			("cast(dist as integer) as d_int",
				("integer", 'm', 'phys.distance', True)),
			("cast(dist as bigint) as d_long",
				("bigint", 'm', 'phys.distance', True)),
			("cast(ra1 as smallint) as ra_short",
				("smallint", 'deg', 'pos.eq.ra', True)),
			("cast(dist as real) as d_float",
				("real", 'm', 'phys.distance', True)),
			("cast(dist as double precision) as d_double",
				("double precision", 'm', 'phys.distance', True)),
			("cast(dist as timestamp) as d_ts",
				("timestamp", 'm', 'phys.distance', True)),
			("cast(NULL as bigint) as n_long",
				("bigint", '', '', True)),
			("cast(2 as smallint) as lit_int",
				("smallint", '', '', True)),
			("cast(2 as varchar(*)) as varchar_lit",
				("text", '', '', True)),
		]
		selectList = ", ".join(item for item, _ in selectItems)
		query = "select "+selectList+" from spatial"
		self._assertColumns(self._getColSeq(query),
			[props for _, props in selectItems])
	
	def testPoint(self):
		query = ("SELECT CAST('23 24' AS POINT),"
			"CAST(coverage AS CIRCLE),"
			"CAST('12 23' || '23 24' || '24 25' as polygon)"
			" from geo")
		expected = [
			("spoint", '', '', False),
			("scircle", '', '', False),
			("spoly", '', '', False)]
		self._assertColumns(self._getColSeq(query), expected)
		

class WithColResTest(ColumnTest):
	"""tests for column metadata resolution through WITH (CTE) clauses.
	"""
	def testSimple(self):
		query = ("with hollow as (select * from spatial)"
			" select dist, ra2 from hollow")
		self._assertColumns(self._getColSeq(query), [
			("real", "m", "phys.distance", False),
			("real", "rad", "pos.eq.ra", False),])

	def testWithSetOperations(self):
		# A CTE whose body is a UNION still passes on the column metadata.
		query = ("with hollow as"
			" (select * from spatial where dist>5"
			"   union select * from spatial where ra1<3)"
			" select dist, ra2 from hollow")
		self._assertColumns(self._getColSeq(query), [
			("real", "m", "phys.distance", False),
			("real", "rad", "pos.eq.ra", False),])

	def testMultipleWith(self):
		# Two CTEs joined with each other; metadata comes from both sources.
		query = ("with hollow as (select dist, ra2 from spatial),"
			"   filled as (select mass as dist, mag from misc)"
			" select ra2, mag from hollow natural join filled")
		self._assertColumns(self._getColSeq(query), [
			("real", "rad", "pos.eq.ra", False),
			("real", "mag", "phot.mag", False)])


class _UploadTDWithOID(testhelpers.TestResource):
	"""A test resource giving the table definition of an uploaded VOTable
	with a single float column named oid.

	oid is a name postgres treats specially; tests using this resource
	expect it to come out as oid_.
	"""
	def make(self, ignored):
		from io import StringIO
		from gavo import votable
		from gavo.formats import votableread

		voTableSource = StringIO(
			"""<VOTABLE><RESOURCE><TABLE>
				<FIELD name="oid" datatype="float"/>
				<DATA><TABLEDATA><TR><TD>1</TD></TR></TABLEDATA></DATA>
				</TABLE></RESOURCE></VOTABLE>""")
		# votable.parse yields the tables in the document; we only need
		# the first (and only) one.
		firstTable = next(votable.parse(voTableSource))
		nameMaker = votableread.AutoQuotedNameMaker()
		return votableread.makeTableDefForVOTable(
			"foo", firstTable.tableDefinition, nameMaker)

# Module-level resource instance; test classes reference it in their
# `resources` lists (as "nastyTD").
_uploadTDWithOID = _UploadTDWithOID()


class UploadColResTest(ColumnTest):
	"""tests for resolving columns of TAP_UPLOAD tables.
	"""
	resources = [("nastyTD", _uploadTDWithOID)]

	def setUp(self):
		ColumnTest.setUp(self)
		# make the test table "adql" available as an upload
		uploadTables = [testhelpers.getTestTable("adql")]
		self.fieldInfoGetter = adqlglue.DaCHSFieldInfoGetter(
			tdsForUploads=uploadTables)
	
	def testNormalResolution(self):
		resolved = self._getColSeq("select alpha, rv from TAP_UPLOAD.adql")
		self._assertColumns(resolved, [
			("real", 'deg', 'pos.eq.ra;meta.main', False),
			("double precision", 'km/s', 'phys.veloc;pos.heliocentric', False),])

	def testFailedResolutionCol(self):
		# alp is not a column of the uploaded table
		self.assertRaises(base.NotFoundError, self._getColSeq,
			'select alp, rv from TAP_UPLOAD.adql')
	
	def testFailedResolutionTable(self):
		# junk is not an uploaded table
		self.assertRaises(base.NotFoundError, self._getColSeq,
			'select alpha, rv from TAP_UPLOAD.junk')

	def testPGForbiddenNames(self):
		# The upload's oid column must come out renamed, both through the
		# star and through the qualified reference.
		uploadTables = [self.nastyTD]
		self.fieldInfoGetter = adqlglue.DaCHSFieldInfoGetter(
			tdsForUploads=uploadTables)
		resolved = self._getColSeq(
			"select q.*, q.oid from (select oid from tap_upload.foo) as q")
		for index in (0, 1):
			self.assertEqual(resolved[index][0], "oid_")


class STCTest(ColumnTest):
	"""tests for working STC inference in ADQL expressions.

	In the spatial test table, ra1 is in ICRS and ra2 in FK4 (see
	testSimple); misc.mag has no spatial frame.
	"""
	def testSimple(self):
		cs = self._getColSeq("select ra1, ra2 from spatial")
		self.assertEqual(cs[0][1].stc.astroSystem.spaceFrame.refFrame, 'ICRS')
		self.assertEqual(cs[1][1].stc.astroSystem.spaceFrame.refFrame, 'FK4')

	def testBroken(self):
		# Combining columns from different frames marks the STC as broken.
		cs = self._getColSeq("select ra1+ra2 from spatial")
		self.assertTrue(hasattr(cs[0][1].stc, "broken"))

	def testOKPoint(self):
		cs, ctx = self._getColSeqAndCtx(
			"select point('ICRS', ra1, 2) from spatial")
		self.assertEqual(cs[0][1].stc.astroSystem.spaceFrame.refFrame, 'ICRS')
		self.assertEqual(ctx.errors, [])

	def testEmptyCoosysInherits(self):
		cs, ctx = self._getColSeqAndCtx(
			"select point('', ra2, 2) from spatial")
		self.assertEqual(cs[0][1].stc.astroSystem.spaceFrame.refFrame, 'FK4')
		self.assertEqual(ctx.errors, [])

	def testEmptyCoosysBecomesNone(self):
		cs, ctx = self._getColSeqAndCtx(
			"select point('', mag, 2) from misc")
		self.assertEqual(cs[0][1].stc.astroSystem.spaceFrame.refFrame, None)
		self.assertEqual(ctx.errors, [])

	def testMissingCoosysBecomesNone(self):
		# Like testEmptyCoosysBecomesNone, but with the coordinate system
		# argument left out entirely (cf. testMissingCoosysInherits).
		cs, ctx = self._getColSeqAndCtx(
			"select point(mag, 2) from misc")
		self.assertEqual(cs[0][1].stc.astroSystem.spaceFrame.refFrame, None)
		self.assertEqual(ctx.errors, [])

	def testNULLCoosysInherits(self):
		cs, ctx = self._getColSeqAndCtx(
			"select point(NULL, ra1, 2) from spatial")
		self.assertEqual(cs[0][1].stc.astroSystem.spaceFrame.refFrame, 'ICRS')
		self.assertEqual(ctx.errors, [])

	def testMissingCoosysInherits(self):
		cs, ctx = self._getColSeqAndCtx(
			"select point(ra2, 2) from spatial")
		self.assertEqual(cs[0][1].stc.astroSystem.spaceFrame.refFrame, 'FK4')
		self.assertEqual(ctx.errors, [])

	def testMissingCoosysInheritsCircle(self):
		cs, ctx = self._getColSeqAndCtx(
			"select circle(ra2, dist, 2) from spatial")
		self.assertEqual(cs[0][1].stc.astroSystem.spaceFrame.refFrame, 'FK4')
		self.assertEqual(ctx.errors, [])

	def testPointBadCoo(self):
		# An explicit ICRS on an FK4 column: the explicit frame wins, but
		# the mismatch is reported in ctx.errors.
		cs, ctx = self._getColSeqAndCtx(
			"select point('ICRS', ra2, 2) from spatial")
		self.assertEqual(cs[0][1].stc.astroSystem.spaceFrame.refFrame, 'ICRS')
		self.assertEqual(ctx.errors, ['When constructing point:'
			' Argument 1 has incompatible STC'])

	def testPointFunctionsSelect(self):
		cs, ctx = self._getColSeqAndCtx(
			"select coordsys(p), coord1(p), coord2(p) from"
			"	(select point('FK5', ra1, width) as p from spatial) as q")
		self._assertColumns(cs, [
			("text", '', 'meta.ref;pos.frame', False),
			("double precision", 'deg', None, False),
			("double precision", 'deg', None, False)])

	def testBadSTCSRegion(self):
		self.assertRaisesWithMsg(adql.RegionError,
			"Invalid argument to REGION: 'Time TT'.",
			self._getColSeqAndCtx, (
				"select * from spatial where 1=intersects("
				"region('Time TT'), circle('icrs', 1, 1, 0.1))",))

	def testRegionExpressionRaises(self):
		# REGION only accepts a literal, not a string expression.
		self.assertRaisesWithMsg(adql.RegionError,
			"Invalid argument to REGION: ''Position' || alphaName || deltaName'.",
			self._getColSeqAndCtx, (
				"select * from spatial where 1=intersects("
				"region('Position' || alphaName || deltaName),"
				" circle('icrs', 1, 1, 0.1))",))

	def testSTCSRegion(self):
		cs, ctx = self._getColSeqAndCtx(
				"select region('Circle FK4 10 10 1')"
				" from spatial")
		self.assertEqual(cs[0][1].unit, "deg")
	
	def testPolygonInheritsGeo(self):
		cs, ctx = self._getColSeqAndCtx(
				"select polygon(pt, point(1, 2), point(3,4))"
				" from geo")
		self.assertEqual(cs[0][1].stc.astroSystem.spaceFrame.refFrame,
			'GALACTIC_II')

	def testPolygonInheritsSplit(self):
		cs, ctx = self._getColSeqAndCtx(
				"select polygon(ra2, dist, ra1, height, 2, 3)"
				" from spatial")
		self.assertEqual(cs[0][1].stc.astroSystem.spaceFrame.refFrame,
			'FK4')
	
	# NOTE(review): this method lacks the "test" prefix, so unittest never
	# collects or runs it.  If that is not deliberate, rename it to
	# testTimestampMetaInference after confirming its expectations hold.
	def timestampMetaInference(self):
		cs, ctx = self._getColSeqAndCtx(
			"select timestamp('2017-12-01'), TIMESTAMP(dt)"
			" from geo")
		self.assertEqual(cs[0][1].ucd, '')
		self.assertEqual(cs[0][1].unit, '')
		self.assertEqual(cs[0][1].type, 'timestamp')
		self.assertEqual(cs[0][1].stc, None)
		self.assertEqual(cs[1][1].ucd, 'time;obs')
		self.assertEqual(cs[1][1].type, 'timestamp')
		self.assertEqual(cs[1][1].stc.astroSystem.timeFrame.timeScale, "TT")


class FunctionNodeTest(unittest.TestCase):
	"""tests for argument parsing in nodes.FunctionMixin and friends.
	"""
	def setUp(self):
		self.grammar = adql.getGrammar()

	def _firstSelectExpr(self, query):
		"""returns the expression node of the first select item of query.
		"""
		parsed = self.grammar.parseString(query)[0]
		return parsed.selectList.selectFields[0].expr
	
	def testPlainArgparse(self):
		point = self._firstSelectExpr("select POINT('ICRS', width,height)"
			" from spatial")
		self.assertEqual(point.cooSys, "ICRS")
		self.assertEqual(nodes.flatten(point.x), "width")
		self.assertEqual(nodes.flatten(point.y), "height")

	def testExprArgparse(self):
		# a compound expression must end up as a single argument
		point = self._firstSelectExpr("select POINT('ICRS', "
			"5*width+height*LOG(width),height)"
			" from spatial")
		self.assertEqual(point.cooSys, "ICRS")
		self.assertEqual(nodes.flatten(point.x), "5 * width + height * LOG(width)")
		self.assertEqual(nodes.flatten(point.y), "height")


class ComplexExpressionTest(unittest.TestCase):
	"""quite random tests for correct processing of complex-ish search expressions.
	"""
	def testOne(self):
		query = ("select top 5 * from"
			" lsw.plates where dateobs between 'J2416642 ' and 'J2416643'")
		tree = adql.getGrammar().parseString(query)[0]
		whereParts = tree.whereClause.children
		self.assertEqual(whereParts[1].name, "dateobs")
		self.assertEqual(adql.flatten(whereParts[-1]), "'J2416643'")


class NameSuggestingTest(testhelpers.VerboseTest,
		metaclass=testhelpers.SamplesBasedAutoTest):
	"""tests for the table names suggestAName derives from parsed queries.

	Each sample is a pair of (ADQL query, expected suggested name).
	"""
	def _runTest(self, sample):
		query, expectedName = sample
		parsed = adql.getGrammar().parseString(query)[0]
		self.assertEqual(parsed.suggestAName(), expectedName)

	samples = [
		("select * from plaintable", "plaintable"),
		('select * from "plaintable"', "_plaintable"),
		('select * from "useless & table"" name"', "_uselesstablename"),
		('select * from "3columns"', "_3columns"),
		('select * from t1 join t2', "t1_t2"),
# 5
		('select * from (select * from gnug) as booger', "booger"),
		('select * from (select * from gnug) as booger join boog',
			"booger_boog"),]


class _FlatteningTest(testhelpers.VerboseTest):
	"""a base for tests comparing flattened ADQL parse trees with
	expected strings.
	"""
	def _assertFlattensTo(self, rawADQL, flattenedADQL):
		actual = adql.flatten(adql.parseToTree(rawADQL))
		self.assertEqual(actual, flattenedADQL)


class MiscFlatteningTest(_FlatteningTest):
	"""tests for flattening of plain ADQL trees.
	"""
	def testCircle(self):
		query = ("select alphaFloat, deltaFloat from ppmx.data"
			" where contains(point('ICRS', alphaFloat, deltaFloat), "
			" circle('ICRS', 23, 24, 0.2))=1")
		expected = ("SELECT alphaFloat, deltaFloat FROM ppmx.data WHERE"
			" CONTAINS(POINT(alphaFloat, deltaFloat),"
			" CIRCLE(ICRS,23,24,0.2)) = 1")
		self._assertFlattensTo(query, expected)

	def testFunctions(self):
		query = "select round(x,2)as a, truncate(x,-2) as b from foo"
		expected = "SELECT ROUND(x, 2) AS a, TRUNCATE(x, - 2) AS b FROM foo"
		self._assertFlattensTo(query, expected)

	def testJoin(self):
		query = ("SELECT ra1, dec, mass FROM\n"
			" (SELECT * FROM spatial) as q LEFT OUTER JOIN spatial2\n"
			" USING (ra1, dist) JOIN misc ON (dist=mass)")
		expected = ("SELECT ra1, dec, mass FROM (SELECT * FROM spatial) AS q"
			" LEFT OUTER JOIN spatial2 USING ( ra1 , dist ) JOIN misc"
			" ON ( dist = mass )")
		self._assertFlattensTo(query, expected)

	def testCommaJoin(self):
		query = "SELECT ra1, dec, mass FROM\n spatial, spatial2, misc"
		expected = "SELECT ra1, dec, mass FROM spatial , spatial2 , misc"
		self._assertFlattensTo(query, expected)

	def testSubJoin(self):
		query = ("SELECT ra1, dec, mass FROM\n"
			" (spatial join spatial2 using (ra1)), misc")
		expected = ("SELECT ra1, dec, mass FROM"
			" (spatial JOIN spatial2 USING ( ra1 )) , misc")
		self._assertFlattensTo(query, expected)

	def testConcat(self):
		query = "select 'ivo://' ||  name || '%' as pat from crazy"
		expected = "SELECT 'ivo://' || name || '%' AS pat FROM crazy"
		self._assertFlattensTo(query, expected)
	
	def testAliasExpr(self):
		query = "select a+b/(8+x) as num from crazy"
		expected = "SELECT a + b / ( 8 + x ) AS num FROM crazy"
		self._assertFlattensTo(query, expected)


class CommentTest(_FlatteningTest):
	"""tests that SQL comments (-- to end of line) vanish from flattened trees.
	"""
	def testTopComment(self):
		query = ("-- opening remarks;\n"
			"-- quite a few of them, actually.\nselect * from foo")
		self._assertFlattensTo(query, "SELECT * FROM foo")
	
	def testEmbeddedComments(self):
		query = ("select -- comment\n"
			"bar, --comment\n"
			"quux --comment\n"
			"from -- comment\n"
			"foo --comment")
		self._assertFlattensTo(query, "SELECT bar, quux FROM foo")

	def testStringJoining(self):
		# Adjacent string literals separated by a comment are still joined.
		query = "select * from bar where a='qua' -- cmt\n'tsch'"
		self._assertFlattensTo(query, "SELECT * FROM bar WHERE a = 'quatsch'")
	
	def testLeadingWhitespaceCleanup(self):
		query = "select * from--comment\n   bar"
		self._assertFlattensTo(query, "SELECT * FROM bar")
	
	def testEquivalentToWhitespace(self):
		# A comment directly adjacent to tokens still separates them.
		query = "select * from--comment\nbar"
		self._assertFlattensTo(query, "SELECT * FROM bar")


class _MorphTestWithTestTables(testhelpers.VerboseTest):
	"""a base for morph tests needing trees annotated against the test tables.
	"""
	def _parseAnnotating(self, query):
		# parseAnnotating returns a pair; we only need the annotated tree.
		_, annotatedTree = adql.parseAnnotating(query, _getFieldInfoGetter())
		return annotatedTree


class SpatialMorphTest(_MorphTestWithTestTables):
	"""tests for morphing ADQL geometry predicates and DISTANCE into
	q3c_join calls and pgsphere expressions.

	Several expected strings contain doubled blanks; since the tests compare
	with exact equality, these are what flatten emits and must be kept
	verbatim.
	"""
# These are tests exercising mainly the ugly optimisation hacks in nodes.py
# (_sortLargeFirst and friends).  They are mainly built on the _largeTable
# (spatial2) and _smallTable (spatial) annotations introduced in the field
# definitions above.
	def setUp(self):
		# NOTE(review): self.grammar is not used by any test in this class.
		self.grammar = adql.getGrammar()

	def testCircleIn(self):
		s, t = morphpg.morphPG(
			self._parseAnnotating("select ra1, ra2 from spatial"
			" where contains(point('ICRS', ra1, ra2), "
				" circle('ICRS', 23, 24, 0.2))=1"))
		self.assertEqual(adql.flatten(t),
			"SELECT ra1, ra2 FROM spatial WHERE"
				" q3c_join(23, 24, ra1, ra2, 0.2)")
	
	def testCircleOut(self):
		# comparing CONTAINS against 0 becomes a negated q3c_join
		s, t = morphpg.morphPG(
			self._parseAnnotating("select ra1, ra2 from spatial"
			" where 0=contains(point('ICRS', ra1, ra2),"
				" circle('ICRS', 23, 24, 0.2))"))
		self.assertEqual(adql.flatten(t),
			"SELECT ra1, ra2 FROM spatial WHERE"
				" NOT ( q3c_join(23, 24, ra1, ra2, 0.2) )")

	def testConstantsFirstPoint(self):
		# constant arguments are moved to the front of q3c_join
		s, t = morphpg.morphPG(
			self._parseAnnotating("select ra1 from spatial"
			" where 0=contains(point('ICRS', 23, 24),"
				" circle('ICRS', ra1, ra2, 0.2))"))
		self.assertEqual(adql.flatten(t),
			"SELECT ra1 FROM spatial WHERE"
				" NOT ( q3c_join(23, 24, ra1, ra2, 0.2) )")

	def testConstantsFirstCircle(self):
		s, t = morphpg.morphPG(
			self._parseAnnotating("select ra1 from spatial"
			" where 0=contains(point('ICRS', ra1, ra2),"
				" circle('ICRS', 23, 24, 0.2))"))
		self.assertEqual(adql.flatten(t),
			"SELECT ra1 FROM spatial WHERE"
				" NOT ( q3c_join(23, 24, ra1, ra2, 0.2) )")

	def testConstantsFirstDistanceDirect(self):
		s, t = morphpg.morphPG(
			self._parseAnnotating("select ra1 from spatial"
			" where distance(23, 24, ra1, ra2)<0.2"))
		self.assertEqual(adql.flatten(t),
			"SELECT ra1 FROM spatial WHERE"
				"  q3c_join(23, 24, ra1, ra2, 0.2)")

	def testConstantsFirstDistanceSwap(self):
		s, t = morphpg.morphPG(
			self._parseAnnotating("select ra1 from spatial"
			" where distance(ra1, ra2, 23, 24)<0.2"))
		self.assertEqual(adql.flatten(t),
			"SELECT ra1 FROM spatial WHERE"
				"  q3c_join(23, 24, ra1, ra2, 0.2)")

	def testCircleAnnotated(self):
		# with annotation, the star is expanded into qualified columns
		s, t = morphpg.morphPG(
			self._parseAnnotating("SELECT TOP 10 * FROM spatial"
			" WHERE 1=CONTAINS(POINT('ICRS', ra1, ra2),"
			"  CIRCLE('ICRS', 10, 10, 0.5))"))
		self.assertEqual(adql.flatten(t),
			"SELECT spatial.dist, spatial.width, spatial.height, spatial.ra1, spatial.ra2 FROM spatial WHERE q3c_join(10, 10, ra1, ra2, 0.5) LIMIT 10")

	def testMogrifiedIntersect(self):
		# INTERSECTS of a circle and a point is treated like CONTAINS
		s, t = morphpg.morphPG(
			self._parseAnnotating("SELECT TOP 10 * FROM spatial"
			" WHERE 1=INTERSECTS(CIRCLE('ICRS', 10, 10, 0.5),"
				"POINT('ICRS', ra1, ra2))"))
		self.assertEqual(adql.flatten(t),
			"SELECT spatial.dist, spatial.width, spatial.height, spatial.ra1, spatial.ra2 FROM spatial WHERE q3c_join(10, 10, ra1, ra2, 0.5) LIMIT 10")

	def testDistanceTranslatedCrossmatch(self):
		s, t = morphpg.morphPG(
			self._parseAnnotating("select ra1 from spatial as a join spatial as b"
			" where distance(a.ra1, a.ra2, b.ra1, b.ra2)<0.001"))
		morphed = adql.flatten(t)
		self.assertEqual(morphed,
			"SELECT ra1 FROM spatial AS a JOIN spatial AS b"
			"  WHERE  q3c_join(a.ra1, a.ra2, b.ra1, b.ra2, 0.001)")

	def testOp2DistanceTranslatedCrossmatch(self):
		# same as above with the comparison operands swapped
		s, t = morphpg.morphPG(
			self._parseAnnotating("select ra1 from spatial as a join spatial as b"
			" where 0.001 > distance(a.ra1, a.ra2, b.ra1, b.ra2)"))
		morphed = adql.flatten(t)
		self.assertEqual(morphed,
			"SELECT ra1 FROM spatial AS a JOIN spatial AS b"
			"  WHERE  q3c_join(a.ra1, a.ra2, b.ra1, b.ra2, 0.001)")
	
	def testDistanceTranslatedInvertedCrossmatch(self):
		s, t = morphpg.morphPG(
			self._parseAnnotating("select * from spatial as a join spatial as b"
			" where distance(a.ra1, a.ra2, b.ra1, b.ra2)>=0.001"))
		morphed = adql.flatten(t)
		self.assertEqual(morphed, "SELECT dist, width, height, ra1, ra2"
			" FROM spatial AS a JOIN spatial AS b  WHERE"
			" NOT  q3c_join(a.ra1, a.ra2, b.ra1, b.ra2, 0.001)")

	def testDistanceTranslatedSelect(self):
		# DISTANCE in the select list becomes a pgsphere distance in degrees
		s, t = morphpg.morphPG(
			self._parseAnnotating(
				"select distance(ra1, ra2*2, ra2, dist) AS d from spatial"))
		self.assertEqual(adql.flatten(t),
			"SELECT DEGREES((spoint(RADIANS(ra1), RADIANS(ra2 * 2))) <-> (spoint(RADIANS(ra2), RADIANS(dist)))) AS d FROM spatial")

	def testDistanceTranslatedPGS(self):
		# note that the two points come out in the opposite order
		s, t = morphpg.morphPG(
			self._parseAnnotating(
				"select distance(ra1, dec*2, ra1, dist) AS d from spatial2"))
		self.assertEqual(adql.flatten(t),
			"SELECT DEGREES((spoint(RADIANS(ra1), RADIANS(dist))) <->"
			" (spoint(RADIANS(ra1), RADIANS(dec * 2)))) AS d FROM spatial2")

	def testDistanceOptimisedOnTablesize(self):
		#	spatial2 is the larger table here and hence needs to be in the
		# second argument.
		s, t = morphpg.morphPG(
			self._parseAnnotating("select ra1 from spatial as a join spatial2"
				" as b where distance(a.ra1, a.ra2, b.ra1, b.dec)<0.001"))
		morphed = adql.flatten(t)
		self.assertEqual(morphed, "SELECT ra1"
			" FROM spatial AS a JOIN spatial2 AS b  WHERE"
			"  q3c_join(a.ra1, a.ra2, b.ra1, b.dec, 0.001)")
	
	def testDistanceOptimisedOnTablesizeSwap(self):
		#	spatial2 is the larger table here and hence needs to be in the
		# second argument.
		s, t = morphpg.morphPG(
			self._parseAnnotating("select ra1 from spatial as a join spatial2"
				" as b where distance(b.ra1, b.dec, a.ra1, a.ra2)<0.001"))
		morphed = adql.flatten(t)
		self.assertEqual(morphed, "SELECT ra1"
			" FROM spatial AS a JOIN spatial2 AS b  WHERE"
			"  q3c_join(a.ra1, a.ra2, b.ra1, b.dec, 0.001)")

	def testDistanceOptimisedWithCTE(self):
		#	I sort CTEs into the first position no matter what, since they never
		# have q3c indexes.
		s, t = morphpg.morphPG(
			self._parseAnnotating("with sample as (select ra1, dec from spatial2)"
				" select ra1 from spatial as a join sample"
				" as b where distance(b.ra1, b.dec, a.ra1, a.ra2)<0.001"))
		morphed = adql.flatten(t)
		self.assertEqual(morphed, "WITH sample AS MATERIALIZED"
			" ( SELECT ra1, dec FROM spatial2 )"
			" SELECT ra1 FROM spatial AS a JOIN sample AS b  WHERE"
			"  q3c_join(b.ra1, b.dec, a.ra1, a.ra2, 0.001)")


class PQMorphTest(testhelpers.VerboseTest):
	"""tests for morphing non-geometry ADQL syntax to postgres SQL.

	Expected strings may use ASWHATEVER in place of generated column
	aliases (see _testMorph).
	"""
	# nastyTD is the upload table definition with a column named oid.
	resources = [("nastyTD", _uploadTDWithOID)]

	def _testMorph(self, stIn, stOut, fieldInfoGetter=None):
		"""asserts that the ADQL query stIn morphs into the SQL stOut.

		If fieldInfoGetter is passed, the tree is annotated with it before
		morphing (which, e.g., expands stars into qualified columns).  The
		comparison uses assertEqualIgnoringAliases, presumably so that
		ASWHATEVER in stOut matches generated aliases.
		"""
		tree = adql.parseToTree(stIn)
		if fieldInfoGetter:
			_ = adql.annotate(tree, fieldInfoGetter)
		status, t = adql.morphPG(tree)
		flattened = nodes.flatten(t)
		self.assertEqualIgnoringAliases(flattened, stOut)

	def testSyntax(self):
		# TOP becomes LIMIT, emitted before OFFSET
		self._testMorph("select distinct top 10 x, y from foo offset 3",
			'SELECT DISTINCT x, y FROM foo  LIMIT 10 OFFSET 3')

	def testTOP0(self):
		self._testMorph("select top 0 * from foo",
			'SELECT * FROM foo LIMIT 0')

	def testWhitespace(self):
		self._testMorph("select\t distinct top\n\r\n    10 x, y from foo",
			'SELECT DISTINCT x, y FROM foo LIMIT 10')
	
	def testGroupby(self):
		self._testMorph("select count(*), inc from ("
			" select round(x) as inc from foo) as q group by inc",
			"SELECT COUNT ( * ) ASWHATEVER, inc FROM"
			" (SELECT ROUND(x) AS inc FROM foo) AS q"
			" GROUP BY inc")

	def testTwoArgRound(self):
		# two-argument ROUND/TRUNCATE are emulated by scaling with 10^digits
		self._testMorph(
			"select round(x, 2) as a, truncate(x, -2) as b from foo",
			'SELECT ROUND((x)*10^(2)) / 10^(2) AS a, TRUNC((x)*'
				'10^(- 2)) / 10^(- 2) AS b FROM foo')
	
	def testExprArgs(self):
		self._testMorph(
			"select truncate(round((x*2)+y, 4)) from foo",
			'SELECT TRUNC(ROUND((( x * 2 ) + y)*10^(4)) / 10^(4)) ASWHATEVER FROM foo')

	def testPointFunctionWithFieldInfo(self):
		# COORDSYS on a point with a literal frame is replaced by that literal
		t = adql.parseToTree("select coordsys(q.p) from "
			"(select point('ICRS', ra1, ra2) as p from spatial) as q")
		ctx = adql.annotate(t, _getFieldInfoGetter())
		self.assertEqual(ctx.errors[0],
			'When constructing point: Argument 2 has incompatible STC')
		status, t = adql.morphPG(t)
		self.assertEqualIgnoringAliases(nodes.flatten(t),
			"SELECT 'ICRS' ASWHATEVER FROM (SELECT spoint"
			"(RADIANS(ra1), RADIANS(ra2)) AS p FROM spatial) AS q")

	def testStringReplacedNumerics(self):
		self._testMorph("select square(x+x) from foo",
			"SELECT (x + x)^2 ASWHATEVER FROM foo")

	def testNumerics(self):
		# ADQL log10/log map to postgres LOG/LN; RAND's seed is dropped
		self._testMorph("select log10(x), log(x), rand(), rand(5), "
			" TRUNCATE(x), TRUNCATE(x,3) from foo",
			'SELECT LOG(x) ASWHATEVER, LN(x) ASWHATEVER, random() ASWHATEVER,'
				' random() ASWHATEVER, TRUNC('
				'x) ASWHATEVER, TRUNC((x)*10^(3)) / 10^(3) ASWHATEVER FROM foo')

	def testHarmless(self):
		self._testMorph("select delta*2, alpha*mag, alpha+delta"
			" from something where mag<-10",
			'SELECT delta * 2 ASWHATEVER, alpha * mag ASWHATEVER, alpha + delta ASWHATEVER FROM something'
			' WHERE mag < - 10')

	def testUnaryLogic(self):
		self._testMorph("select x from something where y not in (1,2)",
			'SELECT x FROM something WHERE y NOT IN ( 1 , 2 )')

	def testOrder(self):
		self._testMorph("select top 100 * from spatial where dist>10"
			" order by dist, height",
			'SELECT spatial.dist, spatial.width, spatial.height, spatial.ra1, spatial.ra2 FROM spatial WHERE dist > 10 ORDER BY dist , height LIMIT 100',
			_getFieldInfoGetter())

	def testUploadKilled(self):
		# the TAP_UPLOAD schema prefix is removed from table references
		self._testMorph("select * from TAP_UPLOAD.abc",
			"SELECT * FROM abc")

	def testAliasedUploadKilled(self):
		self._testMorph("select * from TAP_UPLOAD.abc as o",
			"SELECT * FROM abc AS o")

	def testUploadColRef(self):
		self._testMorph("select TAP_UPLOAD.abc.c from TAP_UPLOAD.abc",
			"SELECT abc.c FROM abc")
	
	def testUploadColRefInGeom(self):
		self._testMorph("select POINT('', TAP_UPLOAD.abc.b, TAP_UPLOAD.abc.c)"
			" from TAP_UPLOAD.abc",
			"SELECT spoint(RADIANS(abc.b), RADIANS(abc.c)) ASWHATEVER FROM abc")

	def testUploadColRefInGeomContains(self):
		self._testMorph("SELECT TAP_UPLOAD.user_table.ra FROM"
			" TAP_UPLOAD.user_table WHERE (1=CONTAINS(POINT('ICRS',"
			" usnob.data.raj2000, usnob.data.dej2000), CIRCLE('ICRS',"
			" TAP_UPLOAD.user_table.ra2000, a.dec2000, 0.016666666666666666)))",
			"SELECT user_table.ra FROM user_table WHERE ("
			" ((spoint(RADIANS(usnob.data.raj2000), RADIANS(usnob.data.dej2000)))"
			" <@ (scircle(spoint(RADIANS(user_table.ra2000), RADIANS(a.dec2000)),"
			" RADIANS(0.016666666666666666)))) )")

	def testSTCSSingle(self):
		# STC-S regions become pgsphere literals in radians
		self._testMorph(
			"select * from foo where 1=CONTAINS(REGION('Position ICRS 1 2'), x)",
			"SELECT * FROM foo WHERE"
			" ((spoint '(0.0174532925,0.0349065850)') <@ (x))")

	def testSTCSExpr(self):
		self._testMorph(
			"select * from foo where 1=CONTAINS("
				"REGION('Union ICRS (Position 1 2 Intersection"
				" (circle  1 2 3 box 1 2 3 4 circle 30 40 2))'),"
				" REGION('circle GALACTIC 1 2 3'))",
				"SELECT * FROM foo WHERE ((spoint '(0.0174532925,0.0349065850)' <@ ((scircle '< (0.0174532925, 0.0349065850), 0.0523598776 >')+strans(1.346356097441,-1.097319001837,0.574770524729)))) OR (((scircle '< (0.0174532925, 0.0349065850), 0.0523598776 >' <@ ((scircle '< (0.0174532925, 0.0349065850), 0.0523598776 >')+strans(1.346356097441,-1.097319001837,0.574770524729)))) AND ((spoly '{(-0.0087266463,0.0000000000),(-0.0087266463,0.0698131701),(0.0436332313,0.0698131701),(0.0436332313,0.0000000000)}' <@ ((scircle '< (0.0174532925, 0.0349065850), 0.0523598776 >')+strans(1.346356097441,-1.097319001837,0.574770524729)))) AND ((scircle '< (0.5235987756, 0.6981317008), 0.0349065850 >' <@ ((scircle '< (0.0174532925, 0.0349065850), 0.0523598776 >')+strans(1.346356097441,-1.097319001837,0.574770524729)))))")
# TODO: Have a long, close look at this

	def testSTCSNotRegion(self):
		self._testMorph(
			"select * from foo where 1=INTERSECTS(REGION('NOT (circle  1 2 3)'), x)",
			"SELECT * FROM foo WHERE NOT ((scircle '< (0.0174532925, 0.0349065850), 0.0523598776 >' && (x)))")

	def testIsNotNull(self):
		self._testMorph(
			"select * from foo where x is not null",
			"SELECT * FROM foo WHERE x IS NOT NULL")

	def testIsNull(self):
		self._testMorph(
			"select * from foo where x is null",
			"SELECT * FROM foo WHERE x IS NULL")

	def testMultiJoin(self):
		self._testMorph(
			"select * from spatial natural join spatial2 join misc on (dist=speed)",
			"SELECT spatial.dist, spatial.width, spatial.height, spatial.ra1, spatial.ra2, spatial2.ra1, spatial2.dec, spatial2.dist, spatial2.t, misc.mass, misc.mag, misc.speed FROM spatial NATURAL JOIN spatial2  JOIN misc ON ( dist = speed )",
			_getFieldInfoGetter())

	@unittest.skipUnless(testhelpers.hasUDF("IVO_EPOCH_PROP"), "pgsphere too old")
	def testMoveAndUnit(self):
		# IN_UNIT becomes a multiplication by the conversion factor
		self._testMorph("select ivo_apply_pm("
			"in_unit(ra1, 'arcmin'), in_unit(dec, 'deg'),"
			"in_unit(ra1/t, 'mas/s'), in_unit(dec/(t+1), 'uarcsec/min'), 80) as gnack"
			" from spatial natural join spatial2 where"
			" 5<distance(point(ra1, dec),"
			" ivo_apply_pm(ra1, ra2, in_unit(dec/t, 'deg/yr'),"
			"   in_unit(ra1/t, 'mas/h'), 30))",
			"SELECT IVO_APPLY_PM((ra1 * 60), (dec * 1), ((ra1 / t) "
				"* 999.9999999999999), ((dec / ( t + 1 )) * 60000000), 80) AS gnack"
				" FROM spatial NATURAL JOIN spatial2  WHERE NOT "
				" (spoint(RADIANS(ra1), RADIANS(dec))) <@ scircle("
				"IVO_APPLY_PM(ra1, ra2, ((dec / t) * 8766), ((ra1 / t) "
				"* 3600000), 30), RADIANS(5))",
			_getFieldInfoGetter())

	def testQualifiedStar(self):
		self._testMorph(
			"select spatial.*, misc.* from spatial natural join spatial2"
				" join misc on (dist=speed)",
			"SELECT spatial.dist, spatial.width, spatial.height, spatial.ra1, spatial.ra2, misc.mass, misc.mag, misc.speed FROM spatial NATURAL JOIN spatial2  JOIN misc ON ( dist = speed )",
			_getFieldInfoGetter())
	
	def testStarWithAlias(self):
		self._testMorph("select * from spatial as b",
			"SELECT b.dist, b.width, b.height, b.ra1, b.ra2 FROM spatial AS b",
			_getFieldInfoGetter())

	def testStarWithJoin(self):
		self._testMorph("select * from spatial join spatial2 on (width=dec)",
			"SELECT spatial.dist, spatial.width, spatial.height,"
			" spatial.ra1, spatial.ra2, spatial2.ra1, spatial2.dec,"
			" spatial2.dist, spatial2.t FROM spatial JOIN spatial2 ON ( width = dec )",
			_getFieldInfoGetter())

	def testStarWithSubquery(self):
		# the unnamed expression gets a generated alias; the regex requires
		# the same alias in the outer select list and the subquery (\1).
		tree = adql.parseToTree("select * from spatial join "
			" (select ra1+dec, dist-2 as foo, dec from spatial2 offset 0) as q"
			" ON ( width = dec )")
		adql.annotate(tree, _getFieldInfoGetter())
		status, t = adql.morphPG(tree)
		flattened = nodes.flatten(t)
		self.assertTrue(re.match(r'SELECT spatial.dist, spatial.width, spatial.height, spatial.ra1, spatial.ra2, q.([a-z]*), q.foo, q.dec FROM spatial JOIN \(SELECT ra1 \+ dec AS \1, dist - 2 AS foo, dec FROM spatial2  OFFSET 0\) AS q ON \( width = dec \)$', flattened))

	def testSetLimitIntegrated(self):
		# TOPs inside a top-level set operation end up as one trailing LIMIT
		self._testMorph("select top 3 * from x union (select top 40 a from y"
			" except select * from z)",
			"SELECT * FROM x UNION ( SELECT a FROM y EXCEPT SELECT * FROM z ) LIMIT 40")
	
	def testDeepSetLimitProtected(self):
		# a TOP within a subquery stays with that subquery
		self._testMorph("select * from (select TOP 30 * from x) as q union"
			" select TOP 4 * from u",
			"SELECT * FROM (SELECT * FROM x LIMIT 30) AS q UNION SELECT * FROM u LIMIT 4")

	def testUCDCOL(self):
		# non postgres-specific, but we need the annotation
		self._testMorph("select UCDCOL('pos.eq.dec*'), UCDCOL('phys.dim')"
			" from spatial natural join spatial2",
			"SELECT dec, width FROM spatial NATURAL JOIN spatial2 ",
			_getFieldInfoGetter())

	def testInUnit(self):
		# non postgres-specific, but we need the annotation
		self._testMorph("select in_unit("
				"in_unit(ra1, 'deg')/in_unit(height, 'pc')+"
				"in_unit(ra2, 'deg')/in_unit(width, 'pc'), 'rad/pc') as fantasy"
				" from spatial",
			"SELECT (((ra1 * 1) / (height * 3.240755744239556e-14) + (ra2 * 57.29577951308232) / (width * 3.240755744239557e-17)) * 0.0174532925199433) AS fantasy FROM spatial",
			_getFieldInfoGetter())

	def testTimestamp(self):
		# ADQL TIMESTAMP() becomes a postgres ::TIMESTAMP cast
		self._testMorph("select TIMESTAMP('2007-02-03' || 'T12:33:44') from geo"
			" where dt>TIMESTAMP('1997-03-02')",
			"SELECT ('2007-02-03' || 'T12:33:44')::TIMESTAMP ASWHATEVER FROM geo WHERE dt > ('1997-03-02')::TIMESTAMP")

	def testBitwise(self):
		# BITWISE_* functions map to the postgres operators &, |, # and ~
		self._testMorph("select BITWISE_AND(mass, mag) as one,"
			" BITWISE_OR(ROUND(mass), power(mag, 2)) as two,"
			" BITWISE_XOR(BITWISE_NOT(mass)+3, power(mag, 2)/2) as three"
			" from misc",
			"SELECT (mass)&(mag) AS one, (ROUND(mass))|(POWER(mag, 2)) AS two, (~(mass) + 3)#(POWER(mag, 2) / 2) AS three FROM misc")

	def testOidInUpload(self):
		# an uploaded column named oid is renamed to oid_ everywhere
		self._testMorph(
			"select q.*, q.oid from (select oid from tap_upload.foo) as q",
			"SELECT q.oid_, q.oid_ FROM (SELECT oid_ FROM foo) AS q",
			fieldInfoGetter=adqlglue.DaCHSFieldInfoGetter(
			tdsForUploads=[self.nastyTD]))

	def testArray(self):
		self._testMorph("select vals[3] as x from crazy where vals[round(mass)]=3",
			"SELECT vals [ 3 ] AS x FROM crazy WHERE vals [ ROUND(mass) ] = 3",
			_getFieldInfoGetter())

	def testEmbeddedQuotes(self):
		self._testMorph("select count(*) as x from crazy where name='O''Toole'",
			"SELECT COUNT ( * ) AS x FROM crazy WHERE name = 'O''Toole'",
			_getFieldInfoGetter())

	def testSampling(self):
		self._testMorph("select * from foo as a tablesample( 0.01 )",
			"SELECT * FROM foo AS a TABLESAMPLE SYSTEM (0.01)")

	def testSetGenerating(self):
		self._testMorph("select * from generate_series(low, high) as foo",
			"SELECT * FROM generate_series ( low , high ) AS foo")

	def testCharStarInCast(self):
		self._testMorph("select cast(x as char(*)) as h from foo",
			"SELECT CAST ( x AS TEXT ) AS h FROM foo")

	def testCoalesce(self):
		self._testMorph("select coalesce(x, y+3, 'nothing') from foo",
			"SELECT COALESCE ( x , y + 3 , 'nothing' ) ASWHATEVER FROM foo")

	def testArrayAgg(self):
		# ARR_MAP becomes a subquery aggregating over the unnested array
		self._testMorph("select arr_map(sin(x)+1, arr) as z from foo",
			"SELECT (SELECT array_agg( SIN(x) + 1 ) from unnest( arr ) x) AS z FROM foo")

	def testCastToPoint(self):
		self._testMorph("select cast('22' || ' -31' as pOint) from foo",
			"SELECT cast_to_point('22' || ' -31') ASWHATEVER FROM foo")

	def testCastToCircle(self):
		self._testMorph("select cast(MOC('4/22') as CIRCLE) as bcirc from foo",
			"SELECT cast_to_circle(smoc('4/22')) AS bcirc FROM foo")

	def testCastToPolygon(self):
		self._testMorph("select cast(CIRCLE(2, 3, 4) as polygon) from foo",
			"SELECT cast_to_polygon(scircle(spoint(RADIANS(2), RADIANS(3))"
				", RADIANS(4))) ASWHATEVER FROM foo")

	def testCreateTable(self):
		# tap_user tables get the (here anonymous) user name prefixed; the
		# finally clause drops the table again via dropUserUploadedTable.
		from gavo.web import tap_uploads
		try:
			self._testMorph(
				"Create TABLE tap_user.knall AS SELECT * FROM spatial",
				"CREATE TABLE tap_user._anonymous_knall AS  SELECT spatial.dist, spatial.width, spatial.height, spatial.ra1, spatial.ra2 FROM spatial",
				_getFieldInfoGetter())
		finally:
			with base.getWritableTableConn() as conn:
				tap_uploads.dropUserUploadedTable(conn, "anonymous", "knall")


class PGSMorphTest(testhelpers.VerboseTest,
		metaclass=testhelpers.SamplesBasedAutoTest):
	"""Checks the pgsphere translation of ADQL geometry constructs.

	Each sample is a pair of (ADQL source, expected postgres SQL).  The
	comparison goes through assertEqualIgnoringAliases; the expected SQL
	uses ASWHATEVER where, presumably, a generated alias is expected.
	"""
	def _runTest(self, sample):
		adqlSource, expectedSQL = sample
		_, annotatedTree = adql.parseAnnotating(adqlSource, _getFieldInfoGetter())
		_, morphedTree = adql.morphPG(annotatedTree)
		self.assertEqualIgnoringAliases(nodes.flatten(morphedTree), expectedSQL)

	samples = [
		("select AREA(circle('ICRS', COORD1(p1), coord2(p1), 2)),"
				" DISTANCE(p1,p2), centroid(circle('ICRS', coord1(p1), coord2(p1),"
				" 3)) from (select point('ICRS', ra1, ra2) as p1,"
				"   point('ICRS', ra2, dist) as p2 from spatial) as q",
				"SELECT 3282.806350011744*AREA(scircle(spoint(RADIANS(DEGREES(long(p1))), RADIANS(DEGREES(lat(p1)))), RADIANS(2))) ASWHATEVER, DEGREES((p1) <-> (p2)) ASWHATEVER, @@(scircle(spoint(RADIANS(DEGREES(long(p1))), RADIANS(DEGREES(lat(p1)))), RADIANS(3))) ASWHATEVER FROM (SELECT spoint(RADIANS(ra1), RADIANS(ra2)) AS p1, spoint(RADIANS(ra2), RADIANS(dist)) AS p2 FROM spatial) AS q"),
		("select coord1(pt) from geo", 'SELECT DEGREES(long(pt)) ASWHATEVER FROM geo'),
		("select coord2(pt) from geo", 'SELECT DEGREES(lat(pt)) ASWHATEVER FROM geo'),
		("select coordsys(q.p) as c from (select point('ICRS', dist, width)"
			" as p from spatial) as q",
			"SELECT 'ICRS' AS c FROM (SELECT spoint(RADIANS(dist), RADIANS(width)) AS p FROM spatial) AS q"),
		("select ra1 from spatial where"
				" Intersects(circle('ICRS', ra1, ra2,"
				" height*height), polygon('ICRS', 1, 12, 3, 4, 5, 6, 7, 8))=0",
				"SELECT ra1 FROM spatial WHERE NOT ( ((scircle(spoint(RADIANS(ra1), RADIANS(ra2)), RADIANS(height * height))) && ((SELECT spoly(q.p) FROM (VALUES (0, spoint(RADIANS(1), RADIANS(12))), (1, spoint(RADIANS(3), RADIANS(4))), (2, spoint(RADIANS(5), RADIANS(6))), (3, spoint(RADIANS(7), RADIANS(8))) ORDER BY column1) as q(ind,p)))) )"),

		# 5
		("select ra1 from spatial where"
				" contains(circle('ICRS', ra1, ra2,"
				" height*height), box('ICRS', dist, width, height, ra2))=0",
			"SELECT ra1 FROM spatial WHERE NOT ( ((scircle(spoint(RADIANS(ra1), RADIANS(ra2)), RADIANS(height * height))) <@ ((SELECT spoly(q.p) FROM (VALUES (0, spoint(RADIANS(dist)-RADIANS(height)/2, RADIANS(width)-RADIANS(ra2)/2)), (1, spoint(RADIANS(dist)-RADIANS(height)/2, RADIANS(width)+RADIANS(ra2)/2)), (2, spoint(RADIANS(dist)+RADIANS(height)/2, RADIANS(width)+RADIANS(ra2)/2)), (3, spoint(RADIANS(dist)+RADIANS(height)/2, RADIANS(width)-RADIANS(ra2)/2)) ORDER BY column1) as q(ind,p)))) )"),
		("select point('ICRS', cos(ra1)*sin(ra2), cos(ra2)*sin(ra1)),"
				" circle('ICRS', width, height, 25-dist*dist) from spatial",
			'SELECT spoint(RADIANS(COS(ra1) * SIN(ra2)), RADIANS(COS(ra2) * SIN(ra1))) ASWHATEVER, scircle(spoint(RADIANS(width), RADIANS(height)), RADIANS(25 - dist * dist)) ASWHATEVER FROM spatial'),
		("select POiNT('ICRS', 1, 2), CIRCLE('ICRS', 2, 3, 4),"
				" bOx('ICRS', 2 ,3, 4, 5), polygon('ICRS', 2, 3, 4, 5, 6, 7)"
				" from spatial",
			'SELECT spoint(RADIANS(1), RADIANS(2)) ASWHATEVER,'
			' scircle(spoint(RADIANS(2), RADIANS(3)), RADIANS(4)) ASWHATEVER,'
			' (SELECT spoly(q.p) FROM (VALUES (0, spoint(RADIANS(2)-RADIANS(4)/2, RADIANS(3)-RADIANS(5)/2)), (1, spoint(RADIANS(2)-RADIANS(4)/2, RADIANS(3)+RADIANS(5)/2)), (2, spoint(RADIANS(2)+RADIANS(4)/2, RADIANS(3)+RADIANS(5)/2)), (3, spoint(RADIANS(2)+RADIANS(4)/2, RADIANS(3)-RADIANS(5)/2)) ORDER BY column1) as q(ind,p)) ASWHATEVER,'
			' (SELECT spoly(q.p) FROM (VALUES (0, spoint(RADIANS(2), RADIANS(3))), (1, spoint(RADIANS(4), RADIANS(5))), (2, spoint(RADIANS(6), RADIANS(7))) ORDER BY column1) as q(ind,p)) ASWHATEVER FROM spatial'),
		("select Box('ICRS',ra1,ra2,dist*100,width*100)"
			"	from spatial where dist!=0 and width!=0",
			"SELECT (SELECT spoly(q.p) FROM (VALUES (0, spoint(RADIANS(ra1)-RADIANS(dist * 100)/2, RADIANS(ra2)-RADIANS(width * 100)/2)), (1, spoint(RADIANS(ra1)-RADIANS(dist * 100)/2, RADIANS(ra2)+RADIANS(width * 100)/2)), (2, spoint(RADIANS(ra1)+RADIANS(dist * 100)/2, RADIANS(ra2)+RADIANS(width * 100)/2)), (3, spoint(RADIANS(ra1)+RADIANS(dist * 100)/2, RADIANS(ra2)-RADIANS(width * 100)/2)) ORDER BY column1) as q(ind,p)) ASWHATEVER FROM spatial WHERE dist != 0 AND width != 0"),
		("select * from spatial where 1=contains(point('fk4', 1,2),"
			" circle('Galactic',2,3,4))",
			"SELECT spatial.dist, spatial.width, spatial.height, spatial.ra1, spatial.ra2 FROM spatial WHERE (((spoint(RADIANS(2), RADIANS(3)))-strans(1.346356097441,-1.097319001837,0.574770524729)+strans(1.565186433367,-0.004859055280,-1.576368104353)) <@ (scircle(spoint(RADIANS(1), RADIANS(2)), RADIANS(4))))"),

		# 10
		("select ra2 from spatial where contains(point('UNKNOWN', height,width),"
			" circle('Galactic',2,3,4))=1",
			"SELECT ra2 FROM spatial WHERE ((spoint(RADIANS(height), RADIANS(width))) <@ (scircle(spoint(RADIANS(2), RADIANS(3)), RADIANS(4))))"),
		("select dt from geo where 1=intersects(coverage,"
			"circle('icrs', 10, 10, 1))",
			"SELECT dt FROM geo WHERE ((coverage) && (scircle(spoint(RADIANS(10), RADIANS(10)), RADIANS(1))))"),
		("select \"plAin\" from quoted where 1=intersects(\"left-right\","
			"circle('icrs', 10, 10, 1))",
			"SELECT \"plAin\" FROM quoted WHERE ((\"left-right\") && (scircle(spoint(RADIANS(10), RADIANS(10)), RADIANS(1))))"),
		("select contains(coverage, circle('', 10, 10, 1)) from geo",
			"SELECT CONTAINS(coverage, scircle(spoint(RADIANS(10), RADIANS(10)), RADIANS(1))) ASWHATEVER FROM geo"),
		("select * from (select point(ra1, ra2) as p from spatial) as q"
			" where 1=intersects(circle(p, 0.1), circle(1,2,3))",
			"SELECT q.p FROM (SELECT spoint(RADIANS(ra1), RADIANS(ra2)) AS p FROM spatial) AS q WHERE ((scircle(p, RADIANS(0.1))) && (scircle(spoint(RADIANS(1), RADIANS(2)), RADIANS(3))))"),

		# 15
		("select distance(ra1, dec, 12, 13) as d1,"
			" distance(point(ra1, dec), point(12, 13)) as d2"
			" from spatial2"
			" where distance(10, 13, ra1, dec) < 2",
			"SELECT DEGREES((spoint(RADIANS(ra1), RADIANS(dec))) <-> (spoint(RADIANS(12), RADIANS(13)))) AS d1, DEGREES((spoint(RADIANS(ra1), RADIANS(dec))) <-> (spoint(RADIANS(12), RADIANS(13)))) AS d2 FROM spatial2 WHERE  (spoint(RADIANS(ra1), RADIANS(dec))) <@ scircle(spoint(RADIANS(10), RADIANS(13)), RADIANS(2))"),
		("select 1 as c"
			" from spatial2"
			" where 2>=distance(10, 13, ra1, dec)",
			'SELECT 1 AS c FROM spatial2 WHERE  (spoint(RADIANS(ra1), RADIANS(dec))) <@ scircle(spoint(RADIANS(10), RADIANS(13)), RADIANS(2))'),
		("select 1 as c"
			" from spatial2"
			" where 2<distance(ra1, dec, 10, 13)",
			"SELECT 1 AS c FROM spatial2 WHERE NOT  (spoint(RADIANS(ra1), RADIANS(dec))) <@ scircle(spoint(RADIANS(10), RADIANS(13)), RADIANS(2))"),
		("select 1 as c"
			" from spatial2"
			" where distance(10, 13, ra1, dec)>cos(dec)",
			"SELECT 1 AS c FROM spatial2 WHERE NOT  (spoint(RADIANS(ra1), RADIANS(dec))) <@ scircle(spoint(RADIANS(10), RADIANS(13)), RADIANS(COS(dec)))"),
		("select 1 as c"
			" from spatial2"
			" where distance(10, 13, ra1, dec)>distance(3, 10, dist, dec)",
			'SELECT 1 AS c FROM spatial2 WHERE  (spoint(RADIANS(dist), RADIANS(dec))) <@ scircle(spoint(RADIANS(3), RADIANS(10)), RADIANS(DEGREES((spoint(RADIANS(dist), RADIANS(dec))) <-> (spoint(RADIANS(3), RADIANS(10))))))'),

		# 20
		("select 1 as c"
			" from spatial2"
			" where distance(ra1, dec, 10, 13)>2",
			"SELECT 1 AS c FROM spatial2 WHERE NOT  (spoint(RADIANS(ra1), RADIANS(dec))) <@ scircle(spoint(RADIANS(10), RADIANS(13)), RADIANS(2))"),
		("select MOC(6, p), moc('4/3,'|| dist) from"
			" (select point(ra1, ra2) as p, dist from spatial) as q",
			"SELECT smoc(6, p) ASWHATEVER, smoc('4/3,' || dist) ASWHATEVER FROM (SELECT spoint(RADIANS(ra1), RADIANS(ra2)) AS p, dist FROM spatial) AS q"),
			]


class PGSNoMorphTest(testhelpers.VerboseTest):
	"""CENTROID of anything but circles and points makes morphPG fail.
	"""
	_noCentroidMessage = ("Can only compute centroids of circles and points yet."
		"  Complain to make us implement other geometries faster.")

	def _assertNoCentroid(self, tree):
		"""Asserts that morphing tree raises the no-centroid MorphError.
		"""
		self.assertRaisesWithMsg(adql.MorphError,
			self._noCentroidMessage,
			adql.morphPG,
			(tree,))

	def testPolygonNoCentroid(self):
		self._assertNoCentroid(adql.parseToTree(
			"select centroid(polygon('ICRS', 12, 13, 14, 15, 15, 17)) from foo"))

	def testReferencedBoxNoCentroid(self):
		self._assertNoCentroid(parseWithArtificialTable(
			"select centroid(b) from (select"
				" box('', 1, 1, 2, 2) as b from spatial) as q"))


class UDFMorphTest(_MorphTestWithTestTables):
	"""Morphing of queries that use user-defined functions.
	"""
	def _flattenedMorph(self, query):
		"""Returns the SQL that morphpg.morphPG produces for the ADQL query.
		"""
		_, morphedTree = morphpg.morphPG(self._parseAnnotating(query))
		return adql.flatten(morphedTree)

	def testDistanceWithCoord(self):
		morphed = self._flattenedMorph(
			"select dist from spatial where"
			"  2>distance(coord1(ivo_healpix_center(5, 100)),"
			"	coord2(ivo_healpix_center(5, 100)), ra1, ra2)")
		self.assertEqual(morphed,
			"SELECT dist FROM spatial WHERE  q3c_join(COORD1(center_of_healpix_nest(5, 100)), COORD2(center_of_healpix_nest(5, 100)), ra1, ra2, 2)")

	def testCircleWithCoord(self):
		morphed = self._flattenedMorph(
			"select dist from spatial where"
			"  1=contains(point(ra1, ra2), "
			" CIRCLE(coord1(ivo_healpix_center(5, 100)),"
			"	coord2(ivo_healpix_center(5, 100)), 2))")
		self.assertEqual(morphed,
			"SELECT dist FROM spatial WHERE q3c_join(DEGREES(long(center_of_healpix_nest(5, 100))), DEGREES(lat(center_of_healpix_nest(5, 100))), ra1, ra2, 2)")

	@unittest.skipUnless(testhelpers.hasUDF("IVO_EPOCH_PROP"), "pgsphere too old")
	def testPolygonConst(self):
		morphed = self._flattenedMorph(
			"select polygon(coord1(p1), coord2(p1), coord1(p2), coord2(p2),"
				"2,3) from (SELECT POINT(ra1, ra2) as p1, ivo_apply_pm(ra1, ra2,"
				" 1e-7, 2e-7, 10) as p2 from spatial) as q")
		self.assertEqualIgnoringAliases(morphed,
			"SELECT (SELECT spoly(q.p) FROM (VALUES (0, spoint(RADIANS("
			"DEGREES(long(p1))), RADIANS(DEGREES(lat(p1))))), "
			"(1, spoint(RADIANS(DEGREES(long(p2))), RADIANS(DEGREES(lat(p2))))),"
			" (2, spoint(RADIANS(2), RADIANS(3))) ORDER BY column1) as q(ind,p))"
			" ASWHATEVER FROM (SELECT spoint(RADIANS(ra1), RADIANS(ra2)) AS p1,"
			" IVO_APPLY_PM(ra1, ra2, 1e-7, 2e-7, 10) AS p2 FROM spatial) AS q")

	def testNestedUDFs(self):
		morphed = self._flattenedMorph(
			"select dist from spatial where"
			" 1=ivo_interval_overlaps(dist-1, dist+1,"
			" gavo_specconv(500, 'pm', 'm'), gavo_specconv(600, 'pm', 'm'))")
		self.assertEqual(morphed,
			"SELECT dist FROM spatial WHERE dist + 1 >= 1e-12 * ( 500 ) / 1 AND 1e-12 * ( 600 ) / 1 >= dist - 1 AND dist - 1 <= dist + 1 AND 1e-12 * ( 500 ) / 1 <= 1e-12 * ( 600 ) / 1")


class GlueTest(testhelpers.VerboseTest):
	"""Tests for some aspects of adqlglue.

	These check null literals and types of the output table descriptions.
	"""
	def _getOutputTD(self, query):
		"""Returns adqlglue's output table description for query.
		"""
		return adqlglue._getTableDescForOutput(parseWithArtificialTable(query))

	def testAutoNull(self):
		outputTD = self._getOutputTD("select * from crazy")
		self.assertEqual(
			outputTD.getColumnByName("ct").values.nullLiteral, "-2147483648")

	def testSpecifiedNull(self):
		outputTD = self._getOutputTD("select * from crazy")
		self.assertEqual(outputTD.getColumnByName("wot").values.nullLiteral, "-1")

	def testSpecifiedNullOverridden(self):
		firstCol = self._getOutputTD("select 2+wot from crazy").columns[0]
		self.assertEqual(firstCol.values.nullLiteral, '-9223372036854775808')

	def testPureByteaNotPromoted(self):
		firstCol = self._getOutputTD("select wotb from crazy").columns[0]
		self.assertEqual(firstCol.values.nullLiteral, '254')
		self.assertEqual(firstCol.type, 'bytea')

	def testByteaInMultiplication(self):
		# This is probably behaviour that doesn't work with postgres anyway.
		# Fix the whole unsignedByte mess by not linking it to bytea.
		firstCol = self._getOutputTD("select 2*wotb from crazy").columns[0]
		self.assertEqual(firstCol.type, 'smallint')
		self.assertEqual(firstCol.values.nullLiteral, "-32768")


class TableheadTest(testhelpers.VerboseTest):
	"""Tableheads of output columns derived from select list items.
	"""
	def _assertTablehead(self, expr, tablehead):
		"""Asserts that the output column for "select expr from spatial"
		has the given tablehead.
		"""
		# parseAnnotating returns (context, tree); only the tree is needed.
		_, annotatedTree = adql.parseAnnotating(
			f"select {expr} from spatial", _getFieldInfoGetter())
		outputTD = adqlglue._getTableDescForOutput(annotatedTree)
		self.assertEqual(outputTD.columns[0].tablehead, tablehead)

	def testDirectColumnReference(self):
		self._assertTablehead("ra1", "Raw RA")

	def testNoInventedTablehead(self):
		self._assertTablehead("dist", None)

	def testAsHonoured(self):
		self._assertTablehead("dist as foo", "foo")

	def testDelimitedAs(self):
		self._assertTablehead('dist as "some ""mad"" terms"', 'some "mad" terms')

	def testExpression(self):
		self._assertTablehead('power(dist, 3)', 'POWER(dist, 3)')

	def testLongExpression(self):
		self._assertTablehead('sqrt(power(dist, 3)+log10(width))',
			'SQRT(POWER(dist, 3) + LOG10(w…')


class QueryTest(testhelpers.VerboseTest):
	"""performs some actual queries to test the whole thing.

	Queries are run through the adqlQuerier resource against the ADQL
	test tables.
	"""
	resources = [("ds",  adqlTestTable), ("querier", adqlQuerier),
		("geomds", tresc.adqlGeoTable)]

	def setUp(self):
		testhelpers.VerboseTest.setUp(self)
		self.tableName = self.ds.tables["adql"].tableDef.getQName()

	def _assertFieldProperties(self, dataField, expected):
		"""Asserts that dataField has the (attribute name, value) pairs
		in expected.

		Missing attributes are compared as None.
		"""
		for label, value in expected:
			found = getattr(dataField, label, None)
			self.assertEqual(found, value,
				"Data field %s:"
				" Expected %s for %s, found %s"%(dataField.name, repr(value),
					label, repr(found)))

	def runQuery(self, query, **kwargs):
		return self.querier.queryADQL(query, **kwargs)

	def testPlainSelect(self):
		res = self.runQuery(
			"select alpha, delta, NULL from %s where mag<0"%
			self.tableName)
		self.assertEqual(res.tableDef.id, self.tableName.split(".")[-1])

		row = testhelpers.pickSingle(res.rows)
		self.assertEqual(len(row), 3)
		self.assertEqual(row["alpha"], 290.125)
		raField, deField, nullField = res.tableDef.columns
		self._assertFieldProperties(raField, [("ucd", 'pos.eq.ra;meta.main'),
			("description", 'A sample RA'), ("unit", 'deg'),
			("tablehead", "Raw RA")])
		self._assertFieldProperties(deField, [("ucd", 'pos.eq.dec;meta.main'),
			("description", 'A sample Dec'), ("unit", 'deg'),
			("tablehead", None)])
		self._assertFieldProperties(nullField, [("type", "real"),
			("description", ' [ADQL: NULL]'), ("unit", '')])

	def testStarSelect(self):
		res = self.runQuery("select * from %s where mag<0"%
			self.tableName)
		self.assertEqual(len(testhelpers.pickSingle(res.rows)), 5)
		fields = res.tableDef.columns
		self._assertFieldProperties(fields[0], [("ucd", 'pos.eq.ra;meta.main'),
			("description", 'A sample RA'), ("unit", 'deg'),
			("tablehead", "Raw RA")])
		self._assertFieldProperties(fields[1], [("ucd", 'pos.eq.dec;meta.main'),
			("description", 'A sample Dec'), ("unit", 'deg'),
			("tablehead", None)])
		self._assertFieldProperties(fields[3], [
			("ucd", 'phys.veloc;pos.heliocentric'),
			("description", 'A sample radial velocity'), ("unit", 'km/s')])
		self._assertFieldProperties(fields[4], [
			("ucd", ''), ("description", ''), ("unit", '')])

	def testQualifiedStarSelect(self):
		res = self.runQuery("select %s.* from %s join %s as q1"
			" using (mag) where q1.mag<0"%(
			self.tableName, self.tableName, self.tableName))
		self.assertEqual(res.tableDef.id, "adql_q1")
		self.assertEqual(len(testhelpers.pickSingle(res.rows)), 5)
		fields = res.tableDef.columns
		self._assertFieldProperties(fields[0], [("ucd", 'pos.eq.ra;meta.main'),
			("description", 'A sample RA'), ("unit", 'deg'),
			("tablehead", "Raw RA")])

	def testNoCase(self):
		# will just raise an Exception if things are broken.
		self.runQuery("select ALPHA, DeLtA, MaG from %s"%self.tableName)

	def testDelimitedMapping(self):
		res = self.runQuery('select alpha from "test"."adql"')
		self.assertEqual(
			str(res.tableDef.columns[0].name),
			'alpha')
		self.assertEqual(
			res.tableDef.id,
			"_adql")

	def testDelimitedBadTableFails(self):
		# hack: the group by in the query below currently suppresses the use
		# of cursors which would make our error messages a good deal uglier.
		self.assertRaisesWithMsg(base.DBError,
		'relation "test.Adql" does not exist\nLINE 1: SELECT "alpha" FROM test."Adql" GROUP BY "alpha" LIMIT 20000\n                            ^\n',
		self.runQuery,
		('select "alpha" from test."Adql" group by "alpha"',))

	def testDelimitedNameForDerived(self):
		res = self.runQuery('select "SEL-ECT".alpha from'
			'(SELECT alpha FROM test.adql) "SEL-ECT"')
		self.assertEqual(
			res.tableDef.columns[0].name,
			'alpha')
		self.assertEqual(res.tableDef.id, '"SEL-ECT"')

	def testTainting(self):
		res = self.runQuery("select delta*2, alpha*mag, alpha+delta"
			" from %s where mag<-10"% self.tableName)
		f1, f2, f3 = res.tableDef.columns
		self._assertFieldProperties(f1, [("ucd", 'pos.eq.dec;meta.main'),
			("description", 'A sample Dec -- *TAINTED*: the value was operated'
				' on in a way that unit and ucd may be severely wrong'
				' [ADQL: delta * 2]'),
			("unit", 'deg')])
		self._assertFieldProperties(f2, [("ucd", ''),
			("description", 'This field has traces of: A sample RA;'
				' A sample magnitude -- *TAINTED*: the value was operated'
				' on in a way that unit and ucd may be severely wrong'
				' [ADQL: alpha * mag]'),
			("unit", 'deg*mag')])
		self._assertFieldProperties(f3, [("ucd", ''),
			("description", 'This field has traces of: A sample RA; A sample Dec'
				' -- *TAINTED*: the value was operated on in a way that unit and'
				' ucd may be severely wrong [ADQL: alpha + delta]'),
			("unit", 'deg')])

	def testTransformation(self):
		res = self.runQuery("select mag from %s where"
			" 1=contains(point('icrs', alpha, delta),"
			"   circle('galactic', 107,-47, 1))"%self.tableName)
		self.assertEqual(list(res)[0]["mag"], 10.25)

	def testGeometryInSelect(self):
		res = self.runQuery(
			"select rv, point('icrs', alpha, delta) as p, mag, alpha, delta,"
			" contains(point(alpha, delta), circle('', 3, 15, 1)) as c,"
			" contains(point('', alpha, delta), circle(3, 15, 2)) as c1,"
			" intersects(point(alpha, delta), circle(3, 15, 2)) as i1,"
			" intersects(circle(alpha, delta, 1.5), circle(3, 15, 1.5)) as i2,"
			" intersects(circle(alpha, delta, 0.5), circle(3, 15, 0.5)) as i3,"
			" circle('icrs', alpha, delta, 10) as ci,"
			" circle(delta, alpha, 5) as cn"
			" from %s where mag>5"%self.tableName)
		rows = list(res)

		expected = pgsphere.SPoint.fromDegrees(2, 14)
		self.assertAlmostEqual(rows[0]["p"].x, expected.x)
		self.assertAlmostEqual(rows[0]["p"].y, expected.y)
		self.assertEqual(rows[0]["rV"], -23.75)
		self.assertEqual(rows[0]["c"], 0)
		self.assertEqual(rows[0]["c1"], 1)
		self.assertEqual(rows[0]["i1"], 1)
		self.assertEqual(rows[0]["i2"], 1)
		self.assertEqual(rows[0]["i3"], 0)
		expected = pgsphere.SCircle.fromDALI([2, 14, 10])
		expected2 = pgsphere.SCircle.fromDALI([14, 2, 5])
		self.assertAlmostEqual(rows[0]["ci"].center.x, expected.center.x)
		self.assertAlmostEqual(rows[0]["ci"].radius, expected.radius)
		self.assertAlmostEqual(rows[0]["cn"].center.y, expected2.center.y)
		self.assertAlmostEqual(rows[0]["cn"].radius, expected2.radius)
		self.assertEqual(res.tableDef.getColumnByName("p").type,
			"spoint")
		self.assertEqual(res.tableDef.getColumnByName("ci").type,
			"scircle")
		self.assertEqual(res.tableDef.getColumnByName("cn").type,
			"scircle")

	def testQuotedIdentifier(self):
		res = self.runQuery(
			'select "rv", rV from %s where delta=89'%self.tableName)
		self.assertEqual(res.rows, [{"rV": 28., "rV_": 28.}])

	def testNoCasefixingOnExpressions(self):
		res = self.runQuery(
			'select rv, rv+1 from %s where delta=89'%self.tableName)
		self.assertEqual(res.rows[0]["rV"], 28.0)
		# the other name must be a result of intToFunnyWord, i.e., all-lowercase.
		otherId = (set(res.rows[0].keys())-{"rV"}).pop()
		self.assertTrue(re.match("[a-z]+$", otherId))

	def testDistanceDegrees(self):
		res = self.runQuery(
			"select DISTANCE(POINT('ICRS', 22, -3), POINT('ICRS', 183, 50)) as d"
			" from %s"%self.tableName)
		self.assertAlmostEqual(res.rows[0]["d"], 130.31777623681)

	def testInUnitGeometry(self):
		res = self.runQuery(
			"select in_unit(DISTANCE(POINT('ICRS', 22, -3),"
				" POINT('ICRS', 183, 50)), 'rad') as d"
			" from %s"%self.tableName)
		self.assertAlmostEqual(res.rows[0]["d"], 2.27447426920956)

	@unittest.skipUnless(testhelpers.hasUDF("IVO_EPOCH_PROP"), "pgsphere too old")
	def testApplyPM(self):
		res = self.runQuery(
			"SELECT moved, coord1(moved) as ra, coord2(moved) as dec FROM (SELECT"
			" ivo_epoch_prop_pos(alpha, delta, NULL, 7200, -3600, NULL, 2020, 1965) as moved"
			" FROM %s where alpha between 20 and 26) AS q"%self.tableName)
		mapped = list(base.SerManager(res).getMappedValues())[0]
		self.assertAlmostEqual(mapped["moved"][0],  24.88665977523)
		self.assertAlmostEqual(mapped["moved"][1], -13.9449737974)
		self.assertAlmostEqual(mapped["ra"],  24.88665977523)
		self.assertAlmostEqual(mapped["dec"], -13.9449737974)

	def testStringFunctions(self):
		res = self.runQuery(
			"SELECT UPPer(table_name) as tn, lower(table_name) as td"
				" from tap_schema.columns where column_name ='table_name'")
		self.assertIn({'tn': 'TAP_SCHEMA.TABLES', 'td': 'tap_schema.tables'},
			res.rows)

	def testDMAnnotation(self):
		res = self.runQuery("SELECT TOP 1 * FROM %s"%self.tableName)
		ann = next(res.tableDef.iterAnnotationsOfType("geojson:FeatureCollection"
			))
		self.assertEqual(ann["feature"]["geometry"]["type"], "sepcoo")
		self.assertEqual(ann["feature"]["geometry"]["latitude"].value,
			res.tableDef.getByName("delta"))

	def testMOCVsPoint(self):
		res = self.runQuery("SELECT row_id FROM test.adqlgeo"
			" WHERE 1=CONTAINS(POINT('ICRS', 55.5, 20.7), a_moc)")
		self.assertEqual([r["row_id"] for r in res.rows],
			['moc-6-1', 'moc-4-1'])
		res = self.runQuery("SELECT row_id FROM test.adqlgeo"
			" WHERE 1=CONTAINS(POINT('ICRS', 56, 21), a_moc)")
		self.assertEqual([r["row_id"] for r in res.rows],
			['moc-4-1'])

	def testSpointedCircle(self):
		res = self.runQuery("SELECT row_id FROM test.adqlgeo"
			" WHERE 1=INTERSECTS(CIRCLE(a_point, 1), CIRCLE(23, 42, 1))")
		self.assertEqual([r["row_id"] for r in res.rows],
			['moc-6-1'])

	def testSpointedPolygon(self):
		res = self.runQuery("SELECT row_id,"
			" polygon(a_point,POINT(23,42), POINT(23, 41)) as p FROM test.adqlgeo"
			" WHERE 1=CONTAINS(POINT(24, 43), POLYGON(a_point,"
			" POINT(23,42), POINT(23, 41), POINT(24, 41)))")
		self.assertEqual([r["row_id"] for r in res.rows],
			['moc-4-1'])
		self.assertAlmostEqual(res.rows[-1]["p"].points[0].x, 0.436332313)

	def testCTE(self):
		res = self.runQuery("WITH knall as (SELECT POINT(alpha, delta) as pt"
			" from test.adql)"
			" SELECT * FROM kNall")
		self.assertEqual(len(res.rows), 3)
		self.assertEqual(type(res.rows[0]["pt"]), pgsphere.SPoint)
		self.assertEqual(res.tableDef.columns[0].name, "pt")
		# Search all info metas for sql_query.  The break must only happen
		# once that info has been found and checked; otherwise the loop
		# would stop at the first info of whatever name, and the for/else
		# failure could never catch a missing sql_query info.
		for infoMeta in res.iterMeta("info"):
			if infoMeta.infoName=="sql_query":
				self.assertEqual(infoMeta.infoValue,
					'WITH knall AS MATERIALIZED ( SELECT spoint(RADIANS(alpha), RADIANS(delta)) AS pt FROM test.adql ) SELECT kNall.pt FROM kNall LIMIT 20000')
				break
		else:
			self.fail("No sql_query info meta?")

	def testSetGeneratingFunction(self):
		res = self.runQuery("SELECT * FROM GENERATE_SERIES(1,2)")
		self.assertEqual(len(res.rows), 2)
		self.assertEqual(res.tableDef.columns[0].name, "generate_series")
		self.assertEqual(res.tableDef.id, "generate_series")

	def testMOCMaking(self):
		res = self.runQuery("SELECT moc('4/' || z) as g"
			" FROM GENERATE_SERIES(9,12) as z")
		self.assertEqual(res.rows[-1]["g"].asDALI(), "4/12")

	def testTimestampOutput(self):
		res = self.runQuery("select top 1 timestamp('2020-01-12')"
			" as t from tap_schema.tables")
		self.assertEqual(res.rows[0]["t"], datetime.datetime(2020, 1, 12))

	def testGeoCasts(self):
		res = self.runQuery("select"
			" cast(cast(alpha as char(*)) || ' ' || cast(delta as char(*))"
			"   as point) as pt,"
			" cast('20 -10 ' || rv as circle) as ci,"
			" cast(cast('20 -10 14' as circle) as point) as cpt,"
			" cast('10 80 120 40 35 -10' as polygon) as poly"
			" from test.adql where mag<0")
		# alpha': 290.125, 'delta': 89.0, 'mag': -1.0, 'rV': 28.0
		row = res.rows[0]
		self.assertEqual(res.tableDef.getColumnByName("pt").ucd, "")
		self.assertAlmostEqual(row["pt"].asDALI()[0], 290.125)
		self.assertAlmostEqual(row["ci"].asDALI()[-1], 28)
		self.assertAlmostEqual(row["cpt"].asDALI()[1], -10)
		self.assertAlmostEqual(row["poly"].asDALI()[4], 35)

	def testToplevelSetOperation(self):
		res = self.runQuery("select alpha from test.adql where delta<80"
			" except select alpha from test.adql where delta<0")
		self.assertEqual(res.tableDef.columns[0].name, "alpha")
		self.assertEqual(res.rows, [{'alpha': 2.0}])


class SimpleSTCSTest(testhelpers.VerboseTest):
	"""Tests for the parser returned by tapstc.getSimpleSTCSParser.

	The tests cover parsing of STC-S geometries and geometry expressions
	(NOT, UNION, INTERSECTION) and rejection of unsupported input.
	"""
	def setUp(self):
		self.parse = tapstc.getSimpleSTCSParser()

	def testPosParses(self):
		res = self.parse("Position 10 20 ")
		self.assertEqual(res.pgType, "spoint")
		self.assertAlmostEqual(res.x, 0.174532925199432)
		self.assertEqual(res.cooSys, "UNKNOWN")

	def testCircleParses(self):
		res = self.parse(" Circle ICRS 10 20 1e0")
		self.assertEqual(res.pgType, "scircle")
		self.assertEqual(res.cooSys, "ICRS")

	def testBadCircleRaises(self):
		self.assertRaisesWithMsg(stc.STCSParseError,
			'STC-S circles want three numbers.',
			self.parse,
			("Circle 2 1",))

	def testBoxParses(self):
		res = self.parse("box TOPOCENTER SPHERICAL2 -10  20 2.1 5.4")
		self.assertEqual(res.pgType, "spoly")

	def testPolyParses(self):
		res = self.parse("PolyGon FK4 TOPOCENTER SPHERICAL2 -10  20 2.1 5.4 1 3")
		self.assertEqual(res.pgType, "spoly")

	def testNotParses(self):
		res = self.parse("NOT  (Box ICRS 1 2 3 4)")
		# assertIsInstance reports the actual type on failure.
		self.assertIsInstance(res, tapstc.GeomExpr)
		self.assertAlmostEqual(
			testhelpers.pickSingle(res.operands).points[0].x,
			-0.00872664626)
		self.assertEqual(res.cooSys, "UNKNOWN")

	def testSimpleOpParses(self):
		res = self.parse("UNiON (Box ICRS 1 2 3 4 Circle 1 2 3)")
		self.assertIsInstance(res, tapstc.GeomExpr)
		self.assertEqual(res.operator, "UNION")
		self.assertEqual(len(res.operands), 2)
		self.assertEqual(res.operands[0].pgType, "spoly")
		self.assertEqual(res.operands[1].pgType, "scircle")
		self.assertEqual(res.cooSys, "UNKNOWN")

	def testComplexOpParses(self):
		res = self.parse("INtersection FK4 ("
			"UNiON BARYCENTER (Box ICRS 1 2 3 4 Circle 1 2 3)"
			" Polygon ICRS GEOCENTER 2 3 4 5 6 7"
			" Circle Fk4 spherical2 3 4 5)")
		self.assertEqual(res.operands[0].operator, "UNION")
		self.assertEqual(res.operands[1].cooSys, "ICRS")
		self.assertEqual(res.cooSys, "FK4")

	def testCartesianRaises(self):
		self.assertRaisesWithMsg(stc.STCSParseError,
			'Only SPHERICAL2 STC-S supported here',
			self.parse,
			("Position CARTESIAN3 1 2 3",))


class IntersectsFallbackTest(testhelpers.VerboseTest):
	"""Does INTERSECT fall back to CONTAINS?

	When one operand of INTERSECTS is a point, the function node is
	expected to become CONTAINS with the point first.  When no operand
	is a point (testNotouch), it stays INTERSECTS.
	"""
	def _annotate(self, query):
		"""Returns the annotated parse tree of query.
		"""
		_, tree = adql.parseAnnotating(query, _getFieldInfoGetter())
		return tree

	def _whereFunction(self, query):
		"""Returns the function node compared in query's WHERE clause.
		"""
		return self._annotate(query).whereClause.children[1].op1

	def _joinFunction(self, query):
		"""Returns the function node compared in query's join condition.
		"""
		joinSpec = self._annotate(query).fromClause.tableReference.joinSpecification
		return joinSpec.children[2].op1

	def testArg1(self):
		funNode = self._whereFunction(
			"SELECT pt from geo where intersects(pt, circle('ICRS', 2, 2, 1))=1")
		self.assertEqual(funNode.funName, "CONTAINS")
		self.assertEqual(funNode.args[0].type, "columnReference")
		self.assertEqual(funNode.args[1].type, "circle")

	def testArg2(self):
		funNode = self._whereFunction(
			"SELECT pt from geo where intersects(circle('ICRS', 2, 2, 1), pt)=1")
		self.assertEqual(funNode.funName, "CONTAINS")
		self.assertEqual(funNode.args[0].type, "columnReference")
		self.assertEqual(funNode.args[1].type, "circle")

	def testExpr(self):
		funNode = self._whereFunction(
			"SELECT pt from geo where intersects("
			"point('ICRS', 2, 2), circle('ICRS', 2, 2, 1))=1")
		self.assertEqual(funNode.funName, "CONTAINS")
		self.assertEqual(funNode.args[0].type, "point")

	def testNotouch(self):
		funNode = self._whereFunction(
			"SELECT pt from geo where intersects("
			"box('ICRS', 2, 2, 3, 3), circle('ICRS', 2, 2, 1))=1")
		self.assertEqual(funNode.funName, "INTERSECTS")

	def testJoinCond(self):
		funNode = self._joinFunction(
			"SELECT * from geo as a join geo as b on (intersects("
			"circle('ICRS', coord1(b.pt), coord2(b.pt), 1), a.pt)=1)")
		self.assertEqual(funNode.funName, "CONTAINS")
		self.assertEqual(funNode.args[0].fieldInfo.type, "spoint")

	def testJoinCondGeoCol(self):
		funNode = self._joinFunction(
			"SELECT * from geo as a join geo as b on (intersects("
			"circle('ICRS', b.pt, 1), a.pt)=1)")
		self.assertEqual(funNode.funName, "CONTAINS")
		self.assertEqual(funNode.args[0].fieldInfo.type, "spoint")


class AnnotationTest(testhelpers.VerboseTest):
	"""Column resolution and fieldInfo inference during annotation.
	"""
	def testFailing(self):
		# ra1 is a column of both spatial and spatial2, so the unqualified
		# reference within the subquery is ambiguous.
		ambiguousQuery = ("SELECT ra1 from spatial WHERE EXISTS"
			" (SELECT 1 FROM spatial2 where ra1=t)")
		self.assertRaisesWithMsg(adql.AmbiguousColumn,
			"ra1",
			adql.parseAnnotating,
			(ambiguousQuery, _getFieldInfoGetter()))

	def testSucceedingUniqueMatch(self):
		# Passes as long as annotation does not raise.
		_ctx, _tree = adql.parseAnnotating(
			"SELECT ra1 from spatial as a WHERE EXISTS"
			" (SELECT 1 FROM spatial where ra1=a.height)",
			_getFieldInfoGetter())

	def testFailingMultiInference(self):
		# The class name shows up in the expected error message.
		class MultiField(nodes.TransparentNode, nodes.FieldInfoedNode):
			type = "multifield"

		twoChildren = [
			nodes.GenericValueExpression(children=[4]),
			nodes.GenericValueExpression(children=[5])]
		node = MultiField(children=twoChildren)

		self.assertRaisesWithMsg(adql.Error,
			"More than one child with fieldInfo with no behaviour defined in"
			" MultiField, children [('children', <ADQL Node genericValueExpression>),"
			" ('children', <ADQL Node genericValueExpression>)]",
			fieldinfos._annotateNodeRecurse,
			(node, annotations.AnnotationContext(_getFieldInfoGetter())))

	def testFailingNullInference(self):
		# The class name shows up in the expected error message.
		class MultiField(nodes.TransparentNode, nodes.FieldInfoedNode):
			type = "multifield"

		node = MultiField(children=[])

		self.assertRaisesWithMsg(adql.Error,
			"No child with fieldInfo with no behaviour defined in"
			" MultiField, children []",
			fieldinfos._annotateNodeRecurse,
			(node, annotations.AnnotationContext(_getFieldInfoGetter())))


if __name__=="__main__":
	# Runs the tests via testhelpers.main when this module is executed
	# directly; presumably only the class passed in is run.
	# NOTE(review): STCTest is not defined in the part of the file
	# reviewed here -- confirm it exists above, or pick another class.
	testhelpers.main(STCTest)
