#!/usr/bin/env python
#
# Copyright 2010 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


import httplib
import os
import socket
import unittest

from protorpc import messages
from protorpc import protobuf
from protorpc import protojson
from protorpc import remote
from protorpc import test_util
from protorpc import transport
from protorpc import webapp_test_util
from protorpc.wsgi import util as wsgi_util

import mox

# Module-level package name; presumably read by protorpc when naming the
# message types defined below -- confirm against protorpc's descriptor code.
package = 'transport_test'


class ModuleInterfaceTest(test_util.ModuleInterfaceTest,
                          test_util.TestCase):
  """Runs test_util's shared module-interface checks against `transport`."""

  # Module whose public interface is verified by the inherited tests.
  MODULE = transport


class Message(messages.Message):
  """Single-field message used as both request and response in these tests."""

  value = messages.StringField(1)


class Service(remote.Service):
  """Minimal service whose `method.remote` info TransportTest sends RPCs for."""

  @remote.method(Message, Message)
  def method(self, request):
    """Placeholder remote method; the transport tests stub out execution."""


# Remove once transport.Rpc no longer has to be subclassed.
class TestRpc(transport.Rpc):
  """Rpc subclass that records whether its _wait_impl hook was called."""

  # Class-level default; set to True on the instance by _wait_impl.
  waited = False

  def _wait_impl(self):
    """Record that a wait was requested instead of doing any real waiting."""
    self.waited = True


class RpcTest(test_util.TestCase):
  """Tests for the client-side Rpc state transitions, using TestRpc."""

  def setUp(self):
    self.request = Message(value=u'request')
    self.response = Message(value=u'response')
    # Application-level failure status used by the set_status tests.
    self.status = remote.RpcStatus(state=remote.RpcState.APPLICATION_ERROR,
                                   error_message='an error',
                                   error_name='blam')

    self.rpc = TestRpc(self.request)

  def testConstructor(self):
    self.assertEquals(self.request, self.rpc.request)
    self.assertEquals(remote.RpcState.RUNNING, self.rpc.state)
    self.assertEquals(None, self.rpc.error_message)
    self.assertEquals(None, self.rpc.error_name)

  def testResponse(self):
    # Reading `response` on a still-RUNNING rpc must go through _wait_impl.
    # (This test needs the `test` prefix to be collected by unittest; a plain
    # `response` method would also be shadowed by self.response from setUp.)
    self.assertFalse(self.rpc.waited)
    self.assertEquals(None, self.rpc.response)
    self.assertTrue(self.rpc.waited)

  def testSetResponse(self):
    self.rpc.set_response(self.response)

    self.assertEquals(self.request, self.rpc.request)
    self.assertEquals(remote.RpcState.OK, self.rpc.state)
    self.assertEquals(self.response, self.rpc.response)
    self.assertEquals(None, self.rpc.error_message)
    self.assertEquals(None, self.rpc.error_name)

  def testSetResponseAlreadySet(self):
    self.rpc.set_response(self.response)

    self.assertRaisesWithRegexpMatch(
      transport.RpcStateError,
      'RPC must be in RUNNING state to change to OK',
      self.rpc.set_response,
      self.response)

  def testSetResponseAlreadyError(self):
    self.rpc.set_status(self.status)

    self.assertRaisesWithRegexpMatch(
      transport.RpcStateError,
      'RPC must be in RUNNING state to change to OK',
      self.rpc.set_response,
      self.response)

  def testSetStatus(self):
    self.rpc.set_status(self.status)

    self.assertEquals(self.request, self.rpc.request)
    self.assertEquals(remote.RpcState.APPLICATION_ERROR, self.rpc.state)
    self.assertEquals('an error', self.rpc.error_message)
    self.assertEquals('blam', self.rpc.error_name)
    self.assertRaisesWithRegexpMatch(remote.ApplicationError,
                                     'an error',
                                     getattr, self.rpc, 'response')

  def testSetStatusAlreadySet(self):
    # Setting an error status after a successful response must be rejected.
    self.rpc.set_response(self.response)

    self.assertRaisesWithRegexpMatch(
      transport.RpcStateError,
      'RPC must be in RUNNING state to change to APPLICATION_ERROR',
      self.rpc.set_status,
      self.status)

  def testSetNonMessage(self):
    self.assertRaisesWithRegexpMatch(
      TypeError,
      'Expected Message type, received 10',
      self.rpc.set_response,
      10)

  def testSetStatusAlreadyError(self):
    # Setting a second error status after a first one must be rejected.
    self.rpc.set_status(self.status)

    self.assertRaisesWithRegexpMatch(
      transport.RpcStateError,
      'RPC must be in RUNNING state to change to APPLICATION_ERROR',
      self.rpc.set_status,
      self.status)

  def testSetUninitializedStatus(self):
    self.assertRaises(messages.ValidationError,
                      self.rpc.set_status,
                      remote.RpcStatus())


