"""
Tests for the generic UWS (Universal Worker Service) machinery.

For TAP-specific tests, use taptest.py.
"""

#c Copyright 2008-2024, the GAVO project <gavo@ari.uni-heidelberg.de>
#c
#c This program is free software, covered by the GNU GPL.  See the
#c COPYING file in the source distribution.


import datetime
import queue
import time
import threading

from gavo.helpers import testhelpers

from gavo import base
from gavo import rscdesc #noflake: for base.caches registration
from gavo import svcs
from gavo import utils
from gavo.helpers import trialhelpers
from gavo.protocols import dali
from gavo.protocols import uws
from gavo.protocols import uwsactions

import tresc

class _PlainTransitions(uws.UWSTransitions):
	"""A minimal UWS state machine for exercising the generic UWS code.

	Jobs can be queued, run, completed and aborted.  "Running" a job
	does no work at all; it just marks the job COMPLETED.
	"""
	def __init__(self):
		transitionTable = [
			# (source phase, target phase, name of the callback method)
			(uws.PENDING, uws.QUEUED, "noOp"),
			(uws.PENDING, uws.EXECUTING, "run"),
			(uws.QUEUED, uws.EXECUTING, "run"),
			(uws.EXECUTING, uws.COMPLETED, "noOp"),
			(uws.QUEUED, uws.ABORTED, "noOp"),
			(uws.EXECUTING, uws.ABORTED, "noOp"),
		]
		super().__init__("plain", transitionTable)

	def run(self, newState, writableJob, ignored):
		"""Completes writableJob immediately instead of executing anything."""
		writableJob.change(phase=uws.COMPLETED)


class UWSTestJob(uws.BaseUWSJob):
	"""A UWS job class kept in the testjobs table of data/uwstest.

	Phase changes are governed by the trivial _PlainTransitions machine.
	"""
	_transitions = _PlainTransitions()
	_jobsTDId = "data/uwstest#testjobs"


# The UWS instance most tests in this module work with; it manages
# UWSTestJob jobs and uses uwsactions.JobActions for the job actions.
_TEST_UWS = uws.UWS(UWSTestJob, uwsactions.JobActions())


class _TestUWSTable(tresc.RDDataResource):
	"""A test resource for the "import" data element of the data/uwstest RD.

	Test cases that touch the test jobs table list it in their resources
	(presumably so the table exists while they run).
	"""
	dataId = "import"
	rdName = "data/uwstest"

_testUWSTable = _TestUWSTable()


class UWSObjectTest(testhelpers.VerboseTest):
	"""Tests for the UWS object itself: generated SQL, table definition
	caching, job counting and lookup failures.
	"""
	resources = [("testUWSTable", _testUWSTable)]

	def testGetStatement(self):
		self.assertEqual(_TEST_UWS._statements["getById"][1],
			"SELECT jobId, phase, executionDuration, destructionTime,"
			" owner, parameters, runId, startTime, endTime,"
			" error, creationTime, magic"
			" FROM test.testjobs WHERE jobId=%(jobId)s ")

	def testExGetStatement(self):
		self.assertEqual(_TEST_UWS._statements["getByIdEx"][1],
			"SELECT jobId, phase, executionDuration, destructionTime,"
			" owner, parameters, runId, startTime, endTime,"
			" error, creationTime, magic"
			" FROM test.testjobs WHERE jobId=%(jobId)s FOR UPDATE ")

	def testFeedStatement(self):
		self.assertEqual(_TEST_UWS._statements["feedToIdEx"][1],
			'INSERT INTO test.testjobs (jobId, phase, executionDuration,'
			' destructionTime, owner, parameters, runId,'
			' startTime, endTime, error, creationTime, magic) VALUES'
			' (%(jobId)s, %(phase)s,'
			' %(executionDuration)s, %(destructionTime)s, %(owner)s,'
			' %(parameters)s, %(runId)s, %(startTime)s, %(endTime)s,'
			' %(error)s, %(creationTime)s, %(magic)s)')

	def testJobsTDCache(self):
		td1 = UWSTestJob.jobsTD
		td2 = UWSTestJob.jobsTD
		self.assertEqual(td1.columns[0].name, "jobId")
		# repeated access must hand out the very same object
		self.assertIs(td1, td2)

	def testCountFunctions(self):
		jobId = _TEST_UWS.getNewJobId()
		# Everything after creating the job sits in the try so the job
		# is destroyed even when one of the assertions fails; otherwise
		# it would leak into later tests and skew their counts.
		try:
			self.assertEqual(_TEST_UWS.countQueuedJobs(), 0)
			_TEST_UWS.changeToPhase(jobId, uws.QUEUED)
			self.assertEqual(_TEST_UWS.countQueuedJobs(), 1)
			self.assertEqual(_TEST_UWS.countRunningJobs(), 0)
		finally:
			_TEST_UWS.destroy(jobId)

	def testNotFoundRaised(self):
		self.assertRaises(uws.JobNotFound,
			_TEST_UWS.getJob,
			"there's no way this id could ever exist")


