#!/usr/bin/env python
#
# Copyright 2010 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tests for protorpc.remote."""

__author__ = 'rafek@google.com (Rafe Kaplan)'


import sys
import types
import unittest
from wsgiref import headers

from protorpc import descriptor
from protorpc import message_types
from protorpc import messages
from protorpc import protobuf
from protorpc import protojson
from protorpc import remote
from protorpc import test_util
from protorpc import transport

import mox


class ModuleInterfaceTest(test_util.ModuleInterfaceTest,
                          test_util.TestCase):

  MODULE = remote


class Request(messages.Message):
  """Test request message."""

  value = messages.StringField(1)


class Response(messages.Message):
  """Test response message."""

  value = messages.StringField(1)


class MyService(remote.Service):

  @remote.method(Request, Response)
  def remote_method(self, request):
    response = Response()
    response.value = request.value
    return response


class SimpleRequest(messages.Message):
  """Simple request message type used for tests."""

  param1 = messages.StringField(1)
  param2 = messages.StringField(2)


class SimpleResponse(messages.Message):
  """Simple response message type used for tests."""


class BasicService(remote.Service):
  """A basic service with decorated remote method."""

  def __init__(self):
    self.request_ids = []

  @remote.method(SimpleRequest, SimpleResponse)
  def remote_method(self, request):
    self.request_ids.append(id(request))
    return SimpleResponse()


class RpcErrorTest(test_util.TestCase):

  def testFromStatus(self):
    for state in remote.RpcState:
      exception = remote.RpcError.from_state
    self.assertEquals(remote.ServerError,
                      remote.RpcError.from_state('SERVER_ERROR'))


class ApplicationErrorTest(test_util.TestCase):

  def testErrorCode(self):
    self.assertEquals('blam',
                      remote.ApplicationError('an error', 'blam').error_name)

  def testStr(self):
    self.assertEquals('an error', str(remote.ApplicationError('an error', 1)))

  def testRepr(self):
    self.assertEquals("ApplicationError('an error', 1)",
                      repr(remote.ApplicationError('an error', 1)))

    self.assertEquals("ApplicationError('an error')",
                      repr(remote.ApplicationError('an error')))


class MethodTest(test_util.TestCase):
  """Test remote method decorator."""

  def testMethod(self):
    """Test use of remote decorator."""
    self.assertEquals(SimpleRequest,
                      BasicService.remote_method.remote.request_type)
    self.assertEquals(SimpleResponse,
                      BasicService.remote_method.remote.response_type)
    self.assertTrue(isinstance(BasicService.remote_method.remote.method,
                               types.FunctionType))

  def testMethodMessageResolution(self):
    """Test use of remote decorator to resolve message types by name."""
    class OtherService(remote.Service):

      @remote.method('SimpleRequest', 'SimpleResponse')
      def remote_method(self, request):
        pass

    self.assertEquals(SimpleRequest,
                      OtherService.remote_method.remote.request_type)
    self.assertEquals(SimpleResponse,
                      OtherService.remote_method.remote.response_type)

  def testMethodMessageResolution_NotFound(self):
    """Test failure to find message types."""
    class OtherService(remote.Service):

      @remote.method('NoSuchRequest', 'NoSuchResponse')
      def remote_method(self, request):
        pass

    self.assertRaisesWithRegexpMatch(
      messages.DefinitionNotFoundError,
      'Could not find definition for NoSuchRequest',
      getattr,
      OtherService.remote_method.remote,
      'request_type')

    self.assertRaisesWithRegexpMatch(
      messages.DefinitionNotFoundError,
      'Could not find definition for NoSuchResponse',
      getattr,
      OtherService.remote_method.remote,
      'response_type')

  def testInvocation(self):
    """Test that invocation passes request through properly."""
    service = BasicService()
    request = SimpleRequest()
    self.assertEquals(SimpleResponse(), service.remote_method(request))
    self.assertEquals([id(request)], service.request_ids)

  def testInvocation_WrongRequestType(self):
    """Wrong request type passed to remote method."""
    service = BasicService()

    self.assertRaises(remote.RequestError,
                      service.remote_method,
                      'wrong')

    self.assertRaises(remote.RequestError,
                      service.remote_method,
                      None)

    self.assertRaises(remote.RequestError,
                      service.remote_method,
                      SimpleResponse())

  def testInvocation_WrongResponseType(self):
    """Wrong response type returned from remote method."""

    class AnotherService(object):

      @remote.method(SimpleRequest, SimpleResponse)
      def remote_method(self, unused_request):
        return self.return_this

    service = AnotherService()

    service.return_this = 'wrong'
    self.assertRaises(remote.ServerError,
                      service.remote_method,
                      SimpleRequest())
    service.return_this = None
    self.assertRaises(remote.ServerError,
                      service.remote_method,
                      SimpleRequest())
    service.return_this = SimpleRequest()
    self.assertRaises(remote.ServerError,
                      service.remote_method,
                      SimpleRequest())

  def testBadRequestType(self):
    """Test bad request types used in remote definition."""

    for request_type in (None, 1020, messages.Message, str):

      def declare():
        class BadService(object):

          @remote.method(request_type, SimpleResponse)
          def remote_method(self, request):
            pass

      self.assertRaises(TypeError, declare)

  def testBadResponseType(self):
    """Test bad response types used in remote definition."""

    for response_type in (None, 1020, messages.Message, str):

      def declare():
        class BadService(object):

          @remote.method(SimpleRequest, response_type)
          def remote_method(self, request):
            pass

      self.assertRaises(TypeError, declare)


