"""
Copyright (c) 2008-2020, Jesus Cea Avion <jcea@jcea.es>
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:

    1. Redistributions of source code must retain the above copyright
    notice, this list of conditions and the following disclaimer.

    2. Redistributions in binary form must reproduce the above
    copyright notice, this list of conditions and the following
    disclaimer in the documentation and/or other materials provided
    with the distribution.

    3. Neither the name of Jesus Cea Avion nor the names of its
    contributors may be used to endorse or promote products derived
    from this software without specific prior written permission.

    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
    CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
    INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
    MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
    DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
    BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
            TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
            DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
    ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
    TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
    THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
    SUCH DAMAGE.
    """

"""TestCases for multi-threaded access to a DB.
"""

import os
import sys
import time
import errno
from random import random

# Separator used by makeData() when joining repeated copies of a key.
DASH = '-'

# Make sure the name WindowsError always resolves: the bare lookup below
# raises NameError where the builtin does not exist, in which case a plain
# Exception subclass is defined in its place.
try:
    WindowsError
except NameError:
    class WindowsError(Exception):
        pass

import unittest
from test_all import db, dbutils, test_support, verbose, have_threads, \
        get_new_environment_path, get_new_database_path

if have_threads :
    from threading import Thread
    if sys.version_info[0] < 3 :
        from threading import currentThread
    else :
        from threading import current_thread as currentThread


#----------------------------------------------------------------------

class BaseThreadedTestCase(unittest.TestCase):
    """Shared fixture: one DBEnv and one DB opened inside it for each test.

    Derived classes pick the access method and the flag sets through the
    class attributes below, and may override setEnvOpts() to configure the
    environment before it is opened.
    """
    dbtype       = db.DB_UNKNOWN  # must be set in derived class
    dbopenflags  = 0
    dbsetflags   = 0
    envflags     = 0

    def setUp(self):
        """Create a new environment directory and open the env and the DB."""
        if verbose:
            dbutils._deadlock_VerboseFile = sys.stdout

        # Environment: create, let the subclass tune it, then open it.
        self.homeDir = get_new_environment_path()
        self.env = db.DBEnv()
        self.setEnvOpts()
        env_open_flags = self.envflags | db.DB_CREATE
        self.env.open(self.homeDir, env_open_flags)

        # Database: named after the concrete test class.
        self.filename = '%s.db' % self.__class__.__name__
        self.d = db.DB(self.env)
        if self.dbsetflags:
            self.d.set_flags(self.dbsetflags)
        db_open_flags = self.dbopenflags | db.DB_CREATE
        self.d.open(self.filename, self.dbtype, db_open_flags)

    def tearDown(self):
        """Close the DB, then the environment, then delete the directory."""
        for handle in (self.d, self.env):
            handle.close()
        test_support.rmtree(self.homeDir)

    def setEnvOpts(self):
        """Hook called between creating and opening the environment."""
        pass

    def makeData(self, key):
        """Return the value stored under ``key``: five copies joined by DASH."""
        copies = (key,) * 5
        return DASH.join(copies)


#----------------------------------------------------------------------


