#!/usr/bin/env python
#
# Copyright 2010 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tests for protorpc.stub."""

__author__ = 'rafek@google.com (Rafe Kaplan)'


import new
import StringIO
import sys
import unittest

from protorpc import definition
from protorpc import descriptor
from protorpc import message_types
from protorpc import messages
from protorpc import protobuf
from protorpc import remote
from protorpc import test_util

import mox


class ModuleInterfaceTest(test_util.ModuleInterfaceTest,
                          test_util.TestCase):
  """Applies the shared test_util.ModuleInterfaceTest checks to definition."""

  # Module under test.  Presumably read by the inherited checks from
  # test_util.ModuleInterfaceTest -- confirm against test_util.
  MODULE = definition


class DefineEnumTest(test_util.TestCase):
  """Tests for definition.define_enum."""

  def testDefineEnum_Empty(self):
    """Test defining an enum that has no values."""
    empty_descriptor = descriptor.EnumDescriptor()
    empty_descriptor.name = 'Empty'

    defined = definition.define_enum(empty_descriptor, 'whatever')

    self.assertEqual('Empty', defined.__name__)
    self.assertEqual('whatever', defined.__module__)

    # Describing the generated class must reproduce the input descriptor.
    self.assertEqual(empty_descriptor, descriptor.describe_enum(defined))

  def testDefineEnum(self):
    """Test defining an enum with several values."""
    value_descriptors = []
    for name, number in (('RED', 1), ('GREEN', 2), ('BLUE', 3)):
      value_descriptor = descriptor.EnumValueDescriptor()
      value_descriptor.name = name
      value_descriptor.number = number
      value_descriptors.append(value_descriptor)

    colors_descriptor = descriptor.EnumDescriptor()
    colors_descriptor.name = 'Colors'
    colors_descriptor.values = value_descriptors

    defined = definition.define_enum(colors_descriptor, 'whatever')

    self.assertEqual('Colors', defined.__name__)
    self.assertEqual('whatever', defined.__module__)

    # Describing the generated class must reproduce the input descriptor.
    self.assertEqual(colors_descriptor, descriptor.describe_enum(defined))