class TransportTest(test_util.TestCase):
  """Tests protocol selection on the base transport.Transport class."""

  def setUp(self):
    # Start each test from a fresh default protocol registry, since
    # testProtocolByName registers an extra protocol on it.
    remote.Protocols.set_default(remote.Protocols.new_default())

  def do_test(self, protocol, trans):
    """Send one RPC through `trans` and check that the response comes back.

    trans._start_rpc is replaced with a stub that checks the remote info and
    request it receives and returns an already-completed TestRpc.

    Args:
      protocol: Protocol module expected on trans.protocol.
      trans: transport.Transport instance under test.
    """
    request = Message(value=u'request')
    response = Message(value=u'response')

    self.assertEquals(protocol, trans.protocol)

    def transport_rpc(remote_info, rpc_request):
      # Named remote_info so the `remote` module is not shadowed here.
      self.assertEquals(remote_info, Service.method.remote)
      self.assertEquals(request, rpc_request)
      rpc = TestRpc(request)
      rpc.set_response(response)
      return rpc
    trans._start_rpc = transport_rpc

    rpc = trans.send_rpc(Service.method.remote, request)
    self.assertEquals(response, rpc.response)

  def testDefaultProtocol(self):
    trans = transport.Transport()
    self.do_test(protobuf, trans)
    self.assertEquals(protobuf, trans.protocol_config.protocol)
    self.assertEquals('default', trans.protocol_config.name)

  def testAlternateProtocol(self):
    trans = transport.Transport(protocol=protojson)
    self.do_test(protojson, trans)
    self.assertEquals(protojson, trans.protocol_config.protocol)
    self.assertEquals('default', trans.protocol_config.name)

  def testProtocolConfig(self):
    protocol_config = remote.ProtocolConfig(
      protojson, 'protoconfig', 'image/png')
    trans = transport.Transport(protocol=protocol_config)
    self.do_test(protojson, trans)
    self.assertTrue(trans.protocol_config is protocol_config)

  def testProtocolByName(self):
    remote.Protocols.get_default().add_protocol(
      protojson, 'png', 'image/png', ())
    trans = transport.Transport(protocol='png')
    self.do_test(protojson, trans)


@remote.method(Message, Message)
def my_method(self, request):
  """Module-level remote method whose `.remote` info the HTTP tests send.

  The HTTP transport tests only use my_method.remote; calling the function
  itself is a test error.
  """
  complaint = 'self.my_method should not be directly invoked.'
  self.fail(complaint)


class FakeConnectionClass(object):
  """Holder for a pair of mock request/response objects.

  Args:
    mox: Object providing CreateMockAnything(); called once for the request
      mock and then once for the response mock.
  """

  def __init__(self, mox):
    factory = mox.CreateMockAnything
    # The right-hand side is evaluated left to right, so the request mock is
    # still created before the response mock.
    self.request, self.response = factory(), factory()


