"""
Tests to do with obtaining and managing table- and column-level
metadata (in user.info and vicinity).
"""


#c Copyright 2008-2024, the GAVO project <gavo@ari.uni-heidelberg.de>
#c
#c This program is free software, covered by the GNU GPL.  See the
#c COPYING file in the source distribution.


from gavo.helpers import testhelpers

from gavo import api
from gavo import base
from gavo.helpers import testtricks
from gavo.user import info
from gavo.user import limits

import tresc


class LimitsAnnotatorQueryGenTest(testhelpers.VerboseTest,
		metaclass=testhelpers.SamplesBasedAutoTest):
	"""Tests for the select expressions of the annotators that
	info.getAnnotators produces.

	Each sample is a pair of a column definition (as an XML literal)
	and the list of select expressions expected for a table consisting
	of just that column.
	"""
	def _runTest(self, sample):
		columnLiteral, expectedSelects = sample
		tableLiteral = '<table id="c">%s</table>'%columnLiteral
		tableDef = api.parseFromString(api.TableDef, tableLiteral)
		# NOTE(review): adopting the table into //tap is presumably what
		# makes the enumerate sample refer to tap_schema.c -- confirm.
		tableDef.parent = base.caches.getRD("//tap")
		annotators, _ = info.getAnnotators(tableDef)
		self.assertEqual(
			[annotator.select for annotator in annotators],
			expectedSelects)

	samples = [
		(
			'<column name="x" type="double precision"/>',
			[
				"MAX(x)",
				"MIN(x)",
				'percentile_cont(ARRAY[0.03, 0.5, 0.97]) WITHIN GROUP (ORDER BY x)',
				'AVG(CASE WHEN x IS NULL THEN 0 ELSE 1 END)',
			]),
		(
			'<column name="quoted/yikes" type="bigint"/>',
			[
				'MAX("yikes")',
				'MIN("yikes")',
				'percentile_cont(ARRAY[0.03, 0.5, 0.97]) WITHIN GROUP (ORDER BY "yikes")',
				'AVG(CASE WHEN "yikes" IS NULL THEN 0 ELSE 1 END)',
			]),
		(
			'<column name="x" type="char"><values nullLiteral="u"/></column>',
			[
				"MAX(x)",
				"MIN(x)",
				'AVG(CASE WHEN x IS NULL THEN 0 ELSE 1 END)',
			]),
		(
			'<column name="p" type="spoly"/>',
			['AVG(CASE WHEN p IS NULL THEN 0 ELSE 1 END)']),
		(
			'<column name="x" type="char"><property key="statistics">'
			'no</property></column>',
			[]),
# 05
		(
			'<column name="x" type="text"><property key="statistics">'
			'enumerate</property></column>',
			[
				'MAX(x)',
				'MIN(x)',
				"(SELECT jsonb_object_agg(COALESCE(val, 'NULL'), ct)"
				" FROM ( SELECT x AS val, count(*) AS ct"
				" FROM tap_schema.c GROUP BY x) AS subquery_x)",
				'AVG(CASE WHEN x IS NULL THEN 0 ELSE 1 END)',
			]),
	]


class _AnnotatedTable(testhelpers.TestResource):
	"""A test resource returning the table definition of the SSA test
	table after info.annotateDBTable has been run on it.
	"""
	resources = [("inputTable", tresc.ssaTestTable)]

	def make(self, deps):
		annotatedTD = deps["inputTable"].tableDef
		# annotateDBTable's return value is not used; what is handed out
		# is the table definition object that was passed in.
		info.annotateDBTable(annotatedTD)
		return annotatedTD


class DefaultSampleSizeTest(testhelpers.VerboseTest,
		metaclass=testhelpers.SamplesBasedAutoTest):
	"""Tests for info.getDefaultSamplePercent.

	Samples are pairs of a row count and the sample percentage expected
	for a table of that size.
	"""
	def _runTest(self, sample):
		tableSize, expectedPercent = sample
		computedPercent = info.getDefaultSamplePercent(tableSize)
		self.assertEqual(computedPercent, expectedPercent)

	# Admittedly, testing this is a bit silly, as we are just evaluating
	# an expression; take the samples as a record of my gut feeling of
	# what the curve ought to look like.
	samples = [
		(18, 100),
		(180, 100),
		(180000, 100),
		(1e6, 100),
		(1e7, 50),
		(1e8, 26),
		(1e9, 16),
		(1e10, 11),
		(1e100, 1),
	]


