# Written by Bram Cohen
# see LICENSE.txt for license information

from bisect import insort
import socket
from cStringIO import StringIO
from traceback import print_exc
from errno import EWOULDBLOCK, EINTR
try:
    from select import poll, error, POLLIN, POLLOUT, POLLERR, POLLHUP
    timemult = 1000
except ImportError:
    from selectpoll import poll, error, POLLIN, POLLOUT, POLLERR, POLLHUP
    timemult = 1
from threading import Thread, Event
from time import time, sleep
import sys
from random import randrange
from traceback import print_stack

# poll() event mask meaning "readable or writable".  NOTE: this shadows the
# builtin all() inside this module; the name is kept because SingleSocket
# (and possibly other modules) refer to it.
all = POLLOUT | POLLIN

class SingleSocket:
    """One TCP connection owned by a RawServer.

    Outgoing strings are queued in self.buffer and pushed to the kernel by
    try_write(); whatever the kernel does not accept immediately is retried
    when the RawServer sees the socket become writable.  self.socket is set
    to None once the connection has been closed.
    """

    def __init__(self, raw_server, sock, handler):
        self.raw_server = raw_server
        self.socket = sock
        self.handler = handler
        # pending outgoing strings, oldest first
        self.buffer = []
        # refreshed whenever the RawServer reads from this socket; compared
        # against RawServer.timeout by its idle-connection scan
        self.last_hit = time()
        # cached so the descriptor is still known after self.socket is cleared
        self.fileno = sock.fileno()
        # False until the RawServer marks the connection established;
        # write() only attempts to send once this is True
        self.connected = False

    def get_ip(self):
        """Return the peer's IP address, or 'no connection' if unavailable.

        Also safe after the socket has been closed (the RawServer clears
        self.socket before invoking handler.connection_lost()).
        """
        if self.socket is None:
            return 'no connection'
        try:
            return self.socket.getpeername()[0]
        except socket.error:
            return 'no connection'

    def close(self, unregister=True):
        """Close the connection without notifying the handler.

        Does nothing if the socket is already closed, e.g. because the
        RawServer closed it after an error or an idle timeout.

        unregister -- when true, also remove this connection from the
            RawServer; pass False only if the caller already did so.
        """
        sock = self.socket
        if sock is None:
            return
        self.socket = None
        self.buffer = []
        if unregister:
            self.raw_server.unregister(self.fileno)
        sock.close()

    def shutdown(self, val):
        """Shut down part of the connection; *val* is passed to socket.shutdown()."""
        self.socket.shutdown(val)

    def is_flushed(self):
        """Return True when no outgoing data is queued."""
        return len(self.buffer) == 0

    def write(self, s):
        """Queue *s* for sending and try to send it right away.

        A send is only attempted when the connection is established and the
        buffer was empty before this call (otherwise earlier data is still
        waiting for the socket to become writable).  Returns True if the
        buffer is empty afterwards, i.e. everything was handed to the kernel.
        """
        assert self.socket is not None
        self.buffer.append(s)
        if self.connected and len(self.buffer) == 1:
            self.try_write()
        return not len(self.buffer)

    def try_write(self):
        """Send as much buffered data as the socket accepts without blocking.

        On a send error other than EWOULDBLOCK the connection is appended to
        raw_server.dead_from_write for the RawServer to close later, and the
        poll registration is left untouched.  Otherwise the socket is
        re-registered for POLLIN alone if the buffer drained, or for
        POLLIN|POLLOUT if data remains, so the RawServer retries on writability.
        """
        if self.connected:
            try:
                while self.buffer != []:
                    amount = self.socket.send(self.buffer[0])
                    if amount != len(self.buffer[0]):
                        # partial send: keep the unsent tail, wait for POLLOUT
                        if amount != 0:
                            self.buffer[0] = self.buffer[0][amount:]
                        break
                    del self.buffer[0]
            except socket.error, e:
                code, msg = e
                if code != EWOULDBLOCK:
                    self.raw_server.dead_from_write.append(self)
                    return
        if self.buffer == []:
            self.raw_server.poll.register(self.socket, POLLIN)
        else:
            self.raw_server.poll.register(self.socket, all)

