# Copyright (c) 2005-2007 Forest Bond.
# This file is part of the sclapp software package.
# 
# sclapp is free software; you can redistribute it and/or modify it under the
# terms of the GNU General Public License version 2 as published by the Free
# Software Foundation.
# 
# A copy of the license has been included in the COPYING file.

import sys, os, re
from unittest import TestCase

import sclapp
from sclapp import debug_logging

def execHead(num_lines = 5):
    """Fork a child process that replaces itself with ``head -n<num_lines>``.

    The child inherits the parent's open file descriptors, so ``head`` reads
    from and writes to whatever stdin/stdout the parent had at fork time.

    Returns the child's pid in the parent (suitable for waitForPid()).

    If the exec fails in the child (for instance because ``head`` cannot be
    found), the child exits immediately with status 127 rather than
    returning into -- and continuing to run -- the caller's code.
    """
    pid = os.fork()
    if not pid:
        try:
            os.execvp('head', [ 'head', '-n%u' % num_lines ])
        finally:
            # execvp only comes back by raising.  Never let the forked child
            # fall through into the parent's code path (e.g. the rest of a
            # test run); 127 is the status shells use for "cannot execute".
            os._exit(127)
    return pid

# Pattern for the text logSignals() formats: 'pid <pid> caught signals: <list>'.
# Group 1 is the decimal pid; group 2 is the comma/space separated list of
# signal numbers, which may be empty.  getLoggedSignals() applies it with
# .match(), so a log line must begin with this text to be recognized.
CAUGHT_SIGNALS_REGEX = r'pid ([0-9]+) caught signals: ([0-9, ]*)'
# Compiled once at import time so per-line matching doesn't re-parse it.
caught_signals_regex_compiled = re.compile(CAUGHT_SIGNALS_REGEX)

def dumpLogFile():
    """Write the debug log file's contents to stdout for inspection.

    Output is a 'Logfile contents:' heading, an 80-dash rule, the raw log
    text, and a closing 80-dash rule.
    """
    rule = 80 * '-'
    out = sys.stdout
    out.write('Logfile contents:\n')
    out.write(rule + '\n')
    out.write(debug_logging.readLogFile())
    out.write(rule + '\n')

def getLoggedSignals(pid):
    """Return the list of signal numbers logged as caught by ``pid``.

    Scans the debug log line by line and uses the first line that starts
    with the 'pid <pid> caught signals: ...' text and whose pid equals
    ``pid``.  Its signal list is returned as a list of ints (empty when the
    logged list was empty).  Returns None if no such line exists.
    """
    for text in debug_logging.readLogFile().split('\n'):
        found = caught_signals_regex_compiled.match(text)
        if found is None:
            continue
        groups = found.groups()
        assert(len(groups) == 2)
        if int(groups[0]) != pid:
            continue
        # Drop the empty pieces produced by an empty list or stray commas.
        pieces = [piece.strip() for piece in groups[1].split(',')]
        return [int(piece) for piece in pieces if piece]
    return None

def verifySignalCaught(signum, pid):
    """Return True if the debug log lists ``signum`` among ``pid``'s signals.

    Returns False when nothing was logged for ``pid`` at all.
    """
    logged = getLoggedSignals(pid)
    if logged is None:
        return False
    return signum in logged

def assertSignalCaught(signum, pid):
    """Fail unless the debug log records ``signum`` as caught by ``pid``.

    Raises:
      AssertionError: if no 'caught signals' line was logged for ``pid``, or
        the logged list does not contain ``signum``.  The message shows the
        signal list that was actually found (None if none was logged).
    """
    # Read the log once so the message reports the same data that was checked.
    signums = getLoggedSignals(pid)
    assert ((signums is not None) and (signum in signums)), (
      'Expected pid %r to have caught signal %r; logged signals: %r' % (
        pid, signum, signums))

def logSignals():
    """Log this process's pid and the signals sclapp reports as caught.

    The message has the form 'pid <pid> caught signals: <n>, <m>, ...',
    i.e. the text that getLoggedSignals() parses back out of the log.
    """
    signal_list = ', '.join(map(str, sclapp.getCaughtSignals()))
    message = 'pid %u caught signals: %s' % (os.getpid(), signal_list)
    debug_logging.logMessage(message)

def waitForPid(pid):
    """Block until child process ``pid`` terminates.

    Returns the ``(pid, status)`` pair produced by ``os.waitpid``.
    """
    no_options = 0  # no WNOHANG: wait rather than poll
    reaped = os.waitpid(pid, no_options)
    return reaped

