"""Pub/Sub commands (SUBSCRIBE/PUBLISH/PUBSUB and their sharded variants) for fakeredis."""
from typing import Tuple, Any, Dict, Callable, List, Iterable
from fakeredis import _msgs as msgs
from fakeredis._commands import command
from fakeredis._helpers import NoResponse, compile_pattern, SimpleError
class PubSubCommandsMixin:
    """Redis Pub/Sub commands (classic and sharded) for a fake connection.

    The host class must provide ``put_response`` (called with every reply list
    this mixin emits, both on this connection and on subscriber connections)
    and ``_server``, whose ``subscribers``, ``psubscribers`` and
    ``ssubscribers`` dicts map a channel or pattern to the set of connections
    subscribed to it.  NOTE(review): ``_subscribe`` indexes those dicts
    directly, so they are presumably ``defaultdict(set)`` -- confirm on the
    server class.
    """

    # Supplied by the host class; receives one reply (a list) per call.
    put_response: Callable[[Any], None]

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Total subscriptions held by this connection across channels,
        # patterns and shard channels; echoed in every (un)subscribe reply.
        self._pubsub = 0
        self._server: Any
        # Emulated server version, compared against tuples like (7,) and (7, 1).
        self.version: Tuple[int, ...]

    def _subscribe(self, channels: Iterable[bytes], subscribers: Dict[bytes, Any], mtype: bytes) -> NoResponse:
        """Add this connection to *subscribers* for every entry of *channels*.

        One ``[mtype, channel, count]`` reply is emitted per channel -- also for
        channels already subscribed to, which leave the count unchanged.  The
        replies go out through ``put_response``, so the command itself returns
        ``NoResponse()``.
        """
        for channel in channels:
            subs = subscribers[channel]
            if self not in subs:
                subs.add(self)
                self._pubsub += 1
            self.put_response([mtype, channel, self._pubsub])
        return NoResponse()

    def _unsubscribe(self, channels: Iterable[bytes], subscribers: Dict[bytes, Any], mtype: bytes) -> NoResponse:
        """Remove this connection from *subscribers* for each of *channels*.

        With no *channels*, every entry this connection is subscribed to is
        used.  One ``[mtype, channel, count]`` reply is emitted per channel,
        including channels this connection was not subscribed to.  Entries left
        without subscribers are deleted, so PUBSUB CHANNELS / NUMPAT only see
        active ones.
        """
        if not channels:
            # Snapshot first: the loop below deletes entries from
            # ``subscribers`` and must not do so while iterating it.
            channels = [channel for channel, subs in subscribers.items() if self in subs]
        for channel in channels:
            # ``.get`` rather than indexing, so unknown channels are not
            # inserted as empty entries.
            subs = subscribers.get(channel, set())
            if self in subs:
                subs.remove(self)
                if not subs:
                    del subscribers[channel]
                self._pubsub -= 1
            self.put_response([mtype, channel, self._pubsub])
        return NoResponse()

    def _numsub(self, subscribers: Dict[bytes, Any], *channels: bytes) -> List[Any]:
        """Return ``[channel1, count1, channel2, count2, ...]`` for *channels*.

        Channels absent from *subscribers* are reported with a count of 0.
        """
        result: List[Any] = []
        for channel in channels:
            result += [channel, len(subscribers.get(channel, ()))]
        return result

    @command((bytes,), (bytes,), flags=msgs.FLAG_NO_SCRIPT)
    def psubscribe(self, *patterns: bytes) -> NoResponse:
        """PSUBSCRIBE: subscribe this connection to one or more glob patterns."""
        return self._subscribe(patterns, self._server.psubscribers, b"psubscribe")

    @command((bytes,), (bytes,), flags=msgs.FLAG_NO_SCRIPT)
    def subscribe(self, *channels: bytes) -> NoResponse:
        """SUBSCRIBE: subscribe this connection to one or more channels."""
        return self._subscribe(channels, self._server.subscribers, b"subscribe")

    @command((bytes,), (bytes,), flags=msgs.FLAG_NO_SCRIPT)
    def ssubscribe(self, *channels: bytes) -> NoResponse:
        """SSUBSCRIBE: subscribe this connection to one or more shard channels."""
        return self._subscribe(channels, self._server.ssubscribers, b"ssubscribe")

    @command((), (bytes,), flags=msgs.FLAG_NO_SCRIPT)
    def punsubscribe(self, *patterns: bytes) -> NoResponse:
        """PUNSUBSCRIBE: drop the given patterns, or all patterns if none given."""
        return self._unsubscribe(patterns, self._server.psubscribers, b"punsubscribe")

    @command((), (bytes,), flags=msgs.FLAG_NO_SCRIPT)
    def unsubscribe(self, *channels: bytes) -> NoResponse:
        """UNSUBSCRIBE: drop the given channels, or all channels if none given."""
        return self._unsubscribe(channels, self._server.subscribers, b"unsubscribe")

    @command(fixed=(), repeat=(bytes,), flags=msgs.FLAG_NO_SCRIPT)
    def sunsubscribe(self, *channels: bytes) -> NoResponse:
        """SUNSUBSCRIBE: drop the given shard channels, or all if none given."""
        return self._unsubscribe(channels, self._server.ssubscribers, b"sunsubscribe")

    @command((bytes, bytes))
    def publish(self, channel: bytes, message: bytes) -> int:
        """PUBLISH: deliver *message* to the subscribers of *channel*.

        Exact subscribers receive ``[b"message", channel, message]``; each
        pattern subscription matching *channel* receives
        ``[b"pmessage", pattern, channel, message]``.  Returns the number of
        deliveries, so a connection subscribed both to the channel and to a
        matching pattern is counted once per subscription.
        """
        receivers = 0
        msg = [b"message", channel, message]
        for sock in self._server.subscribers.get(channel, set()):
            sock.put_response(msg)
            receivers += 1
        for pattern, socks in self._server.psubscribers.items():
            if compile_pattern(pattern).match(channel):
                pmsg = [b"pmessage", pattern, channel, message]
                for sock in socks:
                    sock.put_response(pmsg)
                    receivers += 1
        return receivers

    @command((bytes, bytes))
    def spublish(self, channel: bytes, message: bytes) -> int:
        """SPUBLISH: deliver *message* to the shard-channel subscribers of *channel*.

        Each SSUBSCRIBE subscriber receives ``[b"smessage", channel, message]``.
        Unlike PUBLISH, shard messages are not delivered to pattern
        (PSUBSCRIBE) subscribers -- Redis skips pattern matching for sharded
        Pub/Sub -- so only exact shard subscribers are notified and counted.
        """
        receivers = 0
        msg = [b"smessage", channel, message]
        for sock in self._server.ssubscribers.get(channel, set()):
            sock.put_response(msg)
            receivers += 1
        return receivers

    @command(name="PUBSUB NUMPAT", fixed=(), repeat=())
    def pubsub_numpat(self, *_: Any) -> int:
        """PUBSUB NUMPAT: number of distinct patterns with at least one subscriber."""
        return len(self._server.psubscribers)

    def _channels(self, subscribers_dict: Dict[bytes, Any], *patterns: bytes) -> List[bytes]:
        """Return the keys of *subscribers_dict*, filtered by ``patterns[0]`` if given.

        Only the first pattern is used; any further arguments are ignored.
        """
        channels = list(subscribers_dict)
        if patterns:
            regex = compile_pattern(patterns[0])
            channels = [ch for ch in channels if regex.match(ch)]
        return channels

    @command(name="PUBSUB CHANNELS", fixed=(), repeat=(bytes,))
    def pubsub_channels(self, *args: bytes) -> List[bytes]:
        """PUBSUB CHANNELS [pattern]: active channels, optionally pattern-filtered."""
        return self._channels(self._server.subscribers, *args)

    @command(name="PUBSUB SHARDCHANNELS", fixed=(), repeat=(bytes,))
    def pubsub_shardchannels(self, *args: bytes) -> List[bytes]:
        """PUBSUB SHARDCHANNELS [pattern]: active shard channels, optionally filtered."""
        return self._channels(self._server.ssubscribers, *args)

    @command(name="PUBSUB NUMSUB", fixed=(), repeat=(bytes,))
    def pubsub_numsub(self, *args: bytes) -> List[Any]:
        """PUBSUB NUMSUB [channel ...]: flat channel/subscriber-count pairs."""
        return self._numsub(self._server.subscribers, *args)

    @command(name="PUBSUB SHARDNUMSUB", fixed=(), repeat=(bytes,))
    def pubsub_shardnumsub(self, *args: bytes) -> List[Any]:
        """PUBSUB SHARDNUMSUB [channel ...]: flat shard-channel/count pairs."""
        return self._numsub(self._server.ssubscribers, *args)

    @command(name="PUBSUB", fixed=())
    def pubsub(self, *args: Any) -> None:
        """Bare PUBSUB without a subcommand: always rejected with ``WRONG_ARGS_MSG6``."""
        raise SimpleError(msgs.WRONG_ARGS_MSG6.format("pubsub"))

    @command(name="PUBSUB HELP", fixed=())
    def pubsub_help(self, *args: Any) -> List[bytes]:
        """PUBSUB HELP: the subcommand help text for the emulated server version.

        The SHARD* subcommands are listed from version 7 on; the final line
        reads "Print" from 7.1 on and "Prints" before that.
        """
        # Description lines are indented by four spaces, as in Redis's reply.
        help_strings = [
            "PUBSUB <subcommand> [<arg> [value] [opt] ...]. Subcommands are:",
            "CHANNELS [<pattern>]",
            "    Return the currently active channels matching a <pattern> (default: '*').",
            "NUMPAT",
            "    Return number of subscriptions to patterns.",
            "NUMSUB [<channel> ...]",
            "    Return the number of subscribers for the specified channels, excluding",
            "    pattern subscriptions(default: no channels).",
        ]
        if self.version >= (7,):
            help_strings += [
                "SHARDCHANNELS [<pattern>]",
                "    Return the currently active shard level channels matching a <pattern> (default: '*').",
                "SHARDNUMSUB [<shardchannel> ...]",
                "    Return the number of subscribers for the specified shard level channel(s)",
            ]
        help_strings += [
            "HELP",
            "    Print this help." if self.version >= (7, 1) else "    Prints this help.",
        ]
        return [s.encode() for s in help_strings]