1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170
|
# Copyright (c) str4d <str4d@mail.i2p>
# See COPYING for details.
from builtins import object
try:
    # Python 3: mock ships with the standard library.
    from unittest.mock import Mock
except ImportError:
    # Python 2: fall back to the standalone ``mock`` backport. Only
    # ImportError is caught so unrelated errors are not silently swallowed.
    from mock import Mock
from twisted.internet import defer
from twisted.internet.error import ConnectionLost, ConnectionRefusedError
from twisted.internet.protocol import ClientFactory
from twisted.python import failure
from twisted.test import proto_helpers
from txi2p.sam import constants as c
# Shared Failure fixtures: the factory tests below pass these to
# proto.connectionLost() and fac.clientConnectionFailed() respectively.
connectionLostFailure = failure.Failure(ConnectionLost())
connectionRefusedFailure = failure.Failure(ConnectionRefusedError())
class SAMProtocolTestMixin(object):
    """Shared tests for SAM protocol classes.

    The host test case must set ``protocol`` to the protocol class under
    test and must provide the usual ``TestCase`` assertion helpers and
    ``addCleanup``.
    """

    def makeProto(self, *a, **kw):
        """Build a stub factory and a connected protocol instance.

        Positional and keyword arguments are forwarded to ``ClientFactory``,
        except for the optional ``_protoClass`` keyword, which overrides
        ``self.protocol`` as the protocol class to build.

        Returns a ``(factory, protocol)`` tuple. The protocol has already
        been connected to an in-memory ``StringTransport``.
        """
        protoClass = kw.pop('_protoClass', self.protocol)
        fac = ClientFactory(*a, **kw)
        fac.nickname = 'foo'
        fac.privKey = None
        fac.port = None
        fac.localPort = None
        fac.options = {}
        fac.sigType = None
        fac.protocol = protoClass
        fac.resultNotOK = Mock()

        def raiseReason(reason):
            # Re-raise the wrapped exception so a connection failure
            # surfaces directly in the calling test.
            raise reason.value
        fac.connectionFailed = raiseReason

        proto = fac.buildProtocol(None)
        transport = proto_helpers.StringTransport()
        # Replace abortConnection with a no-op for these tests.
        transport.abortConnection = lambda: None
        proto.makeConnection(transport)
        return fac, proto

    def _makeKeepaliveProto(self):
        """Return ``(factory, protocol)`` with the receiver in keepalive state.

        Registers ``receiver.stopPinging`` as a cleanup and clears whatever
        was written to the transport during connection setup.
        """
        fac, proto = self.makeProto()
        self.addCleanup(proto.receiver.stopPinging)
        proto.transport.clear()
        # Enable keepalive
        proto.receiver.currentRule = 'State_keepalive'
        proto._parser._setupInterp()
        return fac, proto

    def _sendPingAndVerify(self, proto):
        """Send a ping, check it hit the wire and armed the timeout.

        The transport is cleared afterwards so the caller can inspect only
        the data written in response to what it feeds in next.
        """
        proto.receiver._sendPing()
        self.assertEqual(
            'PING %s\n' % proto.receiver.lastPing,
            proto.transport.value().decode('utf-8'))
        self.assertTrue(proto.receiver.pingTimeout.active())
        proto.transport.clear()

    def test_initSendsHello(self):
        """Connecting writes a ``HELLO VERSION`` line to the transport."""
        fac, proto = self.makeProto()
        # Compare bytes with bytes instead of relying on repr(bytes).
        self.assertIn(b'HELLO VERSION', proto.transport.value())

    def test_helloReturnsError(self):
        """A non-OK HELLO reply is reported through ``resultNotOK``."""
        fac, proto = self.makeProto()
        proto.transport.clear()
        proto.dataReceived(
            b'HELLO REPLY RESULT=I2P_ERROR MESSAGE="foo bar baz"\n')
        fac.resultNotOK.assert_called_with('I2P_ERROR', 'foo bar baz')

    def test_pingReceived(self):
        """A bare ``PING`` is answered with a bare ``PONG``."""
        fac, proto = self._makeKeepaliveProto()
        proto.dataReceived(b'PING\n')
        self.assertEqual(b'PONG\n', proto.transport.value())

    def test_pingReceivedWithData(self):
        """The payload of a ``PING`` is echoed back in the ``PONG``."""
        fac, proto = self._makeKeepaliveProto()
        proto.dataReceived(b'PING some random data\n')
        self.assertEqual(
            b'PONG some random data\n',
            proto.transport.value())

    def test_pingReceivedResetsTimeout(self):
        """Receiving a ``PING`` deactivates the pending ping timeout."""
        fac, proto = self._makeKeepaliveProto()
        self._sendPingAndVerify(proto)
        proto.dataReceived(b'PING\n')
        self.assertEqual(b'PONG\n', proto.transport.value())
        self.assertFalse(proto.receiver.pingTimeout.active())

    def test_validPongResponseResetsTimeout(self):
        """A ``PONG`` matching the last ping deactivates the timeout."""
        fac, proto = self._makeKeepaliveProto()
        self._sendPingAndVerify(proto)
        proto.dataReceived(
            ('PONG %s\n' % proto.receiver.lastPing).encode('utf-8'))
        self.assertFalse(proto.receiver.pingTimeout.active())

    def test_invalidPongResponseDoesNotResetTimeout(self):
        """A ``PONG`` with unexpected data leaves the timeout active."""
        fac, proto = self._makeKeepaliveProto()
        self._sendPingAndVerify(proto)
        proto.dataReceived(b'PONG not what was expected\n')
        self.assertTrue(proto.receiver.pingTimeout.active())
