File: curvezmq.py

package info (click to toggle)
python-parsl 2025.01.13%2Bds-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 12,072 kB
  • sloc: python: 23,817; makefile: 349; sh: 276; ansic: 45
file content (209 lines) | stat: -rw-r--r-- 7,073 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import os
from abc import ABCMeta, abstractmethod
from typing import Optional, Tuple, Union

import zmq
import zmq.auth
from zmq.auth.thread import ThreadAuthenticator


def create_certificates(base_dir: Union[str, os.PathLike]):
    """Generate the server and client CurveZMQ certificates under ``base_dir``.

    The certificates are written to a ``certificates`` subdirectory, which is
    created (requesting mode ``0o700``) when it does not exist yet. Existing
    certificate files in that directory are overwritten.

    Parameters
    ----------
    base_dir : str | os.PathLike
        Parent directory of the private certificates directory.

    Returns
    -------
    str
        Path of the certificates directory.
    """
    cert_dir = os.path.join(base_dir, "certificates")
    os.makedirs(cert_dir, mode=0o700, exist_ok=True)

    # Server first, then client: one certificate pair per role.
    for role in ("server", "client"):
        zmq.auth.create_certificates(cert_dir, name=role)

    return cert_dir


def _load_certificate(
    cert_dir: Union[str, os.PathLike], name: str
) -> Tuple[bytes, bytes]:
    """Load the ``name`` key pair from a private certificates directory.

    Parameters
    ----------
    cert_dir : str | os.PathLike
        Directory holding the certificate files. Its permission bits must be
        exactly ``0o700``.
    name : str
        Certificate name; the file read is ``<name>.key_secret``.

    Returns
    -------
    tuple of (bytes, bytes)
        The public key followed by the secret key.

    Raises
    ------
    OSError
        If the permission bits of ``cert_dir`` are not ``0o700``.
    ValueError
        If the certificate file holds no secret key.
    """
    permission_bits = os.stat(cert_dir).st_mode & 0o777
    if permission_bits != 0o700:
        raise OSError(f"The certificates directory must be private: {cert_dir}")

    # pyzmq creates secret key files with the '.key_secret' extension
    # Ref: https://github.com/zeromq/pyzmq/blob/ae615d4097ccfbc6b5c17de60355cbe6e00a6065/zmq/auth/certs.py#L73
    key_path = os.path.join(cert_dir, f"{name}.key_secret")
    key_pair = zmq.auth.load_certificate(key_path)
    public_key, secret_key = key_pair
    if secret_key is None:
        raise ValueError(f"No secret key found in {key_path}")

    return public_key, secret_key


class BaseContext(metaclass=ABCMeta):
    """Shared base of the CurveZMQ server and client contexts.

    Wraps a ``zmq.Context`` and keeps the certificate directory that concrete
    subclasses use when they configure sockets.

    Parameters
    ----------
    cert_dir : str | os.PathLike | None
        Path to the certificate directory. None disables encryption.
    """

    def __init__(self, cert_dir: Optional[Union[str, os.PathLike]]) -> None:
        self.cert_dir = cert_dir
        self._ctx = zmq.Context()

    @property
    def encrypted(self):
        """Whether encryption is enabled.

        Encryption is disabled (False) exactly when ``self.cert_dir`` is None.
        """
        if self.cert_dir is None:
            return False
        return True

    @property
    def closed(self):
        """The ``closed`` flag of the wrapped ``zmq.Context``."""
        return self._ctx.closed

    @abstractmethod
    def socket(self, socket_type: int, *args, **kwargs) -> zmq.Socket:
        """Create a socket associated with this context.

        Implementations apply all necessary certificates and socket options.

        Parameters
        ----------
        socket_type : int
            The socket type, which can be any of the 0MQ socket types: REQ, REP,
            PUB, SUB, PAIR, DEALER, ROUTER, PULL, PUSH, etc.

        args:
            passed to the zmq.Context.socket method.

        kwargs:
            passed to the zmq.Context.socket method.
        """

    def term(self):
        """Terminate the wrapped context."""
        self._ctx.term()

    def destroy(self, linger: Optional[int] = None):
        """Close every socket of this context, then terminate the context.

        .. warning::

            destroy involves calling ``zmq_close()``, which is **NOT** threadsafe.
            If there are active sockets in other threads, this must not be called.

        Parameters
        ----------
        linger : int, optional
            If specified, set LINGER on sockets prior to closing them.
        """
        self._ctx.destroy(linger)

    def recreate(self, linger: Optional[int] = None):
        """Destroy the current context and replace it with a fresh one.

        Parameters
        ----------
        linger : int, optional
            If specified, set LINGER on sockets prior to closing them.
        """
        # Dispatches to the subclass's destroy() when it is overridden.
        self.destroy(linger)
        self._ctx = zmq.Context()


