#!/usr/bin/env python

# StressSuite, a unittest testSuite designed for stressing software and
# catching memory leaks.
#
# Copyright (C) 2001  Christian Reis <kiko@async.com.br>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Library General Public
# License as published by the Free Software Foundation; either
# version 2 of the License, or (at your option) any later version.
# 
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Library General Public License for more details.
# 
# You should have received a copy of the GNU Library General Public
# License along with this library; if not, write to the Free
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA

import functools
import os
import string
import sys
import time
import unittest

# Module-level configuration. These globals are meant to be set externally
# (by the script that uses this module) before the suite is run. Okay, so
# this is ugly. Sue me. :)

# Number of times StressSuite invokes each test.
RUNS = 1000
# Number of memory samples per test: the stats object is polled every
# int(RUNS / NUM_SAMPLES) invocations, and the leak thresholds used for the
# progress bar and the final report are fractions of this value.
NUM_SAMPLES = 30

# When non-empty, StressSuite only runs tests whose name tag (the second
# "_"-separated field of str(test), e.g. "A1") is listed here.
TESTS = []
# When true, StressSuite prints the iteration number before each invocation.
DEBUG = 0

# A simple stat class that monitors a single process for growth

class SingleStressStats:
    """Track changes in the virtual memory size of the current process.

    Sizes are read from the ``VmSize`` line of ``/proc/<pid>/status`` and are
    treated as kilobytes by the report in ``__str__``. ``update()`` is meant
    to be called periodically; ``leak`` counts how many of those samples
    differed from the previous one.
    """

    def __init__(self):
        self.mypid = str(os.getpid())
        # size: most recent sample, compared against by update().
        self.size = self.get_size()
        # start_size: baseline used for the total growth in the report.
        self.start_size = self.get_size()
        # Number of samples whose size differed from the preceding sample.
        self.leak = 0
        self.startTime = time.time()

    def get_size(self):
        """Return the current VmSize of this process."""
        return self.get_vmsize(self.mypid)

    def get_vmsize(self, pid):
        """Return the integer VmSize value from /proc/<pid>/status.

        pid is the process id as a string. Raises RuntimeError when the
        status file contains no VmSize line.
        """
        sizeh = open('/proc/' + pid + '/status')
        try:
            lines = sizeh.readlines()
        finally:
            # Close the handle even if reading fails.
            sizeh.close()
        for line in lines:
            if line[0:6] == "VmSize":
                return int(line.split()[1])
        # String exceptions are rejected by Python >= 2.6; raise a real one.
        raise RuntimeError("/proc/#/status broken - not found")

    def __str__(self):
        """Return the elapsed time, plus a leak report if one was detected.

        Note that this takes one more sample via update() before reporting.
        """
        self.update()
        diff = self.get_size() - self.start_size
        tdiff = time.time() - self.startTime
        stats = "(%02.2fs)" % (tdiff)
        # if we leak more than half the times
        if self.leak > NUM_SAMPLES * 0.5:
            stats = stats + \
                "\n    *** Leak in process: %.2f Kbytes, %d bytes per invocation" % \
                (diff, (diff * 1024) / (RUNS))
        return stats

    def update(self):
        """Take a sample and return the number of samples that changed size."""
        curr = self.get_vmsize(self.mypid)
        diff = curr - self.size
        # Remember this sample so the next call measures growth since now,
        # not since construction; otherwise a single one-time growth would
        # be counted again on every later sample.
        self.size = curr

        # Attention: leak is a counter because we want to catch
        # reproducible leaks, and not just a single leak (which could be
        # caused by python startup or one-time growth)
        if diff:
            self.leak = self.leak + 1
        return self.leak

# A StressStats class that reads information from /proc for two processes, a
# client and a server.