class DefineFieldTest(test_util.TestCase):
  """Tests for definition.define_field."""

  def _MakeDescriptor(self, variant, name='a_field', type_name=None,
                      label=None, default_value=None):
    """Build a FieldDescriptor numbered 1 for use as define_field input.

    Optional attributes are only assigned when a value is supplied, so an
    omitted attribute is left untouched on the descriptor.

    Args:
      variant: FieldDescriptor.Variant value for the field.
      name: Name to store on the descriptor.
      type_name: Optional type name, used by the message and enum tests.
      label: Optional FieldDescriptor.Label value.
      default_value: Optional default value, given as a string.

    Returns:
      New FieldDescriptor instance.
    """
    field_descriptor = descriptor.FieldDescriptor()
    field_descriptor.name = name
    field_descriptor.number = 1
    field_descriptor.variant = variant
    if type_name is not None:
      field_descriptor.type_name = type_name
    if label is not None:
      field_descriptor.label = label
    if default_value is not None:
      field_descriptor.default_value = default_value
    return field_descriptor

  def _AssertDefinedField(self, field, field_class, variant,
                          required=False, repeated=False):
    """Check the properties shared by every field these tests define.

    Args:
      field: Field instance returned by define_field.
      field_class: Class that field is expected to be an instance of.
      variant: Variant that field is expected to have.
      required: Whether field is expected to be required.
      repeated: Whether field is expected to be repeated.
    """
    # Name will not be set from the original descriptor.
    self.assertFalse(hasattr(field, 'name'))

    self.assertIsInstance(field, field_class)
    self.assertEqual(1, field.number)
    self.assertEqual(variant, field.variant)
    self.assertEqual(required, bool(field.required))
    self.assertEqual(repeated, bool(field.repeated))

  def _AssertTypeNotFound(self, field):
    """Check that reading field.type fails for the undefined type name."""
    self.assertRaisesWithRegexpMatch(messages.DefinitionNotFoundError,
                                     'Could not find definition for '
                                     'something.yet.to.be.Defined',
                                     getattr, field, 'type')

  def testDefineField_Optional(self):
    """Test defining an optional field from a field descriptor."""
    variant = descriptor.FieldDescriptor.Variant.INT32
    field_descriptor = self._MakeDescriptor(
        variant, label=descriptor.FieldDescriptor.Label.OPTIONAL)

    field = definition.define_field(field_descriptor)

    self._AssertDefinedField(field, messages.IntegerField, variant)

  def testDefineField_Required(self):
    """Test defining a required field from a field descriptor."""
    variant = descriptor.FieldDescriptor.Variant.STRING
    field_descriptor = self._MakeDescriptor(
        variant, label=descriptor.FieldDescriptor.Label.REQUIRED)

    field = definition.define_field(field_descriptor)

    self._AssertDefinedField(field, messages.StringField, variant,
                             required=True)

  def testDefineField_Repeated(self):
    """Test defining a repeated field from a field descriptor."""
    variant = descriptor.FieldDescriptor.Variant.DOUBLE
    field_descriptor = self._MakeDescriptor(
        variant, label=descriptor.FieldDescriptor.Label.REPEATED)

    field = definition.define_field(field_descriptor)

    self._AssertDefinedField(field, messages.FloatField, variant,
                             repeated=True)

  def testDefineField_Message(self):
    """Test defining a message field."""
    variant = descriptor.FieldDescriptor.Variant.MESSAGE
    field_descriptor = self._MakeDescriptor(
        variant,
        type_name='something.yet.to.be.Defined',
        label=descriptor.FieldDescriptor.Label.REPEATED)

    field = definition.define_field(field_descriptor)

    self._AssertDefinedField(field, messages.MessageField, variant,
                             repeated=True)
    self._AssertTypeNotFound(field)

  def testDefineField_DateTime(self):
    """Test defining a date time field."""
    variant = descriptor.FieldDescriptor.Variant.MESSAGE
    field_descriptor = self._MakeDescriptor(
        variant,
        name='a_timestamp',
        type_name='protorpc.message_types.DateTimeMessage',
        label=descriptor.FieldDescriptor.Label.REPEATED)

    field = definition.define_field(field_descriptor)

    self._AssertDefinedField(field, message_types.DateTimeField, variant,
                             repeated=True)

  def testDefineField_Enum(self):
    """Test defining an enum field."""
    variant = descriptor.FieldDescriptor.Variant.ENUM
    field_descriptor = self._MakeDescriptor(
        variant,
        type_name='something.yet.to.be.Defined',
        label=descriptor.FieldDescriptor.Label.REPEATED)

    field = definition.define_field(field_descriptor)

    self._AssertDefinedField(field, messages.EnumField, variant,
                             repeated=True)
    self._AssertTypeNotFound(field)

  def testDefineField_Default_Bool(self):
    """Test defining a default value for a bool."""
    variant = descriptor.FieldDescriptor.Variant.BOOL
    field_descriptor = self._MakeDescriptor(variant, default_value=u'true')

    field = definition.define_field(field_descriptor)

    self._AssertDefinedField(field, messages.BooleanField, variant)
    self.assertEqual(True, field.default)

    # The same descriptor with the opposite default yields a False default.
    field_descriptor.default_value = u'false'

    field = definition.define_field(field_descriptor)

    self.assertEqual(False, field.default)

  def testDefineField_Default_Float(self):
    """Test defining a default value for a float."""
    variant = descriptor.FieldDescriptor.Variant.FLOAT
    field_descriptor = self._MakeDescriptor(variant, default_value=u'34.567')

    field = definition.define_field(field_descriptor)

    self._AssertDefinedField(field, messages.FloatField, variant)
    self.assertEqual(34.567, field.default)

  def testDefineField_Default_Int(self):
    """Test defining a default value for an int."""
    variant = descriptor.FieldDescriptor.Variant.INT64
    field_descriptor = self._MakeDescriptor(variant, default_value=u'34')

    field = definition.define_field(field_descriptor)

    self._AssertDefinedField(field, messages.IntegerField, variant)
    self.assertEqual(34, field.default)

  def testDefineField_Default_Str(self):
    """Test defining a default value for a str."""
    variant = descriptor.FieldDescriptor.Variant.STRING
    field_descriptor = self._MakeDescriptor(variant, default_value=u'Test')

    field = definition.define_field(field_descriptor)

    self._AssertDefinedField(field, messages.StringField, variant)
    self.assertEqual(u'Test', field.default)

  def testDefineField_Default_Invalid(self):
    """Test defining a default value that is not valid."""
    field_descriptor = self._MakeDescriptor(
        descriptor.FieldDescriptor.Variant.INT64, default_value=u'Test')

    # Verify that the string is passed to the Constructor.
    mock = mox.Mox()
    mock.StubOutWithMock(messages.IntegerField, '__init__')
    try:
      messages.IntegerField.__init__(
          default=u'Test',
          number=1,
          variant=messages.Variant.INT64
          ).AndRaise(messages.InvalidDefaultError)

      mock.ReplayAll()
      with self.assertRaises(messages.InvalidDefaultError):
        definition.define_field(field_descriptor)
      mock.VerifyAll()
    finally:
      # Always restore the real IntegerField.__init__, even when an assertion
      # above fails; otherwise the stub would leak into later tests.
      mock.ResetAll()
      mock.UnsetStubs()


