1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
|
""" gRPC server for spamcheck service. """
from concurrent import futures
import grpc
from grpc_reflection.v1alpha import reflection
from vyper import v
import api.v1.health_pb2 as health
import api.v1.health_pb2_grpc as health_grpc
import api.v1.spamcheck_pb2 as spam
import api.v1.spamcheck_pb2_grpc as spam_grpc
from app import logger, ValidationError
from app.spammable import generic, issue, snippet
from server.interceptors import LoggingInterceptor, CorrelationIDInterceptor
# Logger object exposed by the app.logger module; used by every handler and
# helper in this file.
log = logger.logger
# The method names can't be snake_case: they must match the CamelCase names in the generated gRPC code.
#
# pylint: disable=invalid-name
class SpamCheckServicer(spam_grpc.SpamcheckServiceServicer):
    """Handler for gRPC spam-check routes.

    Every route wraps the request in its spammable type and returns that
    object's verdict. A request that fails validation (``ValidationError``)
    is answered uniformly: a warning is logged, the call status is set to
    ``INVALID_ARGUMENT`` with the error text as details, and an empty
    ``SpamVerdict`` is returned.
    """

    @staticmethod
    def _invalid_argument(
        context: grpc.ServicerContext, ex: ValidationError, message: str
    ) -> spam.SpamVerdict:
        """Report a validation failure on *context* and return an empty verdict.

        Args:
            context: The RPC context whose status/details are set.
            ex: The validation error; its text becomes the status details.
            message: Log message identifying which spammable type was invalid.

        Returns:
            An empty ``SpamVerdict`` to send alongside the error status.
        """
        fields = {
            "metric": "spamcheck_validation_errors",
            # NOTE(review): correlation_id is presumably attached to the context
            # by CorrelationIDInterceptor; getattr keeps a missing id from
            # raising AttributeError and hiding the validation error.
            "correlation_id": getattr(context, "correlation_id", None),
        }
        log.warning(message, extra=fields)
        context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
        context.set_details(str(ex))
        return spam.SpamVerdict()

    def CheckForSpamGeneric(
        self, request: spam.Generic, context: grpc.ServicerContext
    ) -> spam.SpamVerdict:
        """Route for generic spam."""
        try:
            return generic.Generic(request, context).verdict()
        except ValidationError as ex:
            # Same handling as the issue/snippet routes, so invalid input maps
            # to INVALID_ARGUMENT instead of an unhandled-exception error.
            return self._invalid_argument(context, ex, "Invalid generic")

    def CheckForSpamIssue(
        self, request: spam.Issue, context: grpc.ServicerContext
    ) -> spam.SpamVerdict:
        """Route for issue spam."""
        try:
            return issue.Issue(request, context).verdict()
        except ValidationError as ex:
            return self._invalid_argument(context, ex, "Invalid issue")

    def CheckForSpamSnippet(
        self, request: spam.Snippet, context: grpc.ServicerContext
    ) -> spam.SpamVerdict:
        """Route for snippet spam."""
        try:
            return snippet.Snippet(request, context).verdict()
        except ValidationError as ex:
            return self._invalid_argument(context, ex, "Invalid snippet")
class HealthServicer(health_grpc.HealthServicer):
    """Answer gRPC health probes with a fixed SERVING status.

    No dependency is inspected: any probe that reaches this handler is told
    the service is serving.
    """

    def Check(self, request: health.HealthCheckRequest, context: grpc.ServicerContext):
        """Unary health check: reply once with SERVING."""
        serving = health.HealthCheckResponse.SERVING
        return health.HealthCheckResponse(status=serving)

    def Watch(self, request, context):
        """Streaming health check: emit a single SERVING update, then end the stream."""
        serving = health.HealthCheckResponse.SERVING
        yield health.HealthCheckResponse(status=serving)