class ClientServerStressStats:
    """Track memory size changes of a server process and of this client.

    The server's process id is read from the first line of a file named
    ``pid`` in the current directory. Sizes come from the ``VmSize`` line of
    ``/proc/<pid>/status`` and are treated as kilobytes by the report in
    ``__str__``. Every two-element list kept by this class is ordered
    ``[server, client]``.
    """

    def __init__(self):
        self.mypid = str(os.getpid())
        # get server pid from "pid" file.
        pidfile = open('pid')
        try:
            self.serverpid = pidfile.readline().strip()
        finally:
            pidfile.close()
        # sizes: most recent sample; start_sizes: baseline for the report.
        self.sizes = self.get_size()
        self.start_sizes = self.get_size()
        # store leaks in client and server
        self.leak = [0, 0]
        self.startTime = time.time()

    def get_size(self):
        """Return the current ``[server, client]`` VmSize values."""
        return [self.get_vmsize(self.serverpid),
                self.get_vmsize(self.mypid)]

    def __str__(self):
        """Return the elapsed time, plus a report for each leaking process.

        Note that this takes one more sample via update() before reporting.
        """
        self.update()
        curr = self.get_size()
        # A list (not map()) so it can be indexed on Python 3 as well.
        diff = [now - start for now, start in zip(curr, self.start_sizes)]
        tdiff = time.time() - self.startTime
        stats = "(%02.2fs)" % (tdiff)
        # Report a process when more than half of the samples changed size.
        if self.leak[0] > NUM_SAMPLES * 0.5:
            stats = stats + \
                "\n    *** Leak in server: %.2f Kbytes, %d bytes per invocation" % \
                (diff[0], (diff[0] * 1024) / (RUNS))
        if self.leak[1] > NUM_SAMPLES * 0.5:
            stats = stats + \
                "\n    *** Leak in client: %.2f KBytes, %d bytes per invocation" % \
                (diff[1], (diff[1] * 1024) / (RUNS))
        return stats

    def get_vmsize(self, pid):
        """Return the integer VmSize value from /proc/<pid>/status.

        pid is the process id as a string. Raises RuntimeError when the
        status file contains no VmSize line.
        """
        sizeh = open('/proc/' + pid + '/status')
        try:
            lines = sizeh.readlines()
        finally:
            # Close the handle even if reading fails.
            sizeh.close()
        for line in lines:
            if line[0:6] == "VmSize":
                return int(line.split()[1])
        # String exceptions are rejected by Python >= 2.6; raise a real one.
        raise RuntimeError("/proc/#/status broken - not found")

    def update(self):
        """Take a sample of both processes.

        Returns the sum of the server and client leak counters.
        """
        curr = self.get_size()
        # Change of each process since the previous sample.
        diff = [now - before for now, before in zip(curr, self.sizes)]
        self.sizes = curr

        # Attention: leak is a counter because we want to catch
        # reproducible leaks, and not just a single leak (which could be
        # caused by python startup or one-time growth)
        if diff[0]:
            self.leak[0] = self.leak[0] + 1
        if diff[1]:
            self.leak[1] = self.leak[1] + 1

        return self.leak[0] + self.leak[1]

#
# Stats factory used by StressSuite: it is called with no arguments once
# per test, and the object it returns must provide update() (polled during
# the run) and __str__() (printed after the progress bar).
# Override this if you don't want the default client/server (C/S) test.
#

STATS = ClientServerStressStats

#
# The suite itself.
#

