try:
    from _thread import start_new_thread
except:
    from thread import start_new_thread

import functools
import logging
import os
import sys
import time
import traceback
from multiprocessing import Process

from jobqueue import runjob, store
from jobqueue.client import connect

# Root of the local job store.  The "%s" placeholder is presumably filled in
# with the job id by the jobqueue store module (TODO confirm).
store.ROOT = "/tmp/worker/%s"
# Dispatcher URL used when none is given on the command line (see main()).
DEFAULT_DISPATCHER = "http://reflectometry.org/queue"
# Polling interval in seconds: the join timeout while a job is running, and
# the sleep between attempts when no request is available.
POLLRATE = 10


def log_errors(f):
    """
    Decorate *f* so that exceptions it raises are logged instead of propagated.

    Meant for functions started with ``start_new_thread``, where an unhandled
    exception would otherwise never reach the log.  When *f* raises, the
    wrapper logs the full traceback at ERROR level and returns None;
    otherwise it returns whatever *f* returns.

    :param f: the callable to wrap.
    :returns: the wrapping callable, carrying *f*'s name and docstring.
    """

    @functools.wraps(f)
    def wrapped(*args, **kw):
        try:
            return f(*args, **kw)
        except Exception:
            # format_exception() returns a list of text chunks (the traceback
            # first, then "Type: message").  Join them so the log record is a
            # readable multi-line string rather than the repr of a list.
            logging.error("".join(traceback.format_exception(*sys.exc_info())))
            return None

    return wrapped


def wait_for_result(remote, id, process, queue):
    """
    Wait for job processing to finish.  Meanwhile, prefetch the next
    request.

    Each pass of the loop:

    - waits up to POLLRATE seconds for *process* to exit;
    - asks *remote* for the job's status, and terminates the process if the
      status is no longer "ACTIVE";
    - fetches the next request from *queue*, if one has not been fetched yet.

    :param remote: dispatcher connection providing ``status`` and ``nextjob``.
    :param id: id of the job being run by *process*.
    :param process: the ``multiprocessing.Process`` running the job.
    :param queue: queue name to prefetch the next request from.
    :returns: ``(results, next_request)``.  *results* is what
        ``runjob.results(id)`` returns.  If that raises KeyError, it is a
        synthesized ``{"status": "CANCEL", ...}`` dict when the job was
        terminated here, or ``{"status": "ERROR", ...}`` otherwise.
        *next_request* is the prefetched request, or ``{"request": None}``.
    """
    next_request = {"request": None}
    canceling = False
    while True:
        # Check if process is complete
        process.join(POLLRATE)
        if not process.is_alive():
            break

        # Check that the job is still active, and that it hasn't been
        # canceled, or results reported back from a second worker.
        # If remote server is down, assume the job is still active.
        try:
            response = remote.status(id)
        except Exception:
            logging.debug("status check for job %s failed", id, exc_info=True)
            response = None
        if response and response["status"] != "ACTIVE":
            process.terminate()
            # Reap the terminated child so it does not linger as a zombie and
            # is no longer writing to the store when results are read below.
            # The wait is bounded in case the child ignores SIGTERM.
            process.join(POLLRATE)
            canceling = True
            break

        # Prefetch the next job; this strategy works well if there is
        # only one worker.  If there are many, we may want to leave it
        # for another worker to process.
        if not next_request["request"]:
            # Remote server errors are tolerated; try again on the next pass.
            try:
                next_request = remote.nextjob(queue=queue)
            except Exception:
                logging.debug("prefetching next job failed", exc_info=True)

    # Grab results from the store
    try:
        results = runjob.results(id)
    except KeyError:
        if canceling:
            results = {"status": "CANCEL", "message": "Job canceled"}
        else:
            results = {"status": "ERROR", "message": "Results not found"}

    return results, next_request