def removeLogFile():
    """Delete the debug log file on a best-effort basis.

    Returns the result of debug_logging.removeLogFile(), or None if that
    call raised OSError/IOError (errors are deliberately ignored so callers
    such as setUp/tearDown need not care whether a log file exists).
    """
    outcome = None
    try:
        outcome = debug_logging.removeLogFile()
    except (OSError, IOError):
        pass
    return outcome

def redirectToLogFile():
    """Point this process's stdout and stderr at the debug log file.

    Delegates to sclapp.processes.redirectFds (imported lazily here) and
    returns whatever it returns.
    """
    from sclapp import processes as sclapp_processes
    logfile = debug_logging.DEBUG_LOGFILE
    return sclapp_processes.redirectFds(stdout = logfile, stderr = logfile)

def assertLogFileContains(needle):
    """Fail unless ``needle`` appears somewhere in the debug log file.

    Raises:
      AssertionError: if ``needle`` is absent; the message names the needle
        and includes the log contents that were searched.
    """
    haystack = debug_logging.readLogFile()
    assert (needle in haystack), (
      'Expected to find %r.  Logfile:\n%s' % (needle, haystack))

def assertLogFileDoesNotContain(needle):
    """Fail if ``needle`` appears anywhere in the debug log file.

    Raises:
      AssertionError: if ``needle`` is present; the message names the needle
        and includes the log contents that were searched.
    """
    haystack = debug_logging.readLogFile()
    assert (needle not in haystack), (
      'Expected not to find %r.  Logfile:\n%s' % (needle, haystack))

def grepCount(haystack, needle):
    """Return how many times ``needle`` occurs in ``haystack``.

    Unlike ``str.count``, overlapping matches are counted: each search
    resumes one character after the start of the previous match, so
    ``grepCount('aaa', 'aa')`` is 2.
    """
    total = 0
    pos = haystack.find(needle)
    while pos != -1:
        total += 1
        pos = haystack.find(needle, pos + 1)
    return total

def assertLogFileContainsExactly(needle, num):
    """Fail unless ``needle`` occurs exactly ``num`` times in the debug log.

    Occurrences are counted with grepCount(), so overlapping matches count.

    Raises:
      AssertionError: on a mismatch; the message gives both counts and the
        log contents that were actually counted.
    """
    haystack = debug_logging.readLogFile()
    count = grepCount(haystack, needle)
    # Report the snapshot that was counted; re-reading the log here could
    # show different contents if another process is still writing to it.
    assert (count == num), (
      'Expected exactly %u, found %u.  Logfile:\n%s' % (num, count, haystack))

def assertLogFileContainsAtLeast(needle, min):
    """Fail unless ``needle`` occurs at least ``min`` times in the debug log.

    Occurrences are counted with grepCount(), so overlapping matches count.
    (``min`` shadows the builtin but is kept for keyword-passing callers.)

    Raises:
      AssertionError: if too few occurrences are found; the message gives
        both counts and the log contents that were actually counted.
    """
    haystack = debug_logging.readLogFile()
    count = grepCount(haystack, needle)
    # Report the snapshot that was counted; re-reading the log here could
    # show different contents if another process is still writing to it.
    assert (count >= min), (
      'Expected at least %u, found %u.  Logfile:\n%s' % (min, count, haystack))

def assertLogFileContainsAtMost(needle, max):
    """Fail unless ``needle`` occurs at most ``max`` times in the debug log.

    Occurrences are counted with grepCount(), so overlapping matches count.
    (``max`` shadows the builtin but is kept for keyword-passing callers.)

    Raises:
      AssertionError: if too many occurrences are found; the message gives
        both counts and the log contents that were actually counted.
    """
    haystack = debug_logging.readLogFile()
    count = grepCount(haystack, needle)
    # Report the snapshot that was counted; re-reading the log here could
    # show different contents if another process is still writing to it.
    assert (count <= max), (
      'Expected at most %u, found %u.  Logfile:\n%s' % (max, count, haystack))

class SclappTestCase(TestCase):
    """unittest base class that deletes the debug log file before and after
    every test, so no test starts with a log file left over from another.
    """

    def setUp(self):
        parent = super(SclappTestCase, self)
        # Discard any log left behind by an earlier test or run.
        removeLogFile()
        return parent.setUp()

    def tearDown(self):
        parent = super(SclappTestCase, self)
        # Don't leave this test's log output behind for the next test.
        removeLogFile()
        return parent.tearDown()

# Absolute path of the directory that contains this module.
TEST_DIR = os.path.abspath(os.path.split(__file__)[0])
# Parent directory of TEST_DIR (presumably the project root -- confirm layout).
PROJECT_DIR = os.path.split(TEST_DIR)[0]