class StressSuite(unittest.TestSuite):
    """A TestSuite that runs each test RUNS times while sampling memory.

    For every test a fresh stats object is created with STATS(). The test is
    invoked RUNS times, the stats object is polled roughly NUM_SAMPLES times,
    and one progress character per sample is written to stdout, followed by
    str(stats). Tests whose string form contains "nostress" are skipped, and
    when TESTS is non-empty only tests whose name tag is listed are run.
    """

    def __init__(self, tests=()):
        unittest.TestSuite.__init__(self, tests)

    def _name_tag(self, test):
        """Return the second "_"-separated field of str(test), or None."""
        fields = str(test).split("_")
        if len(fields) < 2:
            return None
        return fields[1]

    def _name_index(self, test):
        """Return the integer ## of a tag of the form X##, or None."""
        tag = self._name_tag(test)
        if tag is None:
            return None
        try:
            return int(tag[1:])
        except ValueError:
            return None

    # This is to get the tests running in the right order
    # Names are of the type X## where X is any letter.
    def _cmp_names(self, x, y):
        """Old-style comparison of two tests by the number in their tag.

        Returns 1, -1 or 0. Equal numbers compare equal, and 0 is also
        returned when either name is malformed so that such tests are not
        reordered (and never make the sort raise).
        """
        nx = self._name_index(x)
        ny = self._name_index(y)
        # Don't try and reorder things with broken names
        if nx is None or ny is None:
            return 0
        if nx > ny:
            return 1
        if nx < ny:
            return -1
        return 0

    def echo(self, str):
        """Write str to stdout and flush so progress shows immediately."""
        # The parameter name shadows the builtin; it is kept so that
        # existing keyword callers keep working.
        sys.stdout.write(str)
        sys.stdout.flush()

    def __call__(self, result):
        """Run the stress loop for every selected test and return result."""
        # Never let the interval drop to 0 (RUNS < NUM_SAMPLES); the modulo
        # in the loop below would raise ZeroDivisionError.
        sample_interval = max(1, int(RUNS / NUM_SAMPLES))

        # Column headers, padded to the widths used by the rows below.
        self.echo("\n")
        self.echo("Test name".ljust(30))
        self.echo(" Progress".ljust(NUM_SAMPLES + 3))
        self.echo(" Time")
        self.echo("\n\n")

        # get those tests in the right order; cmp_to_key keeps the
        # comparison function usable where list.sort() no longer takes one.
        self._tests.sort(key=functools.cmp_to_key(self._cmp_names))
        for test in self._tests:
            name = str(test)

            # filter out nostress tests
            if "nostress" in name:
                continue

            # if TESTS exists, only run tests listed
            if TESTS and self._name_tag(test) not in TESTS:
                continue

            # Honour a stop request before paying for a new stats object.
            if result.shouldStop:
                break
            stats = STATS()

            # Cut up name into something nice: drop everything after the
            # first space, then the prefix before the first "_" (if any).
            fullname = name.split(" ")[0]
            testname = fullname.split("_", 1)[-1]
            # truncate name and pad up to the progress bracket.
            self.echo(testname[:28].ljust(30))

            # The testing loop
            self.echo("[")
            for i in range(0, RUNS):
                if DEBUG:
                    print(i)
                test(result)
                if i % sample_interval == 0:
                    r = stats.update()
                    # The heavier the character, the more samples changed.
                    if r > NUM_SAMPLES * 0.75:
                        self.echo("O")
                    elif r > NUM_SAMPLES * 0.5:
                        self.echo("o")
                    elif r > NUM_SAMPLES * 0.25:
                        self.echo(".")
                    else:
                        self.echo("_")
            self.echo("] " + str(stats) + "\n")

            # Catch multiple errors and just use first one
            if result.errors:
                result.errors = result.errors[:1]
            if result.failures:
                result.failures = result.failures[:1]

        return result

#
# Example run.
#

if __name__ == "__main__":

    # Demo settings: watch only this process and run each test 1000 times.
    RUNS = 1000
    STATS = SingleStressStats

    class FooTest(unittest.TestCase):
        # Class-level list that keeps every chunk read below referenced, so
        # the allocating tests make the process grow on each invocation.
        hoard = []

        def test_A0_nop(self):
            pass

        def test_A1_alloc(self):
            chunk = open("/dev/zero").read(2048)
            self.hoard.append(chunk)

        def test_A2_alloc_big(self):
            chunk = open("/dev/zero").read(8192)
            self.hoard.append(chunk)

    # Collect FooTest's tests into a StressSuite, wrap that in a plain
    # TestSuite and hand it to a quiet text runner.
    stress = unittest.makeSuite(FooTest, suiteClass=StressSuite)
    top = unittest.TestSuite((stress,))
    unittest.TextTestRunner(verbosity=0).run(top)
