From: Devid <setti.davide89@gmail.com>
Date: Tue, 28 Mar 2023 11:30:35 +0200
Subject: Assured pools are closed on loop close in core (#347)

Origin: upstream, https://github.com/django/channels_redis/pull/347
Bug: https://github.com/django/channels_redis/issues/332
Bug-Debian: https://bugs.debian.org/1027387
Last-Update: 2024-02-19
---
 channels_redis/core.py   | 64 ++++++++++++++++++++++++++++--------------------
 channels_redis/pubsub.py | 18 +-------------
 channels_redis/utils.py  | 16 ++++++++++++
 3 files changed, 54 insertions(+), 44 deletions(-)

diff --git a/channels_redis/core.py b/channels_redis/core.py
index 7c04ecd..c3eb3b3 100644
--- a/channels_redis/core.py
+++ b/channels_redis/core.py
@@ -15,7 +15,7 @@ from redis import asyncio as aioredis
 from channels.exceptions import ChannelFull
 from channels.layers import BaseChannelLayer
 
-from .utils import _consistent_hash
+from .utils import _consistent_hash, _wrap_close
 
 logger = logging.getLogger(__name__)
 
@@ -69,6 +69,26 @@ class BoundedQueue(asyncio.Queue):
         return super(BoundedQueue, self).put_nowait(item)
 
 
+class RedisLoopLayer:
+    def __init__(self, channel_layer):
+        self._lock = asyncio.Lock()
+        self.channel_layer = channel_layer
+        self._connections = {}
+
+    def get_connection(self, index):
+        if index not in self._connections:
+            pool = self.channel_layer.create_pool(index)
+            self._connections[index] = aioredis.Redis(connection_pool=pool)
+
+        return self._connections[index]
+
+    async def flush(self):
+        async with self._lock:
+            for index in list(self._connections):
+                connection = self._connections.pop(index)
+                await connection.close(close_connection_pool=True)
+
+
 class RedisChannelLayer(BaseChannelLayer):
     """
     Redis channel layer.
@@ -101,8 +121,7 @@ class RedisChannelLayer(BaseChannelLayer):
         self.hosts = self.decode_hosts(hosts)
         self.ring_size = len(self.hosts)
         # Cached redis connection pools and the event loop they are from
-        self.pools = {}
-        self.pools_loop = None
+        self._layers = {}
         # Normal channels choose a host index by cycling through the available hosts
         self._receive_index_generator = itertools.cycle(range(len(self.hosts)))
         self._send_index_generator = itertools.cycle(range(len(self.hosts)))
@@ -138,7 +157,7 @@ class RedisChannelLayer(BaseChannelLayer):
             return aioredis.sentinel.SentinelConnectionPool(
                 master_name,
                 aioredis.sentinel.Sentinel(sentinels, sentinel_kwargs=sentinel_kwargs),
-                **host
+                **host,
             )
         else:
             return aioredis.ConnectionPool(**host)
@@ -331,7 +350,7 @@ class RedisChannelLayer(BaseChannelLayer):
 
                         raise
 
-                    message, token, exception = None, None, None
+                    message = token = exception = None
                     for task in done:
                         try:
                             result = task.result()
@@ -367,7 +386,7 @@ class RedisChannelLayer(BaseChannelLayer):
                             message_channel, message = await self.receive_single(
                                 real_channel
                             )
-                            if type(message_channel) is list:
+                            if isinstance(message_channel, list):
                                 for chan in message_channel:
                                     self.receive_buffer[chan].put_nowait(message)
                             else:
@@ -459,11 +478,7 @@ class RedisChannelLayer(BaseChannelLayer):
         Returns a new channel name that can be used by something in our
         process as a specific channel.
         """
-        return "%s.%s!%s" % (
-            prefix,
-            self.client_prefix,
-            uuid.uuid4().hex,
-        )
+        return f"{prefix}.{self.client_prefix}!{uuid.uuid4().hex}"
 
     ### Flush extension ###
 
@@ -496,9 +511,8 @@ class RedisChannelLayer(BaseChannelLayer):
         # Flush all cleaners, in case somebody just wanted to close the
         # pools without flushing first.
         await self.wait_received()
-
-        for index in self.pools:
-            await self.pools[index].disconnect()
+        for layer in self._layers.values():
+            await layer.flush()
 
     async def wait_received(self):
         """
@@ -667,7 +681,7 @@ class RedisChannelLayer(BaseChannelLayer):
         """
         Common function to make the storage key for the group.
         """
-        return ("%s:group:%s" % (self.prefix, group)).encode("utf8")
+        return f"{self.prefix}:group:{group}".encode("utf8")
 
     ### Serialization ###
 
@@ -711,7 +725,7 @@ class RedisChannelLayer(BaseChannelLayer):
         return Fernet(formatted_key)
 
     def __str__(self):
-        return "%s(hosts=%s)" % (self.__class__.__name__, self.hosts)
+        return f"{self.__class__.__name__}(hosts={self.hosts})"
 
     ### Connection handling ###
 
@@ -723,18 +737,14 @@ class RedisChannelLayer(BaseChannelLayer):
         # Catch bad indexes
         if not 0 <= index < self.ring_size:
             raise ValueError(
-                "There are only %s hosts - you asked for %s!" % (self.ring_size, index)
+                f"There are only {self.ring_size} hosts - you asked for {index}!"
             )
 
+        loop = asyncio.get_running_loop()
         try:
-            loop = asyncio.get_running_loop()
-            if self.pools_loop != loop:
-                self.pools = {}
-                self.pools_loop = loop
-        except RuntimeError:
-            pass
-
-        if index not in self.pools:
-            self.pools[index] = self.create_pool(index)
+            layer = self._layers[loop]
+        except KeyError:
+            _wrap_close(self, loop)
+            layer = self._layers[loop] = RedisLoopLayer(self)
 
-        return aioredis.Redis(connection_pool=self.pools[index])
+        return layer.get_connection(index)
diff --git a/channels_redis/pubsub.py b/channels_redis/pubsub.py
index 3c10378..550325b 100644
--- a/channels_redis/pubsub.py
+++ b/channels_redis/pubsub.py
@@ -1,32 +1,16 @@
 import asyncio
 import functools
 import logging
-import types
 import uuid
 
 import msgpack
 from redis import asyncio as aioredis
 
-from .utils import _consistent_hash
+from .utils import _consistent_hash, _wrap_close
 
 logger = logging.getLogger(__name__)
 
 
-def _wrap_close(proxy, loop):
-    original_impl = loop.close
-
-    def _wrapper(self, *args, **kwargs):
-        if loop in proxy._layers:
-            layer = proxy._layers[loop]
-            del proxy._layers[loop]
-            loop.run_until_complete(layer.flush())
-
-        self.close = original_impl
-        return self.close(*args, **kwargs)
-
-    loop.close = types.MethodType(_wrapper, loop)
-
-
 async def _async_proxy(obj, name, *args, **kwargs):
     # Must be defined as a function and not a method due to
     # https://bugs.python.org/issue38364
diff --git a/channels_redis/utils.py b/channels_redis/utils.py
index 7b30fdc..d2405bb 100644
--- a/channels_redis/utils.py
+++ b/channels_redis/utils.py
@@ -1,4 +1,5 @@
 import binascii
+import types
 
 
 def _consistent_hash(value, ring_size):
@@ -15,3 +16,18 @@ def _consistent_hash(value, ring_size):
     bigval = binascii.crc32(value) & 0xFFF
     ring_divisor = 4096 / float(ring_size)
     return int(bigval / ring_divisor)
+
+
+def _wrap_close(proxy, loop):
+    original_impl = loop.close
+
+    def _wrapper(self, *args, **kwargs):
+        if loop in proxy._layers:
+            layer = proxy._layers[loop]
+            del proxy._layers[loop]
+            loop.run_until_complete(layer.flush())
+
+        self.close = original_impl
+        return self.close(*args, **kwargs)
+
+    loop.close = types.MethodType(_wrapper, loop)
