
# Jython Database Specification API 2.0
#
# $Id: runner.py,v 1.1 2001/12/14 04:20:03 bzimmer Exp $
#
# Copyright (c) 2001 brian zimmer <bzimmer@ziclix.com>

"""
To run the tests, invoke this script from the command line:

 jython runner.py <xml config file> [vendor, ...]

If no vendors are given, all vendors are tested.  If one or more
vendors are given, only those vendors are tested.
"""

import unittest, sys, os
import xmllib, __builtin__, re

def __imp__(module, attr=None):
	if attr:
		j = __import__(module, globals(), locals())
		return getattr(j, attr)
	else:
		last = module.split(".")[-1]
		return __import__(module, globals(), locals(), last)

class Factory:
	"""
	Data parsed from a <factory> element.

	classname - value of the element's 'class' attribute
	method    - value of the element's 'method' attribute
	"""
	def __init__(self, classname, method):
		self.classname, self.method = classname, method
		# (name, value) pairs from nested <argument> elements, in order
		self.arguments = []
		# name -> value from nested <keyword> elements
		self.keywords = {}

class Testcase:
	"""
	Data parsed from a <testcase> element: which TestCase class to load.

	frm  - module to import from (the 'from' attribute)
	impt - name to import from that module (the 'import' attribute)
	"""
	def __init__(self, frm, impt):
		self.frm, self.impt = frm, impt
		# test method names to leave out of the suite (<ignore> elements)
		self.ignore = []

class Test:
	def __init__(self, name, os):
		self.name = name
		self.os = os
		self.factory = None
		self.tests = []

class Vendor:
	def __init__(self, name, datahandler=None):
		self.name = name
		self.datahandler = datahandler
		self.tests = []
		self.tables = {}

class ConfigParser(xmllib.XMLParser):
	"""
	A simple XML parser for the config file.

	Elements are turned into Vendor / Test / Factory / Testcase objects:

	  <vendor name [datahandler]>    appends a Vendor to self.vendors
	  <test name [os]>               appends a Test to the current vendor
	  <factory class method>         sets the current test's Factory
	  <argument name value [type]>   appends (name, value) to the factory
	  <keyword name value [type]>    sets factory.keywords[name] = value
	  <testcase from import>         appends a Testcase to the current test
	  <ignore name>                  adds name to the current testcase's ignore list
	  <table ref name>text</table>   sets current vendor.tables[ref] = (name, text)

	"Current" means the most recently started element of that kind.
	<argument>/<keyword> values may contain ${key} references, which are
	expanded by value().
	"""
	def __init__(self, **kw):
		xmllib.XMLParser.__init__(self, **kw)
		self.vendors = []
		# (ref, name) of each open <table> element, innermost last
		self.table_stack = []
		# text chunks received so far for each open <table>; parallel to table_stack
		self.table_text = []
		self.re_var = re.compile(r"\${(.*?)}")

	def vendor(self):
		"""Return the most recently started Vendor."""
		assert len(self.vendors) > 0, "no vendors"
		return self.vendors[-1]

	def test(self):
		"""Return the most recently started Test of the current vendor."""
		v = self.vendor()
		assert len(v.tests) > 0, "no tests"
		return v.tests[-1]

	def factory(self):
		"""Return the Factory of the current Test."""
		c = self.test()
		assert c.factory, "no factory"
		return c.factory

	def testcase(self):
		"""Return the most recently started Testcase of the current Test."""
		s = self.test()
		assert len(s.tests) > 0, "no testcases"
		return s.tests[-1]

	def value(self, value):
		"""
		Replace every ${key} in *value* with the Java system property *key*;
		if the property is unset, the key name itself is substituted.
		"""
		def repl(sub):
			from java.lang import System
			return System.getProperty(sub.group(1), sub.group(1))
		value = self.re_var.sub(repl, value)
		return value

	def _typed_value(self, attrs):
		"""
		Return the expanded 'value' attribute, converted with the builtin
		named by the optional 'type' attribute (e.g. type="int").
		"""
		if attrs.has_key('type'):
			# NOTE: any builtin can be named here (including eval/open); the
			# config file is assumed to be trusted.
			convert = getattr(__builtin__, attrs['type'])
			return convert(self.value(attrs['value']))
		return self.value(attrs['value'])

	def start_vendor(self, attrs):
		self.vendors.append(Vendor(attrs['name'], attrs.get('datahandler')))

	def start_test(self, attrs):
		v = self.vendor()
		# 'os' is optional: test() treats a false os as "run everywhere"
		c = Test(attrs['name'], attrs.get('os'))
		v.tests.append(c)

	def start_factory(self, attrs):
		c = self.test()
		f = Factory(attrs['class'], attrs['method'])
		c.factory = f

	def start_argument(self, attrs):
		f = self.factory()
		f.arguments.append((attrs['name'], self._typed_value(attrs)))

	def start_keyword(self, attrs):
		f = self.factory()
		f.keywords[attrs['name']] = self._typed_value(attrs)

	def start_ignore(self, attrs):
		t = self.testcase()
		t.ignore.append(attrs['name'])

	def start_testcase(self, attrs):
		c = self.test()
		c.tests.append(Testcase(attrs['from'], attrs['import']))

	def start_table(self, attrs):
		self.table_stack.append((attrs['ref'], attrs['name']))
		self.table_text.append([])

	def end_table(self):
		del self.table_stack[-1]
		del self.table_text[-1]

	def handle_data(self, data):
		"""
		Record character data that appears inside a <table> element.

		xmllib may deliver one element's text in several calls (for example
		split around entity or character references such as &lt;), so the
		chunks are accumulated and the joined, stripped text is stored.
		"""
		if self.table_stack:
			ref, tabname = self.table_stack[-1]
			chunks = self.table_text[-1]
			chunks.append(data)
			self.vendor().tables[ref] = (tabname, "".join(chunks).strip())

