From 9428df4ba027dea422697cfae995568cd06cd06a Mon Sep 17 00:00:00 2001
From: Aymeric Augustin <aymeric.augustin@m4x.org>
Date: Sun, 23 May 2021 18:51:27 +0200
Subject: [PATCH] Use constant-time comparison for passwords.

Backport of c91b4c2a to 8.1.
---
 src/websockets/auth.py | 29 ++++++++++++++++-------------
 1 file changed, 16 insertions(+), 13 deletions(-)

diff --git a/src/websockets/auth.py b/src/websockets/auth.py
index ae204b8..aeaf15b 100644
--- a/src/websockets/auth.py
+++ b/src/websockets/auth.py
@@ -6,7 +6,9 @@
 
 
 import functools
+import hmac
 import http
+from typing import cast
 from typing import Any, Awaitable, Callable, Iterable, Optional, Tuple, Type, Union
 
 from .exceptions import InvalidHeader
@@ -137,24 +139,25 @@ def basic_auth_protocol_factory(
 
     if credentials is not None:
         if is_credentials(credentials):
-
-            async def check_credentials(username: str, password: str) -> bool:
-                return (username, password) == credentials
-
+            credentials_list = [cast(Credentials, credentials)]
         elif isinstance(credentials, Iterable):
             credentials_list = list(credentials)
-            if all(is_credentials(item) for item in credentials_list):
-                credentials_dict = dict(credentials_list)
-
-                async def check_credentials(username: str, password: str) -> bool:
-                    return credentials_dict.get(username) == password
-
-            else:
+            if not all(is_credentials(item) for item in credentials_list):
                 raise TypeError(f"invalid credentials argument: {credentials}")
-
         else:
             raise TypeError(f"invalid credentials argument: {credentials}")
 
+        credentials_dict = dict(credentials_list)
+
+        async def check_credentials(username: str, password: str) -> bool:
+            try:
+                expected_password = credentials_dict[username]
+            except KeyError:
+                return False
+            return hmac.compare_digest(expected_password, password)
+
     return functools.partial(
-        create_protocol, realm=realm, check_credentials=check_credentials
+        create_protocol,
+        realm=realm,
+        check_credentials=check_credentials,
     )
-- 
2.40.1

