"""
Unit-ish tests for our SIAP2 infrastructure.
"""

#c Copyright 2008-2024, the GAVO project <gavo@ari.uni-heidelberg.de>
#c
#c This program is free software, covered by the GNU GPL.  See the
#c COPYING file in the source distribution.


import functools

from gavo.helpers import testhelpers

from gavo import api
from gavo import rscdef
from gavo.grammars import fitsprodgrammar
from gavo.helpers import trialhelpers
from gavo.protocols import siap

import tresc


@functools.cache
def _getTestTable():
	"""returns a memoised table definition with id sia2 carrying the
	obscore columns.

	Both the structure parent and the meta parent are set to the test RD.
	"""
	tableSource = """<table id="sia2">
			<FEED source="//obscore#obscore-columns"/>
			</table>"""
	tableDef = api.parseFromString(api.TableDef, tableSource)
	tableDef.parent = testhelpers.getTestRD()
	tableDef.setMetaParent(testhelpers.getTestRD())
	return tableDef


def _getRawdictFor(fitsPath):
	"""returns the first raw dict a FITSProdGrammar yields for fitsPath,
	with a minimal set of products#define keys added in artificially.

	The grammar used is kept in the function attribute ``grammar``.
	"""
	productKeys = {
		"prodtblAccref": fitsPath,
		"prodtblPath": fitsPath,
		"prodtblPreview": None,
		"prodtblFsize": 65536,
		"prodtblTable": "test.sia2",
		"prodtblMime": "application/fits",
	}
	rawdict = next(iter(_getRawdictFor.grammar.parse(fitsPath)))
	rawdict.update(productKeys)
	return rawdict
_getRawdictFor.grammar = api.makeStruct(fitsprodgrammar.FITSProdGrammar)


def _getMappedMetaRow(applyStuff, fitsPath="test_data/ex.fits"):
	"""returns the row produced from the raw dict for fitsPath by a rowmaker
	applying //siap2#setMeta and then //siap2#getBandFromFilter.

	setMeta always gets obs_collection bound to "Testers"; applyStuff is
	pasted verbatim into the setMeta apply element (e.g., further binds).
	"""
	rowmakerDef = api.parseFromString(rscdef.RowmakerDef,
		f"""<rowmaker><apply procDef="//siap2#setMeta" name="romp">
			<bind name="obs_collection">"Testers"</bind>
			{applyStuff}
			</apply>
			<apply procDef="//siap2#getBandFromFilter"/></rowmaker>""")
	# _getTestTable is memoised, so a single reference serves both uses.
	targetTable = _getTestTable()
	mapper = rowmakerDef.compileForTableDef(targetTable)
	return mapper(_getRawdictFor(fitsPath), targetTable)


class SetMetaTest(testhelpers.VerboseTest):
	"""tests for the //siap2#setMeta (and //siap2#getBandFromFilter) procs.
	"""
	# dateObs/t_exptime bindings shared by the observation date tests.
	_obsDateBinds = ('<bind name="dateObs">23000</bind>'
		'<bind name="t_exptime">43200</bind>')

	def testCollectionRequired(self):
		# NOTE(review): this literal lacks the closing ">" of </rowmaker.  The
		# (1, 56) position in the expected message may depend on where the
		# parser stops, so verify before "fixing" the markup.
		unboundSource = (
			"""<rowmaker><apply procDef="//siap2#setMeta" name="romp"/></rowmaker""")
		self.assertRaisesWithMsg(api.StructureError,
			"At IO:'<rowmaker><apply procDef=\"//siap2#setMeta\" name=\"romp...',"
			" (1, 56): Parameter obs_collection is not defaulted in romp and"
			" thus must be bound.",
			api.parseFromString,
			(rscdef.RowmakerDef, unboundSource))

	def testSetMetaDefaults(self):
		row = _getMappedMetaRow("")
		expectedValues = {
			"source_table": "test.sia2",
			"dataproduct_type": "image",
			"t_xel": 1,
			"obs_publisher_did": "ivo://x-testing/~?test_data/ex.fits",
			"o_ucd": "phot.count",
			"target_class": None,
			"instrument_name": None,
			"facility_name": "Markus' Proving Grounds",
			"access_estsize": 64,
			"access_url": "test_data/ex.fits",
		}
		for colName, expected in expectedValues.items():
			self.assertEqual(row[colName], expected)

	def testObsDateComputation(self):
		row = _getMappedMetaRow(self._obsDateBinds)
		for colName, expected in [("t_min", 22999.75), ("t_max", 23000.25)]:
			self.assertEqual(row[colName], expected)

	def testObsDateNoOverride(self):
		# an explicitly bound t_min must survive the dateObs-based computation
		row = _getMappedMetaRow(
			self._obsDateBinds+'<bind name="t_min">22900</bind>')
		for colName, expected in [("t_min", 22900), ("t_max", 23000.25)]:
			self.assertEqual(row[colName], expected)

	def testBandMapping(self):
		row = _getMappedMetaRow('<bind name="bandpassId">"Johnson B"</bind>')
		for colName, expected in [("em_min", 3.7e-7), ("em_max", 5.5e-7)]:
			self.assertEqual(row[colName], expected)

	def testBadBandMapping(self):
		# an unknown filter name leaves the spectral limits NULL
		row = _getMappedMetaRow('<bind name="bandpassId">"Jacobson B"</bind>')
		for colName in ["em_min", "em_max"]:
			self.assertEqual(row[colName], None)