class HttpTransportTest(webapp_test_util.WebServerTestBase):
  """Tests transport.HttpTransport against a locally served WSGI app."""

  def setUp(self):
    # The parent setUp is intentionally not called; each test starts its own
    # server through ResetServer with the WSGI application it needs.

    # Presumably read by the base class when building the service URL;
    # testHttps switches it to 'https'.
    self.schema = 'http'
    self.server = None

    self.request = Message(value=u'The request value')
    self.encoded_request = protojson.encode_message(self.request)

    self.response = Message(value=u'The response value')
    self.encoded_response = protojson.encode_message(self.response)

  def testCallSucceeds(self):
    self.ResetServer(wsgi_util.static_page(self.encoded_response,
                                           content_type='application/json'))

    rpc = self.connection.send_rpc(my_method.remote, self.request)
    self.assertEquals(self.response, rpc.response)

  def testHttps(self):
    self.schema = 'https'
    self.ResetServer(wsgi_util.static_page(self.encoded_response,
                                           content_type='application/json'))

    # Create a fake https connection function that really just calls http.
    self.used_https = False
    def https_connection(*args, **kwargs):
      self.used_https = True
      return httplib.HTTPConnection(*args, **kwargs)

    original_https_connection = httplib.HTTPSConnection
    httplib.HTTPSConnection = https_connection
    try:
      rpc = self.connection.send_rpc(my_method.remote, self.request)
    finally:
      httplib.HTTPSConnection = original_https_connection
    self.assertEquals(self.response, rpc.response)
    self.assertTrue(self.used_https)

  def testHttpSocketError(self):
    self.ResetServer(wsgi_util.static_page(self.encoded_response,
                                           content_type='application/json'))

    bad_transport = transport.HttpTransport('http://localhost:-1/blar')
    try:
      bad_transport.send_rpc(my_method.remote, self.request)
    except remote.NetworkError as err:
      self.assertTrue(str(err).startswith('Socket error: gaierror ('))
      self.assertEquals(socket.gaierror, type(err.cause))
      self.assertEquals(8, abs(err.cause.args[0]))  # Sign is system dependent.
    else:
      self.fail('Expected error')

  def testHttpRequestError(self):
    self.ResetServer(wsgi_util.static_page(self.encoded_response,
                                           content_type='application/json'))

    def request_error(*args, **kwargs):
      raise TypeError('Generic Error')
    # Patch HTTPConnection.request so any non-network failure during send is
    # surfaced; always restore the original in the finally clause.
    original_request = httplib.HTTPConnection.request
    httplib.HTTPConnection.request = request_error
    try:
      try:
        self.connection.send_rpc(my_method.remote, self.request)
      except remote.NetworkError as err:
        self.assertEquals('Error communicating with HTTP server', str(err))
        self.assertEquals(TypeError, type(err.cause))
        self.assertEquals('Generic Error', str(err.cause))
      else:
        self.fail('Expected error')
    finally:
      httplib.HTTPConnection.request = original_request

  def testHandleGenericServiceError(self):
    self.ResetServer(wsgi_util.error(httplib.INTERNAL_SERVER_ERROR,
                                     'arbitrary error',
                                     content_type='text/plain'))

    rpc = self.connection.send_rpc(my_method.remote, self.request)
    try:
      rpc.response
    except remote.ServerError as err:
      self.assertEquals('HTTP Error 500: arbitrary error', str(err).strip())
    else:
      self.fail('Expected ServerError')

  def testHandleGenericServiceErrorNoMessage(self):
    self.ResetServer(wsgi_util.error(httplib.NOT_IMPLEMENTED,
                                     ' ',
                                     content_type='text/plain'))

    rpc = self.connection.send_rpc(my_method.remote, self.request)
    try:
      rpc.response
    except remote.ServerError as err:
      self.assertEquals('HTTP Error 501: Not Implemented', str(err).strip())
    else:
      self.fail('Expected ServerError')

  def testHandleStatusContent(self):
    self.ResetServer(wsgi_util.static_page('{"state": "REQUEST_ERROR",'
                                           ' "error_message": "a request error"'
                                           '}',
                                           status=httplib.BAD_REQUEST,
                                           content_type='application/json'))

    rpc = self.connection.send_rpc(my_method.remote, self.request)
    try:
      rpc.response
    except remote.RequestError as err:
      self.assertEquals('a request error', str(err))
    else:
      self.fail('Expected RequestError')

  def testHandleApplicationError(self):
    self.ResetServer(wsgi_util.static_page('{"state": "APPLICATION_ERROR",'
                                           ' "error_message": "an app error",'
                                           ' "error_name": "MY_ERROR_NAME"}',
                                           status=httplib.BAD_REQUEST,
                                           content_type='application/json'))

    rpc = self.connection.send_rpc(my_method.remote, self.request)
    try:
      rpc.response
    except remote.ApplicationError as err:
      self.assertEquals('an app error', str(err))
      self.assertEquals('MY_ERROR_NAME', err.error_name)
    else:
      self.fail('Expected ApplicationError')

  def testHandleUnparsableErrorContent(self):
    self.ResetServer(wsgi_util.static_page('oops',
                                           status=httplib.BAD_REQUEST,
                                           content_type='application/json'))

    rpc = self.connection.send_rpc(my_method.remote, self.request)
    try:
      rpc.response
    except remote.ServerError as err:
      self.assertEquals('HTTP Error 400: oops', str(err))
    else:
      self.fail('Expected ServerError')

  def testHandleEmptyBadRpcStatus(self):
    self.ResetServer(wsgi_util.static_page('{"error_message": "x"}',
                                           status=httplib.BAD_REQUEST,
                                           content_type='application/json'))

    rpc = self.connection.send_rpc(my_method.remote, self.request)
    try:
      rpc.response
    except remote.ServerError as err:
      self.assertEquals('HTTP Error 400: {"error_message": "x"}', str(err))
    else:
      self.fail('Expected ServerError')

  def testUseProtocolConfigContentType(self):
    expected_content_type = 'image/png'
    def expect_content_type(environ, start_response):
      # The server echoes back the request's content type so the client can
      # decode the (empty) response with the same protocol config.
      self.assertEquals(expected_content_type, environ['CONTENT_TYPE'])
      app = wsgi_util.static_page('', content_type=environ['CONTENT_TYPE'])
      return app(environ, start_response)

    self.ResetServer(expect_content_type)

    protocol_config = remote.ProtocolConfig(protojson, 'json', 'image/png')
    self.connection = self.CreateTransport(self.service_url, protocol_config)

    rpc = self.connection.send_rpc(my_method.remote, self.request)
    self.assertEquals(Message(), rpc.response)