def _fetch_certificate():
    """Load the TLS key pair whose file paths come from configuration.

    The paths are taken from the ``tls_private_key`` and ``tls_certificate``
    settings, and both files are read as raw bytes.

    Returns:
        A ``(private_key, certificate_chain)`` tuple of bytes, the pair order
        ``grpc.ssl_server_credentials`` expects.
    """

    def _read_bytes(path):
        # Binary mode: the key/cert material is handed to gRPC untouched.
        with open(path, "rb") as handle:
            return handle.read()

    cert_path = v.get_string("tls_certificate")
    key_path = v.get_string("tls_private_key")
    key_bytes = _read_bytes(key_path)
    chain_bytes = _read_bytes(cert_path)
    return key_bytes, chain_bytes
def _server(addr: str, tls: bool, max_workers: int = 50) -> grpc.Server:
    """Build (but do not start) a gRPC server listening on *addr*.

    Args:
        addr: ``host:port`` address to bind.
        tls: When true, bind a secure port using the key pair returned by
            ``_fetch_certificate`` (no client certificate required);
            otherwise bind an insecure port.
        max_workers: Size of the thread pool that runs RPC handlers.

    Returns:
        The configured, unstarted server with the correlation-ID and logging
        interceptors installed.

    Raises:
        RuntimeError: If the server reports that it could not bind *addr*.
    """
    # NOTE(review): CorrelationIDInterceptor is listed first, presumably so
    # the correlation id exists before LoggingInterceptor runs — confirm
    # before reordering.
    interceptors = [CorrelationIDInterceptor(), LoggingInterceptor()]
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=max_workers),
        interceptors=interceptors,
    )
    if tls:
        creds = grpc.ssl_server_credentials(
            (_fetch_certificate(),),
            root_certificates=None,
            require_client_auth=False,
        )
        port = server.add_secure_port(addr, creds)
    else:
        # This branch is taken because TLS is disabled by the caller, not
        # because certificates are missing (that would raise while reading them).
        log.warning("TLS disabled. Defaulting to insecure channel")
        port = server.add_insecure_port(addr)
    # Some grpcio versions report a bind failure by returning 0 rather than
    # raising; fail loudly instead of starting a server that listens nowhere.
    if port == 0:
        raise RuntimeError(f"failed to bind gRPC server to {addr}")
    return server
def serve():
    """Configure, start and block on the spamcheck gRPC server.

    Reads ``env``, ``tls_enabled`` and ``grpc_addr`` from configuration. A
    purely numeric ``grpc_addr`` is treated as a port and bound on
    ``0.0.0.0``. Server reflection is registered unless ``env`` is
    ``"production"``. Returns once the server terminates, or after stopping
    it on KeyboardInterrupt.
    """
    env = v.get_string("env")
    tls_enabled = v.get_bool("tls_enabled")
    listen_addr = v.get_string("grpc_addr")
    # A bare port such as "50051" becomes a wildcard IPv4 bind address.
    listen_addr = f"0.0.0.0:{listen_addr}" if listen_addr.isdecimal() else listen_addr

    grpc_server = _server(listen_addr, tls_enabled)
    spam_grpc.add_SpamcheckServiceServicer_to_server(SpamCheckServicer(), grpc_server)
    health_grpc.add_HealthServicer_to_server(HealthServicer(), grpc_server)

    # Reflection lets clients discover the API; keep it off in production.
    reflection_enabled = env != "production"
    if reflection_enabled:
        exposed_services = (
            spam.DESCRIPTOR.services_by_name["SpamcheckService"].full_name,
            reflection.SERVICE_NAME,
        )
        reflection.enable_server_reflection(exposed_services, grpc_server)

    grpc_server.start()
    log.info(
        "gRPC server started",
        extra={"tls_enabled": tls_enabled, "addr": listen_addr},
    )
    try:
        grpc_server.wait_for_termination()
    except KeyboardInterrupt:
        log.info("shutting down gRPC server")
        # grace=None: in-flight RPCs are aborted immediately.
        grpc_server.stop(None)
# Start the server when this file is executed as a script rather than imported.
if __name__ == "__main__":
    serve()
|