class GetRemoteMethodTest(test_util.TestCase):
  """Test for is_remote_method."""

  def testGetRemoteMethod(self):
    """Test valid remote method detection."""

    class Service(object):

      @remote.method(Request, Response)
      def remote_method(self, request):
        pass

    self.assertEquals(Service.remote_method.remote,
                      remote.get_remote_method_info(Service.remote_method))
    self.assertTrue(Service.remote_method.remote,
                    remote.get_remote_method_info(Service().remote_method))

  def testGetNotRemoteMethod(self):
    """Test positive result on a remote method."""

    class NotService(object):

      def not_remote_method(self, request):
        pass

    def fn(self):
      pass

    class NotReallyRemote(object):
      """Test negative result on many bad values for remote methods."""

      def not_really(self, request):
        pass

      not_really.remote = 'something else'

    for not_remote in [NotService.not_remote_method,
                       NotService().not_remote_method,
                       NotReallyRemote.not_really,
                       NotReallyRemote().not_really,
                       None,
                       1,
                       'a string',
                       fn]:
      self.assertEquals(None, remote.get_remote_method_info(not_remote))


class RequestStateTest(test_util.TestCase):
  """Test request state."""

  STATE_CLASS = remote.RequestState

  def testConstructor(self):
    """Test constructor."""
    state = self.STATE_CLASS(remote_host='remote-host',
                             remote_address='remote-address',
                             server_host='server-host',
                             server_port=10)
    self.assertEquals('remote-host', state.remote_host)
    self.assertEquals('remote-address', state.remote_address)
    self.assertEquals('server-host', state.server_host)
    self.assertEquals(10, state.server_port)

    state = self.STATE_CLASS()
    self.assertEquals(None, state.remote_host)
    self.assertEquals(None, state.remote_address)
    self.assertEquals(None, state.server_host)
    self.assertEquals(None, state.server_port)

  def testConstructorError(self):
    """Test unexpected keyword argument."""
    self.assertRaises(TypeError,
                      self.STATE_CLASS,
                      x=10)

  def testRepr(self):
    """Test string representation."""
    self.assertEquals('<%s>' % self.STATE_CLASS.__name__,
                      repr(self.STATE_CLASS()))
    self.assertEquals("<%s remote_host='abc'>" % self.STATE_CLASS.__name__,
                      repr(self.STATE_CLASS(remote_host='abc')))
    self.assertEquals("<%s remote_host='abc' "
                      "remote_address='def'>" % self.STATE_CLASS.__name__,
                      repr(self.STATE_CLASS(remote_host='abc',
                                               remote_address='def')))
    self.assertEquals("<%s remote_host='abc' "
                      "remote_address='def' "
                      "server_host='ghi'>" % self.STATE_CLASS.__name__,
                      repr(self.STATE_CLASS(remote_host='abc',
                                            remote_address='def',
                                            server_host='ghi')))
    self.assertEquals("<%s remote_host='abc' "
                      "remote_address='def' "
                      "server_host='ghi' "
                      'server_port=102>' % self.STATE_CLASS.__name__,
                      repr(self.STATE_CLASS(remote_host='abc',
                                            remote_address='def',
                                            server_host='ghi',
                                            server_port=102)))


