from __future__ import unicode_literals

import json
import select
import threading
import time
import uuid

import pika
import pytest

import pika_pool


@pytest.fixture(scope='session')
def params():
    """Session-wide connection parameters for the local test broker."""
    broker_url = 'amqp://guest:guest@localhost:5672/'
    return pika.URLParameters(broker_url)


@pytest.fixture(scope='session', autouse=True)
def schema(request, params):
    """Declare the ``pika_pool_test`` queue once per test session.

    The connection used for the declaration is closed afterwards (even if the
    declaration fails) so it does not linger for the rest of the session.

    :param request: pytest request object (kept for interface compatibility).
    :param params: broker connection parameters from the ``params`` fixture.
    """
    cxn = pika.BlockingConnection(params)
    try:
        channel = cxn.channel()
        channel.queue_declare(queue='pika_pool_test')
    finally:
        # Previously this connection was never closed and leaked.
        cxn.close()


# Messages received by the background consumer started in ``consume``,
# keyed by ``Message.id``; tests check membership here to confirm delivery.
consumed = {}


@pytest.fixture(scope='session', autouse=True)
def consume(params):
    """Start a background consumer that records every test message.

    Each delivery on ``pika_pool_test`` is decoded with ``Message.from_json``
    and stored in the module-level ``consumed`` dict under its id.

    :param params: broker connection parameters from the ``params`` fixture.
    """
    connection = pika.BlockingConnection(params)
    chan = connection.channel()
    chan.queue_declare(queue='pika_pool_test')

    def _on_message(ch, method, properties, body):
        received = Message.from_json(body)
        consumed[received.id] = received

    # NOTE(review): positional callback + ``no_ack`` matches the pre-1.0 pika
    # ``basic_consume`` signature — confirm against the pinned pika version.
    chan.basic_consume(_on_message, queue='pika_pool_test', no_ack=True)

    # Daemon thread so the blocking consume loop does not keep the process
    # alive once the test session finishes.
    worker = threading.Thread(target=chan.start_consuming)
    worker.daemon = True
    worker.start()


@pytest.fixture
def null_pool(params):
    """A ``NullPool`` whose ``create`` callback opens a new BlockingConnection."""
    def _connect():
        return pika.BlockingConnection(params)

    return pika_pool.NullPool(create=_connect)


class Message(dict):
    """A JSON-serialisable test payload identified by its ``id`` key."""

    @classmethod
    def generate(cls, **kwargs):
        """Build a message, assigning a random hex ``id`` unless one is given.

        Any other keyword arguments become additional keys of the message.
        An explicitly passed ``id`` (even ``None``) is kept as-is.
        """
        # Only generate a uuid when needed, and avoid shadowing builtin ``id``.
        msg_id = kwargs.pop('id') if 'id' in kwargs else uuid.uuid4().hex
        return cls(id=msg_id, **kwargs)

    @property
    def id(self):
        """The message identifier stored under the ``'id'`` key."""
        return self['id']

    def to_json(self):
        """Serialise the message to a JSON string."""
        return json.dumps(self)

    @classmethod
    def from_json(cls, raw):
        """Decode a message from JSON ``raw``, given as UTF-8 bytes or text.

        :raises ValueError: if ``raw`` is not valid JSON (or not valid UTF-8).
        """
        # Broker bodies arrive as bytes; also accept already-decoded text
        # instead of failing on ``str.decode``.
        if isinstance(raw, bytes):
            raw = raw.decode('utf-8')
        return cls(json.loads(raw))


class TestNullPool(object):
    """Tests for ``pika_pool.NullPool``."""

    def test_pub(self, null_pool):
        """A message published through the pool reaches the test consumer."""
        msg = Message.generate()
        with null_pool.acquire() as cxn:
            cxn.channel.basic_publish(
                exchange='',
                routing_key='pika_pool_test',
                body=msg.to_json()
            )
        # Delivery is recorded by the background consumer thread, so poll
        # with a deadline instead of relying on one fixed (racy) sleep.
        deadline = time.time() + 5
        while msg.id not in consumed and time.time() < deadline:
            time.sleep(0.01)
        assert msg.id in consumed


@pytest.fixture
def queued_pool(params):
    """A ``QueuedPool`` of BlockingConnections with every limit set to 10.

    Holds up to ``max_size`` queued plus ``max_overflow`` extra connections;
    ``timeout``, ``recycle`` and ``stale`` are also 10.
    """
    def _connect():
        return pika.BlockingConnection(params)

    limits = dict(
        recycle=10,
        stale=10,
        max_size=10,
        max_overflow=10,
        timeout=10,
    )
    return pika_pool.QueuedPool(create=_connect, **limits)


