""" gRPC server for spamcheck service. """
from concurrent import futures

import grpc
from grpc_reflection.v1alpha import reflection
from vyper import v

import api.v1.health_pb2 as health
import api.v1.health_pb2_grpc as health_grpc
import api.v1.spamcheck_pb2 as spam
import api.v1.spamcheck_pb2_grpc as spam_grpc
from app import logger, ValidationError
from app.spammable import generic, issue, snippet
from server.interceptors import LoggingInterceptor, CorrelationIDInterceptor

# Module-level logger used by the servicers and server setup in this file.
log = logger.logger


# The method names can't be snake case: they must match the CamelCase names in the generated gRPC code.
#
# pylint: disable=invalid-name
class SpamCheckServicer(spam_grpc.SpamcheckServiceServicer):
    """Handler for gRPC spam-check routes.

    Every route wraps the request in the matching ``app.spammable`` class and
    returns its verdict. If that raises ``ValidationError``, the failure is
    logged with the ``spamcheck_validation_errors`` metric, the call status is
    set to ``INVALID_ARGUMENT`` with the error text as details, and an empty
    ``SpamVerdict`` is returned.
    """

    @staticmethod
    def _reject_invalid(
        message: str, ex: ValidationError, context: grpc.ServicerContext
    ) -> spam.SpamVerdict:
        """Log a validation failure and mark the call as INVALID_ARGUMENT.

        Args:
            message: Log message naming which request type was invalid.
            ex: The validation error; its text becomes the status details.
            context: Servicer context of the failing call.

        Returns:
            An empty ``SpamVerdict`` to send alongside the error status.
        """
        fields = {
            "metric": "spamcheck_validation_errors",
            # NOTE(review): presumably attached to the context by
            # CorrelationIDInterceptor — confirm.
            "correlation_id": context.correlation_id,
        }
        log.warning(message, extra=fields)
        context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
        context.set_details(str(ex))
        return spam.SpamVerdict()

    def CheckForSpamGeneric(
        self, request: spam.Generic, context: grpc.ServicerContext
    ) -> spam.SpamVerdict:
        """Route for generic spam."""

        try:
            return generic.Generic(request, context).verdict()
        except ValidationError as ex:
            # Handled like the issue/snippet routes so a bad generic request
            # is reported as INVALID_ARGUMENT instead of escaping unhandled
            # (which gRPC would surface as an UNKNOWN status).
            return self._reject_invalid("Invalid generic", ex, context)

    def CheckForSpamIssue(
        self, request: spam.Issue, context: grpc.ServicerContext
    ) -> spam.SpamVerdict:
        """Route for issue spam."""

        try:
            return issue.Issue(request, context).verdict()
        except ValidationError as ex:
            return self._reject_invalid("Invalid issue", ex, context)

    def CheckForSpamSnippet(
        self, request: spam.Snippet, context: grpc.ServicerContext
    ) -> spam.SpamVerdict:
        """Route for snippet spam."""

        try:
            return snippet.Snippet(request, context).verdict()
        except ValidationError as ex:
            return self._reject_invalid("Invalid snippet", ex, context)


class HealthServicer(health_grpc.HealthServicer):
    """Answers gRPC health probes; every probe is told the server is SERVING."""

    @staticmethod
    def _serving() -> health.HealthCheckResponse:
        """Build a health response carrying the SERVING status."""
        return health.HealthCheckResponse(status=health.HealthCheckResponse.SERVING)

    def Check(self, request: health.HealthCheckRequest, context: grpc.ServicerContext):
        """Unary health check: reply SERVING without inspecting the request."""
        return self._serving()

    def Watch(self, request, context):
        """Streaming health check: emit a single SERVING message.

        NOTE(review): the generator ends after one message, so the stream
        closes immediately rather than staying open for status changes —
        confirm clients tolerate this.
        """
        yield self._serving()


def _fetch_certificate():
    """Load the TLS key pair whose file paths come from configuration.

    The ``tls_private_key`` and ``tls_certificate`` config values are treated
    as file paths and read as raw bytes, private key first.

    Returns:
        A ``(private_key, certificate_chain)`` tuple of bytes, the pair shape
        passed to ``grpc.ssl_server_credentials``.
    """

    def _read_bytes(path):
        with open(path, "rb") as handle:
            return handle.read()

    cert_path = v.get_string("tls_certificate")
    key_path = v.get_string("tls_private_key")
    # Tuple elements evaluate left to right: key file first, then certificate.
    return _read_bytes(key_path), _read_bytes(cert_path)


def _server(addr: str, tls: bool, *, max_workers: int = 50) -> grpc.Server:
    """Create (but do not start) a gRPC server listening on ``addr``.

    Args:
        addr: ``host:port`` address to bind.
        tls: When true, bind a TLS port using the key pair returned by
            ``_fetch_certificate()``; otherwise bind a plaintext port.
        max_workers: Size of the thread pool that executes RPC handlers.

    Returns:
        The configured, unstarted ``grpc.Server`` with the correlation-ID and
        logging interceptors installed.
    """
    interceptors = [CorrelationIDInterceptor(), LoggingInterceptor()]
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=max_workers),
        interceptors=interceptors,
    )
    if tls:
        creds = grpc.ssl_server_credentials(
            (_fetch_certificate(),),
            root_certificates=None,
            require_client_auth=False,
        )
        server.add_secure_port(addr, creds)
    else:
        # This branch runs because TLS is disabled in configuration, not
        # because certificate files were looked up and found missing.
        log.warning("TLS disabled. Defaulting to insecure channel")
        server.add_insecure_port(addr)

    return server


def serve():
    """Build the gRPC server, start it, and block until it terminates.

    Reads ``env``, ``tls_enabled`` and ``grpc_addr`` from configuration. A
    ``grpc_addr`` made only of digits is treated as a bare port and bound on
    ``0.0.0.0``. Server reflection is registered in every environment except
    ``production``. A KeyboardInterrupt stops the server with no grace period.
    """
    env = v.get_string("env")
    tls_enabled = v.get_bool("tls_enabled")
    configured_addr = v.get_string("grpc_addr")
    addr = (
        f"0.0.0.0:{configured_addr}" if configured_addr.isdecimal() else configured_addr
    )

    server = _server(addr, tls_enabled)
    spam_grpc.add_SpamcheckServiceServicer_to_server(SpamCheckServicer(), server)
    health_grpc.add_HealthServicer_to_server(HealthServicer(), server)

    # Reflection advertises the service schema to any caller, so it is kept
    # out of production.
    if env == "production":
        pass
    else:
        spam_service = spam.DESCRIPTOR.services_by_name["SpamcheckService"]
        reflection.enable_server_reflection(
            (spam_service.full_name, reflection.SERVICE_NAME), server
        )

    server.start()
    log.info("gRPC server started", extra={"tls_enabled": tls_enabled, "addr": addr})
    try:
        server.wait_for_termination()
    except KeyboardInterrupt:
        log.info("shutting down gRPC server")
        server.stop(None)


# Run the gRPC server when this module is executed directly.
if __name__ == "__main__":
    serve()