class SimpleRequest(messages.Message):
  """Request for LocalService.call_method; `content` is echoed back."""

  content = messages.StringField(1)


class SimpleResponse(messages.Message):
  """Response from LocalService.call_method.

  Carries the echoed request content, the service's factory_value, and the
  host/address/port values read from the service's request_state.
  """

  content = messages.StringField(1)
  factory_value = messages.StringField(2)
  remote_host = messages.StringField(3)
  remote_address = messages.StringField(4)
  server_host = messages.StringField(5)
  server_port = messages.IntegerField(6)


class LocalService(remote.Service):
  """Service exercised in-process by LocalTransportTest."""

  def __init__(self, factory_value='default'):
    # Echoed back by call_method so tests can see how the service was built.
    self.factory_value = factory_value

  @remote.method(SimpleRequest, SimpleResponse)
  def call_method(self, request):
    """Echo the request content plus values from this call's request_state."""
    state = self.request_state
    return SimpleResponse(content=request.content,
                          factory_value=self.factory_value,
                          remote_host=state.remote_host,
                          remote_address=state.remote_address,
                          server_host=state.server_host,
                          server_port=state.server_port)

  @remote.method()
  def raise_totally_unexpected(self, request):
    """Raise a non-RPC exception."""
    raise TypeError('Kablam')

  @remote.method()
  def raise_unexpected(self, request):
    """Raise a RequestError from inside the service."""
    raise remote.RequestError('Huh?')

  @remote.method()
  def raise_application_error(self, request):
    """Raise an ApplicationError with message 'App error'."""
    raise remote.ApplicationError('App error', 10)


class LocalTransportTest(test_util.TestCase):
  """Tests transport.LocalTransport by calling LocalService in-process."""

  def CreateService(self, factory_value='default'):
    """Return a LocalService stub over a factory-built LocalTransport.

    Args:
      factory_value: Passed to LocalService.new_factory; call_method echoes
        it back in SimpleResponse.factory_value.

    Returns:
      A LocalService.Stub wired to a transport.LocalTransport.
    """
    return LocalService.Stub(
      transport.LocalTransport(LocalService.new_factory(factory_value)))

  def testBasicCallWithClass(self):
    stub = LocalService.Stub(transport.LocalTransport(LocalService))
    response = stub.call_method(content='Hello')
    self.assertEquals(SimpleResponse(content='Hello',
                                     factory_value='default',
                                     remote_host=os.uname()[1],
                                     remote_address='127.0.0.1',
                                     server_host=os.uname()[1],
                                     server_port=-1),
                      response)

  def testBasicCallWithFactory(self):
    stub = self.CreateService('assigned')
    response = stub.call_method(content='Hello')
    self.assertEquals(SimpleResponse(content='Hello',
                                     factory_value='assigned',
                                     remote_host=os.uname()[1],
                                     remote_address='127.0.0.1',
                                     server_host=os.uname()[1],
                                     server_port=-1),
                      response)

  def testTotallyUnexpectedError(self):
    stub = LocalService.Stub(transport.LocalTransport(LocalService))
    self.assertRaisesWithRegexpMatch(
      remote.ServerError,
      'Unexpected error TypeError: Kablam',
      stub.raise_totally_unexpected)

  def testUnexpectedError(self):
    stub = LocalService.Stub(transport.LocalTransport(LocalService))
    self.assertRaisesWithRegexpMatch(
      remote.ServerError,
      'Unexpected error RequestError: Huh?',
      stub.raise_unexpected)

  def testApplicationError(self):
    stub = LocalService.Stub(transport.LocalTransport(LocalService))
    self.assertRaisesWithRegexpMatch(
      remote.ApplicationError,
      'App error',
      stub.raise_application_error)


def main():
  """Run this module's tests with the standard unittest runner."""
  unittest.main()


if __name__ == '__main__':
  main()