class LimitsAnnotatorQueryTest(testhelpers.VerboseTest):
	"""Tests for the annotations present on the table definition
	handed out by _AnnotatedTable.
	"""
	resources = [("td", _AnnotatedTable())]

	def _annotationsFor(self, columnName):
		"""returns the annotations of the column columnName of the
		annotated table definition.
		"""
		return self.td.getColumnByName(columnName).annotations

	def testNrowsAnnotated(self):
		self.assertEqual(self.td.nrows, 6)

	def testStringAnnotated(self):
		accrefAnn = self._annotationsFor("accref")
		self.assertEqual(
			set(accrefAnn.keys()),
			{'max_value', 'min_value', 'fill_factor'})
		self.assertEqual(accrefAnn["max_value"], "data/spec3.ssatest.vot")
		self.assertEqual(accrefAnn["min_value"], "data/spec1.ssatest")
		self.assertEqual(accrefAnn["fill_factor"], 1.0)

	def testIntAnnotated(self):
		accsizeAnn = self._annotationsFor("accsize")
		expectedKeys = {
			'max_value', 'min_value', 'fill_factor',
			"percentile03", "median", "percentile97"}
		self.assertEqual(set(accsizeAnn.keys()), expectedKeys)
		self.assertEqual(accsizeAnn["max_value"], 225)
		self.assertEqual(accsizeAnn["percentile03"], 213.0)
		self.assertEqual(accsizeAnn["median"], 219.0)
		self.assertEqual(accsizeAnn["percentile97"], 225.0)

	def testFillFactor(self):
		locationAnn = self._annotationsFor("ssa_location")
		self.assertAlmostEqual(locationAnn["fill_factor"], 2/3.)

	def testSpointAnnotated(self):
		# the only annotation expected for this column is the fill factor
		locationAnn = self._annotationsFor("ssa_location")
		self.assertEqual(set(locationAnn.keys()), {'fill_factor'})

	def testAllNull(self):
		lengthAnn = self._annotationsFor("ssa_length")
		self.assertEqual(set(lengthAnn.keys()), {'fill_factor'})
		self.assertEqual(lengthAnn["fill_factor"], 0)

	def testStringEnumeration(self):
		bandpassAnn = self._annotationsFor("ssa_bandpass")
		self.assertIn('discrete_values', bandpassAnn.keys())
		# each of the three bands is expected to make up a third of the rows
		for band in ["B", "R", "V"]:
			self.assertAlmostEqual(bandpassAnn["discrete_values"][band], 1/3)