class HttpRequestStateTest(RequestStateTest):

  STATE_CLASS = remote.HttpRequestState

  def testHttpMethod(self):
    state = remote.HttpRequestState(http_method='GET')
    self.assertEquals('GET', state.http_method)

  def testHttpMethod(self):
    state = remote.HttpRequestState(service_path='/bar')
    self.assertEquals('/bar', state.service_path)

  def testHeadersList(self):
    state = remote.HttpRequestState(
      headers=[('a', 'b'), ('c', 'd'), ('c', 'e')])

    self.assertEquals(['a', 'c', 'c'], state.headers.keys())
    self.assertEquals(['b'], state.headers.get_all('a'))
    self.assertEquals(['d', 'e'], state.headers.get_all('c'))

  def testHeadersDict(self):
    state = remote.HttpRequestState(headers={'a': 'b', 'c': ['d', 'e']})

    self.assertEquals(['a', 'c', 'c'], sorted(state.headers.keys()))
    self.assertEquals(['b'], state.headers.get_all('a'))
    self.assertEquals(['d', 'e'], state.headers.get_all('c'))

  def testRepr(self):
    super(HttpRequestStateTest, self).testRepr()

    self.assertEquals("<%s remote_host='abc' "
                      "remote_address='def' "
                      "server_host='ghi' "
                      'server_port=102 '
                      "http_method='POST' "
                      "service_path='/bar' "
                      "headers=[('a', 'b'), ('c', 'd')]>" %
                      self.STATE_CLASS.__name__,
                      repr(self.STATE_CLASS(remote_host='abc',
                                            remote_address='def',
                                            server_host='ghi',
                                            server_port=102,
                                            http_method='POST',
                                            service_path='/bar',
                                            headers={'a': 'b', 'c': 'd'},
                                            )))