class RawServer:
    def __init__(self, doneflag, timeout_check_interval, timeout):
        self.timeout_check_interval = timeout_check_interval
        self.timeout = timeout
        self.poll = poll()
        # {socket: SingleSocket}
        self.single_sockets = {}
        self.dead_from_write = []
        self.doneflag = doneflag
        self.funcs = []
        self.externally_added = []
        self.server = None
        self.add_task(self.scan_for_timeouts, timeout_check_interval)

    def add_task(self, func, delay):
        insort(self.funcs, (time() + delay, func))

    def external_add_task(self, func, delay = 0):
        self.externally_added.append((func, delay))

    def scan_for_timeouts(self):
        self.add_task(self.scan_for_timeouts, self.timeout_check_interval)
        t = time() - self.timeout
        tokill = []
        for s in self.single_sockets.values():
            if s.last_hit < t:
                tokill.append(s)
        for k in tokill:
            if k.socket is not None:
                self._close_socket(k)

    def bind(self, port, bind = '', reuse = False):
        server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        if reuse:
            server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.setblocking(0)
        server.bind((bind, port))
        server.listen(5)
        self.poll.register(server, POLLIN)
        self.server = server

    def start_connection(self, dns, handler = None):
        if handler is None:
            handler = self.handler
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setblocking(0)
        try:
            sock.connect_ex(dns)
        except socket.error:
            raise
        except Exception, e:
            raise socket.error(str(e))
        self.poll.register(sock, POLLIN|POLLOUT|POLLERR)
        s = SingleSocket(self, sock, handler)
        self.single_sockets[sock.fileno()] = s
        return s

    def handle_events(self, events):
        for sock, event in events:
            if self.server is not None and sock == self.server.fileno():
                if event & (POLLHUP | POLLERR) != 0:
                    self.poll.unregister(self.server)
                    self.server.close()
                    self.errorfunc('lost server socket')
                else:
                    try:
                        newsock, addr = self.server.accept()
                        newsock.setblocking(0)
                        nss = SingleSocket(self, newsock, self.handler)
                        self.single_sockets[newsock.fileno()] = nss
                        self.poll.register(newsock, POLLIN)
                        self.handler.external_connection_made(nss)
                    except socket.error:
                        sleep(1)
            else:
                s = self.single_sockets.get(sock)
                if s is None:
                    continue
                if (event & (POLLHUP|POLLERR)) != 0:
                    if s.connected:
                        self._close_socket(s, msg='Remote end closed connection')
                    else:
                        self._close_socket(s, msg='Connection refused')
                    continue
                s.connected = True
                if (event & POLLIN) != 0:
                    try:
                        s.last_hit = time()
                        data = s.socket.recv(100000)
                        if data == '':
                            self._close_socket(s, msg='Remote end closed connection')
                        else:
                            s.handler.data_came_in(s, data)
                    except socket.error, e:
                        code, msg = e
                        if code != EWOULDBLOCK:
                            self._close_socket(s, msg=msg)
                            continue
                if (event & POLLOUT) != 0 and s.socket is not None and not s.is_flushed():
                    s.try_write()
                    if s.is_flushed():
                        s.handler.connection_flushed(s)

    def pop_external(self):
        try:
            while True:
                (a, b) = self.externally_added.pop()
                self.add_task(a, b)
        except IndexError:
            pass

    def listen_forever(self, handler):
        self.handler = handler
        try:
            while not self.doneflag.isSet():
                try:
                    self.pop_external()
                    if len(self.funcs) == 0:
                        period = 2 ** 30
                    else:
                        period = self.funcs[0][0] - time()
                    if period < 0:
                        period = 0
                    events = self.poll.poll(period * timemult)
                    if self.doneflag.isSet():
                        return
                    while len(self.funcs) > 0 and self.funcs[0][0] <= time():
                        garbage, func = self.funcs[0]
                        del self.funcs[0]
                        try:
                            func()
                        except KeyboardInterrupt:
                            print_exc()
                            return
                    self._close_dead()
                    self.handle_events(events)
                    if self.doneflag.isSet():
                        return
                    self._close_dead()
                except error:
                    if self.doneflag.isSet():
                        return
                #except KeyboardInterrupt:
                #    print_exc()
                #    return
#                except:
#                    data = StringIO()
#                    print_exc(file = data)
#                    self.errorfunc(data.getvalue())
        finally:
