1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351
|
"""Tests for consumer handling of association responses
This duplicates some things that are covered by test_consumer, but
this works for now.
"""
from openid.test.test_consumer import CatchLogs
from openid.message import Message, OPENID2_NS, OPENID_NS
from openid.server.server import DiffieHellmanSHA1ServerSession
from openid.consumer.consumer import GenericConsumer, ProtocolError
from openid.consumer.discover import OpenIDServiceEndpoint, OPENID_1_1_TYPE,\
OPENID_2_0_TYPE
from openid.store import memstore
import unittest
# Placeholder values for every field of an association response. The
# values themselves are arbitrary; mkAssocResponse copies whichever
# subset of these keys a test asks for.
association_response_values = dict(
    expires_in='1000',
    assoc_handle='a handle',
    assoc_type='a type',
    session_type='a session type',
    ns=OPENID2_NS,
)
def mkAssocResponse(*keys):
    """Build an association response message containing only `keys`.

    The value for each key is taken from `association_response_values`.
    This is useful for testing for missing keys and other times that
    we don't care what the values are.

    Raises KeyError if a requested key is not in
    `association_response_values`.
    """
    args = {key: association_response_values[key] for key in keys}
    return Message.fromOpenIDArgs(args)
class BaseAssocTest(CatchLogs, unittest.TestCase):
    """Shared fixture: a GenericConsumer over an in-memory store, an
    empty OpenIDServiceEndpoint, and a ProtocolError assertion helper."""

    def setUp(self):
        CatchLogs.setUp(self)
        self.endpoint = OpenIDServiceEndpoint()
        store = memstore.MemoryStore()
        self.store = store
        self.consumer = GenericConsumer(store)

    def failUnlessProtocolError(self, str_prefix, func, *args, **kwargs):
        """Assert that func(*args, **kwargs) raises ProtocolError and that
        the exception's first argument starts with `str_prefix`.

        Fails the test (with the returned value in the message) if no
        ProtocolError is raised.
        """
        try:
            outcome = func(*args, **kwargs)
        except ProtocolError as err:
            first_arg = err.args[0]
            self.assertTrue(
                first_arg.startswith(str_prefix),
                'Expected prefix %r, got %r' % (str_prefix, first_arg))
            return
        self.fail('Expected ProtocolError, got %r' % (outcome,))
def mkExtractAssocMissingTest(keys):
    """Factory function for creating missing-field test methods.

    The returned method builds an association response containing only
    `keys` (via mkAssocResponse) and asserts that
    `self.consumer._extractAssociation(msg, None)` raises KeyError,
    i.e. that a response missing required fields is rejected.

    According to 'Association Session Response' subsection 'Common
    Response Parameters', the following fields are required for OpenID
    2.0:

     * ns
     * session_type
     * assoc_handle
     * assoc_type
     * expires_in

    If 'ns' is missing, it will fall back to OpenID 1 checking. In
    OpenID 1, everything except 'session_type' and 'ns' is required.
    """
    # Snapshot the keys so later mutation of the caller's sequence
    # cannot change what the generated test checks.
    keys = tuple(keys)

    def test(self):
        msg = mkAssocResponse(*keys)
        self.assertRaises(KeyError,
                          self.consumer._extractAssociation, msg, None)

    return test
class TestExtractAssociationMissingFieldsOpenID2(BaseAssocTest):
    """OpenID 2 association responses that lack a required field must be
    rejected by _extractAssociation (see mkExtractAssocMissingTest)."""

    # Only the namespace is present; every other required field is absent.
    test_noFields_openid2 = mkExtractAssocMissingTest(('ns',))

    # Each of the following drops exactly one required field.
    test_missingExpires_openid2 = mkExtractAssocMissingTest(
        ('assoc_handle', 'assoc_type', 'session_type', 'ns'))
    test_missingHandle_openid2 = mkExtractAssocMissingTest(
        ('expires_in', 'assoc_type', 'session_type', 'ns'))
    test_missingAssocType_openid2 = mkExtractAssocMissingTest(
        ('expires_in', 'assoc_handle', 'session_type', 'ns'))
    test_missingSessionType_openid2 = mkExtractAssocMissingTest(
        ('expires_in', 'assoc_handle', 'assoc_type', 'ns'))
class TestExtractAssociationMissingFieldsOpenID1(BaseAssocTest):
    """Test for returning an error upon missing fields in association
    responses for OpenID 1"""

    # No fields at all; without 'ns' the response is treated as OpenID 1.
    test_noFields_openid1 = mkExtractAssocMissingTest([])

    # Each of the following drops exactly one field that OpenID 1
    # requires ('session_type' is optional there, so it is never tested).
    test_missingExpires_openid1 = mkExtractAssocMissingTest(
        ['assoc_handle', 'assoc_type'])
    test_missingHandle_openid1 = mkExtractAssocMissingTest(
        ['expires_in', 'assoc_type'])
    test_missingAssocType_openid1 = mkExtractAssocMissingTest(
        ['expires_in', 'assoc_handle'])
