# ClusterShell.Worker.ExecWorker test suite
# First version by A. Degremont 2014-07-10

"""Unit test for ExecWorker"""

import os
import unittest

from TLib import HOSTNAME, make_temp_file, make_temp_filename, make_temp_dir

from ClusterShell.Event import EventHandler
from ClusterShell.Worker.Exec import ExecWorker, WorkerError
from ClusterShell.Task import task_self

class ExecTest(unittest.TestCase):

    def execw(self, **kwargs):
        """helper method to spawn and run ExecWorker"""
        worker = ExecWorker(**kwargs)
        task_self().schedule(worker)
        task_self().run()
        return worker

    def test_no_nodes(self):
        """test ExecWorker with a simple command without nodes"""
        self.execw(nodes=None, handler=None, command="echo ok")
        self.assertEqual(task_self().max_retcode(), None)

    def test_shell_syntax(self):
        """test ExecWorker with a command using shell syntax"""
        cmd = "echo -n 1; echo -n 2"
        self.execw(nodes='localhost', handler=None, command=cmd)
        self.assertEqual(task_self().max_retcode(), 0)
        self.assertEqual(task_self().node_buffer('localhost'), b'12')

    def test_one_node(self):
        """test ExecWorker with a simple command on localhost"""
        self.execw(nodes='localhost', handler=None, command="echo ok")
        self.assertEqual(task_self().max_retcode(), 0)
        self.assertEqual(task_self().node_buffer('localhost'), b'ok')

    def test_one_node_error(self):
        """test ExecWorker with an error command on localhost"""
        self.execw(nodes='localhost', handler=None, command="false")
        self.assertEqual(task_self().max_retcode(), 1)
        self.assertEqual(task_self().node_buffer('localhost'), b'')

    @unittest.skipIf(HOSTNAME == 'localhost', "does not work with hostname set to 'localhost'")
    def test_timeout(self):
        """test ExecWorker with a timeout"""
        nodes = "localhost,%s" % HOSTNAME
        self.execw(nodes=nodes, handler=None, command="sleep 1", timeout=0.2)
        self.assertEqual(task_self().max_retcode(), None)
        self.assertEqual(task_self().num_timeout(), 2)

    def test_node_placeholder(self):
        """test ExecWorker with several nodes and %h (host)"""
        nodes = "localhost,%s" % HOSTNAME
        self.execw(nodes=nodes, handler=None, command="echo %h")
        self.assertEqual(task_self().max_retcode(), 0)
        self.assertEqual(task_self().node_buffer('localhost'), b'localhost')
        self.assertEqual(task_self().node_buffer(HOSTNAME), HOSTNAME.encode('utf-8'))

    def test_bad_placeholder(self):
        """test ExecWorker with unknown placeholder pattern"""
        self.assertRaises(WorkerError, self.execw,
                          nodes="localhost", handler=None, command="echo %x")
        self.assertRaises(WorkerError, self.execw,
                          nodes="localhost", handler=None, command="echo %")

    @unittest.skipIf(HOSTNAME == 'localhost', "does not work with hostname set to 'localhost'")
    def test_rank_placeholder(self):
        """test ExecWorker with several nodes and %n (rank)"""
        nodes = "localhost,%s" % HOSTNAME
        self.execw(nodes=nodes, handler=None, command="echo %n")
        self.assertEqual(task_self().max_retcode(), 0)
        self.assertEqual(set(bytes(msg) for msg, _ in task_self().iter_buffers()),
                         set([b'0', b'1']))

    def test_copy(self):
        """test copying with an ExecWorker and host placeholder"""
        src = make_temp_file(b"data")
        dstdir = make_temp_dir()
        dstpath = os.path.join(dstdir.name, os.path.basename(src.name))
        try:
            pattern = dstpath + ".%h"
            self.execw(nodes='localhost', handler=None, source=src.name,
                       dest=pattern)
            self.assertEqual(task_self().max_retcode(), 0)
            self.assertTrue(os.path.isfile(dstpath + '.localhost'))
        finally:
            os.unlink(dstpath + '.localhost')
            dstdir.cleanup()

    def test_copy_preserve(self):
        """test copying with an ExecWorker (preserve=True)"""
        src = make_temp_file(b"data")
        past_time = 443757600
        os.utime(src.name, (past_time, past_time))
        dstpath = make_temp_filename()
        try:
            self.execw(nodes='localhost', handler=None, source=src.name,
                       dest=dstpath, preserve=True)
            self.assertEqual(task_self().max_retcode(), 0)
            self.assertTrue(os.stat(dstpath).st_mtime, past_time)
        finally:
            os.unlink(dstpath)

    def test_copy_directory(self):
        """test copying directory with an ExecWorker"""
        srcdir = make_temp_dir()
        dstdir = make_temp_dir()
        ref1 = make_temp_file(b"data1", dir=srcdir.name)
        pathdstsrcdir = os.path.join(dstdir.name, os.path.basename(srcdir.name))
        pathdst1 = os.path.join(pathdstsrcdir, os.path.basename(ref1.name))
        try:
            self.execw(nodes='localhost', handler=None, source=srcdir.name,
                       dest=dstdir.name)
            self.assertEqual(task_self().max_retcode(), 0)
            self.assertTrue(os.path.isdir(pathdstsrcdir))
            self.assertTrue(os.path.isfile(pathdst1))
            with open(pathdst1) as dst1:
                self.assertEqual(dst1.readlines()[0], "data1")
        finally:
            os.unlink(pathdst1)
            os.rmdir(pathdstsrcdir)
            ref1.close()
            dstdir.cleanup()
            srcdir.cleanup()

    def test_copy_wrong_directory(self):
        """test copying wrong directory with an ExecWorker"""
        srcdir = make_temp_dir()
        dst = make_temp_file(b"data")
        ref1 = make_temp_file(b"data1", dir=srcdir.name)
        try:
            self.execw(nodes='localhost', handler=None, source=srcdir.name,
                       dest=dst.name, stderr=True)
            self.assertEqual(task_self().max_retcode(), 1)
            self.assertTrue(len(task_self().node_error("localhost")) > 0)
            self.assertTrue(os.path.isfile(ref1.name))
        finally:
            ref1.close()
            srcdir.cleanup()

    def test_rcopy_wrong_directory(self):
        """test ExecWorker reverse copying with wrong directory"""
        with make_temp_dir() as dstbasedirname:
            dstdir = os.path.join(dstbasedirname, "wrong")
            src = make_temp_file(b"data")
            self.assertRaises(ValueError, self.execw, nodes='localhost',
                              handler=None, source=src.name, dest=dstdir,
                              stderr=True, reverse=True)

    def test_abort_on_read(self):
        """test ExecWorker.abort() on read"""

        class TestH(EventHandler):
            def ev_read(self, worker, node, sname, msg):
                worker.abort()
                worker.abort()  # safe but no effect

        self.execw(nodes='localhost', handler=TestH(),
                   command="echo ok; tail -f /dev/null")
        self.assertEqual(task_self().max_retcode(), None)
        self.assertEqual(task_self().node_buffer('localhost'), b'ok')

    def test_abort_on_close(self):
        """test ExecWorker.abort() on close"""

        class TestH(EventHandler):
            def ev_close(self, worker, timedout):
                worker.abort()
                worker.abort()  # safe but no effect

        self.execw(nodes='localhost', handler=TestH(),
                   command="echo ok; sleep .1")
        self.assertEqual(task_self().max_retcode(), 0)
        self.assertEqual(task_self().node_buffer('localhost'), b'ok')
