#!/usr/bin/python
# Copyright (c) 2009 Las Cumbres Observatory (www.lcogt.net)
# Copyright (c) 2010 Jan Dittberner
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
'''
service.py - A wrapper around protocol buffer rpc services.

Allows protocol buffer rpc services to be called directly by method name
either synchronously or asynchronously without having to worry about the
setup details.  RpcService is a class that constructs a service instance
from a service stub in a protoc generated pb2 module.  RpcThread is a
simple thread handling the RPC call outside of the main thread.


Authors: Zach Walker (zwalker@lcogt.net)
         Jan Dittberner (jan@dittberner.info)

Nov 2009, Nov 2010
'''

# Standard library imports
from time import time
import threading

# Third party imports
from protobuf.socketrpc.channel import SocketRpcChannel
from protobuf.socketrpc.error import RpcError

# Module imports
from protobuf.socketrpc import logger

# Module-level logger named after this module; it is not referenced by the
# code shown in this file.
log = logger.getLogger(__name__)


class RpcThread(threading.Thread):
    '''
    Daemon thread that performs a single RPC call on a service stub.

    The stub method is invoked as method(service, controller, request,
    done), where done is an object exposing a run(response) method.  The
    callback given to the constructor may be either:

    * a plain callable with signature callback(request, response); it is
      wrapped so that it also receives the original request, or
    * an object that already has a run(response) method, which is passed
      through to the stub method unchanged.
    '''

    class _DoneAdapter(object):
        '''Expose a callback(request, response) callable through the
        run(response) interface handed to the stub method.'''

        def __init__(self, callback, request):
            self._callback = callback
            self._request = request

        def run(self, response):
            self._callback(self._request, response)

    def __init__(self, method, service, controller, request, callback):
        '''
        Accepted Arguments:
        method -- unbound stub method, called as
                  method(service, controller, request, done)
        service -- the service stub instance the method is invoked on
        controller -- the RPC controller for this call
        request -- the request message passed to the method
        callback -- callable(request, response), or an object with a
                    run(response) method
        '''
        threading.Thread.__init__(self)
        self.method = method
        self.service = service
        self.controller = controller
        self.request = request
        self.callback = callback
        # Mark as daemon so an outstanding RPC does not keep the process
        # alive.  The attribute form replaces setDaemon(), which is
        # deprecated since Python 3.10 (the attribute exists since 2.6).
        self.daemon = True

    def run(self):
        '''Make the RPC call, passing an object with a run(response) method.'''
        if callable(self.callback):
            # Adapt a plain function to the run(response) interface
            done = RpcThread._DoneAdapter(self.callback, self.request)
        else:
            # Assume the callback already provides run(response)
            done = self.callback
        self.method(self.service, self.controller, self.request, done)