class DummyAssocationSession(object):
    """Bare association-session stub that carries only a session type
    and the assoc types it allows (used by the mismatch tests)."""

    def __init__(self, session_type, allowed_assoc_types=()):
        self.session_type, self.allowed_assoc_types = (
            session_type, allowed_assoc_types)
class ExtractAssociationSessionTypeMismatch(BaseAssocTest):
    """_extractAssociation must fail with a 'Session type mismatch'
    ProtocolError when the response's session_type differs from the
    session type of the association session that was requested."""

    def mkTest(requested_session_type, response_session_type, openid1=False):
        """Build a test method for one (requested, returned) session type
        pair. With `openid1` set, 'ns' is left out of the response."""
        def test(self):
            assoc_session = DummyAssocationSession(requested_session_type)
            field_names = list(association_response_values)
            if openid1:
                # Dropping the namespace turns this into an OpenID 1 message.
                field_names.remove('ns')
            msg = mkAssocResponse(*field_names)
            msg.setArg(OPENID_NS, 'session_type', response_session_type)
            self.failUnlessProtocolError(
                'Session type mismatch',
                self.consumer._extractAssociation, msg, assoc_session)
        return test

    # --- OpenID 2 responses ---
    test_typeMismatchNoEncBlank_openid2 = mkTest(
        requested_session_type='no-encryption', response_session_type='')
    test_typeMismatchDHSHA1NoEnc_openid2 = mkTest(
        requested_session_type='DH-SHA1',
        response_session_type='no-encryption')
    test_typeMismatchDHSHA256NoEnc_openid2 = mkTest(
        requested_session_type='DH-SHA256',
        response_session_type='no-encryption')
    test_typeMismatchNoEncDHSHA1_openid2 = mkTest(
        requested_session_type='no-encryption',
        response_session_type='DH-SHA1')

    # --- OpenID 1 responses ---
    # NOTE: despite "NoEnc" in their names, the next two tests pair
    # DH-SHA1 with DH-SHA256; the names are kept for test-id stability.
    test_typeMismatchDHSHA1NoEnc_openid1 = mkTest(
        requested_session_type='DH-SHA1',
        response_session_type='DH-SHA256', openid1=True)
    test_typeMismatchDHSHA256NoEnc_openid1 = mkTest(
        requested_session_type='DH-SHA256',
        response_session_type='DH-SHA1', openid1=True)
    test_typeMismatchNoEncDHSHA1_openid1 = mkTest(
        requested_session_type='no-encryption',
        response_session_type='DH-SHA1', openid1=True)
class TestOpenID1AssociationResponseSessionType(BaseAssocTest):
    """Check which session type GenericConsumer._getOpenID1SessionType
    yields for the 'session_type' field of an OpenID 1 associate
    response."""

    def mkTest(expected_session_type, session_type_value):
        """Return a test method that will check what session type will
        be used if the OpenID 1 response to an associate call sets the
        'session_type' field to `session_type_value`.

        The generated test also asserts that nothing was logged.
        """
        def test(self):
            self._doTest(expected_session_type, session_type_value)
            self.assertEqual(0, len(self.messages))
        return test

    def _doTest(self, expected_session_type, session_type_value):
        """Assert that a response whose 'session_type' is
        `session_type_value` (omitted when None) yields
        `expected_session_type`."""
        # Create a Message with just 'session_type' in it, since
        # that's all this function will use. 'session_type' is
        # omitted entirely when session_type_value is None.
        args = {}
        if session_type_value is not None:
            args['session_type'] = session_type_value
        message = Message.fromOpenIDArgs(args)
        # No 'ns' argument was given, so this must be an OpenID 1 message.
        self.assertTrue(message.isOpenID1())

        actual_session_type = self.consumer._getOpenID1SessionType(message)
        error_message = ('Returned session type parameter %r was expected '
                         'to yield session type %r, but yielded %r' %
                         (session_type_value, expected_session_type,
                          actual_session_type))
        self.assertEqual(
            expected_session_type, actual_session_type, error_message)

    test_none = mkTest(
        session_type_value=None,
        expected_session_type='no-encryption',
    )

    test_empty = mkTest(
        session_type_value='',
        expected_session_type='no-encryption',
    )

    # This one's different because it expects a warning to be logged.
    def test_explicitNoEncryption(self):
        self._doTest(
            session_type_value='no-encryption',
            expected_session_type='no-encryption',
        )
        self.assertEqual(1, len(self.messages))
        log_msg = self.messages[0]
        self.assertEqual(log_msg['levelname'], 'WARNING')
        self.assertTrue(log_msg['msg'].startswith(
            'OpenID server sent "no-encryption"'))

    test_dhSHA1 = mkTest(
        session_type_value='DH-SHA1',
        expected_session_type='DH-SHA1',
    )

    # DH-SHA256 is not a valid session type for OpenID1, but this
    # function does not test that. This is mostly just to make sure
    # that it will pass-through stuff that is not explicitly handled,
    # so it will get handled the same way as it is handled for OpenID
    # 2
    test_dhSHA256 = mkTest(
        session_type_value='DH-SHA256',
        expected_session_type='DH-SHA256',
    )