class ComputePGSTest(testhelpers.VerboseTest):
	"""tests for the //siap2#computePGS rowmaker procedure.
	"""
	def _compileMapper(self, rowmakerSource):
		"""returns a row mapper for the test table compiled from the
		RowmakerDef literal rowmakerSource.
		"""
		rowmakerDef = api.parseFromString(rscdef.RowmakerDef, rowmakerSource)
		return rowmakerDef.compileForTableDef(_getTestTable())

	def testCoverageComputation(self):
		mapper = self._compileMapper(
			"""<rowmaker><apply procDef="//siap2#computePGS" name="romp">
				</apply></rowmaker>""")
		row = mapper(_getRawdictFor("test_data/ex.fits"), _getTestTable())
		for colName, expected in [
				("s_ra", 168.2454770094),
				("s_resolution", 1.0164694847),
				("s_pixel_scale", 1.0164694847),
				("s_xel1", 12),
				("s_fov", 0.00695060466)]:
			self.assertAlmostEqual(row[colName], expected)
		firstVertexLong = row["s_region"].asCooPairs()[0][0]
		self.assertAlmostEqual(firstVertexLong, 168.2470850301)

	def testNoWCSDefault(self):
		# without CD1_1, computePGS fails by default
		mapper = self._compileMapper(
			"""<rowmaker><apply procDef="//siap2#computePGS" name="romp">
				</apply></rowmaker>""")
		inRow = _getRawdictFor("test_data/ex.fits")
		del inRow["CD1_1"]
		self.assertRaisesWithMsg(api.ValidationError,
			"Field romp: While executing romp in <rowmaker without id>:"
			" No WCS information",
			mapper,
			(inRow, _getTestTable()))

	def testNoWCSIgnored(self):
		# with missingIsError False, the missing WCS only leaves s_ra NULL
		mapper = self._compileMapper(
			"""<rowmaker><apply procDef="//siap2#computePGS" name="romp">
				<bind key="missingIsError">False</bind>
				</apply></rowmaker>""")
		inRow = _getRawdictFor("test_data/ex.fits")
		del inRow["CD1_1"]
		row = mapper(inRow, _getTestTable())
		for colName, expected in [("s_ra", None), ("s_xel1", 12)]:
			self.assertEqual(row[colName], expected)


class _SIAP2ImportedTable(tresc.RDDataResource):
	"""a test resource based on the data element with id "import" in the
	RD data/siap2test.
	"""
	dataId = "import"
	rdName = "data/siap2test"


class SIAP2SingleTableTest(testhelpers.VerboseTest):
	"""tests for the svc service of the RD the imported siap2test data
	comes from.
	"""
	resources = [("data", _SIAP2ImportedTable())]

	def _doQuery(self, params):
		"""returns the primary table of svc's siap2.xml response to params.
		"""
		service = self.data.tableDef.rd.getById("svc")
		response = trialhelpers.runSvcWith(service, "siap2.xml", params)
		return response.getPrimaryTable()

	def testServiceNoParams(self):
		row = testhelpers.pickSingle(self._doQuery({}).rows)
		nullColumns = {name for name, value in row.items() if value is None}
		self.assertEqual(nullColumns, {
			'em_max', 'em_min', 'em_res_power', 'em_ucd',
			'instrument_name', 'obs_creator_did',
			'pol_states', 'pol_xel',
			't_resolution', 'target_class', 'target_name'})

		for colName, expected in [
				("calib_level", 2),
				("extracolumn", 5),
				("access_format", "application/fits"),
				# access_url is not value-mapped yet at this point; the full URI
				# is only built during serialisation.
				("access_url", "data/ex.fits"),
				("obs_publisher_did", "ivo://x-testing/~?data/ex.fits"),
				# drop this one once obs_id is sanitised.
				("obs_id", "ivo://x-testing/~?data/ex.fits"),
				("s_xel1", 12)]:
			self.assertEqual(row[colName], expected)

		self.assertAlmostEqual(row["s_ra"], 168.2454770094)
		self.assertAlmostEqual(row["s_pixel_scale"], 1.016469484705153)

	def testParamOverNullColumn(self):
		# target_name is NULL in the test data, so a blank TARGET matches nothing
		table = self._doQuery({"TARGET": [" "]})
		self.assertEqual(len(table.rows), 0)