class TestWithUWSJob(testhelpers.VerboseTest):
	"""A base class for tests that need a fresh UWS job.

	setUp makes a new job in _TEST_UWS and stores it in self.job;
	tearDown destroys it again unless the test has already done so.
	"""
	resources = [("testUWSTable", _testUWSTable)]

	def setUp(self):
		# Run the base class setUp first so whatever it prepares (e.g.,
		# the resources declared above) is in place before we create a job.
		testhelpers.VerboseTest.setUp(self)
		self.job = _TEST_UWS.getNewJob()

	def tearDown(self):
		try:
			_TEST_UWS.destroy(self.job.jobId)
		except uws.JobNotFound:
			# tests may kill jobs themselves
			pass
		finally:
			# Chain to the base class so its cleanup (which the original
			# code skipped) always runs.
			testhelpers.VerboseTest.tearDown(self)


class SimpleJobsTest(TestWithUWSJob):
	"""Tests for attribute access and modification on UWS jobs."""

	def testNonPropertyRaises(self):
		self.assertRaises(AttributeError,
			lambda: self.job.foobar)

	def testCannotChangeRO(self):
		# self.job is read-only; changes must go through changeableJob
		self.assertRaises(TypeError,
			self.job.change, owner="rupert")

	def testChangeWorks(self):
		self.assertIsNone(self.job.owner)
		with _TEST_UWS.changeableJob(self.job.jobId) as wjob:
			self.assertIsNone(wjob.owner)
			wjob.change(owner="rupert")
			self.assertEqual(wjob.owner, "rupert")
		# the read-only job only sees the change after re-reading it
		self.job.update()
		self.assertEqual(self.job.owner, "rupert")

	def testDefaultInPlace(self):
		self.assertEqual(self.job.phase, uws.PENDING)

	def testAssigningIsForbidden(self):
		def fails():
			with _TEST_UWS.changeableJob(self.job.jobId) as wjob:
				wjob.phase = uws.PENDING
		self.assertRaises(TypeError, fails)

	def testAssigningToNonexistingRaises(self):
		with _TEST_UWS.changeableJob(self.job.jobId) as wjob:
			self.assertRaises(AttributeError, wjob.change, foo="bar")

	def testPropertyEncoding(self):
		with _TEST_UWS.changeableJob(self.job.jobId) as wjob:
			wjob.change(parameters={"foo": 4})
		self.job.update()
		self.assertEqual(self.job.parameters["foo"], 4)

	def testNoParameterSerInRO(self):
		def fails():
			self.job.setPar("glob", "someString")
		self.assertRaises(TypeError, fails)

	def testNonmagicParameterSer(self):
		with _TEST_UWS.changeableJob(self.job.jobId) as wjob:
			wjob.setPar("glob", "some string")
		self.job.update()
		self.assertEqual(self.job.parameters["glob"], "some string")

	def testCreationTime(self):
		# The job's creationTime is compared against naive UTC (as the
		# original datetime.utcnow() produced); utcnow() is deprecated since
		# Python 3.12, so compute the same value from an aware timestamp.
		nowUTC = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
		timediff = self.job.creationTime-nowUTC
		self.assertTrue(abs(timediff.total_seconds())<2,
			"Job creation to test execution >2 seconds??")


class PlainActionsTest(TestWithUWSJob):
	"""Tests for phase changes governed by _PlainTransitions."""

	def _forcePhase(self, phase):
		"""sets the test job's phase directly through change() on a
		writable job.
		"""
		with _TEST_UWS.changeableJob(self.job.jobId) as wjob:
			wjob.change(phase=phase)

	def _currentPhase(self):
		"""returns the test job's phase after re-reading the job."""
		self.job.update()
		return self.job.phase

	def testSimpleTransition(self):
		_TEST_UWS.changeToPhase(self.job.jobId, uws.QUEUED)
		self.assertEqual(self._currentPhase(), uws.QUEUED)

	def testTransitionWithCallback(self):
		# PENDING -> EXECUTING calls run, which completes the job at once
		_TEST_UWS.changeToPhase(self.job.jobId, uws.EXECUTING)
		self.assertEqual(self._currentPhase(), uws.COMPLETED)

	def testFailingTransition(self):
		with _TEST_UWS.changeableJob(self.job.jobId) as wjob:
			wjob.change(phase=uws.EXECUTING)
			# the transition table has no EXECUTING -> QUEUED entry
			self.assertRaises(base.ValidationError,
				wjob.getTransitionTo, uws.QUEUED)

	def testFailingGoesToError(self):
		self._forcePhase(uws.EXECUTING)
		_TEST_UWS.changeToPhase(self.job.jobId, uws.QUEUED)
		self.assertEqual(self._currentPhase(), uws.ERROR)

	def testNoEndstateActions(self):
		self._forcePhase(uws.COMPLETED)
		_TEST_UWS.changeToPhase(self.job.jobId, uws.ERROR)
		self.assertEqual(self._currentPhase(), uws.COMPLETED)

	def testNullActionIgnored(self):
		self._forcePhase(uws.QUEUED)
		_TEST_UWS.changeToPhase(self.job.jobId, uws.QUEUED)
		self.assertEqual(self._currentPhase(), uws.QUEUED)