#            for ss in self.single_sockets.values():
#                ss.close()
            if self.server is not None:
                self.poll.unregister(self.server)
                self.server.close()

    def unregister(self, fd):
        del self.single_sockets[fd]
        self.poll.unregister(fd)
        return

    def _close_dead(self):
        while len(self.dead_from_write) > 0:
            old = self.dead_from_write
            self.dead_from_write = []
            for s in old:
                if s.socket is not None:
                    self._close_socket(s)

    def _close_socket(self, s, msg=None):
        sock = s.socket.fileno()
        self.poll.unregister(sock)
        del self.single_sockets[sock]
        s.socket.close()
        s.socket = None
        s.handler.connection_lost(s, msg)

# everything below is for testing

class DummyHandler:
    """Connection handler stub that simply records every callback it gets."""

    def __init__(self):
        # SingleSockets handed to external_connection_made(), in call order
        self.external_made = []
        # (SingleSocket, data) pairs handed to data_came_in(), in call order
        self.data_in = []
        # SingleSockets handed to connection_lost(); the message is dropped
        self.lost = []

    def external_connection_made(self, s):
        """Record an incoming connection."""
        made = self.external_made
        made.append(s)

    def data_came_in(self, s, data):
        """Record a received chunk together with its connection."""
        entry = (s, data)
        self.data_in.append(entry)

    def connection_lost(self, s, msg):
        """Record that *s* was closed; *msg* is ignored."""
        lost = self.lost
        lost.append(s)

    def connection_flushed(self, s):
        """Ignore write-buffer-drained notifications."""
        return None

def sl(rs, handler, port):
    """Bind *rs* to *port*, then run rs.listen_forever(handler) on a new thread."""
    rs.bind(port)
    listener = Thread(target = rs.listen_forever, args = (handler,))
    listener.start()

def loop(rs):
    """Keep a no-op task on *rs* that reschedules itself every 0.1 s.

    listen_forever() bounds its poll() timeout by the earliest queued task,
    so this makes the test servers wake up often enough to notice their
    doneflag promptly.
    """
    def tick():
        rs.add_task(tick, .1)
    rs.add_task(tick, .1)

# Random base port in [5000, 15000) for the tests below, so repeated runs are
# less likely to collide with ports left in TIME_WAIT.
beginport = randrange(5000, 15000)

def test_starting_side_close():
    # Two RawServers on localhost: sa connects to sb, both sides exchange
    # data twice, then the connecting side (ca) closes locally.  Only the
    # accepting side (db) should see connection_lost.  The sleeps give the
    # listener threads time to process events; timing matters here.
    try:
        fa = Event()
        fb = Event()
        da = DummyHandler()
        sa = RawServer(fa, 100, 100)
        loop(sa)
        sl(sa, da, beginport)
        db = DummyHandler()
        sb = RawServer(fb, 100, 100)
        loop(sb)
        sl(sb, db, beginport + 1)

        sleep(.5)
        ca = sa.start_connection(('127.0.0.1', beginport + 1))
        sleep(1)
        
        # only sb sees an incoming connection
        assert da.external_made == []
        assert da.data_in == []
        assert da.lost == []
        assert len(db.external_made) == 1
        cb = db.external_made[0]
        del db.external_made[:]
        assert db.data_in == []
        assert db.lost == []

        ca.write('aaa')
        cb.write('bbb')
        sleep(1)
        
        # each side received exactly what the other wrote
        assert da.external_made == []
        assert da.data_in == [(ca, 'bbb')]
        del da.data_in[:]
        assert da.lost == []
        assert db.external_made == []
        assert db.data_in == [(cb, 'aaa')]
        del db.data_in[:]
        assert db.lost == []

        ca.write('ccc')
        cb.write('ddd')
        sleep(1)
        
        assert da.external_made == []
        assert da.data_in == [(ca, 'ddd')]
        del da.data_in[:]
        assert da.lost == []
        assert db.external_made == []
        assert db.data_in == [(cb, 'ccc')]
        del db.data_in[:]
        assert db.lost == []

        # a local close() does not notify the closing side's own handler
        ca.close()
        sleep(1)

        assert da.external_made == []
        assert da.data_in == []
        assert da.lost == []
        assert db.external_made == []
        assert db.data_in == []
        assert db.lost == [cb]
        del db.lost[:]
    finally:
        # stop both listener threads even if an assertion failed
        fa.set()
        fb.set()