@pytest.fixture
def empty_queued_pool(request, queued_pool):
    """Return ``queued_pool`` with every queued and overflow slot checked out.

    All held connections are released at teardown: the overflow connections
    first, then the queued ones (finalizers run in reverse registration order).
    """
    def _releaser(held):
        # Build a finalizer that releases every connection in ``held``.
        def _release():
            for cxn in held:
                cxn.release()
        return _release

    # Register each finalizer *before* acquiring, and append one connection
    # at a time, so a failing acquire() still releases those already held.
    queued = []
    request.addfinalizer(_releaser(queued))
    for _ in range(queued_pool.max_size):
        queued.append(queued_pool.acquire())

    overflow = []
    request.addfinalizer(_releaser(overflow))
    for _ in range(queued_pool.max_overflow):
        overflow.append(queued_pool.acquire())

    return queued_pool


def test_use_it():
    """Smoke test: publish a JSON message through a ``QueuedPool``.

    Nothing is asserted; the test passes as long as acquiring a connection
    and publishing do not raise.
    """
    broker_url = (
        'amqp://guest:guest@localhost:5672/?'
        'socket_timeout=10&'
        'connection_attempts=2'
    )
    params = pika.URLParameters(broker_url)

    pool = pika_pool.QueuedPool(
        create=lambda: pika.BlockingConnection(parameters=params),
        max_size=10,
        max_overflow=10,
        timeout=10,
        recycle=3600,
        stale=45,
    )

    payload = json.dumps({
        'type': 'banana',
        'description': 'they are yellow'
    })
    props = pika.BasicProperties(
        content_type='application/json',
        content_encoding='utf-8',
        delivery_mode=2,
    )

    # NOTE(review): no 'fruits' queue is declared in this module; if none
    # exists the default exchange presumably drops the message unrouted.
    with pool.acquire() as cxn:
        cxn.channel.basic_publish(
            body=payload,
            exchange='',
            routing_key='fruits',
            properties=props,
        )


class TestQueuedPool(object):
    """Tests for ``pika_pool.QueuedPool``: invalidation, recycling, staleness,
    overflow and acquire timeouts."""

    def test_invalidate_connection(self, queued_pool):
        """An error raised while a connection is held closes that connection."""
        with pytest.raises(select.error):
            with queued_pool.acquire() as cxn:
                fairy = cxn.fairy
                raise select.error(9, 'Bad file descriptor')
        assert fairy.cxn.is_closed

    def test_pub(self, queued_pool):
        """A message published through the pool reaches the test consumer."""
        msg = Message.generate()
        with queued_pool.acquire() as cxn:
            cxn.channel.basic_publish(
                exchange='',
                routing_key='pika_pool_test',
                body=msg.to_json()
            )
        # Delivery is recorded by the background consumer thread, so poll
        # with a deadline instead of relying on one fixed (racy) sleep.
        deadline = time.time() + 5
        while msg.id not in consumed and time.time() < deadline:
            time.sleep(0.01)
        assert msg.id in consumed

    def test_expire(self, queued_pool):
        """A connection backdated by ``recycle`` is replaced on next acquire."""
        # Keep a reference (rather than comparing id()) so the old object
        # cannot be collected and have its id reused by the new connection.
        with queued_pool.acquire() as cxn:
            original = cxn.fairy.cxn
        with queued_pool.acquire() as cxn:
            assert cxn.fairy.cxn is original
            # Backdate creation so the pool treats it as past ``recycle``.
            cxn.fairy.created_at -= queued_pool.recycle
        with queued_pool.acquire() as cxn:
            assert cxn.fairy.cxn is not original

    def test_stale(self, queued_pool):
        """A connection idle for ``stale`` is replaced on next acquire."""
        with queued_pool.acquire() as cxn:
            original = cxn.fairy.cxn
            fairy = cxn.fairy
        with queued_pool.acquire() as cxn:
            assert cxn.fairy.cxn is original
        # Backdate the last release so the idle connection looks stale.
        fairy.released_at -= queued_pool.stale
        with queued_pool.acquire() as cxn:
            assert cxn.fairy.cxn is not original

    def test_overflow(self, queued_pool):
        """A connection acquired beyond ``max_size`` is closed on release."""
        queued = [queued_pool.acquire() for _ in range(queued_pool.max_size)]
        with queued_pool.acquire() as cxn:
            fairy = cxn.fairy
            # Return the queued connections before the overflow one is
            # released at the end of this block.
            for held in queued:
                held.release()
        assert fairy.cxn.is_closed

    def test_timeout(self, empty_queued_pool):
        """acquire() on an exhausted pool raises Timeout after ``pool.timeout``."""
        empty_queued_pool.timeout = 2
        st = time.time()
        with pytest.raises(pika_pool.Timeout):
            empty_queued_pool.acquire()
        elapsed = time.time() - st
        assert elapsed < 2.5

    def test_timeout_override(self, empty_queued_pool):
        """A per-call ``timeout`` overrides the pool's configured timeout."""
        st = time.time()
        with pytest.raises(pika_pool.Timeout):
            empty_queued_pool.acquire(timeout=1)
        elapsed = time.time() - st
        assert elapsed < 1.5
