File: server.py

package info (click to toggle)
ruby-spamcheck 1.10.1-2
  • links: PTS, VCS
  • area: contrib
  • in suites: sid, trixie
  • size: 668 kB
  • sloc: python: 1,261; ruby: 484; makefile: 54; sh: 13
file content (137 lines) | stat: -rw-r--r-- 4,606 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
""" gRPC server for spamcheck service. """
from concurrent import futures

import grpc
from grpc_reflection.v1alpha import reflection
from vyper import v

import api.v1.health_pb2 as health
import api.v1.health_pb2_grpc as health_grpc
import api.v1.spamcheck_pb2 as spam
import api.v1.spamcheck_pb2_grpc as spam_grpc
from app import logger, ValidationError
from app.spammable import generic, issue, snippet
from server.interceptors import LoggingInterceptor, CorrelationIDInterceptor

# Module-level logger used by the servicers and server setup in this module.
log = logger.logger


# The route method names must be CamelCase to match the generated gRPC
# servicer base class, which pylint's snake_case naming check would flag.
#
# pylint: disable=invalid-name
class SpamCheckServicer(spam_grpc.SpamcheckServiceServicer):
    """Handler for gRPC routes.

    Each route builds the matching spammable from the request and returns its
    verdict. A ``ValidationError`` raised while doing so is reported to the
    client as ``INVALID_ARGUMENT`` (with the error text as details) and an
    empty ``SpamVerdict``, and is logged with the
    ``spamcheck_validation_errors`` metric tag.
    """

    @staticmethod
    def _verdict(
        spammable_cls, request, context, invalid_message: str
    ) -> spam.SpamVerdict:
        """Return ``spammable_cls(request, context).verdict()``.

        Args:
            spammable_cls: Spammable class to instantiate (e.g. ``issue.Issue``).
            request: The incoming protobuf request message.
            context: The gRPC servicer context for this call.
            invalid_message: Log message emitted when validation fails.

        Returns:
            The spammable's verdict or, when a ``ValidationError`` is raised,
            an empty ``spam.SpamVerdict`` after the call status has been set
            to ``INVALID_ARGUMENT``.
        """
        try:
            return spammable_cls(request, context).verdict()
        except ValidationError as ex:
            fields = {
                "metric": "spamcheck_validation_errors",
                # Presumably attached by CorrelationIDInterceptor; fall back to
                # None so a missing ID cannot mask the validation error itself.
                "correlation_id": getattr(context, "correlation_id", None),
            }
            log.warning(invalid_message, extra=fields)
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details(str(ex))
            return spam.SpamVerdict()

    def CheckForSpamGeneric(
        self, request: spam.Generic, context: grpc.ServicerContext
    ) -> spam.SpamVerdict:
        """Route for generic spam."""

        return self._verdict(
            generic.Generic, request, context, "Invalid generic spammable"
        )

    def CheckForSpamIssue(
        self, request: spam.Issue, context: grpc.ServicerContext
    ) -> spam.SpamVerdict:
        """Route for issue spam."""

        return self._verdict(issue.Issue, request, context, "Invalid issue")

    def CheckForSpamSnippet(
        self, request: spam.Snippet, context: grpc.ServicerContext
    ) -> spam.SpamVerdict:
        """Route for snippet spam."""

        return self._verdict(snippet.Snippet, request, context, "Invalid snippet")


class HealthServicer(health_grpc.HealthServicer):
    """gRPC health-check handler that always reports the service as SERVING."""

    @staticmethod
    def _serving_response():
        """Build a health response carrying the SERVING status."""
        serving = health.HealthCheckResponse.SERVING
        return health.HealthCheckResponse(status=serving)

    def Check(self, request: health.HealthCheckRequest, context: grpc.ServicerContext):
        """Unary health check: answer with a SERVING status."""
        return self._serving_response()

    def Watch(self, request, context):
        """Streaming health check: emit one SERVING update, then end the stream."""
        yield self._serving_response()


def _fetch_certificate():
    """Load the TLS key pair named in configuration.

    Returns:
        A ``(private_key, certificate_chain)`` tuple of raw bytes, read from
        the files named by the ``tls_private_key`` and ``tls_certificate``
        settings respectively.
    """

    def read_bytes(path):
        with open(path, "rb") as handle:
            return handle.read()

    cert_path = v.get_string("tls_certificate")
    key_path = v.get_string("tls_private_key")
    # Tuple items are evaluated left to right: the key file is read first.
    return read_bytes(key_path), read_bytes(cert_path)


def _server(addr: str, tls: bool, max_workers: int = 50) -> grpc.Server:
    """Create (but do not start) the gRPC server bound to ``addr``.

    Args:
        addr: ``host:port`` address to listen on.
        tls: When true, serve over TLS using the key pair returned by
            ``_fetch_certificate()``; otherwise bind an insecure port.
        max_workers: Size of the thread pool that executes RPC handlers.

    Returns:
        The configured, not yet started ``grpc.Server``.

    Raises:
        RuntimeError: If the address could not be bound.
    """
    interceptors = [CorrelationIDInterceptor(), LoggingInterceptor()]
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=max_workers), interceptors=interceptors
    )
    if tls:
        private_key, certificate_chain = _fetch_certificate()
        creds = grpc.ssl_server_credentials(
            ((private_key, certificate_chain),),
            root_certificates=None,
            require_client_auth=False,
        )
        port = server.add_secure_port(addr, creds)
    else:
        # This branch is selected by configuration (TLS turned off), not by a
        # failed certificate lookup, so the warning says exactly that.
        log.warning("TLS disabled. Defaulting to insecure channel")
        port = server.add_insecure_port(addr)

    # Some grpcio releases signal a bind failure by returning 0 instead of
    # raising; fail loudly rather than start a server listening nowhere.
    if port == 0:
        raise RuntimeError(f"failed to bind gRPC server to {addr}")

    return server


def serve():
    """Build, start and block on the spamcheck gRPC server.

    Reads the ``env``, ``tls_enabled`` and ``grpc_addr`` settings. A
    ``grpc_addr`` consisting only of digits is treated as a bare port and
    bound on ``0.0.0.0``. gRPC server reflection is enabled unless ``env`` is
    ``"production"``. Blocks until the server terminates; on
    ``KeyboardInterrupt`` it calls ``server.stop(None)`` (no grace period).
    """
    env = v.get_string("env")
    tls_enabled = v.get_bool("tls_enabled")
    addr = v.get_string("grpc_addr")
    # A bare port number is expanded to listen on every IPv4 interface.
    addr = f"0.0.0.0:{addr}" if addr.isdecimal() else addr

    server = _server(addr, tls_enabled)
    spam_grpc.add_SpamcheckServiceServicer_to_server(SpamCheckServicer(), server)
    health_grpc.add_HealthServicer_to_server(HealthServicer(), server)

    # Reflection exposes the service schema, so keep it out of production.
    reflection_allowed = env != "production"
    if reflection_allowed:
        spam_service = spam.DESCRIPTOR.services_by_name["SpamcheckService"]
        reflection.enable_server_reflection(
            (spam_service.full_name, reflection.SERVICE_NAME), server
        )

    server.start()
    log.info("gRPC server started", extra={"tls_enabled": tls_enabled, "addr": addr})
    try:
        server.wait_for_termination()
    except KeyboardInterrupt:
        log.info("shutting down gRPC server")
        server.stop(None)


# Start the server when this module is executed directly as a script.
if __name__ == "__main__":
    serve()