#!/usr/bin/python
#
#       Test fragments of example code to ensure they behave as
#       advertised.
#
#       Functionally this is similar to doctest; however, our input is
#       more regular than the input doctest has to handle, so our
#       parsing can be more relaxed.
#

import sys
import os
import traceback
import StringIO

# Prompt strings (without any trailing space) that introduce an example
# command (PS1) and its continuation lines (PS2).  We can't use sys.ps1 and
# sys.ps2 because these are only defined for interactive sessions.
PS1 = '>>>'
PS2 = '...'


class ExampleFailure(Exception):
    """An example file could not be loaded or parsed, or one of its
    commands misbehaved; the message describes what went wrong."""


class ExampleCmd:
    def __init__(self, linenum, line):
        self.linenum = linenum
        self.cmds = [line]
        self.expect = []

    def add_continuation(self, line):
        self.cmds.append(line)

    def has_output(self):
        return bool(self.expect)

    def add_output(self, line):
        self.expect.append(line)

    def run(self, namespace):
        got = None
        expect = ''.join(self.expect)
        src = ''.join(self.cmds)
        try:
            saved_stdout = sys.stdout
            sys.stdout = StringIO.StringIO()
            try:
                code = compile(src, '<string>', 'single')
                exec code in namespace
                got = sys.stdout.getvalue()
            finally:
                sys.stdout = saved_stdout
        except:
            # Is it an expected exception?
            if ('Traceback (innermost last):\n' in expect or
                'Traceback (most recent call last):\n' in expect):
                # Only compare exception type and value - the rest of
                # the traceback isn't necessary.
                exc_type, exc_val = sys.exc_info()[:2]
                got = traceback.format_exception_only(exc_type, exc_val)[-1]
                expect = self.expect[-1]
            else:
                # Unexpected exception - print something useful
                exc_type, exc_val, exc_tb = sys.exc_info()
                exc_tb = exc_tb.tb_next
                psrc = PS1 + " " + src.rstrip().replace("\n", PS2 + " ")
                pexp = traceback.format_exception(exc_type, exc_val, exc_tb)
                pexp = ''.join(pexp).rstrip()
                raise ExampleFailure('Line %d:\n%s\n%s' % 
                                     (self.linenum, psrc, pexp))
        if got != expect:
            raise ExampleFailure('Line %d: output does not match example\n'
                                 'Expected:\n%s\nGot:\n%s' % 
                                 (self.linenum, expect, got))


class ExampleTest:
    def __init__(self, filename):
        self.filename = filename
        self.result = ""
        self.example = []

    def report(self):
        if self.result:
            print "=" * 70
            print "Testing \"%s\" failed," % self.filename,
            print self.result
            print

    def test(self):
        try:
            if not self.example:
                self.load_and_parse()
            self.result = ""
            self.execute_verify_output()
        except ExampleFailure, msg:
            self.result = msg

        return self.result == ""

    def load_and_parse(self):
        self.example = []
        try:
            f = open(self.filename)
        except IOError, (eno, estr):
            raise ExampleFailure("could not load: %s" % estr)
        try:
            for linenum, line in enumerate(f):
                if line.startswith(PS1):
                    self.example.append(ExampleCmd(linenum + 1, 
                                                   line[len(PS1)+1:]))
                elif line.startswith(PS2):
                    if not self.example or self.example[-1].has_output():
                        raise ExampleFailure("Line %d: %r line with no preceeding %r line" % (linenum+1, PS2, PS1))
                    self.example[-1].add_continuation(line[len(PS2)+1:])
                else:
                    if not self.example:
                        raise ExampleFailure("Line %d: output with no preceeding %r line" % (linenum+1, PS1))
                    self.example[-1].add_output(line)
        finally:
            f.close()

    def execute_verify_output(self):
        namespace = {}
        for cmd in self.example:
            cmd.run(namespace)


def files_from_directory(dir):
    """Return the sorted paths of the example files directly inside *dir*.

    Entries whose names start with '.', anything that is not a regular
    file (e.g. subdirectories), and Python sources ending in '.py' are
    left out.
    """
    found = [os.path.join(dir, name) for name in os.listdir(dir)
             if not name.startswith('.')]
    found = [path for path in found
             if os.path.isfile(path) and not path.endswith('.py')]
    found.sort()
    return found


def main():
    if len(sys.argv) > 1:
        filenames = sys.argv[1:]
    else:
        filenames = files_from_directory("doctest")

    fail_cnt = total_cnt = 0
    failures = []

    for filename in filenames:
        if os.path.isfile(filename):
            total_cnt += 1
            e = ExampleTest(filename)
            if not e.test():
                failures.append(e)
                fail_cnt += 1
                sys.stdout.write("!")
            else:
                sys.stdout.write(".")
            sys.stdout.flush()
    sys.stdout.write("\n")

    for e in failures:
        e.report()

    print "%d of %d tests failed" % (fail_cnt, total_cnt)

    sys.exit(fail_cnt > 0)


# Run the example tests only when this file is executed as a script.
if __name__ == '__main__':
    main()