class SIAP2GeometryStringTest(testhelpers.VerboseTest):
	"""tests for parsing SIAPv2 POS literals with siap.parseSIAP2Geometry.
	"""
	def _assertRejected(self, expectedMsg, literal):
		"""asserts that parsing literal raises a ValidationError with
		expectedMsg.
		"""
		self.assertRaisesWithMsg(api.ValidationError,
			expectedMsg,
			siap.parseSIAP2Geometry,
			(literal,))

	def _assertParsesTo(self, literal, expectedSTCS):
		"""asserts that literal parses into a shape whose STC-S in the
		Unknown frame is expectedSTCS.
		"""
		shape = siap.parseSIAP2Geometry(literal)
		self.assertEqual(shape.asSTCS("Unknown"), expectedSTCS)

	def testEmpty(self):
		self._assertRejected(
			"Field POS: Invalid SIAPv2 geometry: '' (expected a SIAPv2 shape name)",
			"")

	def testBadShape(self):
		self._assertRejected(
			"Field POS: Invalid SIAPv2 geometry: 'Trash 13 14 1 4'"
			" (expected a SIAPv2 shape name)",
			"Trash 13 14 1 4")

	def testBadCoo(self):
		self._assertRejected(
			"Field POS: Invalid SIAPv2 coordinates: 'depp 12 13'"
			" (bad floating point literal)",
			"CIRCLE depp 12 13")

	def testGoodCircle(self):
		self._assertParsesTo("CIRCLE 143 82 13",
			"Circle Unknown 143. 82. 13.")

	def testBadCircle(self):
		self._assertRejected(
			"Field POS: Invalid SIAPv2 CIRCLE: 'CIRCLE 12 13'"
			" (need exactly three numbers)",
			"CIRCLE 12 13")

	def testGoodRange(self):
		self._assertParsesTo("RANGE 345 355 -13 13",
			"PositionInterval Unknown 345. -13. 355. 13.")

	def testBadRange1(self):
		self._assertRejected(
			"Field POS: Invalid SIAPv2 RANGE: 'RANGE 345 355 -13'"
			" (need exactly four numbers)",
			"RANGE 345 355 -13")

	def testBadRange2(self):
		self._assertRejected(
			"Field POS: Invalid SIAPv2 RANGE: 'RANGE 345 355 13 -13'"
			" (lower limits must be smaller than upper limits)",
			"RANGE 345 355 13 -13")

	def testGoodPolygon(self):
		self._assertParsesTo("POLYGON 12 13 34 -34 35 12",
			"Polygon Unknown 12. 13. 34. -34. 35. 12.")

	def testBadPolygon(self):
		# an odd number of coordinates is rejected
		self._assertRejected(
			"Field POS: Invalid SIAPv2 POLYGON: '12 13 34 -34 35 1...'"
			" (need more than three coordinate *pairs*)",
			"POLYGON 12 13 34 -34 35 12 22.3290032")