class ServiceTest(test_util.TestCase):
  """Test Service class."""

  def testServiceBase_AllRemoteMethods(self):
    """Test that service base class has no remote methods."""
    self.assertEquals({}, remote.Service.all_remote_methods())

  def testAllRemoteMethods(self):
    """Test all_remote_methods with properly Service subclass."""
    self.assertEquals({'remote_method': MyService.remote_method},
                      MyService.all_remote_methods())

  def testAllRemoteMethods_SubClass(self):
    """Test all_remote_methods on a sub-class of a service."""
    class SubClass(MyService):

      @remote.method(Request, Response)
      def sub_class_method(self, request):
        pass

    self.assertEquals({'remote_method': SubClass.remote_method,
                       'sub_class_method': SubClass.sub_class_method,
                      },
                      SubClass.all_remote_methods())

  def testOverrideMethod(self):
    """Test that trying to override a remote method with remote decorator."""
    class SubClass(MyService):

      def remote_method(self, request):
        response = super(SubClass, self).remote_method(request)
        response.value = '(%s)' % response.value
        return response

    self.assertEquals({'remote_method': SubClass.remote_method,
                      },
                      SubClass.all_remote_methods())

    instance = SubClass()
    self.assertEquals('(Hello)',
                      instance.remote_method(Request(value='Hello')).value)
    self.assertEquals(Request, SubClass.remote_method.remote.request_type)
    self.assertEquals(Response, SubClass.remote_method.remote.response_type)

  def testOverrideMethodWithRemote(self):
    """Test trying to override a remote method with remote decorator."""
    def do_override():
      class SubClass(MyService):

        @remote.method(Request, Response)
        def remote_method(self, request):
          pass

    self.assertRaisesWithRegexpMatch(remote.ServiceDefinitionError,
                                     'Do not use method decorator when '
                                     'overloading remote method remote_method '
                                     'on service SubClass',
                                     do_override)

  def testOverrideMethodWithInvalidValue(self):
    """Test trying to override a remote method with remote decorator."""
    def do_override(bad_value):
      class SubClass(MyService):

        remote_method = bad_value

    for bad_value in [None, 1, 'string', {}]:
      self.assertRaisesWithRegexpMatch(remote.ServiceDefinitionError,
                                       'Must override remote_method in '
                                       'SubClass with a method',
                                       do_override, bad_value)

  def testCallingRemoteMethod(self):
    """Test invoking a remote method."""
    expected = Response()
    expected.value = 'what was passed in'

    request = Request()
    request.value = 'what was passed in'

    service = MyService()
    self.assertEquals(expected, service.remote_method(request))

  def testFactory(self):
    """Test using factory to pass in state."""
    class StatefulService(remote.Service):

      def __init__(self, a, b, c=None):
        self.a = a
        self.b = b
        self.c = c

    state = [1, 2, 3]

    factory = StatefulService.new_factory(1, state)

    module_name = ServiceTest.__module__
    pattern = ('Creates new instances of service StatefulService.\n\n'
               'Returns:\n'
               '  New instance of %s.StatefulService.' % module_name)
    self.assertEqual(pattern, factory.func_doc)
    self.assertEquals('StatefulService_service_factory', factory.func_name)
    self.assertEquals(StatefulService, factory.service_class)

    service = factory()
    self.assertEquals(1, service.a)
    self.assertEquals(id(state), id(service.b))
    self.assertEquals(None, service.c)

    factory = StatefulService.new_factory(2, b=3, c=4)
    service = factory()
    self.assertEquals(2, service.a)
    self.assertEquals(3, service.b)
    self.assertEquals(4, service.c)

  def testFactoryError(self):
    """Test misusing a factory."""
    # Passing positional argument that is not accepted by class.
    self.assertRaises(TypeError, remote.Service.new_factory(1))

    # Passing keyword argument that is not accepted by class.
    self.assertRaises(TypeError, remote.Service.new_factory(x=1))

    class StatefulService(remote.Service):

      def __init__(self, a):
        pass

    # Missing required parameter.
    self.assertRaises(TypeError, StatefulService.new_factory())

  def testDefinitionName(self):
    """Test getting service definition name."""
    class TheService(remote.Service):
      pass

    module_name = test_util.get_module_name(ServiceTest)
    self.assertEqual(TheService.definition_name(),
                     '%s.TheService' % module_name)
    self.assertTrue(TheService.outer_definition_name(),
                    module_name)
    self.assertTrue(TheService.definition_package(),
                    module_name)

  def testDefinitionNameWithPackage(self):
    """Test getting service definition name when package defined."""
    global package
    package = 'my.package'
    try:
      class TheService(remote.Service):
        pass

      self.assertEquals('my.package.TheService', TheService.definition_name())
      self.assertEquals('my.package', TheService.outer_definition_name())
      self.assertEquals('my.package', TheService.definition_package())
    finally:
      del package

  def testDefinitionNameWithNoModule(self):
    """Test getting service definition name when package defined."""
    module = sys.modules[__name__]
    try:
      del sys.modules[__name__]
      class TheService(remote.Service):
        pass

      self.assertEquals('TheService', TheService.definition_name())
      self.assertEquals(None, TheService.outer_definition_name())
      self.assertEquals(None, TheService.definition_package())
    finally:
      sys.modules[__name__] = module