class JobHandlingTest(TestWithUWSJob):
	"""Tests for jobs table cleanup and for locking of jobs."""

	def testCleanupLeaves(self):
		# a fresh job has not expired and must survive a cleanup
		_TEST_UWS.cleanupJobsTable()
		self.assertIn(self.job.jobId, _TEST_UWS.getJobIds())

	def testCleanupCatchesExpired(self):
		# naive UTC, equivalent to the deprecated datetime.utcnow()
		nowUTC = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
		with _TEST_UWS.changeableJob(self.job.jobId) as wjob:
			wjob.change(destructionTime=
				nowUTC-datetime.timedelta(seconds=20))
		_TEST_UWS.cleanupJobsTable()
		self.assertNotIn(self.job.jobId, _TEST_UWS.getJobIds())

	def testCleanupWithArgs(self):
		_TEST_UWS.cleanupJobsTable(includeAll=True)
		self.assertNotIn(self.job.jobId, _TEST_UWS.getJobIds())

	def testLocking(self):
		q = queue.Queue()
		def blockingJob():
			# this runs in a thread while the test holds self.job.jobId
			q.put("Child started")
			with _TEST_UWS.changeableJob(self.job.jobId):
				q.put("Job created")

		child = threading.Thread(target=blockingJob)
		try:
			with _TEST_UWS.changeableJob(self.job.jobId):
				child.start()
				# see that the child thread has started but could not get the job
				self.assertEqual(q.get(True, 1), "Child started")
				# make sure we time out on waiting for another sign of the child --
				# it should be blocking.
				self.assertRaises(queue.Empty, q.get, True, 0.05)
			# we've closed our handle on job, now child can run
			self.assertEqual(q.get(True, 1), "Job created")
		finally:
			# Don't let the child outlive the test: tearDown destroys the
			# very job the child is trying to lock.
			if child.is_alive():
				child.join(5)

	def testTimesOut(self):
		def timesOut(jobId):
			with _TEST_UWS.changeableJob(jobId, timeout=0.001):
				self.fail("No lock happened???")

		# while we hold the job, a second attempt with a tiny timeout must fail
		with _TEST_UWS.changeableJob(self.job.jobId):
			self.assertRaisesWithMsg(base.ReportableError,
				"Could not access the jobs table. This probably means there"
				" is a stale lock on it.  Please notify the service operators.",
				timesOut,
				(self.job.jobId,))


def _makeUWSRequest(inArgs, inMethod="GET"):
	"""returns a trialhelpers.FakeRequest prepared the way the UWS renderers
	prepare normal twisted requests.

	inArgs is the argument dictionary passed to FakeRequest and parsed with
	the UWS grammar into uwsArgs; inMethod is the HTTP method name, which is
	stored on the request after going through utils.bytify.  Uploads are
	processed with dali.mangleUploads before the request is returned.
	"""
	request = trialhelpers.FakeRequest(args=inArgs)
	request.method = utils.bytify(inMethod)
	request.uwsArgs = uws._getUWSGrammar().parseStrargs(inArgs)
	dali.mangleUploads(request)
	return request


