1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417
|
import asyncio
import fnmatch
import random
import re
import string
import time
import warnings
from copy import deepcopy
from django.conf import settings
from django.core.signals import setting_changed
from django.utils.module_loading import import_string
from channels import DEFAULT_CHANNEL_LAYER
from .exceptions import ChannelFull, InvalidChannelLayerError
class ChannelLayerManager:
    """
    Takes a settings dictionary of backends and initialises them on request.

    Instantiated layers are cached in ``self.backends`` keyed by alias; the
    cache is cleared whenever the ``CHANNEL_LAYERS`` setting changes.
    """

    def __init__(self):
        # alias -> instantiated channel layer
        self.backends = {}
        setting_changed.connect(self._reset_backends)

    def _reset_backends(self, setting, **kwargs):
        """
        Removes cached channel layers when the CHANNEL_LAYERS setting changes.
        """
        if setting == "CHANNEL_LAYERS":
            self.backends = {}

    @property
    def configs(self):
        # Lazy load settings so we can be imported
        return getattr(settings, "CHANNEL_LAYERS", {})

    def make_backend(self, name):
        """
        Instantiate the channel layer configured under ``name``, passing its
        ``CONFIG`` dict (empty if absent) as keyword arguments.

        Raises KeyError if ``name`` is not present in CHANNEL_LAYERS.
        """
        config = self.configs[name].get("CONFIG", {})
        return self._make_backend(name, config)

    def make_test_backend(self, name):
        """
        Instantiate channel layer using its test config.

        Raises InvalidChannelLayerError if no TEST_CONFIG is available.
        """
        try:
            config = self.configs[name]["TEST_CONFIG"]
        except KeyError:
            raise InvalidChannelLayerError(
                "No TEST_CONFIG specified for %s" % name
            ) from None
        return self._make_backend(name, config)

    def _make_backend(self, name, config):
        """
        Import the BACKEND class configured for ``name`` and instantiate it
        with ``config`` as keyword arguments.

        Raises InvalidChannelLayerError for an old-style ROUTING config, a
        missing BACKEND, or a BACKEND path that cannot be imported.
        """
        layer_settings = self.configs[name]
        # Check for old format config
        if "ROUTING" in layer_settings:
            raise InvalidChannelLayerError(
                "ROUTING key found for %s - this is no longer needed in Channels 2."
                % name
            )
        # Load the backend class
        try:
            backend_module = layer_settings["BACKEND"]
        except KeyError:
            raise InvalidChannelLayerError(
                "No BACKEND specified for %s" % name
            ) from None
        try:
            backend_class = import_string(backend_module)
        except ImportError as exc:
            # Keep the original ImportError as the cause; it explains *why*
            # the import failed (missing module vs. missing attribute).
            raise InvalidChannelLayerError(
                "Cannot import BACKEND %r specified for %s" % (backend_module, name)
            ) from exc
        # Initialise and pass config
        return backend_class(**config)

    def __getitem__(self, key):
        # Instantiate on first access, then serve from the cache.
        if key not in self.backends:
            self.backends[key] = self.make_backend(key)
        return self.backends[key]

    def __contains__(self, key):
        return key in self.configs

    def set(self, key, layer):
        """
        Sets an alias to point to a new channel layer instance, and returns
        the old one that it replaced (None if there was none). Useful for
        swapping out the backend during tests.
        """
        old = self.backends.get(key, None)
        self.backends[key] = layer
        return old