class DefineMessageTest(test_util.TestCase):
  """Tests for definition.define_message."""

  def _AssertRedefinable(self, original_class):
    """Define a copy of a message class and compare it with the original.

    The class is described, redefined from that descriptor in '__main__', and
    described again; both descriptors must be equal.

    Args:
      original_class: Message class named 'AMessage' to round-trip.
    """
    message_descriptor = descriptor.describe_message(original_class)

    message_class = definition.define_message(message_descriptor, '__main__')

    self.assertEqual('AMessage', message_class.__name__)
    self.assertEqual('__main__', message_class.__module__)

    self.assertEqual(message_descriptor,
                     descriptor.describe_message(message_class))

  def testDefineMessageEmpty(self):
    """Test defining a message with no fields or enums."""

    class AMessage(messages.Message):
      pass

    self._AssertRedefinable(AMessage)

  def testDefineMessageEnumOnly(self):
    """Test defining a message with only enums."""

    class AMessage(messages.Message):
      class NestedEnum(messages.Enum):
        pass

    self._AssertRedefinable(AMessage)

  def testDefineMessageFieldsOnly(self):
    """Test defining a message with only fields."""

    class AMessage(messages.Message):

      field1 = messages.IntegerField(1)
      field2 = messages.StringField(2)

    self._AssertRedefinable(AMessage)

  def testDefineMessage(self):
    """Test defining a message with both an enum and fields."""

    class AMessage(messages.Message):
      class NestedEnum(messages.Enum):
        pass

      field1 = messages.IntegerField(1)
      field2 = messages.StringField(2)

    self._AssertRedefinable(AMessage)


class DefineServiceTest(test_util.TestCase):
  """Tests for definition.define_service."""

  def setUp(self):
    """Create a 'stocks' module holding the request and response messages."""
    self.module = new.module('stocks')

    class GetQuoteRequest(messages.Message):
      __module__ = 'stocks'

      symbols = messages.StringField(1, repeated=True)

    class GetQuoteResponse(messages.Message):
      __module__ = 'stocks'

      prices = messages.IntegerField(1, repeated=True)

    self.module.GetQuoteRequest = GetQuoteRequest
    self.module.GetQuoteResponse = GetQuoteResponse

  def testDefineService(self):
    """Test service definition from descriptor."""
    get_quote = descriptor.MethodDescriptor()
    get_quote.name = 'get_quote'
    get_quote.request_type = 'GetQuoteRequest'
    get_quote.response_type = 'GetQuoteResponse'

    stocks_descriptor = descriptor.ServiceDescriptor()
    stocks_descriptor.name = 'Stocks'
    stocks_descriptor.methods = [get_quote]

    StockService = definition.define_service(stocks_descriptor, self.module)

    self.assertTrue(issubclass(StockService, remote.Service))
    self.assertTrue(issubclass(StockService.Stub, remote.StubBase))

    # The generated service method has no implementation to call.
    service = StockService()
    request = self.module.GetQuoteRequest()
    self.assertRaises(NotImplementedError,
                      service.get_quote, request)

    # The type names from the descriptor resolve to classes in self.module.
    remote_info = service.get_quote.remote
    self.assertEqual(self.module.GetQuoteRequest, remote_info.request_type)
    self.assertEqual(self.module.GetQuoteResponse, remote_info.response_type)