class ServerContext(BaseContext):
    """CurveZMQ server context

    We create server sockets via the `ctx.socket` method, which automatically
    applies the necessary certificates and socket options.

    We handle client certificate authentication in a separate dedicated thread.

    Parameters
    ----------
    cert_dir : str | os.PathLike | None
        Path to the certificate directory. Setting this to None will disable encryption.
    """

    def __init__(self, cert_dir: Optional[Union[str, os.PathLike]]) -> None:
        super().__init__(cert_dir)
        # Stays None when encryption is disabled: no authenticator is started.
        self.auth_thread = None
        if self.encrypted:
            self.auth_thread = self._start_auth_thread()

    def __del__(self):
        # Avoid issues in which the auth_thread attr was
        # previously deleted
        if getattr(self, "auth_thread", None):
            self.auth_thread.stop()

    def _start_auth_thread(self) -> ThreadAuthenticator:
        """Start an authenticator thread on the current context and return it."""
        auth_thread = ThreadAuthenticator(self._ctx)
        auth_thread.start()
        # Only allow certs that are in the cert dir
        assert self.cert_dir  # For mypy
        auth_thread.configure_curve(domain="*", location=self.cert_dir)
        return auth_thread

    def socket(self, socket_type: int, *args, **kwargs) -> zmq.Socket:
        """Create a server socket on this context.

        When encryption is enabled, the server secret key is loaded from the
        certificate directory and the socket is switched to CURVE server mode.
        IPv6 is enabled on every socket.

        Parameters
        ----------
        socket_type : int
            The 0MQ socket type.

        args:
            passed to the zmq.Context.socket method.

        kwargs:
            passed to the zmq.Context.socket method.

        Raises
        ------
        OSError
            If the certificate directory is not private.
        ValueError
            If no secret key is found or the key is rejected by ZMQ.
        """
        sock = self._ctx.socket(socket_type, *args, **kwargs)
        try:
            if self.encrypted:
                assert self.cert_dir  # For mypy
                _, secret_key = _load_certificate(self.cert_dir, name="server")
                try:
                    # Only the clients need the server's public key to encrypt
                    # messages and verify the server's identity.
                    # Ref: http://curvezmq.org/page:read-the-docs
                    sock.setsockopt(zmq.CURVE_SECRETKEY, secret_key)
                except zmq.ZMQError as e:
                    raise ValueError("Invalid CurveZMQ key format") from e
                sock.setsockopt(zmq.CURVE_SERVER, True)  # Must come before bind

            # This flag enables IPV6 in addition to IPV4
            sock.setsockopt(zmq.IPV6, True)
        except BaseException:
            # Do not leak a half-configured socket: an open socket would keep
            # the context from terminating.
            sock.close(linger=0)
            raise
        return sock

    def term(self):
        """Stop the authenticator thread (if any), then terminate the context."""
        if self.auth_thread:
            self.auth_thread.stop()
        super().term()

    def destroy(self, linger: Optional[int] = None):
        """Stop the authenticator thread (if any), then destroy the context.

        Parameters
        ----------
        linger : int, optional
            If specified, set LINGER on sockets prior to closing them.
        """
        if self.auth_thread:
            self.auth_thread.stop()
        super().destroy(linger)

    def recreate(self, linger: Optional[int] = None):
        """Destroy then recreate the context, restarting the authenticator
        thread on the new context when one was in use.

        Parameters
        ----------
        linger : int, optional
            If specified, set LINGER on sockets prior to closing them.
        """
        super().recreate(linger)
        if self.auth_thread:
            self.auth_thread = self._start_auth_thread()


class ClientContext(BaseContext):
    """CurveZMQ client context

    We create client sockets via the `ctx.socket` method, which automatically
    applies the necessary certificates and socket options.

    Parameters
    ----------
    cert_dir : str | os.PathLike | None
        Path to the certificate directory. Setting this to None will disable encryption.
    """

    def socket(self, socket_type: int, *args, **kwargs) -> zmq.Socket:
        """Create a client socket on this context.

        When encryption is enabled, the client key pair and the server public
        key are loaded from the certificate directory and applied to the
        socket. IPv6 is enabled on every socket.

        Parameters
        ----------
        socket_type : int
            The 0MQ socket type.

        args:
            passed to the zmq.Context.socket method.

        kwargs:
            passed to the zmq.Context.socket method.

        Raises
        ------
        OSError
            If the certificate directory is not private.
        ValueError
            If a secret key is missing or a key is rejected by ZMQ.
        """
        sock = self._ctx.socket(socket_type, *args, **kwargs)
        try:
            if self.encrypted:
                assert self.cert_dir  # For mypy
                public_key, secret_key = _load_certificate(self.cert_dir, name="client")
                server_public_key, _ = _load_certificate(self.cert_dir, name="server")
                try:
                    sock.setsockopt(zmq.CURVE_PUBLICKEY, public_key)
                    sock.setsockopt(zmq.CURVE_SECRETKEY, secret_key)
                    sock.setsockopt(zmq.CURVE_SERVERKEY, server_public_key)
                except zmq.ZMQError as e:
                    raise ValueError("Invalid CurveZMQ key format") from e

            # This flag enables IPV6 in addition to IPV4
            sock.setsockopt(zmq.IPV6, True)
        except BaseException:
            # Do not leak a half-configured socket: an open socket would keep
            # the context from terminating.
            sock.close(linger=0)
            raise
        return sock