class LimitsTest(testhelpers.VerboseTest):
	"""Tests for in-memory limits computation and for the SQL that
	info generates to obtain spatial, temporal and spectral coverage.
	"""
	def testLimitsComputing(self):
		tableDef = testhelpers.getTestRD().getById("valSpec").copy(None)
		inputRows = [
			{"a_num": 13, "enum": None},
			{"a_num": None, "enum": "bad"},
			{"a_num": -10, "enum": "horrific"}]
		limitsFound = api.TableForDef(tableDef, rows=inputRows).getLimits()

		numLimits = limitsFound["a_num"]
		self.assertEqual(numLimits.min, -10)
		self.assertEqual(numLimits.max, 13)
		self.assertEqual(numLimits.values, None)

		enumLimits = limitsFound["enum"]
		self.assertEqual(enumLimits.max, None)
		self.assertEqual(enumLimits.min, None)
		self.assertEqual(enumLimits.values, {"bad", "horrific"})

	def testSCSCoverageQuery(self):
		expectedQuery = (
			"SELECT smoc('6/' || string_agg(format('%%s', hpx), ','))\n"
			"FROM (\n"
			"  SELECT DISTINCT healpix_nest(6, spoint(RADIANS(alpha), RADIANS(delta))) AS hpx \n"
			"FROM test.adql\n"
			"WHERE alpha IS NOT NULL AND delta IS NOT NULL\n"
			"GROUP BY hpx\n"
			") as q")
		self.assertEqual(
			info.getMOCQuery(testhelpers.getTestRD().getById("adql"), 6),
			expectedQuery)

	def testSSAPCoverageQuery(self):
		expectedQuery = (
			"SELECT SUM(SMOC(3,\n"
			"  SCIRCLE(ssa_location, RADIANS(COALESCE(ssa_aperture, 1/3600.)))))\n"
			"FROM test.mixctest WHERE ssa_location IS NOT NULL")
		self.assertEqual(
			info.getMOCQuery(api.resolveCrossId("data/ssatest#mixctest"), 3),
			expectedQuery)

	def testSIAPCoverageQuery(self):
		expectedQuery = (
			"SELECT SUM(SMOC(9, coverage))\n"
			"FROM test.pgs_siaptable WHERE coverage IS NOT NULL")
		self.assertEqual(
			info.getMOCQuery(api.resolveCrossId("data/test#pgs_siaptable"), 9),
			expectedQuery)

	def testObscoreCoverageQuery(self):
		expectedQuery = (
			"SELECT SUM(coverage)\n"
			"FROM (SELECT\n"
			"  COALESCE(\n"
			"    SMOC(6, s_region), smoc_disc(6, RADIANS(s_ra), RADIANS(s_dec), RADIANS(s_fov)),\n"
			"    NULL) AS coverage\n"
			"  FROM ivoa.ObsCore\n"
			"  ) AS q")
		self.assertEqual(
			info.getMOCQuery(api.resolveCrossId("//obscore#ObsCore"), 6),
			expectedQuery)

	def testNonstandardTableCoverage(self):
		unsupportedTD = api.resolveCrossId("data/test#typesTable")
		self.assertRaisesWithMsg(api.ReportableError,
			"Table data/test#typesTable does not have columns DaCHS"
			" knows how to get a coverage from.",
			info.getMOCQuery,
			(unsupportedTD, 3))

	def testSIAPLimitsColumns(self):
		siapTD = api.resolveCrossId("data/test#pgs_siaptable")

		lower, upper, convert = info.getTimeLimitsExprs(siapTD)
		self.assertEqual(lower, 'MIN(dateObs)')
		self.assertEqual(upper, 'MAX(dateObs)')
		self.assertAlmostEqual(convert(54000), 54000)

		# NOTE(review): the lower limit comes from the MAX of the upper
		# band column and vice versa; presumably that is because convert
		# maps wavelengths to energies -- confirm.
		lower, upper, convert = info.getSpectralLimitsExprs(siapTD)
		self.assertEqual(upper, 'MIN(bandpassLo)')
		self.assertEqual(lower, 'MAX(bandpassHi)')
		self.assertAlmostEqual(convert(10), 1.9864458241717583e-26)

	def testSSAPLimitsColumns(self):
		ssapTD = api.resolveCrossId("data/ssatest#mixctest")

		lower, upper, convert = info.getTimeLimitsExprs(ssapTD)
		self.assertEqual(lower, 'MIN(ssa_dateObs)')
		self.assertEqual(upper, 'MAX(ssa_dateObs)')
		self.assertAlmostEqual(convert(54000), 54000)

		# NOTE(review): as for SIAP, lower and upper limits are swapped
		# with respect to the column names -- presumably because of a
		# wavelength-to-energy conversion; confirm.
		lower, upper, convert = info.getSpectralLimitsExprs(ssapTD)
		self.assertEqual(upper, 'MIN(ssa_specstart)')
		self.assertEqual(lower, 'MAX(ssa_specend)')
		self.assertAlmostEqual(convert(10), 1.9864458241717583e-26)

	def testMissingLimits(self):
		# both getters are expected to fail in the same way on a table
		# that has no columns they can use.
		failingCalls = [
			(info.getSpectralLimitsExprs,
				"Columns to figure out 'spectral coverage' could not be located in table data/test#typesTable"),
			(info.getTimeLimitsExprs,
				"Columns to figure out 'temporal coverage' could not be located in table data/test#typesTable"),
		]
		for limitsGetter, expectedMessage in failingCalls:
			self.assertRaisesWithMsg(api.NotFoundError,
				expectedMessage,
				limitsGetter,
				(api.resolveCrossId("data/test#typesTable"),))