class SQLTestCase(unittest.TestCase):
	"""
	Base testing class.  It contains the list of table and factory information
	to run any tests.
	"""
	def __init__(self, name, vendor, factory):
		unittest.TestCase.__init__(self, name)
		self.vendor = vendor
		self.factory = factory
		if self.vendor.datahandler:
			self.datahandler = __imp__(self.vendor.datahandler)

	def table(self, name):
		return self.vendor.tables[name]

	def has_table(self, name):
		return self.vendor.tables.has_key(name)

def make_suite(vendor, testcase, factory):
	"""
	Build a TestSuite for one Testcase.

	Loads the class named by testcase.frm / testcase.impt, collects its
	"test*" method names (minus those in testcase.ignore), and instantiates
	the class once per method as clz(method_name, vendor, factory).
	"""
	clz = __imp__(testcase.frm, testcase.impt)
	skipped = testcase.ignore
	method_names = [m for m in unittest.getTestCaseNames(clz, "test") if m not in skipped]
	suite = unittest.TestSuite()
	for method_name in method_names:
		suite.addTest(clz(method_name, vendor, factory))
	return suite

def test(vendors, include=None):
	for vendor in vendors:
		if not include or vendor.name in include:
			print
			print "testing [%s]" % (vendor.name)
			for test in vendor.tests:
				if not test.os or test.os == os.name:
					for testcase in test.tests:
						suite = make_suite(vendor, testcase, test.factory)
						unittest.TextTestRunner().run(suite)
		else:
			print
			print "skipping [%s]" % (vendor.name)

if __name__ == '__main__':
	# usage: runner.py <xml config file> [vendor, ...]
	if len(sys.argv) < 2:
		sys.exit("usage: jython runner.py <xml config file> [vendor, ...]")
	configParser = ConfigParser()
	fp = open(sys.argv[1], "r")
	try:
		configParser.feed(fp.read())
	finally:
		# release the file even if reading or parsing fails
		fp.close()
	# xmllib protocol: force processing of any buffered input as if EOF was reached
	configParser.close()
	test(configParser.vendors, sys.argv[2:])
