import sys
import os
import os.path as op
import tempfile
from subprocess import Popen, check_output, PIPE, STDOUT, CalledProcessError
from srsly.cloudpickle.compat import pickle
from contextlib import contextmanager
from concurrent.futures import ProcessPoolExecutor

import psutil
from srsly.cloudpickle import dumps
from subprocess import TimeoutExpired

# Shorthand for the unpickling function of the pickle module exposed by
# srsly.cloudpickle.compat; used to decode payloads throughout this module.
loads = pickle.loads
# Default timeout (in seconds) handed to the subprocess calls made by the
# helpers in this module.
TIMEOUT = 60
# Module-level global read by the function built in make_local_function().
TEST_GLOBALS = "a test value"


def make_local_function():
    """Build and return a function defined in a local (nested) scope.

    The returned function ignores its argument, asserts that the module
    global ``TEST_GLOBALS`` still holds its expected value and returns
    ``sum(range(10))``.
    """
    def g(x):
        # Exercises both module globals (TEST_GLOBALS) and builtins
        # (sum / range) from inside a locally defined function.
        assert TEST_GLOBALS == "a test value"
        total = sum(range(10))
        return total

    return g


def _make_cwd_env():
    """Return the ``(cwd, env)`` pair used to launch child processes.

    ``cwd`` is the normalized parent directory of the folder containing
    this file. ``env`` is a copy of ``os.environ`` whose ``PYTHONPATH`` is
    replaced by ``<cwd>/tests`` followed by ``<cwd>``.
    """
    here = op.dirname(__file__)
    repo_root = op.normpath(op.join(here, '..'))

    search_path = os.pathsep.join([repo_root + os.sep + "tests", repo_root])

    child_env = dict(os.environ)
    child_env['PYTHONPATH'] = search_path
    return repo_root, child_env


def subprocess_pickle_string(input_data, protocol=None, timeout=TIMEOUT,
                             add_env=None):
    r"""Retrieve pickle string of an object generated by a child Python process

    Pickle the input data into a buffer, send it to a subprocess via
    stdin, expect the subprocess to unpickle, re-pickle that data back
    and send it back to the parent process via stdout. The raw bytes
    written by the child are returned without being unpickled.

    >>> testutils.subprocess_pickle_string([1, 'a', None], protocol=2)
    b'\x80\x02]q\x00(K\x01X\x01\x00\x00\x00aq\x01Ne.'

    Parameters
    ----------
    input_data : object
        Object to pickle and send to the child process.
    protocol : int or None
        Pickle protocol used by the parent; it is also forwarded to the
        child on the command line as ``str(protocol)``.
    timeout : float or None
        Maximum time in seconds to wait for the child process.
    add_env : dict or None
        Extra environment variables merged into the child environment.

    Raises
    ------
    RuntimeError
        If the child exits with a non-zero status, writes anything to
        stderr, or does not finish before ``timeout`` expires.
    """
    # Pickle before spawning the child: if the input cannot be pickled we
    # fail here and no child process is left behind waiting on its stdin.
    pickle_string = dumps(input_data, protocol=protocol)

    # The child runs this very file as a script, which executes
    # pickle_echo(protocol=protocol) from its __main__ section.
    #
    # Protect stderr from any warning, as we will assume an error will happen
    # if it is not empty. A concrete example is pytest using the imp module,
    # which is deprecated in python 3.8
    cmd = [sys.executable, '-W ignore', __file__, "--protocol", str(protocol)]
    cwd, env = _make_cwd_env()
    if add_env:
        env.update(add_env)
    proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, cwd=cwd, env=env,
                 bufsize=4096)
    try:
        out, err = proc.communicate(pickle_string, timeout=timeout)
    except TimeoutExpired as e:
        # communicate() does not kill the child on timeout: do it ourselves
        # and collect whatever it produced for the error message.
        proc.kill()
        out, err = proc.communicate()
        # errors='replace' keeps a decoding problem from masking the timeout.
        message = "\n".join([out.decode('utf-8', errors='replace'),
                             err.decode('utf-8', errors='replace')])
        raise RuntimeError(message) from e

    if proc.returncode != 0 or err:
        message = "Subprocess returned %d: " % proc.returncode
        message += err.decode('utf-8', errors='replace')
        raise RuntimeError(message)
    return out


def subprocess_pickle_echo(input_data, protocol=None, timeout=TIMEOUT,
                           add_env=None):
    """Round-trip an object through a child Python process.

    The input is pickled and piped to a subprocess, which unpickles it,
    pickles it again and writes it back; the bytes received are then
    unpickled in this process and the resulting object is returned.

    >>> subprocess_pickle_echo([1, 'a', None])
    [1, 'a', None]
    """
    forwarded = {
        'protocol': protocol,
        'timeout': timeout,
        'add_env': add_env,
    }
    echoed_bytes = subprocess_pickle_string(input_data, **forwarded)
    return loads(echoed_bytes)


def _read_all_bytes(stream_in, chunk_size=4096):
    """Read ``stream_in`` until end of stream and return all of its bytes.

    Parameters
    ----------
    stream_in : object
        Binary stream exposing ``read(n)``.
    chunk_size : int
        Number of bytes requested per ``read`` call.

    Returns
    -------
    bytes
        Concatenation of every chunk read, ``b""`` for an empty stream.
    """
    chunks = []
    while True:
        data = stream_in.read(chunk_size)
        # Only an empty read marks the end of the stream: a read shorter
        # than chunk_size may simply be a partial read, so stopping on it
        # would truncate the data.
        if not data:
            break
        chunks.append(data)
    # Join once instead of repeatedly concatenating bytes objects.
    return b"".join(chunks)