class _ColstatsContent(testhelpers.TestResource):
	"""A test resource with the column statistics stored in the database
	for the SSA test table.

	make returns a dictionary with two keys: "num" holds the rows of
	dc.simple_col_stats for test.hcdtest, "enum" the rows of
	dc.discrete_string_values for that table.  In both cases, the rows
	are dictionaries, keyed by their column_name.
	"""
	resources = [("inputTable", tresc.ssaTestTable)]

	def make(self, deps):
		inputTable = deps["inputTable"]
		sentinelArgs = {"tableName": inputTable.tableDef.getQName()}

		# insert an artificial sentinel so we can figure out
		# if updateTableLevelStats actually dropped existing data.
		with base.getWritableAdminConn() as conn:
			conn.execute("INSERT INTO dc.simple_col_stats"
				" (tableName, column_name)"
				" VALUES (%(tableName)s, 'sentinel_garbage')",
				sentinelArgs)
			conn.execute("INSERT INTO dc.discrete_string_values"
				" (tableName, column_name)"
				" VALUES (%(tableName)s, 'sentinel_garbage')",
				sentinelArgs)

		# use the table's own connection here, which is also the one
		# the results are read back through below.
		limits.updateTableLevelStats(
			inputTable.tableDef, inputTable.connection)

		numStats = {r["column_name"]: r
			for r in inputTable.connection.queryToDicts(
				"SELECT * FROM dc.simple_col_stats"
				" WHERE tableName='test.hcdtest'")}
		enumStats = {r["column_name"]: r
			for r in inputTable.connection.queryToDicts(
				"SELECT * FROM dc.discrete_string_values"
				" WHERE tableName='test.hcdtest'")}

		return {"num": numStats, "enum": enumStats}


class ColstatsTableTest(testhelpers.VerboseTest):
	"""Tests for what updateTableLevelStats leaves in the statistics
	tables for the SSA test table.

	self.colStats is the dictionary made by _ColstatsContent, i.e.,
	column statistics rows by column name under "num" and "enum".
	"""
	resources = [("colStats", _ColstatsContent())]

	def testPreviousContentsRemoved(self):
		# the sentinels were inserted before the statistics update;
		# if they are still there, old rows have not been dropped.
		self.assertNotIn("sentinel_garbage", self.colStats["num"])
		self.assertNotIn("sentinel_garbage", self.colStats["enum"])

	def testTableDeclared(self):
		self.assertEqual([r["tablename"] for r in self.colStats["num"].values()],
			["test.hcdtest"]*13)

	def testAllNumericColumnsAnnotated(self):
		self.assertEqual(len(self.colStats["num"]), 13)

	def testFillFactorNonNull(self):
		self.assertNotIn(
			None,
			[r["fill_factor"] for r in self.colStats["num"].values()])

	def testFillFactorComputed(self):
		self.assertAlmostEqual(self.colStats["num"]["ssa_snr"]["fill_factor"], 0.)
		self.assertAlmostEqual(
			self.colStats["num"]["ssa_dateObs"]["fill_factor"], 1.)

	def testMedianComputation(self):
		# the value is compared as a string here
		self.assertEqual(self.colStats["num"]["accsize"]["median"], "219.0")

	def testPercentiles(self):
		redshiftStats = self.colStats["num"]["ssa_redshift"]
		self.assertAlmostEqual(
			float(redshiftStats["percentile03"]),
			-0.001)
		self.assertAlmostEqual(
			float(redshiftStats["percentile97"]),
			0.7)

	def testMinMax(self):
		self.assertEqual(self.colStats["num"]["accsize"]["min_value"], "213")
		self.assertEqual(self.colStats["num"]["accsize"]["max_value"], "225")

	def testBandpassDistribution(self):
		row = self.colStats["enum"]["ssa_bandpass"]
		self.assertEqual(row["tablename"], "test.hcdtest")
		self.assertEqual(row["vals"], ["B", "R", "V"])
		self.assertAlmostEqual(row["freqs"][0], 1/3)

	def testTargetDistribution(self):
		row = self.colStats["enum"]["ssa_targname"]
		self.assertEqual(row["vals"],
			['big fart nebula', 'booger star', 'rat hole in the yard'])