class ConcurrentDataStoreBase(BaseThreadedTestCase):
    dbopenflags = db.DB_THREAD
    envflags    = db.DB_THREAD | db.DB_INIT_CDB | db.DB_INIT_MPOOL
    readers     = 0 # derived class should set
    writers     = 0
    records     = 1000

    def test01_1WriterMultiReaders(self):
        if verbose:
            print '\n', '-=' * 30
            print "Running %s.test01_1WriterMultiReaders..." % \
                  self.__class__.__name__

        keys=range(self.records)
        import random
        random.shuffle(keys)
        records_per_writer=self.records//self.writers
        readers_per_writer=self.readers//self.writers
        self.assertEqual(self.records,self.writers*records_per_writer)
        self.assertEqual(self.readers,self.writers*readers_per_writer)
        self.assertTrue((records_per_writer%readers_per_writer)==0)
        readers = []

        for x in xrange(self.readers):
            rt = Thread(target = self.readerThread,
                        args = (self.d, x),
                        name = 'reader %d' % x,
                        )#verbose = verbose)
            if sys.version_info[0] < 3 :
                rt.setDaemon(True)
            else :
                rt.daemon = True
            readers.append(rt)

        writers=[]
        for x in xrange(self.writers):
            a=keys[records_per_writer*x:records_per_writer*(x+1)]
            a.sort()  # Generate conflicts
            b=readers[readers_per_writer*x:readers_per_writer*(x+1)]
            wt = Thread(target = self.writerThread,
                        args = (self.d, a, b),
                        name = 'writer %d' % x,
                        )#verbose = verbose)
            writers.append(wt)

        for t in writers:
            if sys.version_info[0] < 3 :
                t.setDaemon(True)
            else :
                t.daemon = True
            t.start()

        for t in writers:
            t.join()
        for t in readers:
            t.join()

    def writerThread(self, d, keys, readers):
        if sys.version_info[0] < 3 :
            name = currentThread().getName()
        else :
            name = currentThread().name

        if verbose:
            print "%s: creating records %d - %d" % (name, start, stop)

        count=len(keys)//len(readers)
        count2=count
        for x in keys :
            key = '%04d' % x
            dbutils.DeadlockWrap(d.put, key, self.makeData(key),
                                 max_retries=12)
            if verbose and x % 100 == 0:
                print "%s: records %d - %d finished" % (name, start, x)

            count2-=1
            if not count2 :
                readers.pop().start()
                count2=count

        if verbose:
            print "%s: finished creating records" % name

        if verbose:
            print "%s: thread finished" % name

    def readerThread(self, d, readerNum):
        if sys.version_info[0] < 3 :
            name = currentThread().getName()
        else :
            name = currentThread().name

        for i in xrange(5) :
            c = d.cursor()
            count = 0
            rec = c.first()
            while rec:
                count += 1
                key, data = rec
                self.assertEqual(self.makeData(key), data)
                rec = c.next()
            if verbose:
                print "%s: found %d records" % (name, count)
            c.close()

        if verbose:
            print "%s: thread finished" % name


class BTreeConcurrentDataStore(ConcurrentDataStoreBase):
    """ConcurrentDataStoreBase scenario run on a DB_BTREE database."""
    dbtype  = db.DB_BTREE
    records = 1000
    readers = 10
    writers = 2


class HashConcurrentDataStore(ConcurrentDataStoreBase):
    """ConcurrentDataStoreBase scenario run on a DB_HASH database."""
    dbtype  = db.DB_HASH
    records = 1000
    readers = 10
    writers = 2


#----------------------------------------------------------------------

class SimpleThreadedBase(BaseThreadedTestCase):
    dbopenflags = db.DB_THREAD
    envflags    = db.DB_THREAD | db.DB_INIT_MPOOL | db.DB_INIT_LOCK
    readers = 10
    writers = 2
    records = 1000

    def setEnvOpts(self):
        self.env.set_lk_detect(db.DB_LOCK_DEFAULT)

    def test02_SimpleLocks(self):
        if verbose:
            print '\n', '-=' * 30
            print "Running %s.test02_SimpleLocks..." % self.__class__.__name__


        keys=range(self.records)
        import random
        random.shuffle(keys)
        records_per_writer=self.records//self.writers
        readers_per_writer=self.readers//self.writers
        self.assertEqual(self.records,self.writers*records_per_writer)
        self.assertEqual(self.readers,self.writers*readers_per_writer)
        self.assertTrue((records_per_writer%readers_per_writer)==0)

        readers = []
        for x in xrange(self.readers):
            rt = Thread(target = self.readerThread,
                        args = (self.d, x),
                        name = 'reader %d' % x,
                        )#verbose = verbose)
            if sys.version_info[0] < 3 :
                rt.setDaemon(True)
            else :
                rt.daemon = True
            readers.append(rt)

        writers = []
        for x in xrange(self.writers):
            a=keys[records_per_writer*x:records_per_writer*(x+1)]
            a.sort()  # Generate conflicts
            b=readers[readers_per_writer*x:readers_per_writer*(x+1)]
            wt = Thread(target = self.writerThread,
                        args = (self.d, a, b),
                        name = 'writer %d' % x,
                        )#verbose = verbose)
            writers.append(wt)

        for t in writers:
            if sys.version_info[0] < 3 :
                t.setDaemon(True)
            else :
                t.daemon = True
            t.start()

        for t in writers:
            t.join()
        for t in readers:
            t.join()

    def writerThread(self, d, keys, readers):
        if sys.version_info[0] < 3 :
            name = currentThread().getName()
        else :
            name = currentThread().name
        if verbose:
            print "%s: creating records %d - %d" % (name, start, stop)

        count=len(keys)//len(readers)
        count2=count
        for x in keys :
            key = '%04d' % x
            dbutils.DeadlockWrap(d.put, key, self.makeData(key),
                                 max_retries=12)

            if verbose and x % 100 == 0:
                print "%s: records %d - %d finished" % (name, start, x)

            count2-=1
            if not count2 :
                readers.pop().start()
                count2=count

        if verbose:
            print "%s: thread finished" % name

    def readerThread(self, d, readerNum):
        if sys.version_info[0] < 3 :
            name = currentThread().getName()
        else :
            name = currentThread().name

        c = d.cursor()
        count = 0
        rec = dbutils.DeadlockWrap(c.first, max_retries=10)
        while rec:
            count += 1
            key, data = rec
            self.assertEqual(self.makeData(key), data)
            rec = dbutils.DeadlockWrap(c.next, max_retries=10)
        if verbose:
            print "%s: found %d records" % (name, count)
        c.close()

        if verbose:
            print "%s: thread finished" % name