class StubTest(test_util.TestCase):

  def setUp(self):
    self.mox = mox.Mox()
    self.transport = self.mox.CreateMockAnything()

  def testDefinitionName(self):
    self.assertEquals(BasicService.definition_name(),
                      BasicService.Stub.definition_name())
    self.assertEquals(BasicService.outer_definition_name(),
                      BasicService.Stub.outer_definition_name())
    self.assertEquals(BasicService.definition_package(),
                      BasicService.Stub.definition_package())

  def testRemoteMethods(self):
    self.assertEquals(BasicService.all_remote_methods(),
                      BasicService.Stub.all_remote_methods())

  def testSync_WithRequest(self):
    stub = BasicService.Stub(self.transport)

    request = SimpleRequest()
    request.param1 = 'val1'
    request.param2 = 'val2'
    response = SimpleResponse()

    rpc = transport.Rpc(request)
    rpc.set_response(response)
    self.transport.send_rpc(BasicService.remote_method.remote,
                            request).AndReturn(rpc)

    self.mox.ReplayAll()

    self.assertEquals(SimpleResponse(), stub.remote_method(request))

    self.mox.VerifyAll()

  def testSync_WithKwargs(self):
    stub = BasicService.Stub(self.transport)


    request = SimpleRequest()
    request.param1 = 'val1'
    request.param2 = 'val2'
    response = SimpleResponse()

    rpc = transport.Rpc(request)
    rpc.set_response(response)
    self.transport.send_rpc(BasicService.remote_method.remote,
                            request).AndReturn(rpc)

    self.mox.ReplayAll()

    self.assertEquals(SimpleResponse(), stub.remote_method(param1='val1',
                                                           param2='val2'))

    self.mox.VerifyAll()

  def testAsync_WithRequest(self):
    stub = BasicService.Stub(self.transport)

    request = SimpleRequest()
    request.param1 = 'val1'
    request.param2 = 'val2'
    response = SimpleResponse()

    rpc = transport.Rpc(request)

    self.transport.send_rpc(BasicService.remote_method.remote,
                            request).AndReturn(rpc)

    self.mox.ReplayAll()

    self.assertEquals(rpc, stub.async.remote_method(request))

    self.mox.VerifyAll()

  def testAsync_WithKwargs(self):
    stub = BasicService.Stub(self.transport)

    request = SimpleRequest()
    request.param1 = 'val1'
    request.param2 = 'val2'
    response = SimpleResponse()

    rpc = transport.Rpc(request)

    self.transport.send_rpc(BasicService.remote_method.remote,
                            request).AndReturn(rpc)

    self.mox.ReplayAll()

    self.assertEquals(rpc, stub.async.remote_method(param1='val1',
                                                    param2='val2'))

    self.mox.VerifyAll()

  def testAsync_WithRequestAndKwargs(self):
    stub = BasicService.Stub(self.transport)

    request = SimpleRequest()
    request.param1 = 'val1'
    request.param2 = 'val2'
    response = SimpleResponse()

    self.mox.ReplayAll()

    self.assertRaisesWithRegexpMatch(
      TypeError,
      r'May not provide both args and kwargs',
      stub.async.remote_method,
      request,
      param1='val1',
      param2='val2')

    self.mox.VerifyAll()

  def testAsync_WithTooManyPositionals(self):
    stub = BasicService.Stub(self.transport)

    request = SimpleRequest()
    request.param1 = 'val1'
    request.param2 = 'val2'
    response = SimpleResponse()

    self.mox.ReplayAll()

    self.assertRaisesWithRegexpMatch(
      TypeError,
      r'remote_method\(\) takes at most 2 positional arguments \(3 given\)',
      stub.async.remote_method,
      request, 'another value')

    self.mox.VerifyAll()