class SIAP2ServiceTest(testhelpers.VerboseTest):
	"""tests for the sitewide SIAP2 service (//siap2#sitewide) running
	against ivoa.obscore.
	"""
	resources = [("data", tresc.siapTestTable),
		# we want spectra in the table so we can make sure they don't come back
		# in a query below.
		("spectra", tresc.ssaTestTable),
		("obscore", tresc.obscoreTable)]

	def _doQuery(self, params):
		"""returns the primary table of the sitewide service's siap2.xml
		response to params.
		"""
		return trialhelpers.runSvcWith(
			api.resolveCrossId("//siap2#sitewide"),
			"siap2.xml",
			params
		).getPrimaryTable()

	def testBasicCooQuery(self):
		res = self._doQuery({"POS": ["CIRCLE 4 44 0.5"]})
		row = testhelpers.pickSingle(res.rows)
		self.assertEqual(row["access_estsize"], 81)
		self.assertEqual(row["access_url"],
			"image/(1, 45)/(4.1, 1.1)")
		# NOTE(review): here and below, the query checks are silently skipped
		# if the response has no queryPars INFO.
		for info in res.getMeta("info"):
			if info.infoName=="queryPars":
				self.assertEqual(info.getContent(),
					"{'pos0': <pgsphere Circle Unknown 4. 44. 0.5>}")
				self.assertEqual(info.infoValue,
					"(s_region &&%(pos0)s) AND (dataproduct_type in ('image', 'cube'))")
				break

	def testDualCooQuery(self):
		res = self._doQuery({"POS": ["CIRCLE 4 44 0.5", "RANGE 250 260 89 90"]})
		for info in res.getMeta("info"):
			if info.infoName=="queryPars":
				self.assertEqual(info.infoValue,
					'((s_region &&%(pos0)s) OR (s_region &&%(pos1)s))'
					" AND (dataproduct_type in ('image', 'cube'))")
				self.assertEqual(info.getContent(),
					"{'pos0': <pgsphere Circle Unknown 4. 44. 0.5>, 'pos1':"
					" <pgsphere\nPositionInterval Unknown 250. 89. 260. 90.>}")
				break
		self.assertEqual(len(res.rows), 2)

	def testBANDQuery(self):
		res = self._doQuery({"BAND": ["-Inf 0.5", "40 60", "165 +Inf"]})
		self.assertEqual(len(res.rows), 5)
		for info in res.getMeta("info"):
			if info.infoName=="queryPars":
				self.assertEqual(info.infoValue,
					'((%(BAND1)s >= em_min AND em_max >= %(BAND0)s)'
					' OR (%(BAND3)s >= em_min AND em_max >= %(BAND2)s)'
					' OR (%(BAND5)s >= em_min AND em_max >= %(BAND4)s))'
					" AND (dataproduct_type in ('image', 'cube'))")

	def testCombinedQuery(self):
		self.maxDiff = None
		res = self._doQuery({"BAND": ["-Inf 0.5", "40 60"],
			"POS": ["CIRCLE 4 44 0.5", "RANGE 250 260 89 90"]})
		for info in res.getMeta("info"):
			if info.infoName=="queryPars":
				self.assertEqual(info.infoValue,
					'((s_region &&%(pos0)s) OR (s_region &&%(pos1)s)) AND'
					' ((%(BAND1)s >= em_min AND em_max >= %(BAND0)s)'
					' OR (%(BAND3)s >= em_min AND em_max >= %(BAND2)s))'
					" AND (dataproduct_type in ('image', 'cube'))")
		self.assertEqual(len(res.rows), 1)

	def testDPConstraint(self):
		res = self._doQuery({"INSTRUMENT": ["DaCHS test suite"]})
		# this would return the ssaptest spectra if not for the
		# dataproduct_type constraint the service adds (cf. the queryPars
		# checks above)
		self.assertEqual(len(res.rows), 0)
		# just make sure ssaptest actually is in ivoa.obscore; a mismatch
		# is reported by the assertion (there used to be a pdb breakpoint
		# here, which would hang unattended test runs).
		spectra = list(self.spectra.connection.query(
			"SELECT * FROM ivoa.obscore"
			" WHERE instrument_name='DaCHS test suite'"))
		self.assertEqual(len(spectra), 6)

	def testDPOverride(self):
		# NOTE(review): this constrains INSTRUMENT, not the data product type;
		# presumably DPTYPE was intended here -- confirm.
		res = self._doQuery({"INSTRUMENT": ["cube"]})
		# TODO: Add some cubes at some point
		self.assertEqual(len(res.rows), 0)

	def testFOV(self):
		res = self._doQuery({"FOV": ["1 2.5", "11 +Inf"]})
		self.assertEqual(len(res.rows), 6)

	def testTIME(self):
		res = self._doQuery({"TIME": ["-Inf 55200", "55300.5 55302.3"]})
		self.assertEqual(len(res.rows), 2)

	def testPOLPositive(self):
		res = self._doQuery({"POL": ["RR", "LL"]})
		self.assertEqual(len(res.rows), 9)

	def testPOLNegative(self):
		res = self._doQuery({"POL": ["I", "Q", "U", "LR", "X"]})
		self.assertEqual(len(res.rows), 0)


if __name__=="__main__":
	# Hands ComputePGSTest to testhelpers.main -- presumably the test run
	# when none is named on the command line (TODO confirm).
	testhelpers.main(ComputePGSTest)