def test_receiving_side_close():
    """Closing the accepted side reports connection_lost on the connecting side."""
    # Create both flags before anything that can fail, so the finally clause
    # can always stop whichever listener threads were started (previously a
    # bind failure in sl(sa, ...) left fb undefined and the finally raised
    # NameError, hiding the real error).
    fa = Event()
    fb = Event()
    try:
        da = DummyHandler()
        sa = RawServer(fa, 100, 100)
        loop(sa)
        sl(sa, da, beginport + 2)
        db = DummyHandler()
        sb = RawServer(fb, 100, 100)
        loop(sb)
        sl(sb, db, beginport + 3)

        sleep(.5)
        ca = sa.start_connection(('127.0.0.1', beginport + 3))
        sleep(1)

        # only sb sees an incoming connection
        assert da.external_made == []
        assert da.data_in == []
        assert da.lost == []
        assert len(db.external_made) == 1
        cb = db.external_made[0]
        del db.external_made[:]
        assert db.data_in == []
        assert db.lost == []

        ca.write('aaa')
        cb.write('bbb')
        sleep(1)

        assert da.external_made == []
        assert da.data_in == [(ca, 'bbb')]
        del da.data_in[:]
        assert da.lost == []
        assert db.external_made == []
        assert db.data_in == [(cb, 'aaa')]
        del db.data_in[:]
        assert db.lost == []

        ca.write('ccc')
        cb.write('ddd')
        sleep(1)

        assert da.external_made == []
        assert da.data_in == [(ca, 'ddd')]
        del da.data_in[:]
        assert da.lost == []
        assert db.external_made == []
        assert db.data_in == [(cb, 'ccc')]
        del db.data_in[:]
        assert db.lost == []

        # the accepting side closes; only the connecting side is notified
        cb.close()
        sleep(1)

        assert da.external_made == []
        assert da.data_in == []
        assert da.lost == [ca]
        del da.lost[:]
        assert db.external_made == []
        assert db.data_in == []
        assert db.lost == []
    finally:
        fa.set()
        fb.set()

def test_connection_refused():
    """A connect to a port with no listener ends in connection_lost."""
    try:
        fa = Event()
        da = DummyHandler()
        sa = RawServer(fa, 100, 100)
        loop(sa)
        sl(sa, da, beginport + 6)
        sleep(.5)

        # no test in this file listens on beginport + 15
        ca = sa.start_connection(('127.0.0.1', beginport + 15))
        sleep(1)

        assert da.external_made == [] and da.data_in == []
        assert da.lost == [ca]
        del da.lost[:]
    finally:
        fa.set()

def test_both_close():
    """Unregistering and closing both ends locally notifies neither handler."""
    # Create both flags up front so the finally clause never hits NameError
    # (fb used to be created only after sl(sa, ...), which can fail on bind).
    fa = Event()
    fb = Event()
    try:
        da = DummyHandler()
        sa = RawServer(fa, 100, 100)
        loop(sa)
        sl(sa, da, beginport + 4)

        sleep(1)
        db = DummyHandler()
        sb = RawServer(fb, 100, 100)
        loop(sb)
        sl(sb, db, beginport + 5)

        sleep(.5)
        ca = sa.start_connection(('127.0.0.1', beginport + 5))
        sleep(1)

        assert da.external_made == []
        assert da.data_in == []
        assert da.lost == []
        assert len(db.external_made) == 1
        cb = db.external_made[0]
        del db.external_made[:]
        assert db.data_in == []
        assert db.lost == []

        ca.write('aaa')
        cb.write('bbb')
        sleep(1)

        assert da.external_made == []
        assert da.data_in == [(ca, 'bbb')]
        del da.data_in[:]
        assert da.lost == []
        assert db.external_made == []
        assert db.data_in == [(cb, 'aaa')]
        del db.data_in[:]
        assert db.lost == []

        ca.write('ccc')
        cb.write('ddd')
        sleep(1)

        assert da.external_made == []
        assert da.data_in == [(ca, 'ddd')]
        del da.data_in[:]
        assert da.lost == []
        assert db.external_made == []
        assert db.data_in == [(cb, 'ccc')]
        del db.data_in[:]
        assert db.lost == []

        # remove both ends from their servers first, so neither loop ever
        # sees the resulting hang-up
        sa.unregister(ca.fileno)
        sb.unregister(cb.fileno)

        ca.close(unregister=False)
        cb.close(unregister=False)
        sleep(1)

        assert da.external_made == []
        assert da.data_in == []
        assert da.lost == []
        assert db.external_made == []
        assert db.data_in == []
        assert db.lost == []
    finally:
        fa.set()
        fb.set()

