File: kv.py

package info (click to toggle)
python-msgspec 0.19.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 6,356 kB
  • sloc: javascript: 23,944; ansic: 20,540; python: 20,465; makefile: 29; sh: 19
file content (171 lines) | stat: -rw-r--r-- 5,653 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
from __future__ import annotations

import asyncio
import msgspec
from typing import Any


# Some utilities for writing and reading length-prefix framed messages. Using
# length-prefixed framing makes it easier for the reader to determine the
# boundaries of each message before passing it to msgspec to be decoded.
async def prefixed_send(stream: asyncio.StreamWriter, buffer: bytes) -> None:
    """Send one length-prefixed message on ``stream``.

    The frame is a 4-byte unsigned big-endian length header followed by the
    payload bytes. Once both pieces are queued, ``stream.drain()`` is awaited
    so the caller respects the writer's flow control.

    Raises:
        OverflowError: if ``buffer`` is too long to fit in a 4-byte length
            (``len(buffer) >= 2**32``).
    """
    header = len(buffer).to_bytes(4, byteorder="big")
    for part in (header, buffer):
        stream.write(part)
    await stream.drain()


async def prefixed_recv(stream: asyncio.StreamReader) -> bytes:
    """Receive one length-prefixed message from ``stream`` and return its payload.

    First a 4-byte big-endian length header is read, then exactly that many
    payload bytes.

    Raises:
        asyncio.IncompleteReadError: if the stream reaches EOF before the
            header or the full payload has arrived (this is a subclass of
            ``EOFError``).
    """
    header = await stream.readexactly(4)
    size = int.from_bytes(header, byteorder="big")
    return await stream.readexactly(size)


# Define some request types. We set `tag=True` on each type so they can be used
# in a "tagged-union" defining the request types.
class Get(msgspec.Struct, tag=True):
    """Request the value stored under ``key``.

    The server answers with the stored value, or ``None`` if the key is
    missing. ``tag=True`` tags the encoded message with the class name so it
    can be told apart from the other members of the ``Request`` union.
    """

    # The key to look up.
    key: str


class Put(msgspec.Struct, tag=True):
    """Store ``val`` under ``key``, replacing any existing value.

    The server answers with ``None``. Field order matters: ``Put(key, val)``
    is constructed positionally by the client.
    """

    # The key to write.
    key: str
    # The value to associate with ``key``.
    val: str


class Del(msgspec.Struct, tag=True):
    """Remove ``key`` from the store if present; missing keys are ignored.

    The server answers with ``None``.
    """

    # The key to remove.
    key: str


class ListKeys(msgspec.Struct, tag=True):
    """Request every key in the store; the server answers with a sorted list.

    This message carries no fields. Only its tag identifies it.
    """

    pass


# A union of all valid request types. The server's typed decoder is built from
# this union and uses each message's tag to select the concrete Struct type.
# NOTE: ``X | Y`` between classes is evaluated at runtime here (the
# ``__future__`` import only affects annotations), so this needs Python 3.10+.
Request = Get | Put | Del | ListKeys


class Server:
    """An example TCP key-value server using asyncio and msgspec"""

    def __init__(self, host: str = "127.0.0.1", port: int = 8888):
        self.host = host
        self.port = port
        self.kv: dict[str, str] = {}
        # A msgpack encoder for encoding responses
        self.encoder = msgspec.msgpack.Encoder()
        # A *typed* msgpack decoder for decoding requests. If a request doesn't
        # match the specified types, a nice error will be raised.
        self.decoder = msgspec.msgpack.Decoder(Request)

    async def handle_connection(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ):
        """Handle the full lifetime of a single connection"""
        print("Connection opened")
        while True:
            try:
                # Receive and decode a request
                buffer = await prefixed_recv(reader)
                req = self.decoder.decode(buffer)

                # Process the request
                resp = await self.handle_request(req)

                # Encode and write the response
                buffer = self.encoder.encode(resp)
                await prefixed_send(writer, buffer)
            except EOFError:
                print("Connection closed")
                return

    async def handle_request(self, req: Request) -> Any:
        """Handle a single request and return the result (if any)"""
        # We use pattern matching here to branch on the different message types.
        # You could just as well use an if-else statement, but pattern matching
        # works pretty well here.
        match req:
            case Get(key):
                # Return the value for a key, or None if missing
                return self.kv.get(key)
            case Put(key, val):
                # Add a new key-value pair
                self.kv[key] = val
                return None
            case Del(key):
                # Remove a key-value pair if it exists
                self.kv.pop(key, None)
                return None
            case ListKeys():
                # Return a list of all keys in the store
                return sorted(self.kv)

    async def serve_forever(self) -> None:
        server = await asyncio.start_server(
            self.handle_connection, self.host, self.port
        )
        print(f"Serving on tcp://{self.host}:{self.port}...")
        async with server:
            await server.serve_forever()

    def run(self) -> None:
        """Run the server until ctrl-C"""
        asyncio.run(self.serve_forever())


class Client:
    """An example TCP key-value client using asyncio and msgspec.

    Each call sends one request and then waits for the next response on the
    same connection. There is no request/response matching, so a single
    ``Client`` should not be shared by coroutines calling it concurrently.

    The client can be used as an async context manager so the connection is
    always closed::

        async with await Client.create() as client:
            await client.put("k", "v")
    """

    def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        self.reader = reader
        self.writer = writer

    @classmethod
    async def create(cls, host: str = "127.0.0.1", port: int = 8888) -> Client:
        """Open a connection to ``host:port`` and return a client that uses it."""
        reader, writer = await asyncio.open_connection(host, port)
        return cls(reader, writer)

    async def __aenter__(self) -> Client:
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        # Close on every exit path, including exceptions; never suppress them.
        await self.close()

    async def close(self) -> None:
        """Close the client."""
        self.writer.close()
        await self.writer.wait_closed()

    async def request(self, req: Request, resp_type: Any = Any) -> Any:
        """Send a request and await the response.

        Args:
            req: The request message to send.
            resp_type: The type to decode and validate the response as. The
                default, ``Any``, decodes without validation (the original
                behavior).

        Raises:
            msgspec.ValidationError: If the response does not match
                ``resp_type``.
            asyncio.IncompleteReadError: If the server closes the connection
                before a full response arrives.
        """
        # Encode and send the request
        buffer = msgspec.msgpack.encode(req)
        await prefixed_send(self.writer, buffer)

        # Receive and decode the response
        buffer = await prefixed_recv(self.reader)
        return msgspec.msgpack.decode(buffer, type=resp_type)

    async def get(self, key: str) -> str | None:
        """Get a key from the KV store, returning None if not present"""
        # Validate the response against the declared return type.
        return await self.request(Get(key), str | None)

    async def put(self, key: str, val: str) -> None:
        """Put a key-val pair in the KV store"""
        return await self.request(Put(key, val))

    async def delete(self, key: str) -> None:
        """Delete a key-val pair from the KV store"""
        return await self.request(Del(key))

    async def list_keys(self) -> list[str]:
        """List all keys in the KV store"""
        # Validate the response against the declared return type.
        return await self.request(ListKeys(), list[str])


if __name__ == "__main__":
    # When executed as a script, start a server on the default host and port.
    kv_server = Server()
    kv_server.run()