class UserUWSTest(testhelpers.VerboseTest):
	"""Tests for the UWSes of the async services in data/cores
	(parameter handling, uploads, redirects and actual execution).
	"""

	def testBasicJob(self):
		worker = base.resolveCrossId("data/cores#pc").getUWS()
		job = worker.getNewJob()
		try:
			self.assertEqual(job.jobClass, "data/cores#pc")
			with job.getWritable() as wjob:
				wjob.setPar("opre", 2.5)
				wjob.setPar("powers", [1,2,3])
			job.update()
			self.assertEqual(job.parameters["opre"], 2.5)
			self.assertEqual(job.parameters["powers"], [1,2,3])
		finally:
			worker.destroy(job.jobId)

	def testFloatParameter(self):
		worker = base.resolveCrossId("data/cores#pc").getUWS()
		jobId = worker.getNewIdFromRequest(
			_makeUWSRequest({"opre": ["2.5489923488e10"], "opim": ["3.5"]}))
		try:
			job = worker.getJob(jobId)
			self.assertEqual(job.parameters["opre"], 2.5489923488e10)
		finally:
			worker.destroy(jobId)

	def testArrayParameter(self):
		worker = base.resolveCrossId("data/cores#pc").getUWS()
		jobId = worker.getNewIdFromRequest(
			_makeUWSRequest({"powers": ["2 4 78"]}))
		try:
			job = worker.getJob(jobId)
			self.assertEqual(job.parameters["powers"], [2, 4, 78])
		finally:
			worker.destroy(jobId)

	def testParametersElement(self):
		worker = base.resolveCrossId("data/cores#pc").getUWS()
		jobId = worker.getNewIdFromRequest(_makeUWSRequest({
				"powers": ["2 4 78"],
				"upload": ["stuff,param:quux"],
				"quux": trialhelpers.FakeFile("honk.txt", "gacker")}))
		try:
			res = uwsactions.doJobAction(worker, _makeUWSRequest({}),
				(jobId,))
			tree = testhelpers.getXMLTree(res, debug=False)

			self.assertEqual(len(tree.xpath("//parameter")), 5)
			# opim was not given, so this is its default
			self.assertEqual(tree.xpath("//parameter[@id='opim']")[0].text,
				"1.0")

			upEl = tree.xpath("//parameter[@id='stuff']")[0]
			self.assertEqual(upEl.attrib["byReference"], "True")
			self.assertEqual(upEl.text,
				utils.EqualingRE("http://localhost:8080/data/cores/"
					"pc/async/.*/results/stuff"))
		finally:
			worker.destroy(jobId)

	def testUpload(self):
		worker = base.resolveCrossId("data/cores#pc").getUWS()
		jobId = worker.getNewIdFromRequest(
			_makeUWSRequest({
				"upload": ["stuff,param:quux"],
				"quux": trialhelpers.FakeFile("honk.txt", "gacker")}))
		try:
			job = worker.getJob(jobId)

			with job.openFile(job.parameters["stuff"][1]) as f:
				self.assertEqual(f.read(), b"gacker")

			# Re-upload to the same name through the parameters resource;
			# a SeeOther raised by the action is expected and ignored.
			try:
				uwsactions.doJobAction(worker, _makeUWSRequest({
						"upload": ["stuff,param:foo"],
							"foo": trialhelpers.FakeFile("bigbig.txt", "bigbig")},
						inMethod="POST"),
					(jobId, "parameters"))
			except svcs.SeeOther:
				pass

			with job.openFile(job.parameters["stuff"][1]) as f:
				self.assertEqual(f.read(), b"bigbig")
		finally:
			# Use jobId, not job.jobId: if getJob fails, job is unbound and
			# a NameError would mask the real error and leak the job.
			worker.destroy(jobId)

	def testProperClassDeserialisation(self):
		worker1 = base.resolveCrossId("data/cores#pc").getUWS()
		worker2 = base.resolveCrossId("data/cores#uc").getUWS()
		jobId = worker1.getNewIdFromRequest(
			_makeUWSRequest({"opre": ["29"], "powers": "2 3 4".split()}))
		try:
			try:
				worker2.getJob(jobId)
			except svcs.WebRedirect as ex:
				self.assertEqual(ex.dest,
					"http://localhost:8080/data/cores/pc/async/"+jobId)
			else:
				self.fail("Construction of job with wrong UWS doesn't redirect")
		finally:
			worker1.destroy(jobId)

	def testRunningAndDefaulting(self):
		worker = base.resolveCrossId("data/cores#pc").getUWS()
		jobId = worker.getNewIdFromRequest(
			_makeUWSRequest({"powers": ["2 4"], "opre": ["4"]}))
		try:
			job = worker.getJob(jobId)
			worker.changeToPhase(job.jobId, uws.QUEUED, timeout=0)
			# poll for up to 100*0.3 = 30 seconds
			for _ in range(100):
				time.sleep(0.3)
				if job.update().phase==uws.COMPLETED:
					break
			else:
				self.fail("useruws test job didn't complete within 30 seconds.")

			with job.openFile("result") as f:
				res = f.read()

			# that's the info with opim, which is defaulted and hence
			# must come from the context grammar.
			self.assertIn(b"1.0</INFO", res)
			# when this breaks, add proper VOTable parsing here
			self.assertIn(b"QXAAAEEAAABANVNeQyEAAENwAABAtVNe", res)
		finally:
			worker.destroy(jobId)


if __name__=="__main__":
	# Running this file directly only executes JobHandlingTest.
	testhelpers.main(JobHandlingTest)
