#!/usr/bin/env python3
# -*- mode: python -*-
# -*- coding: utf-8 -*-

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""RPC/IPC support."""

import abc
import http.client
import http.server
import io
import logging
import os
import socketserver

from avro import io as avro_io
from avro import protocol
from avro import schema

logger = logging.getLogger(__name__)

# ------------------------------------------------------------------------------
# Constants

def LoadResource(name):
  """Reads a text resource stored alongside this module.

  Args:
    name: File name of the resource, resolved relative to the directory
        containing this module.
  Returns:
    The full content of the resource, as a string.
  """
  dir_path = os.path.dirname(__file__)
  rsrc_path = os.path.join(dir_path, name)
  # Decode explicitly as UTF-8 rather than relying on the platform's
  # locale-dependent default encoding.
  with open(rsrc_path, 'r', encoding='utf-8') as f:
    return f.read()


# Handshake schemas are pulled in during build; they are read at import time
# from the directory containing this module.
HANDSHAKE_REQUEST_SCHEMA_JSON = LoadResource('HandshakeRequest.avsc')
HANDSHAKE_RESPONSE_SCHEMA_JSON = LoadResource('HandshakeResponse.avsc')

# Parsed forms of the handshake schemas.
HANDSHAKE_REQUEST_SCHEMA = schema.Parse(HANDSHAKE_REQUEST_SCHEMA_JSON)
HANDSHAKE_RESPONSE_SCHEMA = schema.Parse(HANDSHAKE_RESPONSE_SCHEMA_JSON)

# The requestor writes handshake requests and reads handshake responses;
# the responder does the opposite.
HANDSHAKE_REQUESTOR_WRITER = avro_io.DatumWriter(HANDSHAKE_REQUEST_SCHEMA)
HANDSHAKE_REQUESTOR_READER = avro_io.DatumReader(HANDSHAKE_RESPONSE_SCHEMA)
HANDSHAKE_RESPONDER_WRITER = avro_io.DatumWriter(HANDSHAKE_RESPONSE_SCHEMA)
HANDSHAKE_RESPONDER_READER = avro_io.DatumReader(HANDSHAKE_REQUEST_SCHEMA)

# Call metadata: a map with values of type bytes.
META_SCHEMA = schema.Parse('{"type": "map", "values": "bytes"}')
META_WRITER = avro_io.DatumWriter(META_SCHEMA)
META_READER = avro_io.DatumReader(META_SCHEMA)

# Schema used to report errors that are not declared by a message:
# a union whose only branch is a string.
SYSTEM_ERROR_SCHEMA = schema.Parse('["string"]')

# MIME type used for the body of HTTP requests and responses.
AVRO_RPC_MIME = 'avro/binary'

# protocol cache

# Map: remote name -> remote MD5 hash
# NOTE(review): within this file the map is only referenced from
# commented-out code — confirm whether other modules use it.
_REMOTE_HASHES = {}

# Decoder/encoder for a 32-bit big-endian integer.
UINT32_BE = avro_io.STRUCT_INT

# Default size of the buffers used to frame messages:
BUFFER_SIZE = 8192


# ------------------------------------------------------------------------------
# Exceptions


class AvroRemoteException(schema.AvroException):
  """Error carried by an error message from an Avro requestor or responder."""

  def __init__(self, fail_msg=None):
    """Initializes the exception.

    Args:
      fail_msg: Optional description (or error datum) of the failure.
    """
    super().__init__(fail_msg)

class ConnectionClosedException(schema.AvroException):
  """Raised when a framed stream ends before a complete frame was read."""


# ------------------------------------------------------------------------------
# Base IPC Classes (Requestor/Responder)


class BaseRequestor(object, metaclass=abc.ABCMeta):
  """Base class for the client side of a protocol interaction.

  A requestor serializes a handshake request followed by a call request,
  delegates the actual exchange to _IssueRequest() (provided by subclasses),
  and offers helpers to decode the handshake response and the call response.
  """

  def __init__(self, local_protocol, transceiver):
    """Initializes a new requestor object.

    Args:
      local_protocol: Avro Protocol describing the messages sent and received.
      transceiver: Transceiver instance to channel messages through.
    """
    self._local_protocol = local_protocol
    self._transceiver = transceiver
    # Remote protocol and hash are unknown until a handshake is prepared.
    self._remote_protocol = None
    self._remote_hash = None
    # When true, the next handshake request embeds the full local protocol.
    self._send_protocol = None

  @property
  def local_protocol(self):
    """Returns: the Avro Protocol describing the messages sent and received."""
    return self._local_protocol

  @property
  def transceiver(self):
    """Returns: the underlying channel used by this requestor."""
    return self._transceiver

  @abc.abstractmethod
  def _IssueRequest(self, call_request, message_name, request_datum):
    """Sends a serialized request and produces the call result.

    Args:
      call_request: Serialized handshake request and call request, as bytes.
      message_name: Name of the message.
      request_datum: The request datum that was serialized into call_request.
    Returns:
      The value Request() hands back to its caller.
    Raises:
      NotImplementedError: always; subclasses must override this method.
    """
    raise NotImplementedError('Abstract method')

  def Request(self, message_name, request_datum):
    """Writes a request message and reads a response or error message.

    Args:
      message_name: Name of the IPC method.
      request_datum: IPC request.
    Returns:
      The IPC response.
    """
    # build handshake and call request
    buffer_writer = io.BytesIO()
    buffer_encoder = avro_io.BinaryEncoder(buffer_writer)
    self._WriteHandshakeRequest(buffer_encoder)
    self._WriteCallRequest(message_name, request_datum, buffer_encoder)

    # send the handshake and call request; block until call response
    call_request = buffer_writer.getvalue()
    return self._IssueRequest(call_request, message_name, request_datum)

  def _WriteHandshakeRequest(self, encoder):
    """Emits the handshake request.

    Args:
      encoder: Encoder to write the handshake request into.
    """
    local_hash = self._local_protocol.md5

    # NOTE: the module-level _REMOTE_HASHES cache is not consulted here.
    # Until a handshake response says otherwise, assume the remote side
    # speaks exactly the local protocol.
    if self._remote_hash is None:
      self._remote_hash = local_hash
      self._remote_protocol = self._local_protocol

    request_datum = {
      'clientHash': local_hash,
      'serverHash': self._remote_hash,
    }
    if self._send_protocol:
      request_datum['clientProtocol'] = str(self._local_protocol)

    logger.info('Sending handshake request: %s', request_datum)
    HANDSHAKE_REQUESTOR_WRITER.write(request_datum, encoder)

  def _WriteCallRequest(self, message_name, request_datum, encoder):
    """Emits a call request.

    The format of a call request is:
      * request metadata, a map with values of type bytes
      * the message name, an Avro string, followed by
      * the message parameters. Parameters are serialized according to
        the message's request declaration.

    Args:
      message_name: Name of the message to send.
      request_datum: Message parameters.
      encoder: Encoder to write the call request into.
    Raises:
      schema.AvroException: if the local protocol has no such message.
    """
    # request metadata (not yet implemented)
    request_metadata = {}
    META_WRITER.write(request_metadata, encoder)

    # Identify message to send:
    message = self.local_protocol.message_map.get(message_name)
    if message is None:
      raise schema.AvroException('Unknown message: %s' % message_name)
    encoder.write_utf8(message.name)

    # message parameters
    self._WriteRequest(message.request, request_datum, encoder)

  def _WriteRequest(self, request_schema, request_datum, encoder):
    """Serializes the message parameters according to request_schema."""
    logger.info('writing request: %s', request_datum)
    datum_writer = avro_io.DatumWriter(request_schema)
    datum_writer.write(request_datum, encoder)

  def _ReadHandshakeResponse(self, decoder):
    """Reads and processes the handshake response message.

    Args:
      decoder: Decoder to read messages from.
    Returns:
      True if the handshake was accepted ('BOTH' or 'CLIENT'), in which case
      this method's callers go on to read a call response from the decoder.
      False on 'NONE': the remote protocol and hash have been recorded and the
      next handshake request will include the full local protocol.
    Raises:
      schema.AvroException: if the match value is not recognized.
    """
    handshake_response = HANDSHAKE_REQUESTOR_READER.read(decoder)
    logger.info('Processing handshake response: %s', handshake_response)
    match = handshake_response['match']
    if match == 'BOTH':
      # Both client and server protocol hashes match:
      self._send_protocol = False
      return True

    if match == 'CLIENT':
      # The server-hash guess sent by this client was wrong:
      # adopt the server's protocol and hash.
      self._remote_protocol = \
          protocol.Parse(handshake_response['serverProtocol'])
      self._remote_hash = handshake_response['serverHash']
      self._send_protocol = False
      return True

    if match == 'NONE':
      # Neither client nor server match:
      self._remote_protocol = \
          protocol.Parse(handshake_response['serverProtocol'])
      self._remote_hash = handshake_response['serverHash']
      self._send_protocol = True
      return False

    raise schema.AvroException('handshake_response.match=%r' % match)

  def _ReadCallResponse(self, message_name, decoder):
    """Reads and processes a method call response.

    The format of a call response is:
      - response metadata, a map with values of type bytes
      - a one-byte error flag boolean, followed by either:
        - if the error flag is false,
          the message response, serialized per the message's response schema.
        - if the error flag is true,
          the error, serialized per the message's error union schema.

    Args:
      message_name: Name of the message the response belongs to.
      decoder: Decoder to read the call response from.
    Returns:
      The decoded message response, when the error flag is false.
    Raises:
      schema.AvroException: if the message is unknown to the remote or to
          the local protocol.
      AvroRemoteException: if the error flag is true; carries the decoded
          error datum.
    """
    # Response metadata must be consumed to advance the decoder;
    # its content is currently unused.
    META_READER.read(decoder)

    # remote response schema
    remote_message_schema = self._remote_protocol.message_map.get(message_name)
    if remote_message_schema is None:
      raise schema.AvroException('Unknown remote message: %s' % message_name)

    # local response schema
    local_message_schema = self._local_protocol.message_map.get(message_name)
    if local_message_schema is None:
      raise schema.AvroException('Unknown local message: %s' % message_name)

    # error flag
    if not decoder.read_boolean():
      writer_schema = remote_message_schema.response
      reader_schema = local_message_schema.response
      return self._ReadResponse(writer_schema, reader_schema, decoder)

    writer_schema = remote_message_schema.errors
    reader_schema = local_message_schema.errors
    raise self._ReadError(writer_schema, reader_schema, decoder)

  def _ReadResponse(self, writer_schema, reader_schema, decoder):
    """Decodes a message response written with writer_schema."""
    datum_reader = avro_io.DatumReader(writer_schema, reader_schema)
    return datum_reader.read(decoder)

  def _ReadError(self, writer_schema, reader_schema, decoder):
    """Decodes an error datum and wraps it in an AvroRemoteException.

    The exception is returned, not raised; the caller raises it.
    """
    datum_reader = avro_io.DatumReader(writer_schema, reader_schema)
    return AvroRemoteException(datum_reader.read(decoder))


class Requestor(BaseRequestor):
  """Requestor performing synchronous exchanges through its transceiver."""

  def _IssueRequest(self, call_request, message_name, request_datum):
    """Sends the serialized request and decodes the reply.

    Args:
      call_request: Serialized handshake request and call request, as bytes.
      message_name: Name of the message.
      request_datum: The request datum that was serialized into call_request.
    Returns:
      The decoded call response.
    """
    raw_reply = self.transceiver.Transceive(call_request)
    reply_decoder = avro_io.BinaryDecoder(io.BytesIO(raw_reply))

    # The reply always starts with the handshake response.
    if not self._ReadHandshakeResponse(reply_decoder):
      # No call response follows a rejected handshake: issue the whole
      # request again, with the handshake state that was just updated.
      return self.Request(message_name, request_datum)
    return self._ReadCallResponse(message_name, reply_decoder)


# ------------------------------------------------------------------------------


class Responder(object, metaclass=abc.ABCMeta):
  """Base class for the server side of a protocol interaction.

  Subclasses implement Invoke() with the actual server logic; Respond() takes
  care of the handshake and of (de)serializing requests and responses.
  """

  def __init__(self, local_protocol):
    """Initializes a new responder.

    Args:
      local_protocol: Avro Protocol implemented by this responder.
    """
    self._local_protocol = local_protocol
    self._local_hash = self._local_protocol.md5
    # Map: protocol MD5 hash -> Protocol, for the protocols seen so far.
    self._protocol_cache = {}

    self.set_protocol_cache(self._local_hash, self._local_protocol)

  @property
  def local_protocol(self):
    """Returns: the Avro Protocol implemented by this responder."""
    return self._local_protocol

  # utility functions to manipulate protocol cache
  def get_protocol_cache(self, hash):
    """Returns: the cached Protocol for the given hash, or None if unknown."""
    return self._protocol_cache.get(hash)

  def set_protocol_cache(self, hash, protocol):
    """Registers a Protocol in the cache under the given hash."""
    self._protocol_cache[hash] = protocol

  def Respond(self, call_request):
    """Entry point to process one procedure call.

    Args:
      call_request: Serialized procedure call request.
    Returns:
      Serialized procedure call response: the handshake response, followed
      (unless the handshake failed) by the call response or an error.
    Raises:
      Exceptions other than schema.AvroException that are raised outside of
      Invoke() propagate to the caller.
    """
    buffer_reader = io.BytesIO(call_request)
    buffer_decoder = avro_io.BinaryDecoder(buffer_reader)
    buffer_writer = io.BytesIO()
    buffer_encoder = avro_io.BinaryEncoder(buffer_writer)
    error = None
    response_metadata = {}
    # Offset in buffer_writer just past the handshake response; remains 0 if
    # the handshake itself fails with an exception.
    handshake_end = 0

    try:
      remote_protocol = self._ProcessHandshake(buffer_decoder, buffer_encoder)
      handshake_end = buffer_writer.tell()
      # handshake failure
      if remote_protocol is None:
        return buffer_writer.getvalue()

      # Read the request using the remote protocol; the request metadata
      # must be consumed but its content is currently unused.
      META_READER.read(buffer_decoder)
      remote_message_name = buffer_decoder.read_utf8()

      # get remote and local request schemas so we can do
      # schema resolution (one fine day)
      remote_message = remote_protocol.message_map.get(remote_message_name)
      if remote_message is None:
        fail_msg = 'Unknown remote message: %s' % remote_message_name
        raise schema.AvroException(fail_msg)
      local_message = self.local_protocol.message_map.get(remote_message_name)
      if local_message is None:
        fail_msg = 'Unknown local message: %s' % remote_message_name
        raise schema.AvroException(fail_msg)
      writer_schema = remote_message.request
      reader_schema = local_message.request
      request = self._ReadRequest(writer_schema, reader_schema, buffer_decoder)
      logger.info('Processing request: %r', request)

      # perform server logic
      try:
        response = self.Invoke(local_message, request)
      except AvroRemoteException as exn:
        error = exn
      except Exception as exn:
        error = AvroRemoteException(str(exn))

      # write response using local protocol
      META_WRITER.write(response_metadata, buffer_encoder)
      buffer_encoder.write_boolean(error is not None)
      if error is None:
        writer_schema = local_message.response
        self._WriteResponse(writer_schema, response, buffer_encoder)
      else:
        writer_schema = local_message.errors
        self._WriteError(writer_schema, error, buffer_encoder)
    except schema.AvroException as exn:
      error = AvroRemoteException(str(exn))
      # Discard whatever part of the call response was already written, but
      # keep the handshake response, then report a system error in its place.
      # The error must go into buffer_writer, which is what gets returned.
      buffer_writer.seek(handshake_end)
      buffer_writer.truncate()
      META_WRITER.write(response_metadata, buffer_encoder)
      buffer_encoder.write_boolean(True)
      self._WriteError(SYSTEM_ERROR_SCHEMA, error, buffer_encoder)
    return buffer_writer.getvalue()

  def _ProcessHandshake(self, decoder, encoder):
    """Processes an RPC handshake.

    Args:
      decoder: Where to read from.
      encoder: Where to write to.
    Returns:
      The Protocol of the remote client, or None if it could not be
      determined (handshake response 'NONE').
    """
    handshake_request = HANDSHAKE_RESPONDER_READER.read(decoder)
    logger.info('Processing handshake request: %s', handshake_request)

    # determine the remote protocol
    client_hash = handshake_request.get('clientHash')
    client_protocol = handshake_request.get('clientProtocol')
    remote_protocol = self.get_protocol_cache(client_hash)
    if remote_protocol is None and client_protocol is not None:
      remote_protocol = protocol.Parse(client_protocol)
      self.set_protocol_cache(client_hash, remote_protocol)

    # evaluate remote's guess of the local protocol
    server_hash = handshake_request.get('serverHash')

    handshake_response = {}
    if remote_protocol is None:
      # The client's protocol is unknown, whatever its guess of ours is.
      handshake_response['match'] = 'NONE'
    elif self._local_hash == server_hash:
      handshake_response['match'] = 'BOTH'
    else:
      handshake_response['match'] = 'CLIENT'

    if handshake_response['match'] != 'BOTH':
      # Tell the client what the local protocol actually is.
      handshake_response['serverProtocol'] = str(self.local_protocol)
      handshake_response['serverHash'] = self._local_hash

    logger.info('Handshake response: %s', handshake_response)
    HANDSHAKE_RESPONDER_WRITER.write(handshake_response, encoder)
    return remote_protocol

  @abc.abstractmethod
  def Invoke(self, local_message, request):
    """Processes one procedure call.

    Args:
      local_message: Avro message specification.
      request: Call request.
    Returns:
      Call response.
    Raises:
      NotImplementedError: always; subclasses must override this method.
          Within Respond(), any exception raised by an implementation is
          reported to the client as an error.
    """
    raise NotImplementedError('abstract method')

  def _ReadRequest(self, writer_schema, reader_schema, decoder):
    """Decodes the message parameters written with writer_schema."""
    datum_reader = avro_io.DatumReader(writer_schema, reader_schema)
    return datum_reader.read(decoder)

  def _WriteResponse(self, writer_schema, response_datum, encoder):
    """Serializes a message response according to writer_schema."""
    datum_writer = avro_io.DatumWriter(writer_schema)
    datum_writer.write(response_datum, encoder)

  def _WriteError(self, writer_schema, error_exception, encoder):
    """Serializes the string form of an error according to writer_schema."""
    datum_writer = avro_io.DatumWriter(writer_schema)
    datum_writer.write(str(error_exception), encoder)


# ------------------------------------------------------------------------------
# Framed message


class FramedReader(object):
  """Reads framed messages from a wrapped file-like object."""

  def __init__(self, reader):
    """Initializes the framed reader.

    Args:
      reader: File-like object to read the frames from.
    """
    self._reader = reader

  def Read(self):
    """Reads one message from the configured reader.

    Returns:
      The message, as bytes.
    """
    assembled = io.BytesIO()
    # Keep appending frames; a frame of size 0 marks the end of the message.
    while True:
      if not self._ReadFrame(assembled) > 0:
        break
    return assembled.getvalue()

  def _ReadFrame(self, message):
    """Reads one frame and appends its content to the given message.

    Args:
      message: Writable bytes buffer the frame content is appended to.
    Returns:
      Size of the frame that was read.
      The empty frame (size 0) indicates the end of a message.
    Raises:
      ConnectionClosedException: if the reader runs dry mid-frame.
    """
    frame_size = self._ReadInt32()
    pending = frame_size
    while pending > 0:
      # A single read() may return fewer bytes than requested.
      piece = self._reader.read(pending)
      received = len(piece)
      if received == 0:
        raise ConnectionClosedException(
            'FramedReader: expecting %d more bytes in frame of size %d, got 0.'
            % (pending, frame_size))
      message.write(piece)
      pending -= received
    return frame_size

  def _ReadInt32(self):
    """Reads a frame header and decodes it with UINT32_BE.

    Raises:
      ConnectionClosedException: if fewer bytes than a full header are read.
    """
    header = self._reader.read(UINT32_BE.size)
    if len(header) != UINT32_BE.size:
      raise ConnectionClosedException('Invalid header: %r' % header)
    (value,) = UINT32_BE.unpack(header)
    return value


class FramedWriter(object):
  """Wrapper around a file-like object to write framed data."""

  def __init__(self, writer):
    """Initializes the framed writer.

    Args:
      writer: File-like object to write the frames into.
    """
    self._writer = writer

  def Write(self, message):
    """Writes a message.

    Message is chunked into a sequence of frames of at most BUFFER_SIZE bytes
    each, terminated by an empty frame.

    Args:
      message: Message to write, as bytes.
    """
    # Frames are capped at BUFFER_SIZE bytes: the chunk size is the smaller of
    # BUFFER_SIZE and what is left, which slicing by offset gives directly
    # without repeatedly copying the remaining tail of the message.
    for start in range(0, len(message), BUFFER_SIZE):
      self._WriteBuffer(message[start:start + BUFFER_SIZE])

    # A message is always terminated by a zero-length buffer.
    self._WriteUnsignedInt32(0)

  def _WriteBuffer(self, chunk):
    """Writes one frame: the chunk length followed by the chunk itself."""
    self._WriteUnsignedInt32(len(chunk))
    self._writer.write(chunk)

  def _WriteUnsignedInt32(self, uint32):
    """Writes an integer encoded with UINT32_BE."""
    self._writer.write(UINT32_BE.pack(uint32))


# ------------------------------------------------------------------------------
# Transceiver (send/receive channel)


class Transceiver(object, metaclass=abc.ABCMeta):
  """Abstract channel used to send requests and receive replies."""

  # abc.abstractproperty is deprecated; stacking @property on top of
  # @abc.abstractmethod is the supported equivalent.
  @property
  @abc.abstractmethod
  def remote_name(self):
    """Returns: the name of the remote end (implementation-defined)."""

  @abc.abstractmethod
  def ReadMessage(self):
    """Reads a single message from the channel.

    Blocks until a message can be read.

    Returns:
      The message read from the channel.
    """

  @abc.abstractmethod
  def WriteMessage(self, message):
    """Writes a message into the channel.

    Blocks until the message has been written.

    Args:
      message: Message to write.
    """

  def Transceive(self, request):
    """Processes a single request-reply interaction.

    Synchronous request-reply interaction: the request is written first,
    then one message is read back.

    Args:
      request: Request message.
    Returns:
      The reply message.
    """
    self.WriteMessage(request)
    return self.ReadMessage()

  def Close(self):
    """Closes this transceiver. The base implementation does nothing."""


class HTTPTransceiver(Transceiver):
  """HTTP-based transceiver implementation."""

  def __init__(self, host, port, req_resource='/'):
    """Initializes a new HTTP transceiver.

    The connection to the remote server is opened immediately.

    Args:
      host: Name or IP address of the remote host to interact with.
      port: Port the remote server is listening on.
      req_resource: Optional HTTP resource path to use, '/' by default.
    """
    self._req_resource = req_resource
    self._conn = http.client.HTTPConnection(host, port)
    self._conn.connect()
    # getpeername() is the address of the remote end of the connection;
    # getsockname() would be the local (client-side) address instead.
    self._remote_name = self._conn.sock.getpeername()

  @property
  def remote_name(self):
    """Returns: the address of the remote end, as reported by the socket."""
    return self._remote_name

  def ReadMessage(self):
    """Reads the framed message carried by the pending HTTP response.

    Returns:
      The message, as bytes.
    """
    response = self._conn.getresponse()
    response_reader = FramedReader(response)
    framed_message = response_reader.Read()
    response.read()    # ensure we're ready for subsequent requests
    return framed_message

  def WriteMessage(self, message):
    """Sends a message as the framed body of an HTTP POST request.

    Args:
      message: Message to write, as bytes.
    """
    req_method = 'POST'
    req_headers = {'Content-Type': AVRO_RPC_MIME}

    # Frame the whole message in memory so it is sent as a single body.
    bio = io.BytesIO()
    req_body_buffer = FramedWriter(bio)
    req_body_buffer.Write(message)
    req_body = bio.getvalue()

    self._conn.request(req_method, self._req_resource, req_body, req_headers)

  def Close(self):
    """Closes the underlying HTTP connection; safe to call more than once."""
    if self._conn is not None:
      self._conn.close()
      self._conn = None


# ------------------------------------------------------------------------------
# Server Implementations


def _MakeHandlerClass(responder):
  """Builds an HTTP request handler class bound to the given responder.

  Args:
    responder: Object whose Respond() method turns a serialized call request
        into a serialized call response.
  Returns:
    A new http.server.BaseHTTPRequestHandler subclass handling POST requests.
  """
  class AvroHTTPRequestHandler(http.server.BaseHTTPRequestHandler):
    def do_POST(self):
      # Decode the framed request body and let the responder process it.
      serialized_request = FramedReader(self.rfile).Read()
      logger.info('Serialized request: %r', serialized_request)
      serialized_response = responder.Respond(serialized_request)
      logger.info('Serialized response: %r', serialized_response)

      # Status line and headers.
      self.send_response(200)
      self.send_header('Content-type', AVRO_RPC_MIME)
      self.end_headers()

      # Framed response body.
      FramedWriter(self.wfile).Write(serialized_response)
      self.wfile.flush()
      logger.info('Response sent')

  return AvroHTTPRequestHandler


class MultiThreadedHTTPServer(socketserver.ThreadingMixIn,
                              http.server.HTTPServer):
  """HTTP server combined with the socketserver threading mix-in."""


class AvroIpcHttpServer(MultiThreadedHTTPServer):
  """Avro IPC server running on top of the multi-threaded HTTP server."""

  def __init__(self, interface, port, responder):
    """Initializes a new Avro IPC server.

    Args:
      interface: Interface the server listens on, eg. 'localhost' or '0.0.0.0'.
      port: TCP port the server listens on, eg. 8000.
      responder: Responder implementation to handle RPCs.
    """
    listen_address = (interface, port)
    # Every incoming request is served by a handler bound to the responder.
    handler_class = _MakeHandlerClass(responder)
    super().__init__(
        server_address=listen_address,
        RequestHandlerClass=handler_class,
    )


# This module is meant to be imported only: refuse to run as a script.
if __name__ == '__main__':
  raise Exception('Not a standalone module')