class _TDWithStats(testhelpers.TestResource):
	"""The stuff from the tstat RD with the in-DB annotations.

	make imports the data of data/tstat, updates the table level
	statistics for its DoNotMixCase table and then returns that table's
	definition from a freshly loaded RD.

	This takes care to tear down the statistics entries again, as
	they might interfere with other tests.
	"""
	resources = [("conn", tresc.dbConnection)]

	def make(self, deps):
		conn = deps["conn"]
		rdId = "data/tstat"
		rd = api.getRD(rdId)

		api.makeData(rd.getById("import"), connection=conn)
		conn.commit()
		limits.updateTableLevelStats(
			rd.getById("DoNotMixCase"), conn)

		# NOTE(review): presumably the RD needs to be re-read for the
		# statistics just written to end up in the table definition --
		# confirm.
		base.caches.clearForName(rdId)
		return api.resolveCrossId(rdId+"#DoNotMixCase")

	def clean(self, ignored):
		# statistics are kept in two tables; remove our rows from both
		# of them so that nothing is left behind for later tests.
		with base.getWritableAdminConn() as conn:
			conn.execute("delete from dc.simple_col_stats where"
				" tableName='test.DoNotMixCase'")
			conn.execute("delete from dc.discrete_string_values where"
				" tableName='test.DoNotMixCase'")


# A single instance of the resource that several test classes and
# resources below declare as a dependency.
_tdWithStats = _TDWithStats()

class ColstatsInjectionTest(testhelpers.VerboseTest):
	"""Tests for the statistics that show up in the values elements
	of the columns of the table definition made by _TDWithStats.
	"""
	resources = [("statsTD", _tdWithStats)]

	def _valuesOf(self, columnName):
		"""returns the values attribute of the column columnName
		of the test table definition.
		"""
		return self.statsTD.getColumnByName(columnName).values

	def testFloatAnnotationInValues(self):
		floatValues = self._valuesOf("plain_float")
		self.assertEqual(floatValues.min, "-1")
		self.assertEqual(floatValues.max, "1.0")
		self.assertAlmostEqual(
			float(floatValues.median), 0.41666666, places=6)

	def testDelimitedAnnotation(self):
		# the column name is a delimited identifier
		quotedValues = self._valuesOf(api.QuotedName("gg-/l"))
		self.assertAlmostEqual(
			float(quotedValues.percentile03), 1.27, places=6)
		self.assertEqual(quotedValues.median, "6.5")
		self.assertAlmostEqual(
			float(quotedValues.percentile97), 15.37, places=6)
		self.assertEqual(quotedValues.fillFactor, "1.0")

	def testManuallyAnnotatedColumn(self):
		expectedMinima = [
			("plain_float", "-1"),
			("dontlook", "-3.6")]
		for columnName, expectedMin in expectedMinima:
			self.assertEqual(self._valuesOf(columnName).min, expectedMin)

	def testStatisticsSkip(self):
		skippedValues = self._valuesOf("dontlook")
		self.assertEqual(skippedValues.fillFactor, None)

	def testNoneIsSkipped(self):
		nullValues = self._valuesOf("allnull")
		self.assertEqual(nullValues.median, None)

	def testStringSel(self):
		optionTitles = [
			option.title for option in self._valuesOf("hurgel").options]
		self.assertEqual(optionTitles, ["nma0", "nma1"])