def test_normal():
    """Tasks run in due-time order, including one added while the loop runs."""
    l = []
    f = Event()
    try:
        s = RawServer(f, 100, 100)
        loop(s)
        sl(s, DummyHandler(), beginport + 7)
        s.add_task(lambda l = l: l.append('b'), 2)
        s.add_task(lambda l = l: l.append('a'), 1)
        s.add_task(lambda l = l: l.append('d'), 4)
        sleep(1.5)
        # due at ~3.0s, i.e. between 'b' (2s) and 'd' (4s)
        s.add_task(lambda l = l: l.append('c'), 1.5)
        sleep(3)
        assert l == ['a', 'b', 'c', 'd']
    finally:
        # Always stop the (non-daemon) listener thread; previously a failed
        # assertion skipped f.set() and left the process hanging.
        f.set()

def test_catch_exception():
    """A task that raises must not prevent later tasks from running."""
    l = []
    f = Event()
    try:
        # fourth argument: don't report the expected ZeroDivisionError
        s = RawServer(f, 100, 100, False)
        loop(s)
        sl(s, DummyHandler(), beginport + 9)
        s.add_task(lambda l = l: l.append('b'), 2)
        s.add_task(lambda: 4/0, 1)
        sleep(3)
        assert l == ['b']
    finally:
        # Always stop the (non-daemon) listener thread, even when the
        # constructor or an assertion fails.
        f.set()

def test_closes_if_not_hit():
    """An idle connection is closed once sa's 2-second timeout elapses."""
    # Create both flags up front so the finally clause never hits NameError
    # (fb used to be created only after sl(sa, ...), which can fail on bind).
    fa = Event()
    fb = Event()
    try:
        da = DummyHandler()
        # check every 2s, close connections idle for more than 2s
        sa = RawServer(fa, 2, 2)
        loop(sa)
        sl(sa, da, beginport + 14)

        sleep(1)
        db = DummyHandler()
        sb = RawServer(fb, 100, 100)
        loop(sb)
        sl(sb, db, beginport + 13)

        sleep(.5)
        sa.start_connection(('127.0.0.1', beginport + 13))
        sleep(1)

        assert da.external_made == []
        assert da.data_in == []
        assert da.lost == []
        assert len(db.external_made) == 1
        del db.external_made[:]
        assert db.data_in == []
        assert db.lost == []

        # no traffic: sa times the connection out, sb then sees it close
        sleep(3.1)

        assert len(da.lost) == 1
        assert len(db.lost) == 1
    finally:
        fa.set()
        fb.set()

def test_does_not_close_if_hit():
    """A connection that receives data is not closed by the idle timeout."""
    try:
        fa = Event()
        fb = Event()
        da = DummyHandler()
        # check every 2s, close connections idle for more than 2s
        sa = RawServer(fa, 2, 2)
        loop(sa)
        sl(sa, da, beginport + 12)

        sleep(1)
        db = DummyHandler()
        sb = RawServer(fb, 100, 100)
        loop(sb)
        # Own port: this used beginport + 13, the same port
        # test_closes_if_not_hit binds, so the two tests could collide.
        sl(sb, db, beginport + 11)

        sleep(.5)
        sa.start_connection(('127.0.0.1', beginport + 11))
        sleep(1)

        assert da.external_made == []
        assert da.data_in == []
        assert da.lost == []
        assert len(db.external_made) == 1
        cb = db.external_made[0]
        del db.external_made[:]
        assert db.data_in == []
        assert db.lost == []

        # incoming data refreshes last_hit on sa's side
        cb.write('bbb')
        sleep(.5)

        assert da.lost == []
        assert db.lost == []
    finally:
        fa.set()
        fb.set()