class BTreeSimpleThreaded(SimpleThreadedBase):
    """SimpleThreadedBase scenario run on a DB_BTREE database."""
    dbtype = db.DB_BTREE


class HashSimpleThreaded(SimpleThreadedBase):
    """SimpleThreadedBase scenario run on a DB_HASH database."""
    dbtype = db.DB_HASH


#----------------------------------------------------------------------


class ThreadedTransactionsBase(BaseThreadedTestCase):
    dbopenflags = db.DB_THREAD | db.DB_AUTO_COMMIT
    envflags    = (db.DB_THREAD |
                   db.DB_INIT_MPOOL |
                   db.DB_INIT_LOCK |
                   db.DB_INIT_LOG |
                   db.DB_INIT_TXN
                   )
    readers = 0
    writers = 0
    records = 2000
    txnFlag = 0

    def setEnvOpts(self):
        #self.env.set_lk_detect(db.DB_LOCK_DEFAULT)
        pass

    def test03_ThreadedTransactions(self):
        if verbose:
            print '\n', '-=' * 30
            print "Running %s.test03_ThreadedTransactions..." % \
                  self.__class__.__name__

        keys=range(self.records)
        import random
        random.shuffle(keys)
        records_per_writer=self.records//self.writers
        readers_per_writer=self.readers//self.writers
        self.assertEqual(self.records,self.writers*records_per_writer)
        self.assertEqual(self.readers,self.writers*readers_per_writer)
        self.assertTrue((records_per_writer%readers_per_writer)==0)

        readers=[]
        for x in xrange(self.readers):
            rt = Thread(target = self.readerThread,
                        args = (self.d, x),
                        name = 'reader %d' % x,
                        )#verbose = verbose)
            if sys.version_info[0] < 3 :
                rt.setDaemon(True)
            else :
                rt.daemon = True
            readers.append(rt)

        writers = []
        for x in xrange(self.writers):
            a=keys[records_per_writer*x:records_per_writer*(x+1)]
            b=readers[readers_per_writer*x:readers_per_writer*(x+1)]
            wt = Thread(target = self.writerThread,
                        args = (self.d, a, b),
                        name = 'writer %d' % x,
                        )#verbose = verbose)
            writers.append(wt)

        dt = Thread(target = self.deadlockThread)
        if sys.version_info[0] < 3 :
            dt.setDaemon(True)
        else :
            dt.daemon = True
        dt.start()

        for t in writers:
            if sys.version_info[0] < 3 :
                t.setDaemon(True)
            else :
                t.daemon = True
            t.start()

        for t in writers:
            t.join()
        for t in readers:
            t.join()

        self.doLockDetect = False
        dt.join()

    def writerThread(self, d, keys, readers):
        if sys.version_info[0] < 3 :
            name = currentThread().getName()
        else :
            name = currentThread().name

        count=len(keys)//len(readers)
        while len(keys):
            try:
                txn = self.env.txn_begin(None, self.txnFlag)
                keys2=keys[:count]
                for x in keys2 :
                    key = '%04d' % x
                    d.put(key, self.makeData(key), txn)
                    if verbose and x % 100 == 0:
                        print "%s: records %d - %d finished" % (name, start, x)
                txn.commit()
                keys=keys[count:]
                readers.pop().start()
            except (db.DBLockDeadlockError, db.DBLockNotGrantedError), val:
                if verbose:
                    print "%s: Aborting transaction (%s)" % (name,
                            val.args[1])
                txn.abort()

        if verbose:
            print "%s: thread finished" % name

    def readerThread(self, d, readerNum):
        if sys.version_info[0] < 3 :
            name = currentThread().getName()
        else :
            name = currentThread().name

        finished = False
        while not finished:
            try:
                txn = self.env.txn_begin(None, self.txnFlag)
                c = d.cursor(txn)
                count = 0
                rec = c.first()
                while rec:
                    count += 1
                    key, data = rec
                    self.assertEqual(self.makeData(key), data)
                    rec = c.next()
                if verbose: print "%s: found %d records" % (name, count)
                c.close()
                txn.commit()
                finished = True
            except (db.DBLockDeadlockError, db.DBLockNotGrantedError), val:
                if verbose:
                    print "%s: Aborting transaction (%s)" % (name,
                            val.args[1])
                c.close()
                txn.abort()

        if verbose:
            print "%s: thread finished" % name

    def deadlockThread(self):
        self.doLockDetect = True
        while self.doLockDetect:
            time.sleep(0.05)
            try:
                aborted = self.env.lock_detect(
                    db.DB_LOCK_RANDOM, db.DB_LOCK_CONFLICT)
                if verbose and aborted:
                    print "deadlock: Aborted %d deadlocked transaction(s)" \
                          % aborted
            except db.DBError:
                pass