class BaseChannelLayer:
    """
    Base channel layer class that others can inherit from, with useful
    common functionality: channel/group name validation and per-channel
    capacity lookup.
    """

    MAX_NAME_LENGTH = 100

    def __init__(self, expiry=60, capacity=100, channel_capacity=None):
        """
        :param expiry: stored as ``self.expiry`` for subclasses to apply to
            messages.
        :param capacity: default capacity returned by get_capacity().
        :param channel_capacity: mapping (or iterable of pairs) of channel
            glob patterns / compiled regexes to capacities.
        """
        self.expiry = expiry
        self.capacity = capacity
        # get_capacity() iterates (pattern, capacity) pairs, so compile here;
        # storing the raw dict made it try to unpack the dict's string keys.
        self.channel_capacity = self.compile_capacities(channel_capacity or {})

    def compile_capacities(self, channel_capacity):
        """
        Takes an input channel_capacity mapping (or an iterable of
        ``(pattern, value)`` pairs, e.g. an already-compiled list) and returns
        the compiled list of ``(regex, capacity)`` tuples that get_capacity
        will look for as self.channel_capacity.

        Precompiled regexes are kept as-is; any other pattern is interpreted
        as a glob via fnmatch.translate.
        """
        if hasattr(channel_capacity, "items"):
            pairs = channel_capacity.items()
        else:
            pairs = channel_capacity
        result = []
        for pattern, value in pairs:
            # If they passed in a precompiled regex, leave it, else interpret
            # it as a glob.
            if hasattr(pattern, "match"):
                result.append((pattern, value))
            else:
                result.append((re.compile(fnmatch.translate(pattern)), value))
        return result

    def get_capacity(self, channel):
        """
        Gets the correct capacity for the given channel; either the default,
        or a matching result from channel_capacity. Returns the first matching
        result; if you want to control the order of matches, use an ordered dict
        as input.
        """
        for pattern, capacity in self.channel_capacity:
            if pattern.match(channel):
                return capacity
        return self.capacity

    def match_type_and_length(self, name):
        """Return True if ``name`` is a str shorter than MAX_NAME_LENGTH."""
        return isinstance(name, str) and len(name) < self.MAX_NAME_LENGTH

    # Name validation functions
    channel_name_regex = re.compile(r"^[a-zA-Z\d\-_.]+(\![\d\w\-_.]*)?$")
    group_name_regex = re.compile(r"^[a-zA-Z\d\-_.]+$")
    invalid_name_error = (
        "{} name must be a valid unicode string "
        + "with length < {} ".format(MAX_NAME_LENGTH)
        + "containing only ASCII alphanumerics, hyphens, underscores, or periods."
    )

    def require_valid_channel_name(self, name, receive=False):
        """
        Validate a channel name, returning True or raising TypeError.

        With ``receive=True``, a process-specific name (containing ``!``)
        must end at the ``!``.
        """
        if not self.match_type_and_length(name):
            raise TypeError(self.invalid_name_error.format("Channel"))
        # fullmatch, not match: with match() the "$" anchor also accepts a
        # trailing "\n", letting names like "foo\n" through.
        if not self.channel_name_regex.fullmatch(name):
            raise TypeError(self.invalid_name_error.format("Channel"))
        if receive and "!" in name and not name.endswith("!"):
            raise TypeError("Specific channel names in receive() must end at the !")
        return True

    def require_valid_group_name(self, name):
        """Validate a group name, returning True or raising TypeError."""
        if not self.match_type_and_length(name):
            raise TypeError(self.invalid_name_error.format("Group"))
        # fullmatch rejects a trailing "\n" that "$" would otherwise allow.
        if not self.group_name_regex.fullmatch(name):
            raise TypeError(self.invalid_name_error.format("Group"))
        return True

    def valid_channel_names(self, names, receive=False):
        """
        Validate every channel name in the non-empty list ``names``.

        Raises AssertionError if ``names`` is not a non-empty list, and
        TypeError for the first invalid name.
        """
        assert names and isinstance(names, list), "names must be a non-empty list"
        for channel in names:
            self.require_valid_channel_name(channel, receive=receive)
        return True

    def non_local_name(self, name):
        """
        Given a channel name, returns the "non-local" part. If the channel name
        is a process-specific channel (contains !) this means the part up to
        and including the !; if it is anything else, this means the full name.
        """
        head, bang, _ = name.partition("!")
        return head + bang

    async def send(self, channel, message):
        raise NotImplementedError("send() should be implemented in a channel layer")

    async def receive(self, channel):
        raise NotImplementedError("receive() should be implemented in a channel layer")

    async def new_channel(self):
        raise NotImplementedError(
            "new_channel() should be implemented in a channel layer"
        )

    async def flush(self):
        raise NotImplementedError("flush() not implemented (flush extension)")

    async def group_add(self, group, channel):
        raise NotImplementedError("group_add() not implemented (groups extension)")

    async def group_discard(self, group, channel):
        raise NotImplementedError("group_discard() not implemented (groups extension)")

    async def group_send(self, group, message):
        raise NotImplementedError("group_send() not implemented (groups extension)")

    # Deprecated methods.
    def valid_channel_name(self, channel_name, receive=False):
        """
        Deprecated: Use require_valid_channel_name instead.
        """
        warnings.warn(
            "valid_channel_name is deprecated, use require_valid_channel_name instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        # Forward ``receive``; it was previously accepted but silently ignored.
        return self.require_valid_channel_name(channel_name, receive=receive)

    def valid_group_name(self, group_name):
        """
        Deprecated: Use require_valid_group_name instead.
        """
        warnings.warn(
            "valid_group_name is deprecated, use require_valid_group_name instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.require_valid_group_name(group_name)
class InMemoryChannelLayer(BaseChannelLayer):
    """
    In-memory channel layer implementation.

    Channel queues and group memberships are kept in plain dicts on the
    instance, so messages are only visible within this process.
    """

    def __init__(
        self,
        expiry=60,
        group_expiry=86400,
        capacity=100,
        channel_capacity=None,
        **kwargs,
    ):
        super().__init__(
            expiry=expiry,
            capacity=capacity,
            channel_capacity=channel_capacity,
            **kwargs,
        )
        # channel name -> asyncio.Queue of (expires_at, message) tuples
        self.channels = {}
        # group name -> {channel name: time.time() when it was added}
        self.groups = {}
        # Seconds after being added before a group membership lapses.
        self.group_expiry = group_expiry

    # Channel layer API

    extensions = ["groups", "flush"]

    def _get_queue(self, channel):
        """
        Return the queue for ``channel``, creating and registering one sized
        by get_capacity() if none exists yet.
        """
        queue = self.channels.get(channel)
        if queue is None:
            queue = asyncio.Queue(maxsize=self.get_capacity(channel))
            self.channels[channel] = queue
        return queue

    async def send(self, channel, message):
        """
        Send a message onto a (general or specific) channel.

        A deep copy of ``message`` is queued together with its expiry time
        (now + self.expiry).

        Raises ChannelFull if the channel's queue is at capacity.
        """
        # Typecheck
        assert isinstance(message, dict), "message is not a dict"
        self.require_valid_channel_name(channel)
        # Callers must not supply the "__asgi_channel__" key themselves.
        assert "__asgi_channel__" not in message
        queue = self._get_queue(channel)
        # Add message
        try:
            queue.put_nowait((time.time() + self.expiry, deepcopy(message)))
        except asyncio.QueueFull:
            raise ChannelFull(channel) from None

    async def receive(self, channel):
        """
        Receive the first message that arrives on the channel.

        If more than one coroutine waits on the same channel, each message is
        delivered to only one of them.
        """
        self.require_valid_channel_name(channel)
        self._clean_expired()
        queue = self._get_queue(channel)
        # NOTE(review): once this queue drains it is removed from
        # self.channels even if other receivers are still awaiting it; later
        # sends create a fresh queue those receivers never see.
        try:
            _, message = await queue.get()
        finally:
            # Drop drained queues so idle channels don't accumulate, but never
            # remove a queue that has since replaced this one (e.g. after
            # flush() or expiry cleanup) -- it may hold undelivered messages.
            if queue.empty() and self.channels.get(channel) is queue:
                self.channels.pop(channel, None)
        return message

    async def new_channel(self, prefix="specific."):
        """
        Returns a new channel name that can be used by something in our
        process as a specific channel.

        The name is ``<prefix>.inmemory!`` followed by 12 random ASCII letters.
        """
        return "%s.inmemory!%s" % (
            prefix,
            "".join(random.choice(string.ascii_letters) for i in range(12)),
        )

    # Expire cleanup

    def _clean_expired(self):
        """
        Goes through all messages and groups and removes those that are expired.
        Any channel with an expired message is removed from all groups, and
        groups left without members are dropped.
        """
        now = time.time()
        # Channel cleanup
        for channel, queue in list(self.channels.items()):
            # Peek at the head item's expiry via the queue's private deque;
            # items are FIFO, so stop at the first unexpired one.
            while not queue.empty() and queue._queue[0][0] < now:
                queue.get_nowait()
                # Any removal prompts group discard
                self._remove_from_groups(channel)
                # Delete the channel only when expiry has just drained it.
                if queue.empty():
                    self.channels.pop(channel, None)
        # Group expiration
        timeout = int(now) - self.group_expiry
        for group, channels in list(self.groups.items()):
            for name, timestamp in list(channels.items()):
                # If join time is older than group_expiry
                # end the group membership
                if timestamp and timestamp < timeout:
                    channels.pop(name, None)
            # Don't keep empty group dicts around forever.
            if not channels:
                self.groups.pop(group, None)

    # Flush extension

    async def flush(self):
        """Discard every queued message and every group membership."""
        self.channels = {}
        self.groups = {}

    async def close(self):
        # Nothing to do
        pass

    def _remove_from_groups(self, channel):
        """
        Removes a channel from all groups. Used when a message on it expires.
        """
        for channels in self.groups.values():
            channels.pop(channel, None)

    # Groups extension

    async def group_add(self, group, channel):
        """
        Adds the channel name to a group.
        """
        # Check the inputs
        self.require_valid_group_name(group)
        self.require_valid_channel_name(channel)
        # Record the join time so _clean_expired() can lapse the membership.
        self.groups.setdefault(group, {})[channel] = time.time()

    async def group_discard(self, group, channel):
        """
        Removes the channel from the group, deleting the group once empty.
        """
        # Both should be text and valid
        self.require_valid_channel_name(channel)
        self.require_valid_group_name(group)
        # Remove from group set
        group_channels = self.groups.get(group, None)
        if group_channels:
            # remove channel if in group
            group_channels.pop(channel, None)
            # is group now empty? If yes remove it
            if not group_channels:
                self.groups.pop(group, None)

    async def group_send(self, group, message):
        """
        Send ``message`` to every channel in ``group``.

        Delivery is best effort: channels whose queues are full are skipped
        (their ChannelFull is swallowed).
        """
        # Check types
        assert isinstance(message, dict), "Message is not a dict"
        self.require_valid_group_name(group)
        # Run clean
        self._clean_expired()
        # Send to each channel
        ops = [
            asyncio.create_task(self.send(channel, message))
            for channel in self.groups.get(group, {})
        ]
        for send_result in asyncio.as_completed(ops):
            try:
                await send_result
            except ChannelFull:
                pass
def get_channel_layer(alias=DEFAULT_CHANNEL_LAYER):
    """
    Look up the channel layer registered under ``alias``.

    Any KeyError raised by the manager lookup (e.g. ``alias`` missing from
    CHANNEL_LAYERS) is turned into a ``None`` return value.
    """
    try:
        layer = channel_layers[alias]
    except KeyError:
        layer = None
    return layer
# Default global instance of the channel layer manager. get_channel_layer()
# resolves aliases through this object when it is called, which is why the
# assignment can come after that function. Constructing the manager also
# connects it to Django's setting_changed signal.
channel_layers = ChannelLayerManager()
|