def pickle_echo(stream_in=None, stream_out=None, protocol=None):
    """Unpickle an object from an input stream and re-pickle it to an output.

    ``stream_in`` defaults to ``sys.stdin`` and ``stream_out`` to
    ``sys.stdout``. Both streams are closed once they have been used.
    """
    source = sys.stdin if stream_in is None else stream_in
    sink = sys.stdout if stream_out is None else stream_out

    # Text streams expose their underlying bytes stream as ``.buffer``;
    # switch to it whenever it is available.
    source = getattr(source, 'buffer', source)
    sink = getattr(sink, 'buffer', sink)

    payload = _read_all_bytes(source)
    source.close()

    echoed = dumps(loads(payload), protocol=protocol)
    sink.write(echoed)
    sink.close()


def call_func(payload, protocol):
    """Remote function call that uses cloudpickle to transport everything

    ``payload`` is a pickled ``(func, args, kwargs)`` triple. The function
    is called and its return value -- or the exception it raised -- is
    pickled with ``protocol`` and returned.
    """
    target, positional, named = loads(payload)
    try:
        outcome = target(*positional, **named)
    except BaseException as exc:
        # The exception is shipped back as a value rather than propagated.
        outcome = exc
    return dumps(outcome, protocol=protocol)


class _Worker:
    """Single-process pool that runs cloudpickled calls synchronously."""

    def __init__(self, protocol=None):
        self.protocol = protocol
        self.pool = ProcessPoolExecutor(max_workers=1)
        # Submit a trivial task and wait for it so that the worker process
        # is started before the constructor returns.
        warmup = self.pool.submit(id, 42)
        warmup.result()

    def run(self, func, *args, **kwargs):
        """Synchronous remote function call"""
        request = dumps((func, args, kwargs), protocol=self.protocol)
        future = self.pool.submit(call_func, request, self.protocol)
        outcome = loads(future.result())

        # call_func sends exceptions back as values: re-raise them here.
        if not isinstance(outcome, BaseException):
            return outcome
        raise outcome

    def memsize(self):
        """Return the resident set size of the worker process, 0 if none."""
        # Entries may be process objects or plain pids: normalize to pids.
        pids = [getattr(entry, "pid", entry)
                for entry in list(self.pool._processes)]
        if not pids:
            return 0
        if len(pids) > 1:
            raise RuntimeError("Unexpected number of workers: %d"
                               % len(pids))
        worker_process = psutil.Process(pids[0])
        return worker_process.memory_info().rss

    def close(self):
        """Shut the pool down, waiting for it to finish."""
        self.pool.shutdown(wait=True)


@contextmanager
def subprocess_worker(protocol=None):
    """Context manager yielding a worker backed by a single child process.

    Parameters
    ----------
    protocol : int or None
        Pickle protocol the worker uses to exchange payloads.

    Yields
    ------
    _Worker
        The running worker; it is closed when the ``with`` block exits.
    """
    worker = _Worker(protocol=protocol)
    try:
        yield worker
    finally:
        # Always shut the pool down, including when the body of the
        # ``with`` block raises; otherwise the worker process is leaked.
        worker.close()


def assert_run_python_script(source_code, timeout=TIMEOUT):
    """Utility to help check pickleability of objects defined in __main__

    The script provided in the source code should return 0 and not print
    anything on stderr or stdout.

    Parameters
    ----------
    source_code : str
        Python source, written to a temporary file and run as a script.
    timeout : float or None
        Maximum time in seconds the script is allowed to run.

    Raises
    ------
    RuntimeError
        If the script exits with a non-zero status or times out.
    AssertionError
        If the script succeeds but produced output.
    """
    fd, source_file = tempfile.mkstemp(suffix='_src_test_cloudpickle.py')
    os.close(fd)
    try:
        with open(source_file, 'wb') as f:
            f.write(source_code.encode('utf-8'))
        cmd = [sys.executable, '-W ignore', source_file]
        cwd, env = _make_cwd_env()
        # If coverage is running, pass the config file to the subprocess
        coverage_rc = os.environ.get("COVERAGE_PROCESS_START")
        if coverage_rc:
            env['COVERAGE_PROCESS_START'] = coverage_rc
        try:
            # stderr is merged into stdout so that any output is captured.
            out = check_output(cmd, cwd=cwd, stderr=STDOUT, env=env,
                               timeout=timeout)
        except CalledProcessError as e:
            raise RuntimeError("script errored with output:\n%s"
                               % e.output.decode('utf-8')) from e
        except TimeoutExpired as e:
            # No output may have been captured before the timeout, in which
            # case e.output is None rather than empty bytes.
            partial_output = e.output or b""
            raise RuntimeError("script timeout, output so far:\n%s"
                               % partial_output.decode('utf-8')) from e
        if out != b"":
            raise AssertionError(out.decode('utf-8'))
    finally:
        os.unlink(source_file)


def _protocol_from_argv(argv):
    """Extract the pickle protocol from a ``--protocol <value>`` argument.

    The parent process formats the protocol with ``str()``, so an unset
    protocol arrives as the literal string ``"None"``. That value, as well
    as a missing ``--protocol`` flag, maps to ``None``; any other value is
    converted with ``int()``.
    """
    if '--protocol' not in argv:
        return None
    raw_protocol = argv[argv.index('--protocol') + 1]
    if raw_protocol == 'None':
        return None
    return int(raw_protocol)


if __name__ == '__main__':
    # Script mode: echo the pickle received on stdin back to stdout.
    pickle_echo(protocol=_protocol_from_argv(sys.argv))