class _VOSITablesetWithStats(testhelpers.TestResource):
	"""a VOSIDataService tableset over a statistically annotated table.

	make returns a pair of the rendered XML literal and the tree
	testhelpers.getXMLTree makes from it.
	"""
	resources = [("statsTD", _tdWithStats)]

	def make(self, deps):
		from gavo.registry import tableset
		from gavo.web.vosi import VTM

		tableElement = tableset.getTableForTableDef(
			deps["statsTD"],
			set(),
			rootElement=VTM.table)
		rendered = tableElement.render()
		return rendered, testhelpers.getXMLTree(rendered, debug=False)


class ColstatsVSTest(testhelpers.VerboseTest, testtricks.XSDTestMixin):
	"""Tests for the column statistics attributes in the tableset
	made by _VOSITablesetWithStats.

	self.rt is a pair of the XML literal and the tree parsed from it.
	"""
	resources = [("rt", _VOSITablesetWithStats())]

	@staticmethod
	def _N(attName):
		# returns attName qualified with the ColStats namespace
		return f"{{http://dc.g-vo.org/ColStats-1}}{attName}"

	def _statFor(self, columnName, attName):
		"""returns the value of the ColStats attribute attName on the
		first column element with the name columnName.
		"""
		columnElement = self.rt[1].xpath("column[name='%s']"%columnName)[0]
		return columnElement.get(self._N(attName))

	def testValid(self):
		self.assertValidates(self.rt[0])

	def testMinValueDefinedPlain(self):
		self.assertEqual(self._statFor("plain_float", "min-value"), "-1")

	def testMinValueDefinedDelimited(self):
		# NOTE(review): despite the name, this inspects max-value.
		self.assertEqual(self._statFor('"gg-/l"', "max-value"), "16")

	def testMedianDefined(self):
		self.assertEqual(self._statFor('"gg-/l"', "median"), "6.5")

	def testLowerPercentileDefined(self):
		lowerPercentile = float(self._statFor("plain_float", "percentile03"))
		self.assertAlmostEqual(lowerPercentile, 0.2575)

	def testUpperPercentileDefined(self):
		upperPercentile = float(self._statFor("plain_float", "percentile97"))
		self.assertAlmostEqual(upperPercentile, 0.955)

	def testFillFactorDefined(self):
		self.assertEqual(self._statFor("plain_float", "fillFactor"), "1")


class DumpingTest(testhelpers.VerboseTest):
	"""Tests for the output of the statistics dumping functions
	in limits.
	"""
	resources = [
		("conn", tresc.dbConnection),
		("adqlgeo", tresc.adqlGeoTable),
		("withstats", _tdWithStats)]

	def _assertDumpHas(self, dumper, subject, expectedStrings):
		"""asserts that all expectedStrings are in item 1 of what
		testhelpers.captureOutput returns for dumper(subject, self.conn).
		"""
		captured = testhelpers.captureOutput(dumper, (subject, self.conn))
		testtricks.assertHasStrings(captured[1], expectedStrings)

	def testUnimportedTable(self):
		rd = base.parseFromString(api.RD,
			"""<resource schema="test"><table id="dumb" onDisk="True">
			<column name="x"/></table></resource>""")
		self._assertDumpHas(
			limits.dumpTableLevelStats,
			rd.getById("dumb"),
			["Statistics for test.dumb", "No metadata (table not imported)?"])

	def testImportedTable(self):
		self._assertDumpHas(
			limits.dumpTableLevelStats,
			self.adqlgeo.tableDef,
			["Statistics for test.ADQLgeo", "|rows| = <Unknown>"])

	def testAnalysedTable(self):
		self._assertDumpHas(
			limits.dumpTableLevelStats,
			self.withstats,
			["Statistics for test.DoNotMixCase", "|rows| = 4",
			"plain_float", '"gg-/l"', " 0.257500000… ", " 0.955", " None "])

	def testRDInfo(self):
		self._assertDumpHas(
			limits.dumpStatsForRD,
			self.withstats.rd,
			["Statistics for RD data/tstat", "No RD stats.",
				"Statistics for test.DoNotMixCase"])

if __name__=="__main__":
	# NOTE(review): presumably LimitsAnnotatorQueryTest is what is run
	# when no test is selected on the command line -- confirm against
	# testhelpers.main.
	testhelpers.main(LimitsAnnotatorQueryTest)
