"""Client for cache server.

See cachesvr.py for the protocol description.
"""

import argparse
import trollius as asyncio
from trollius import From, Return
from trollius import test_utils
import json
import logging

ARGS = argparse.ArgumentParser(description='Cache client example.')

# Transport / event-loop options.
ARGS.add_argument('--tls', action='store_true', dest='tls', default=False,
                  help='Use TLS')
ARGS.add_argument('--iocp', action='store_true', dest='iocp', default=False,
                  help='Use IOCP event loop (Windows only)')
ARGS.add_argument('--host', action='store', dest='host', default='localhost',
                  help='Host name')
ARGS.add_argument('--port', action='store', dest='port', type=int,
                  default=54321, help='Port number')

# Timing / reconnect options.
ARGS.add_argument('--timeout', action='store', dest='timeout', type=float,
                  default=5, help='Timeout')
ARGS.add_argument('--max_backoff', action='store', dest='max_backoff',
                  type=float, default=5, help='Max backoff on reconnect')

# Load-generation options.
ARGS.add_argument('--ntasks', action='store', dest='ntasks', type=int,
                  default=10, help='Number of tester tasks')
ARGS.add_argument('--ntries', action='store', dest='ntries', type=int,
                  default=5, help='Number of request tries before giving up')


# Command-line options, parsed once at import time. CacheClient.activity()
# (max_backoff), main() and testing() (timeout) read this module-level global.
args = ARGS.parse_args()


