From: Chris Lamb <lamby@debian.org>
Date: Sun, 14 Jun 2020 12:08:06 +0100
Subject: CVE-2020-13254

---
 django/core/cache/__init__.py           |  3 ++-
 django/core/cache/backends/base.py      | 33 +++++++++++++++++++++------------
 django/core/cache/backends/memcached.py | 22 ++++++++++++++++++++--
 3 files changed, 43 insertions(+), 15 deletions(-)

diff --git a/django/core/cache/__init__.py b/django/core/cache/__init__.py
index cd2bb43..bc1e4de 100644
--- a/django/core/cache/__init__.py
+++ b/django/core/cache/__init__.py
@@ -18,12 +18,13 @@ from django.conf import settings
 from django.core import signals
 from django.core.cache.backends.base import (
     BaseCache, CacheKeyWarning, InvalidCacheBackendError,
+    InvalidCacheKey,
 )
 from django.utils.module_loading import import_string
 
 __all__ = [
     'cache', 'DEFAULT_CACHE_ALIAS', 'InvalidCacheBackendError',
-    'CacheKeyWarning', 'BaseCache',
+    'CacheKeyWarning', 'BaseCache', 'InvalidCacheKey',
 ]
 
 DEFAULT_CACHE_ALIAS = 'default'
diff --git a/django/core/cache/backends/base.py b/django/core/cache/backends/base.py
index 1235f7e..db8cc37 100644
--- a/django/core/cache/backends/base.py
+++ b/django/core/cache/backends/base.py
@@ -16,6 +16,10 @@ class CacheKeyWarning(DjangoRuntimeWarning):
     pass
 
 
+class InvalidCacheKey(ValueError):
+    pass
+
+
 # Stub class to ensure not passing in a `timeout` argument results in
 # the default timeout
 DEFAULT_TIMEOUT = object()
@@ -233,18 +237,8 @@ class BaseCache(object):
         backend. This encourages (but does not force) writing backend-portable
         cache code.
         """
-        if len(key) > MEMCACHE_MAX_KEY_LENGTH:
-            warnings.warn(
-                'Cache key will cause errors if used with memcached: %r '
-                '(longer than %s)' % (key, MEMCACHE_MAX_KEY_LENGTH), CacheKeyWarning
-            )
-        for char in key:
-            if ord(char) < 33 or ord(char) == 127:
-                warnings.warn(
-                    'Cache key contains characters that will cause errors if '
-                    'used with memcached: %r' % key, CacheKeyWarning
-                )
-                break
+        for warning in memcache_key_warnings(key):
+            warnings.warn(warning, CacheKeyWarning)
 
     def incr_version(self, key, delta=1, version=None):
         """Adds delta to the cache version for the supplied key. Returns the
@@ -270,3 +264,18 @@ class BaseCache(object):
     def close(self, **kwargs):
         """Close the cache connection"""
         pass
+
+
+def memcache_key_warnings(key):
+    if len(key) > MEMCACHE_MAX_KEY_LENGTH:
+        yield (
+            'Cache key will cause errors if used with memcached: %r '
+            '(longer than %s)' % (key, MEMCACHE_MAX_KEY_LENGTH)
+        )
+    for char in key:
+        if ord(char) < 33 or ord(char) == 127:
+            yield (
+                'Cache key contains characters that will cause errors if '
+                'used with memcached: %r' % key
+            )
+            break
diff --git a/django/core/cache/backends/memcached.py b/django/core/cache/backends/memcached.py
index 4cf25fb..b77ea30 100644
--- a/django/core/cache/backends/memcached.py
+++ b/django/core/cache/backends/memcached.py
@@ -5,7 +5,9 @@ import re
 import time
 import warnings
 
-from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache
+from django.core.cache.backends.base import (
+    DEFAULT_TIMEOUT, BaseCache, InvalidCacheKey, memcache_key_warnings,
+)
 from django.utils import six
 from django.utils.deprecation import RemovedInDjango21Warning
 from django.utils.encoding import force_str
@@ -72,10 +74,12 @@ class BaseMemcachedCache(BaseCache):
 
     def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
         key = self.make_key(key, version=version)
+        self.validate_key(key)
         return self._cache.add(key, value, self.get_backend_timeout(timeout))
 
     def get(self, key, default=None, version=None):
         key = self.make_key(key, version=version)
+        self.validate_key(key)
         val = self._cache.get(key)
         if val is None:
             return default
@@ -83,16 +87,20 @@ class BaseMemcachedCache(BaseCache):
 
     def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
         key = self.make_key(key, version=version)
+        self.validate_key(key)
         if not self._cache.set(key, value, self.get_backend_timeout(timeout)):
             # make sure the key doesn't keep its old value in case of failure to set (memcached's 1MB limit)
             self._cache.delete(key)
 
     def delete(self, key, version=None):
         key = self.make_key(key, version=version)
+        self.validate_key(key)
         self._cache.delete(key)
 
     def get_many(self, keys, version=None):
         new_keys = [self.make_key(x, version=version) for x in keys]
+        for key in new_keys:
+            self.validate_key(key)
         ret = self._cache.get_multi(new_keys)
         if ret:
             _ = {}
@@ -108,6 +116,7 @@ class BaseMemcachedCache(BaseCache):
 
     def incr(self, key, delta=1, version=None):
         key = self.make_key(key, version=version)
+        self.validate_key(key)
         # memcached doesn't support a negative delta
         if delta < 0:
             return self._cache.decr(key, -delta)
@@ -126,6 +135,7 @@ class BaseMemcachedCache(BaseCache):
 
     def decr(self, key, delta=1, version=None):
         key = self.make_key(key, version=version)
+        self.validate_key(key)
         # memcached doesn't support a negative delta
         if delta < 0:
             return self._cache.incr(key, -delta)
@@ -146,15 +156,23 @@ class BaseMemcachedCache(BaseCache):
         safe_data = {}
         for key, value in data.items():
             key = self.make_key(key, version=version)
+            self.validate_key(key)
             safe_data[key] = value
         self._cache.set_multi(safe_data, self.get_backend_timeout(timeout))
 
     def delete_many(self, keys, version=None):
-        self._cache.delete_multi(self.make_key(key, version=version) for key in keys)
+        keys = [self.make_key(key, version=version) for key in keys]
+        for key in keys:
+            self.validate_key(key)
+        self._cache.delete_multi(keys)
 
     def clear(self):
         self._cache.flush_all()
 
+    def validate_key(self, key):
+        for warning in memcache_key_warnings(key):
+            raise InvalidCacheKey(warning)
+
 
 class MemcachedCache(BaseMemcachedCache):
     "An implementation of a cache binding using python-memcached"