class DummyAssociationSession(object):
    """Association-session stub whose extractSecret hands back a fixed
    secret and records (on the instance) that it was called."""

    session_type = None
    allowed_assoc_types = None
    # Association secrets are bytes.
    secret = b"shh! don't tell!"
    extract_secret_called = False

    def extractSecret(self, message):
        # `message` is ignored; only the fact of the call is recorded.
        result = self.secret
        self.extract_secret_called = True
        return result
class TestInvalidFields(BaseAssocTest):
    """Run _extractAssociation on a fully valid OpenID 2 response, and on
    variants in which one field has been made invalid."""

    def setUp(self):
        BaseAssocTest.setUp(self)
        self.session_type = 'testing-session'
        # This must be something that works for Association.fromExpiresIn
        self.assoc_type = 'HMAC-SHA1'
        self.assoc_handle = 'testing-assoc-handle'

        # Every argument in this response is valid.
        response_args = {
            'expires_in': '1000',
            'assoc_handle': self.assoc_handle,
            'assoc_type': self.assoc_type,
            'session_type': self.session_type,
            'ns': OPENID2_NS,
        }
        self.assoc_response = Message.fromOpenIDArgs(response_args)

        # A session matching the response's session type that accepts
        # the response's assoc type.
        session = DummyAssociationSession()
        session.session_type = self.session_type
        session.allowed_assoc_types = [self.assoc_type]
        self.assoc_session = session

    def test_worksWithGoodFields(self):
        """Handle a full successful association response"""
        session = self.assoc_session
        assoc = self.consumer._extractAssociation(self.assoc_response,
                                                  session)
        self.assertTrue(session.extract_secret_called)
        self.assertEqual(session.secret, assoc.secret)
        self.assertEqual(1000, assoc.lifetime)
        self.assertEqual(self.assoc_handle, assoc.handle)
        self.assertEqual(self.assoc_type, assoc.assoc_type)

    def test_badAssocType(self):
        # With an empty allow-list, the response's assoc type is not
        # valid for the given session.
        self.assoc_session.allowed_assoc_types = []
        self.failUnlessProtocolError(
            'Unsupported assoc_type for session',
            self.consumer._extractAssociation,
            self.assoc_response, self.assoc_session)

    def test_badExpiresIn(self):
        # A non-numeric expires_in must make extraction fail.
        self.assoc_response.setArg(OPENID_NS, 'expires_in', 'forever')
        self.failUnlessProtocolError(
            'Invalid expires_in',
            self.consumer._extractAssociation,
            self.assoc_response, self.assoc_session)
# XXX: This is what causes most of the imports in this file. It is
# sort of a unit test and sort of a functional test. I'm not terribly
# fond of it.
class TestExtractAssociationDiffieHellman(BaseAssocTest):
    """Round-trip a DH-SHA1 association through the real server-side
    DiffieHellmanSHA1ServerSession and check what the consumer extracts."""

    # Secret passed to the server session's answer(); the consumer is
    # expected to recover exactly this value (see test_success).
    secret = b'x' * 20

    def _setUpDH(self):
        """Create a consumer DH-SHA1/HMAC-SHA1 associate request and a
        matching server response.

        Returns a (consumer session, response Message) pair.
        """
        sess, message = self.consumer._createAssociateRequest(
            self.endpoint, 'HMAC-SHA1', 'DH-SHA1')

        # XXX: this is testing _createAssociateRequest
        self.assertEqual(self.endpoint.compatibilityMode(),
                         message.isOpenID1())

        server_sess = DiffieHellmanSHA1ServerSession.fromMessage(message)
        server_resp = server_sess.answer(self.secret)
        server_resp['assoc_type'] = 'HMAC-SHA1'
        server_resp['assoc_handle'] = 'handle'
        server_resp['expires_in'] = '1000'
        server_resp['session_type'] = 'DH-SHA1'
        return sess, Message.fromOpenIDArgs(server_resp)

    def test_success(self):
        sess, server_resp = self._setUpDH()
        ret = self.consumer._extractAssociation(server_resp, sess)
        self.assertIsNotNone(ret)
        self.assertEqual(ret.assoc_type, 'HMAC-SHA1')
        self.assertEqual(ret.secret, self.secret)
        self.assertEqual(ret.handle, 'handle')
        self.assertEqual(ret.lifetime, 1000)

    def test_openid2success(self):
        # Use openid 2 type in endpoint so _setUpDH checks
        # compatibility mode state properly
        self.endpoint.type_uris = [OPENID_2_0_TYPE, OPENID_1_1_TYPE]
        self.test_success()

    def test_badDHValues(self):
        sess, server_resp = self._setUpDH()
        # Overwrite enc_mac_key with a bogus value; extraction must fail.
        server_resp.setArg(OPENID_NS, 'enc_mac_key', '\x00\x00\x00')
        self.failUnlessProtocolError('Malformed response for',
                                     self.consumer._extractAssociation,
                                     server_resp, sess)
|