class IsErrorStatusTest(test_util.TestCase):

  def testIsError(self):
    for state in (s for s in remote.RpcState if s > remote.RpcState.RUNNING):
      status = remote.RpcStatus(state=state)
      self.assertTrue(remote.is_error_status(status))

  def testIsNotError(self):
    for state in (s for s in remote.RpcState if s <= remote.RpcState.RUNNING):
      status = remote.RpcStatus(state=state)
      self.assertFalse(remote.is_error_status(status))

  def testStateNone(self):
    self.assertRaises(messages.ValidationError,
                      remote.is_error_status, remote.RpcStatus())


class CheckRpcStatusTest(test_util.TestCase):

  def testStateNone(self):
    self.assertRaises(messages.ValidationError,
                      remote.check_rpc_status, remote.RpcStatus())

  def testNoError(self):
    for state in (remote.RpcState.OK, remote.RpcState.RUNNING):
      remote.check_rpc_status(remote.RpcStatus(state=state))

  def testErrorState(self):
    status = remote.RpcStatus(state=remote.RpcState.REQUEST_ERROR,
                              error_message='a request error')
    self.assertRaisesWithRegexpMatch(remote.RequestError,
                                     'a request error',
                                     remote.check_rpc_status, status)

  def testApplicationErrorState(self):
    status = remote.RpcStatus(state=remote.RpcState.APPLICATION_ERROR,
                              error_message='an application error',
                              error_name='blam')
    try:
      remote.check_rpc_status(status)
      self.fail('Should have raised application error.')
    except remote.ApplicationError as err:
      self.assertEquals('an application error', str(err))
      self.assertEquals('blam', err.error_name)


class ProtocolConfigTest(test_util.TestCase):

  def testConstructor(self):
    config = remote.ProtocolConfig(
      protojson,
      'proto1',
      'application/X-Json',
      iter(['text/Json', 'text/JavaScript']))
    self.assertEquals(protojson, config.protocol)
    self.assertEquals('proto1', config.name)
    self.assertEquals('application/x-json', config.default_content_type)
    self.assertEquals(('text/json', 'text/javascript'),
                      config.alternate_content_types)
    self.assertEquals(('application/x-json', 'text/json', 'text/javascript'),
                      config.content_types)

  def testConstructorDefaults(self):
    config = remote.ProtocolConfig(protojson, 'proto2')
    self.assertEquals(protojson, config.protocol)
    self.assertEquals('proto2', config.name)
    self.assertEquals('application/json', config.default_content_type)
    self.assertEquals(('application/x-javascript',
                       'text/javascript',
                       'text/x-javascript',
                       'text/x-json',
                       'text/json'),
                      config.alternate_content_types)
    self.assertEquals(('application/json',
                       'application/x-javascript',
                       'text/javascript',
                       'text/x-javascript',
                       'text/x-json',
                       'text/json'), config.content_types)

  def testEmptyAlternativeTypes(self):
    config = remote.ProtocolConfig(protojson, 'proto2',
                                   alternative_content_types=())
    self.assertEquals(protojson, config.protocol)
    self.assertEquals('proto2', config.name)
    self.assertEquals('application/json', config.default_content_type)
    self.assertEquals((), config.alternate_content_types)
    self.assertEquals(('application/json',), config.content_types)

  def testDuplicateContentTypes(self):
    self.assertRaises(remote.ServiceConfigurationError,
                      remote.ProtocolConfig,
                      protojson,
                      'json',
                      'text/plain',
                      ('text/plain',))

    self.assertRaises(remote.ServiceConfigurationError,
                      remote.ProtocolConfig,
                      protojson,
                      'json',
                      'text/plain',
                      ('text/html', 'text/html'))

  def testEncodeMessage(self):
    config = remote.ProtocolConfig(protojson, 'proto2')
    encoded_message = config.encode_message(
        remote.RpcStatus(state=remote.RpcState.SERVER_ERROR,
                         error_message='bad error'))

    # Convert back to a dictionary from JSON.
    dict_message = protojson.json.loads(encoded_message)
    self.assertEquals({'state': 'SERVER_ERROR', 'error_message': 'bad error'},
                      dict_message)

  def testDecodeMessage(self):
    config = remote.ProtocolConfig(protojson, 'proto2')
    self.assertEquals(
      remote.RpcStatus(state=remote.RpcState.SERVER_ERROR,
                       error_message="bad error"),
      config.decode_message(
        remote.RpcStatus,
        '{"state": "SERVER_ERROR", "error_message": "bad error"}'))