class CacheClient:
    """Multiplexing cache client.

    This wraps a single connection to the cache server.  The
    connection is automatically re-opened when an error occurs.

    Multiple tasks may share this object; each request is written to
    the one connection tagged with a request id, and responses are
    matched back to the waiting caller by that id.

    The public API is get(), set(), delete() (all are coroutines).
    """

    def __init__(self, host, port, sslctx=None, loop=None):
        self.host = host
        self.port = port
        self.sslctx = sslctx
        self.loop = loop
        # (payload, waiter) pairs waiting for a live connection.
        self.todo = set()
        # True while a connection is open and request() may send directly.
        self.initialized = False
        # Per-connection state; reset by activity() on every (re)connect.
        self.reader = None
        self.writer = None
        self.next_id = 0
        self.pending = {}
        # Background task that owns the connection.
        self.task = asyncio.Task(self.activity(), loop=self.loop)

    @asyncio.coroutine
    def get(self, key):
        """Return the 'value' field of the response, or None."""
        resp = yield From(self.request('get', key))
        if resp is None:
            raise Return()
        raise Return(resp.get('value'))

    @asyncio.coroutine
    def set(self, key, value):
        """Return True if the server answered with status 'ok'."""
        resp = yield From(self.request('set', key, value))
        if resp is None:
            raise Return(False)
        raise Return(resp.get('status') == 'ok')

    @asyncio.coroutine
    def delete(self, key):
        """Return True if the server answered with status 'ok'."""
        resp = yield From(self.request('delete', key))
        if resp is None:
            raise Return(False)
        raise Return(resp.get('status') == 'ok')

    @asyncio.coroutine
    def request(self, type, key, value=None):
        """Send one request and return the server's decoded JSON response.

        If no connection is currently live, or sending fails with IOError,
        the request is queued in self.todo and sent by activity() after
        the next successful connect.
        """
        assert not self.task.done()
        data = {'type': type, 'key': key}
        if value is not None:
            data['value'] = value
        payload = json.dumps(data).encode('utf8')
        waiter = asyncio.Future(loop=self.loop)
        if self.initialized:
            try:
                yield From(self.send(payload, waiter))
            except IOError:
                self.todo.add((payload, waiter))
        else:
            self.todo.add((payload, waiter))
        result = (yield From(waiter))
        raise Return(result)

    @asyncio.coroutine
    def activity(self):
        """Connect, flush queued requests, and dispatch responses forever.

        Connection failures are retried with a growing backoff capped at
        args.max_backoff.  Any error while the connection is up closes it
        and requeues unanswered requests.  Cancelling self.task stops
        this loop.
        """
        backoff = 0
        while True:
            try:
                self.reader, self.writer = yield From(asyncio.open_connection(
                    self.host, self.port, ssl=self.sslctx, loop=self.loop))
            except asyncio.CancelledError:
                # In trollius CancelledError is an Exception subclass; do not
                # let the generic handler below swallow a cancellation.
                raise
            except Exception as exc:
                backoff = min(args.max_backoff, backoff + (backoff//2) + 1)
                logging.info('Error connecting: %r; sleep %s', exc, backoff)
                yield From(asyncio.sleep(backoff, loop=self.loop))
                continue
            backoff = 0
            self.next_id = 0
            self.pending = {}
            self.initialized = True
            try:
                while self.todo:
                    payload, waiter = self.todo.pop()
                    # Callers may have given up (e.g. wait_for timeout).
                    if not waiter.done():
                        yield From(self.send(payload, waiter))
                while True:
                    resp_id, resp = yield From(self.process())
                    if resp_id in self.pending:
                        payload, waiter = self.pending.pop(resp_id)
                        if not waiter.done():
                            waiter.set_result(resp)
            except asyncio.CancelledError:
                self._drop_connection()
                raise
            except Exception as exc:
                self._drop_connection()
                logging.info('Error processing: %r', exc)

    def _drop_connection(self):
        """Close the connection and requeue requests still awaiting a reply."""
        self.initialized = False
        self.writer.close()
        while self.pending:
            _, pair = self.pending.popitem()
            if not pair[1].done():
                self.todo.add(pair)

    @asyncio.coroutine
    def send(self, payload, waiter):
        """Write a 'request <id> <size>' header plus payload and track it.

        The request is registered in self.pending before draining, so a
        failure during drain still leaves it eligible for requeueing.
        """
        self.next_id += 1
        req_id = self.next_id
        frame = 'request %d %d\n' % (req_id, len(payload))
        self.writer.write(frame.encode('ascii'))
        self.writer.write(payload)
        self.pending[req_id] = payload, waiter
        yield From(self.writer.drain())

    @asyncio.coroutine
    def process(self):
        """Read one response frame and return (resp_id, decoded_json).

        The frame is a header line 'response <id> <size>' followed by
        <size> bytes of UTF-8 JSON.

        Raises EOFError when the connection is closed, and IOError on an
        'error' line or a malformed header.
        """
        frame = yield From(self.reader.readline())
        if not frame:
            raise EOFError()
        parts = frame.split(None, 1)
        if not parts:
            raise IOError('Bad frame: %r' % frame)
        head = parts[0]
        tail = parts[1] if len(parts) > 1 else b''
        if head == b'error':
            raise IOError('OOB error: %r' % tail)
        if head != b'response':
            raise IOError('Bad frame: %r' % frame)
        try:
            resp_id, resp_size = map(int, tail.split())
        except ValueError:
            # Wrong field count or non-integer fields.
            raise IOError('Bad frame: %r' % frame)
        if resp_size < 0:
            raise IOError('Bad frame: %r' % frame)
        data = yield From(self.reader.readexactly(resp_size))
        if len(data) != resp_size:
            raise EOFError()
        resp = json.loads(data.decode('utf8'))
        raise Return(resp_id, resp)


def main():
    """Run args.ntasks tester coroutines against one shared CacheClient.

    Builds a fresh event loop (a ProactorEventLoop when --iocp is given)
    instead of installing a global one, optionally enables TLS with a
    dummy SSL context, and always closes the loop on the way out.
    """
    asyncio.set_event_loop(None)
    if not args.iocp:
        loop = asyncio.new_event_loop()
    else:
        from trollius.windows_events import ProactorEventLoop
        loop = ProactorEventLoop()
    sslctx = test_utils.dummy_ssl_context() if args.tls else None
    cache = CacheClient(args.host, args.port, sslctx=sslctx, loop=loop)
    try:
        testers = [testing(i, cache, loop) for i in range(args.ntasks)]
        loop.run_until_complete(asyncio.gather(*testers, loop=loop))
    finally:
        loop.close()


@asyncio.coroutine
def testing(label, cache, loop):
    """Repeatedly run a set/get/delete/get cycle on key 'foo-<label>'.

    Each cache call is bounded by args.timeout.  A timeout is logged as a
    warning and the cycle restarts; any other exception is logged with
    its traceback and ends this tester.
    """

    def w(g):
        # Bound a single cache call by the --timeout option.
        return asyncio.wait_for(g, args.timeout, loop=loop)

    key = 'foo-%s' % label
    while True:
        logging.info('%s %s', label, '-'*20)
        try:
            ret = yield From(w(cache.set(key, 'hello-%s-world' % label)))
            logging.info('%s set %s', label, ret)
            ret = yield From(w(cache.get(key)))
            logging.info('%s get %s', label, ret)
            ret = yield From(w(cache.delete(key)))
            logging.info('%s del %s', label, ret)
            ret = yield From(w(cache.get(key)))
            logging.info('%s get2 %s', label, ret)
        except asyncio.TimeoutError:
            # logging.warn is a deprecated alias (removed in Python 3.13).
            logging.warning('%s Timeout', label)
        except Exception as exc:
            logging.exception('%s Client exception: %r', label, exc)
            break


if __name__ == '__main__':
    # Script entry point: enable INFO logging so the per-task progress
    # messages from testing() are visible, then run the testers.
    logging.basicConfig(level=logging.INFO)
    main()