class ModuleTest(test_util.TestCase):
  """Tests for module creation and import functions."""

  # Packages of the file descriptors shared by the file set tests.
  _PACKAGES = (u'standalone', u'root.nested', u'root.nested.nested')

  def MakeFileDescriptor(self, package):
    """Helper method to construct FileDescriptors.

    Creates FileDescriptor with an EnumDescriptor, a MessageDescriptor and a
    ServiceDescriptor.

    Args:
      package: Package name to give new file descriptors.

    Returns:
      New FileDescriptor instance.
    """
    enum_descriptor = descriptor.EnumDescriptor()
    enum_descriptor.name = u'MyEnum'

    message_descriptor = descriptor.MessageDescriptor()
    message_descriptor.name = u'MyMessage'

    service_descriptor = descriptor.ServiceDescriptor()
    service_descriptor.name = u'MyService'

    file_descriptor = descriptor.FileDescriptor()
    file_descriptor.package = package
    file_descriptor.enum_types = [enum_descriptor]
    file_descriptor.message_types = [message_descriptor]
    file_descriptor.service_types = [service_descriptor]

    return file_descriptor

  def _MakeFileSet(self, *extra_files):
    """Build a FileSet with one file per _PACKAGES entry, plus extra_files.

    Args:
      *extra_files: FileDescriptor instances appended after the shared ones.

    Returns:
      New FileSet instance.
    """
    files = [self.MakeFileDescriptor(package) for package in self._PACKAGES]
    files.extend(extra_files)

    file_set = descriptor.FileSet()
    file_set.files = files
    return file_set

  def _MakeRootModules(self):
    """Create the pre-existing 'root' and 'root.nested' modules.

    Returns:
      Tuple (root, nested) where nested is also reachable as root.nested.
    """
    root = new.module('root')
    nested = new.module('root.nested')
    root.nested = nested
    return root, nested

  def _DescribeModules(self, modules, packages):
    """Describe, as one FileSet, the modules registered under packages.

    Args:
      modules: Dictionary mapping package names to modules.
      packages: Package names to look up in modules, in order.

    Returns:
      FileSet describing the selected modules.
    """
    return descriptor.describe_file_set(
        [modules[package] for package in packages])

  def testDefineModule(self):
    """Test define_file function."""
    file_descriptor = self.MakeFileDescriptor('my.package')

    module = definition.define_file(file_descriptor)

    self.assertEqual('my.package', module.__name__)
    self.assertEqual('my.package', module.MyEnum.__module__)
    self.assertEqual('my.package', module.MyMessage.__module__)
    self.assertEqual('my.package', module.MyService.__module__)

    self.assertEqual(file_descriptor, descriptor.describe_file(module))

  def testDefineModule_ReuseModule(self):
    """Test updating module with additional definitions."""
    file_descriptor = self.MakeFileDescriptor('my.package')

    module = new.module('override')
    self.assertEqual(module, definition.define_file(file_descriptor, module))

    self.assertEqual('override', module.MyEnum.__module__)
    self.assertEqual('override', module.MyMessage.__module__)
    self.assertEqual('override', module.MyService.__module__)

    # One thing is different between original descriptor and new.
    file_descriptor.package = 'override'
    self.assertEqual(file_descriptor, descriptor.describe_file(module))

  def testImportFile(self):
    """Test importing FileDescriptor into module space."""
    modules = {}
    file_descriptor = self.MakeFileDescriptor('standalone')
    definition.import_file(file_descriptor, modules=modules)
    self.assertEqual(file_descriptor,
                     descriptor.describe_file(modules['standalone']))

  def testImportFile_InToExisting(self):
    """Test importing FileDescriptor into existing module."""
    module = new.module('standalone')
    modules = {'standalone': module}
    file_descriptor = self.MakeFileDescriptor('standalone')
    definition.import_file(file_descriptor, modules=modules)
    self.assertEqual(module, modules['standalone'])
    self.assertEqual(file_descriptor,
                     descriptor.describe_file(modules['standalone']))

  def testImportFile_InToGlobalModules(self):
    """Test importing FileDescriptor into global modules."""
    original_modules = sys.modules
    try:
      # Work on a copy so the real module table is restored untouched.
      sys.modules = dict(original_modules)
      sys.modules.pop('standalone', None)
      file_descriptor = self.MakeFileDescriptor('standalone')
      definition.import_file(file_descriptor)
      self.assertEqual(file_descriptor,
                       descriptor.describe_file(sys.modules['standalone']))
    finally:
      sys.modules = original_modules

  def testImportFile_Nested(self):
    """Test importing FileDescriptor into a nested module."""
    modules = {}
    file_descriptor = self.MakeFileDescriptor('root.nested')
    definition.import_file(file_descriptor, modules=modules)
    self.assertEqual(modules['root'].nested, modules['root.nested'])
    self.assertEqual(file_descriptor,
                     descriptor.describe_file(modules['root.nested']))

  def testImportFile_NoPackage(self):
    """Test importing FileDescriptor with no package."""
    file_descriptor = self.MakeFileDescriptor('does not matter')
    file_descriptor.reset('package')
    self.assertRaisesWithRegexpMatch(ValueError,
                                     'File descriptor must have package name',
                                     definition.import_file,
                                     file_descriptor)

  def testImportFileSet(self):
    """Test importing a whole file set."""
    file_set = self._MakeFileSet()

    root, nested = self._MakeRootModules()
    modules = {
        'root': root,
        'root.nested': nested,
    }

    definition.import_file_set(file_set, modules=modules)

    # Pre-existing modules are reused rather than replaced.
    self.assertEqual(root, modules['root'])
    self.assertEqual(nested, modules['root.nested'])
    self.assertEqual(nested.nested, modules['root.nested.nested'])

    self.assertEqual(file_set, self._DescribeModules(modules, self._PACKAGES))

  def testImportFileSetFromFile(self):
    """Test importing a whole file set from a file."""
    file_set = self._MakeFileSet()

    stream = StringIO.StringIO(protobuf.encode_message(file_set))

    self.mox = mox.Mox()
    opener = self.mox.CreateMockAnything()
    opener('my-file.dat', 'rb').AndReturn(stream)

    self.mox.ReplayAll()

    modules = {}
    definition.import_file_set('my-file.dat', modules=modules, _open=opener)

    # Confirm the file really was opened through the injected opener.
    self.mox.VerifyAll()

    self.assertEqual(file_set, self._DescribeModules(modules, self._PACKAGES))

  def testImportBuiltInProtorpcClasses(self):
    """Test that built in Protorpc classes are skipped."""
    file_set = self._MakeFileSet(descriptor.describe_file(descriptor))

    root, nested = self._MakeRootModules()
    modules = {
        'root': root,
        'root.nested': nested,
        'protorpc.descriptor': descriptor,
    }

    definition.import_file_set(file_set, modules=modules)

    self.assertEqual(root, modules['root'])
    self.assertEqual(nested, modules['root.nested'])
    self.assertEqual(nested.nested, modules['root.nested.nested'])
    # The real descriptor module must still be the one registered.
    self.assertEqual(descriptor, modules['protorpc.descriptor'])

    self.assertEqual(
        file_set,
        self._DescribeModules(modules,
                              self._PACKAGES + ('protorpc.descriptor',)))


# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
  unittest.main()