class ProtocolsTest(test_util.TestCase):

  def setUp(self):
    self.protocols = remote.Protocols()

  def testEmpty(self):
    self.assertEquals((), self.protocols.names)
    self.assertEquals((), self.protocols.content_types)

  def testAddProtocolAllDefaults(self):
    self.protocols.add_protocol(protojson, 'json')
    self.assertEquals(('json',), self.protocols.names)
    self.assertEquals(('application/json',
                       'application/x-javascript',
                       'text/javascript',
                       'text/json',
                       'text/x-javascript',
                       'text/x-json'),
                      self.protocols.content_types)

  def testAddProtocolNoDefaultAlternatives(self):
    class Protocol(object):
      CONTENT_TYPE = 'text/plain'

    self.protocols.add_protocol(Protocol, 'text')
    self.assertEquals(('text',), self.protocols.names)
    self.assertEquals(('text/plain',), self.protocols.content_types)

  def testAddProtocolOverrideDefaults(self):
    self.protocols.add_protocol(protojson, 'json',
                                default_content_type='text/blar',
                                alternative_content_types=('text/blam',
                                                           'text/blim'))
    self.assertEquals(('json',), self.protocols.names)
    self.assertEquals(('text/blam', 'text/blar', 'text/blim'),
                      self.protocols.content_types)

  def testLookupByName(self):
    self.protocols.add_protocol(protojson, 'json')
    self.protocols.add_protocol(protojson, 'json2',
                                default_content_type='text/plain',
                                alternative_content_types=())

    self.assertEquals('json', self.protocols.lookup_by_name('JsOn').name)
    self.assertEquals('json2', self.protocols.lookup_by_name('Json2').name)

  def testLookupByContentType(self):
    self.protocols.add_protocol(protojson, 'json')
    self.protocols.add_protocol(protojson, 'json2',
                                default_content_type='text/plain',
                                alternative_content_types=())

    self.assertEquals(
      'json',
      self.protocols.lookup_by_content_type('AppliCation/Json').name)

    self.assertEquals(
      'json',
      self.protocols.lookup_by_content_type('text/x-Json').name)

    self.assertEquals(
      'json2',
      self.protocols.lookup_by_content_type('text/Plain').name)

  def testNewDefault(self):
    protocols = remote.Protocols.new_default()
    self.assertEquals(('protobuf', 'protojson'), protocols.names)

    protobuf_protocol = protocols.lookup_by_name('protobuf')
    self.assertEquals(protobuf, protobuf_protocol.protocol)

    protojson_protocol = protocols.lookup_by_name('protojson')
    self.assertEquals(protojson.ProtoJson.get_default(),
                      protojson_protocol.protocol)

  def testGetDefaultProtocols(self):
    protocols = remote.Protocols.get_default()
    self.assertEquals(('protobuf', 'protojson'), protocols.names)

    protobuf_protocol = protocols.lookup_by_name('protobuf')
    self.assertEquals(protobuf, protobuf_protocol.protocol)

    protojson_protocol = protocols.lookup_by_name('protojson')
    self.assertEquals(protojson.ProtoJson.get_default(),
                      protojson_protocol.protocol)

    self.assertTrue(protocols is remote.Protocols.get_default())

  def testSetDefaultProtocols(self):
    protocols = remote.Protocols()
    remote.Protocols.set_default(protocols)
    self.assertTrue(protocols is remote.Protocols.get_default())

  def testSetDefaultWithoutProtocols(self):
    self.assertRaises(TypeError, remote.Protocols.set_default, None)
    self.assertRaises(TypeError, remote.Protocols.set_default, 'hi protocols')
    self.assertRaises(TypeError, remote.Protocols.set_default, {})


def main():
  unittest.main()


if __name__ == '__main__':
  main()
