File: pubsub_mixin.py

package info (click to toggle)
python-fakeredis 2.29.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,772 kB
  • sloc: python: 19,002; sh: 8; makefile: 5
file content (169 lines) | stat: -rw-r--r-- 7,511 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
from typing import Tuple, Any, Dict, Callable, List, Iterable

from fakeredis import _msgs as msgs
from fakeredis._commands import command
from fakeredis._helpers import NoResponse, compile_pattern, SimpleError


class PubSubCommandsMixin:
    put_response: Callable[[Any], None]

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super(PubSubCommandsMixin, self).__init__(*args, **kwargs)
        self._pubsub = 0  # Count of subscriptions
        self._server: Any
        self.version: Tuple[int]

    def _subscribe(self, channels: Iterable[bytes], subscribers: Dict[bytes, Any], mtype: bytes) -> NoResponse:
        for channel in channels:
            subs = subscribers[channel]
            if self not in subs:
                subs.add(self)
                self._pubsub += 1
            msg = [mtype, channel, self._pubsub]
            self.put_response(msg)
        return NoResponse()

    def _unsubscribe(self, channels: Iterable[bytes], subscribers: Dict[bytes, Any], mtype: bytes) -> NoResponse:
        if not channels:
            channels = []
            for channel, subs in subscribers.items():
                if self in subs:
                    channels.append(channel)
        for channel in channels:
            subs = subscribers.get(channel, set())
            if self in subs:
                subs.remove(self)
                if not subs:
                    del subscribers[channel]
                self._pubsub -= 1
            msg = [mtype, channel, self._pubsub]
            self.put_response(msg)
        return NoResponse()

    def _numsub(self, subscribers: Dict[bytes, Any], *channels: bytes) -> List[Any]:
        tuples_list = [(ch, len(subscribers.get(ch, []))) for ch in channels]
        return [item for sublist in tuples_list for item in sublist]

    @command((bytes,), (bytes,), flags=msgs.FLAG_NO_SCRIPT)
    def psubscribe(self, *patterns: bytes) -> NoResponse:
        return self._subscribe(patterns, self._server.psubscribers, b"psubscribe")

    @command((bytes,), (bytes,), flags=msgs.FLAG_NO_SCRIPT)
    def subscribe(self, *channels: bytes) -> NoResponse:
        return self._subscribe(channels, self._server.subscribers, b"subscribe")

    @command((bytes,), (bytes,), flags=msgs.FLAG_NO_SCRIPT)
    def ssubscribe(self, *channels: bytes) -> NoResponse:
        return self._subscribe(channels, self._server.ssubscribers, b"ssubscribe")

    @command((), (bytes,), flags=msgs.FLAG_NO_SCRIPT)
    def punsubscribe(self, *patterns: bytes) -> NoResponse:
        return self._unsubscribe(patterns, self._server.psubscribers, b"punsubscribe")

    @command((), (bytes,), flags=msgs.FLAG_NO_SCRIPT)
    def unsubscribe(self, *channels: bytes) -> NoResponse:
        return self._unsubscribe(channels, self._server.subscribers, b"unsubscribe")

    @command(fixed=(), repeat=(bytes,), flags=msgs.FLAG_NO_SCRIPT)
    def sunsubscribe(self, *channels: bytes) -> NoResponse:
        return self._unsubscribe(channels, self._server.ssubscribers, b"sunsubscribe")

    @command((bytes, bytes))
    def publish(self, channel: bytes, message: bytes) -> int:
        receivers = 0
        msg = [b"message", channel, message]
        subs = self._server.subscribers.get(channel, set())
        for sock in subs:
            sock.put_response(msg)
            receivers += 1
        for pattern, socks in self._server.psubscribers.items():
            regex = compile_pattern(pattern)
            if regex.match(channel):
                msg = [b"pmessage", pattern, channel, message]
                for sock in socks:
                    sock.put_response(msg)
                    receivers += 1
        return receivers

    @command((bytes, bytes))
    def spublish(self, channel: bytes, message: bytes) -> int:
        receivers = 0
        msg = [b"smessage", channel, message]
        subs = self._server.ssubscribers.get(channel, set())
        for sock in subs:
            sock.put_response(msg)
            receivers += 1
        for pattern, socks in self._server.psubscribers.items():
            regex = compile_pattern(pattern)
            if regex.match(channel):
                msg = [b"pmessage", pattern, channel, message]
                for sock in socks:
                    sock.put_response(msg)
                    receivers += 1
        return receivers

    @command(name="PUBSUB NUMPAT", fixed=(), repeat=())
    def pubsub_numpat(self, *_: Any) -> int:
        return len(self._server.psubscribers)

    def _channels(self, subscribers_dict: Dict[bytes, Any], *patterns: bytes) -> List[bytes]:
        channels = list(subscribers_dict.keys())
        if len(patterns) > 0:
            regex = compile_pattern(patterns[0])
            channels = [ch for ch in channels if regex.match(ch)]
        return channels

    @command(name="PUBSUB CHANNELS", fixed=(), repeat=(bytes,))
    def pubsub_channels(self, *args: bytes) -> List[bytes]:
        return self._channels(self._server.subscribers, *args)

    @command(name="PUBSUB SHARDCHANNELS", fixed=(), repeat=(bytes,))
    def pubsub_shardchannels(self, *args: bytes) -> List[bytes]:
        return self._channels(self._server.ssubscribers, *args)

    @command(name="PUBSUB NUMSUB", fixed=(), repeat=(bytes,))
    def pubsub_numsub(self, *args: bytes) -> List[Any]:
        return self._numsub(self._server.subscribers, *args)

    @command(name="PUBSUB SHARDNUMSUB", fixed=(), repeat=(bytes,))
    def pubsub_shardnumsub(self, *args: bytes) -> List[Any]:
        return self._numsub(self._server.ssubscribers, *args)

    @command(name="PUBSUB", fixed=())
    def pubsub(self, *args: Any) -> None:
        raise SimpleError(msgs.WRONG_ARGS_MSG6.format("pubsub"))

    @command(name="PUBSUB HELP", fixed=())
    def pubsub_help(self, *args: Any) -> List[bytes]:
        if self.version >= (7,):
            help_strings = [
                "PUBSUB <subcommand> [<arg> [value] [opt] ...]. Subcommands are:",
                "CHANNELS [<pattern>]",
                "    Return the currently active channels matching a <pattern> (default: '*').",
                "NUMPAT",
                "    Return number of subscriptions to patterns.",
                "NUMSUB [<channel> ...]",
                "    Return the number of subscribers for the specified channels, excluding",
                "    pattern subscriptions(default: no channels).",
                "SHARDCHANNELS [<pattern>]",
                "    Return the currently active shard level channels matching a <pattern> (default: '*').",
                "SHARDNUMSUB [<shardchannel> ...]",
                "    Return the number of subscribers for the specified shard level channel(s)",
                "HELP",
                ("    Prints this help." if self.version < (7, 1) else "    Print this help."),
            ]
        else:
            help_strings = [
                "PUBSUB <subcommand> [<arg> [value] [opt] ...]. Subcommands are:",
                "CHANNELS [<pattern>]",
                "    Return the currently active channels matching a <pattern> (default: '*').",
                "NUMPAT",
                "    Return number of subscriptions to patterns.",
                "NUMSUB [<channel> ...]",
                "    Return the number of subscribers for the specified channels, excluding",
                "    pattern subscriptions(default: no channels).",
                "HELP",
                "    Prints this help.",
            ]
        return [s.encode() for s in help_strings]