class RpcService(object):
    '''
    Class abstracting the Protocol Buffer RPC calls for a supplied
    service stub.

    For every method in the stub's service descriptor, an attribute with
    the method's name is added to the instance.  That attribute is a
    function name(request, timeout=None, callback=None) that forwards
    to call().
    '''

    # Template for the __doc__ of each generated per-method wrapper;
    # filled with %-formatting by _rpc_doc().
    _RPC_DOC_TEMPLATE = (
        '%(method)s method of the %(service)s from the %(module)s module '
        'generated by the protoc compiler.  This method can be called '
        'synchronously by setting timeout -> ms or asynchronously by '
        'setting callback -> function(request,response)\n\n'
        'Synchronous Example:\n'
        '\trequest = %(input)s()\n'
        '\ttry:\n'
        '\t\t#Wait 1000ms for a response\n'
        '\t\tresponse = %(method)s(request, timeout=1000)\n'
        '\texcept RpcError:\n'
        '\t\t#Handle exception\n\n'
        'Asynchronous Example:\n'
        '\tdef callback(request,response):\n'
        '\t\t#Do some stuff\n'
        '\trequest = %(input)s()\n'
        '\ttry:\n'
        '\t\t%(method)s(request, callback=callback)\n'
        '\texcept RpcError:\n'
        '\t\t#Handle exception\n\n')

    def __init__(self, service_stub_class, port, host):
        '''
        Construct a new RpcService for the given service stub class.

        A SocketRpcChannel to host:port is created, the stub class is
        instantiated on that channel, and one wrapper function per RPC
        method of the stub's service is attached to this object.

        Accepted Arguments:
        service_stub_class -- (Service_Stub) The client side RPC
                              stub class produced by protoc from
                              the .proto file
        port -- (Integer) The port on which the service is running
                on the RPC server.
        host -- (String) The hostname or IP address of the server
                running the RPC service.
        '''
        self.service_stub_class = service_stub_class
        self.port = port
        self.host = host

        # Setup the RPC channel and bind the stub to it
        self.channel = SocketRpcChannel(host=self.host, port=self.port)
        self.service = self.service_stub_class(self.channel)

        # Go through the service's methods and add a wrapper function to
        # this object that calls the method through call()
        for method in service_stub_class.GetDescriptor().methods:
            # service and method are bound as defaults so every wrapper
            # keeps its own method name rather than the loop's last value
            rpc = lambda request, timeout=None, callback=None, service=self, \
                method=method.name: \
                service.call(service_stub_class.__dict__[method], request,
                             timeout, callback)
            rpc.__doc__ = self._rpc_doc(method)
            self.__dict__[method.name] = rpc

    def _rpc_doc(self, method):
        '''Build the docstring for the wrapper of one service method.'''
        return self._RPC_DOC_TEMPLATE % {
            'method': method.name,
            'service': self.service_stub_class.DESCRIPTOR.name,
            'module': self.service_stub_class.__module__,
            'input': method.input_type.name,
        }

    def call(self, rpc, request, timeout=None, callback=None):
        '''
        Invoke a stub method on a background RpcThread.

        Synchronous mode (callback is None): wait up to timeout
        milliseconds for the response and return it.  Raises
        RpcError('request timed out') if no response has been delivered
        by then, including when the call finished without delivering one.

        Asynchronous mode (callback given): the callback is handed to the
        stub method on the worker thread, timeout is ignored, and None is
        returned immediately.

        Accepted Arguments:
        rpc -- unbound stub method to invoke
        request -- the request message
        timeout -- (Integer) ms to wait for a synchronous response;
                   defaults to 100
        callback -- callable(request, response, ...) or an object whose
                    class defines run(response, ...)

        Raises:
        TypeError -- if callback does not have the required signature
        RpcError -- if a synchronous call gets no response in time
        '''
        # Shared with synch_callback, which runs on the worker thread
        result = {'done': False, 'response': None}

        def synch_callback(request, response):
            result['response'] = response
            result['done'] = True

        if callback is None:
            # No callback passed in: this call is meant to be synchronous
            rpc_callback = synch_callback
        else:
            # Check the arity of the function (or of its run method)
            if callable(callback):
                code = callback.__code__
            elif callback.__class__.__dict__.get('run') is None:
                code = None
            else:
                code = callback.run.__code__
            if code is None or code.co_argcount < 2:
                raise TypeError("callback must be a callable with signature " +
                                "callback(request, response, ...) or an " +
                                "object with a callable run function with " +
                                "the same signature")
            rpc_callback = callback

        # Create a controller for this call
        controller = self.channel.newController()

        # Run the call on its own thread so an asynchronous request can
        # return immediately
        rpc_thread = RpcThread(rpc, self.service, controller,
                               request, rpc_callback)
        rpc_thread.start()

        if callback is not None:
            return

        if timeout is None:
            timeout = 100

        # timeout is in milliseconds, but Thread.join() takes seconds.
        # Divide by a float so small integer timeouts are not truncated
        # to 0 under Python 2.
        rpc_thread.join(timeout / 1000.0)

        if not result['done']:
            raise RpcError('request timed out')

        return result['response']