class BTreeThreadedTransactions(ThreadedTransactionsBase):
    """ThreadedTransactionsBase scenario run on a DB_BTREE database."""
    dbtype = db.DB_BTREE
    records = 1000
    readers = 10
    writers = 2

class HashThreadedTransactions(ThreadedTransactionsBase):
    """ThreadedTransactionsBase scenario run on a DB_HASH database."""
    dbtype = db.DB_HASH
    records = 1000
    readers = 10
    writers = 2

class BTreeThreadedNoWaitTransactions(ThreadedTransactionsBase):
    """DB_BTREE scenario whose transactions are begun with DB_TXN_NOWAIT."""
    dbtype = db.DB_BTREE
    txnFlag = db.DB_TXN_NOWAIT
    records = 1000
    readers = 10
    writers = 2

class HashThreadedNoWaitTransactions(ThreadedTransactionsBase):
    """DB_HASH scenario whose transactions are begun with DB_TXN_NOWAIT."""
    dbtype = db.DB_HASH
    txnFlag = db.DB_TXN_NOWAIT
    records = 1000
    readers = 10
    writers = 2


#----------------------------------------------------------------------

def test_suite():
    suite = unittest.TestSuite()

    if have_threads:
        suite.addTest(unittest.makeSuite(BTreeConcurrentDataStore))
        suite.addTest(unittest.makeSuite(HashConcurrentDataStore))
        suite.addTest(unittest.makeSuite(BTreeSimpleThreaded))
        suite.addTest(unittest.makeSuite(HashSimpleThreaded))
        suite.addTest(unittest.makeSuite(BTreeThreadedTransactions))
        suite.addTest(unittest.makeSuite(HashThreadedTransactions))
        suite.addTest(unittest.makeSuite(BTreeThreadedNoWaitTransactions))
        suite.addTest(unittest.makeSuite(HashThreadedNoWaitTransactions))

    else:
        print "Threads not available, skipping thread tests."

    return suite


# When executed as a script, run the suite built by test_suite().
if __name__ == '__main__':
    unittest.main(defaultTest='test_suite')