@log_errors
def update_remote(dispatcher, id, queue, results):
    """
    Send a finished job's results and output files to the dispatcher.

    Steps:

    1. Remove the "results" entry from the local store, if it exists.
    2. Post *results* together with every entry of the job's store directory.
    3. Delete those entries and the directory itself.

    Thanks to ``log_errors``, any exception is logged instead of raised.

    NOTE(review): this assumes the job directory holds only plain files.  A
    subdirectory would make ``os.unlink``/``os.rmdir`` raise, which
    ``log_errors`` would then log.
    """
    job_dir = store.path(id)

    # The stored "results" entry is not needed here; a missing entry is fine.
    try:
        store.delete(id, "results")
    except KeyError:
        pass

    outputs = [os.path.join(job_dir, entry) for entry in os.listdir(job_dir)]

    # Use a dedicated connection so this can run on a background thread.  The
    # caller can then start the next job while large result files upload.
    uploader = connect(dispatcher)
    uploader.postjob(id=id, results=results, queue=queue, files=outputs)

    # postjob returned without raising, so the local copies can be removed.
    # If it had raised, log_errors would have logged it and the files would
    # have been left in place.
    for filename in outputs:
        os.unlink(filename)
    os.rmdir(job_dir)


def serve(dispatcher, queue):
    """
    Run the work server.

    The loop never returns.  On each pass it:

    1. fetches a request from *queue* on *dispatcher*;
    2. runs the request with ``runjob.run`` in a child process;
    3. waits for it with ``wait_for_result``, which may prefetch the next
       request;
    4. uploads the results on a background thread via ``update_remote``.

    When no request is available, it sleeps POLLRATE seconds and retries.

    :param dispatcher: dispatcher URL passed to ``connect``.
    :param queue: name of the queue to take jobs from; must not be None.
    :raises ValueError: if *queue* is None.
    """
    if queue is None:
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        raise ValueError("queue name is required")
    next_request = {"request": None}
    remote = connect(dispatcher)
    while True:
        if not next_request["request"]:
            try:
                next_request = remote.nextjob(queue=queue)
            except Exception:
                # Dispatcher unreachable or misbehaving: log the error and
                # retry after the poll delay below.
                logging.error(traceback.format_exc())
        if not next_request["request"]:
            time.sleep(POLLRATE)
            continue

        # Use .get() so a request without an "id" key is rejected by the
        # check below instead of crashing the server loop with KeyError.
        jobid = next_request.get("id")
        if jobid is None:
            logging.error("request has no job id")
            next_request = {"request": None}
            continue

        logging.info("processing job %s", jobid)
        process = Process(target=runjob.run, args=(jobid, next_request["request"]))
        process.start()
        results, next_request = wait_for_result(remote, jobid, process, queue)
        # Upload in the background so the next (possibly prefetched) job can
        # start without waiting for the transfer.
        start_new_thread(update_remote, (dispatcher, jobid, queue, results))


def main():
    """
    Command-line entry point: ``worker QUEUE [DISPATCHER]``.

    If QUEUE is missing, prints "Requires queue name" to stderr and exits
    with status 1.  Otherwise it lowers the process's scheduling priority
    (best effort) and serves jobs from QUEUE forever.  DISPATCHER defaults
    to DEFAULT_DISPATCHER.
    """
    if len(sys.argv) <= 1:
        # Exit here; otherwise sys.argv[1] below would raise IndexError.
        print("Requires queue name", file=sys.stderr)
        sys.exit(1)
    queue = sys.argv[1]
    dispatcher = sys.argv[2] if len(sys.argv) > 2 else DEFAULT_DISPATCHER
    try:
        os.nice(19)
    except (AttributeError, OSError):
        # os.nice does not exist on Windows, and the OS may refuse the call.
        # Running at normal priority is acceptable.
        pass
    serve(queue=queue, dispatcher=dispatcher)


# Start the worker only when run as a script, not when imported.
if __name__ == "__main__":
    main()