class SAMFactoryTestMixin(object):
    """Shared tests for SAM client factories.

    The host test case must set ``factory`` (the factory class under test)
    and ``blankFactoryArgs`` (positional arguments used to construct it).
    ``assertFailure`` is also required; presumably the host is a
    ``twisted.trial`` test case -- confirm against the concrete subclasses.
    """

    def setUp(self):
        """Reset the record of ``abortConnection`` calls."""
        self.aborted = []

    def makeProto(self, *a, **kw):
        """Build the factory under test and a connected protocol.

        All arguments are forwarded to ``self.factory``. The transport's
        ``abortConnection`` is replaced so that each call appends ``True``
        to ``self.aborted``.

        Returns a ``(factory, protocol)`` tuple.
        """
        fac = self.factory(*a, **kw)
        proto = fac.buildProtocol(None)
        transport = proto_helpers.StringTransport()
        transport.abortConnection = lambda: self.aborted.append(True)
        proto.makeConnection(transport)
        return fac, proto

    def test_cancellation(self):
        """Cancelling the deferred aborts the connection."""
        fac, proto = self.makeProto(*self.blankFactoryArgs)
        fac.deferred.cancel()
        self.assertTrue(self.aborted)
        return self.assertFailure(fac.deferred, defer.CancelledError)

    def test_cancellationBeforeFailure(self):
        """A connection loss after cancellation still yields CancelledError."""
        fac, proto = self.makeProto(*self.blankFactoryArgs)
        fac.deferred.cancel()
        proto.connectionLost(connectionLostFailure)
        self.assertTrue(self.aborted)
        return self.assertFailure(fac.deferred, defer.CancelledError)

    def test_cancellationAfterFailure(self):
        """Cancelling after a connection loss does not abort the transport."""
        fac, proto = self.makeProto(*self.blankFactoryArgs)
        proto.connectionLost(connectionLostFailure)
        fac.deferred.cancel()
        self.assertFalse(self.aborted)
        return self.assertFailure(fac.deferred, ConnectionLost)

    def test_clientConnectionFailed(self):
        """A refused connection fails the deferred with the same error."""
        fac, proto = self.makeProto(*self.blankFactoryArgs)
        fac.clientConnectionFailed(None, connectionRefusedFailure)
        return self.assertFailure(fac.deferred, ConnectionRefusedError)

    def test_resultNotOK(self):
        """Every result in ``samErrorMap`` raises its mapped exception."""
        fac, proto = self.makeProto(*self.blankFactoryArgs)
        for result, error in c.samErrorMap.items():
            self.assertRaises(error, fac.resultNotOK, result, '')
|