commit 0bc73254f41acb140187e0c89606311f88de5b7b
Author: Ron Frederick <ronf@timeheart.net>
Date:   Mon Dec 18 07:41:57 2023 -0800

    Implement "strict kex" support to harden AsyncSSH against Terrapin Attack
    
    This commit implements "strict kex" support and other countermeasures to
    protect against the Terrapin Attack described in CVE-2023-48795. Thanks
    once again go to Fabian Bäumer, Marcus Brinkmann, and Jörg Schwenk for
    identifying and reporting this vulnerability and providing detailed
    analysis and suggestions about proposed fixes.

Index: b/asyncssh/connection.py
===================================================================
--- a/asyncssh/connection.py
+++ b/asyncssh/connection.py
@@ -810,6 +810,7 @@ class SSHConnection(SSHPacketHandler, as
         self._kexinit_sent = False
         self._kex_complete = False
         self._ignore_first_kex = False
+        self._strict_kex = False
 
         self._gss: Optional[GSSBase] = None
         self._gss_kex = False
@@ -1343,10 +1344,13 @@ class SSHConnection(SSHPacketHandler, as
             (alg_type, b','.join(local_algs).decode('ascii'),
              b','.join(remote_algs).decode('ascii')))
 
-    def _get_ext_info_kex_alg(self) -> List[bytes]:
-        """Return the kex alg to add if any to request extension info"""
+    def _get_extra_kex_algs(self) -> List[bytes]:
+        """Return the extra kex algs to add"""
 
-        return [b'ext-info-c' if self.is_client() else b'ext-info-s']
+        if self.is_client():
+            return [b'ext-info-c', b'kex-strict-c-v00@openssh.com']
+        else:
+            return [b'ext-info-s', b'kex-strict-s-v00@openssh.com']
 
     def _send(self, data: bytes) -> None:
         """Send data to the SSH connection"""
@@ -1487,6 +1491,11 @@ class SSHConnection(SSHPacketHandler, as
                 self._ignore_first_kex = False
             else:
                 handler = self._kex
+        elif self._strict_kex and not self._recv_encryption and \
+                MSG_IGNORE <= pkttype <= MSG_DEBUG:
+            skip_reason = 'strict kex violation'
+            exc_reason = 'Strict key exchange violation: ' \
+                         'unexpected packet type %d received' % pkttype
         elif (self._auth and
               MSG_USERAUTH_FIRST <= pkttype <= MSG_USERAUTH_LAST):
             handler = self._auth
@@ -1516,15 +1525,26 @@ class SSHConnection(SSHPacketHandler, as
                 raise ProtocolError(str(exc)) from None
 
             if not processed:
-                self.logger.debug1('Unknown packet type %d received', pkttype)
-                self.send_packet(MSG_UNIMPLEMENTED, UInt32(seq))
+                if self._strict_kex and not self._recv_encryption:
+                    exc_reason = 'Strict key exchange violation: ' \
+                                 'unexpected packet type %d received' % pkttype
+                else:
+                    self.logger.debug1('Unknown packet type %d received',
+                                       pkttype)
+                    self.send_packet(MSG_UNIMPLEMENTED, UInt32(seq))
 
         if exc_reason:
             raise ProtocolError(exc_reason)
 
         if self._transport:
-            self._recv_seq = (seq + 1) & 0xffffffff
             self._recv_handler = self._recv_pkthdr
+            if self._recv_seq == 0xffffffff and not self._recv_encryption:
+                raise ProtocolError('Sequence rollover before kex complete')
+
+            if pkttype == MSG_NEWKEYS and self._strict_kex:
+                self._recv_seq = 0
+            else:
+                self._recv_seq = (seq + 1) & 0xffffffff
 
         return True
 
@@ -1579,7 +1599,15 @@ class SSHConnection(SSHPacketHandler, as
             mac = b''
 
         self._send(packet + mac)
-        self._send_seq = (seq + 1) & 0xffffffff
+
+        if self._send_seq == 0xffffffff and not self._send_encryption:
+            self._send_seq = 0
+            raise ProtocolError('Sequence rollover before kex complete')
+
+        if pkttype == MSG_NEWKEYS and self._strict_kex:
+            self._send_seq = 0
+        else:
+            self._send_seq = (seq + 1) & 0xffffffff
 
         if self._kex_complete:
             self._rekey_bytes_sent += pktlen
@@ -1623,7 +1651,7 @@ class SSHConnection(SSHPacketHandler, as
 
         kex_algs = expand_kex_algs(self._kex_algs, gss_mechs,
                                    bool(self._server_host_key_algs)) + \
-                   self._get_ext_info_kex_alg()
+                   self._get_extra_kex_algs()
 
         host_key_algs = self._server_host_key_algs or [b'null']
 
@@ -2106,13 +2134,26 @@ class SSHConnection(SSHPacketHandler, as
         if self.is_server():
             self._client_kexinit = packet.get_consumed_payload()
 
-            if b'ext-info-c' in peer_kex_algs and not self._session_id:
-                self._can_send_ext_info = True
+            if not self._session_id:
+                if b'ext-info-c' in peer_kex_algs:
+                    self._can_send_ext_info = True
+
+                if b'kex-strict-c-v00@openssh.com' in peer_kex_algs:
+                    self._strict_kex = True
         else:
             self._server_kexinit = packet.get_consumed_payload()
 
-            if b'ext-info-s' in peer_kex_algs and not self._session_id:
-                self._can_send_ext_info = True
+            if not self._session_id:
+                if b'ext-info-s' in peer_kex_algs:
+                    self._can_send_ext_info = True
+
+                if b'kex-strict-s-v00@openssh.com' in peer_kex_algs:
+                    self._strict_kex = True
+
+        if self._strict_kex and not self._recv_encryption and \
+                self._recv_seq != 0:
+            raise ProtocolError('Strict key exchange violation: '
+                                'KEXINIT was not the first packet')
 
         if self._kexinit_sent:
             self._kexinit_sent = False
Index: b/tests/test_connection.py
===================================================================
--- a/tests/test_connection.py
+++ b/tests/test_connection.py
@@ -31,9 +31,10 @@ import unittest
 from unittest.mock import patch
 
 import asyncssh
-from asyncssh.constants import MSG_UNIMPLEMENTED, MSG_DEBUG
+from asyncssh.constants import MSG_UNIMPLEMENTED, MSG_DEBUG, MSG_IGNORE
 from asyncssh.constants import MSG_SERVICE_REQUEST, MSG_SERVICE_ACCEPT
 from asyncssh.constants import MSG_KEXINIT, MSG_NEWKEYS
+from asyncssh.constants import MSG_KEX_FIRST, MSG_KEX_LAST
 from asyncssh.constants import MSG_USERAUTH_REQUEST, MSG_USERAUTH_SUCCESS
 from asyncssh.constants import MSG_USERAUTH_FAILURE, MSG_USERAUTH_BANNER
 from asyncssh.constants import MSG_GLOBAL_REQUEST
@@ -43,6 +44,7 @@ from asyncssh.compression import get_com
 from asyncssh.crypto.cipher import GCMCipher
 from asyncssh.encryption import get_encryption_algs
 from asyncssh.kex import get_kex_algs
+from asyncssh.kex_dh import MSG_KEX_ECDH_REPLY
 from asyncssh.mac import _HMAC, _mac_handler, get_mac_algs
 from asyncssh.packet import Boolean, NameList, String, UInt32
 from asyncssh.public_key import get_default_public_key_algs
@@ -51,8 +53,8 @@ from asyncssh.public_key import get_defa
 
 from .server import Server, ServerTestCase
 
-from .util import asynctest, gss_available, patch_gss, run
-from .util import patch_getnameinfo, x509_available
+from .util import asynctest, patch_extra_kex, patch_getnameinfo, patch_gss, run
+from .util import gss_available, x509_available
 
 
 try:
@@ -901,22 +903,6 @@ class _TestConnection(ServerTestCase):
                 await self.connect(kex_algs=['fail'])
 
     @asynctest
-    async def test_skip_ext_info(self):
-        """Test not requesting extension info from the server"""
-
-        def skip_ext_info(self):
-            """Don't request extension information"""
-
-            # pylint: disable=unused-argument
-
-            return []
-
-        with patch('asyncssh.connection.SSHConnection._get_ext_info_kex_alg',
-                   skip_ext_info):
-            async with self.connect():
-                pass
-
-    @asynctest
     async def test_unknown_ext_info(self):
         """Test receiving unknown extension information"""
 
@@ -941,6 +927,54 @@ class _TestConnection(ServerTestCase):
                 pass
 
     @asynctest
+    async def test_message_before_kexinit_strict_kex(self):
+        """Test receiving a message before KEXINIT with strict_kex enabled"""
+
+        def send_packet(self, pkttype, *args, **kwargs):
+            if pkttype == MSG_KEXINIT:
+                self.send_packet(MSG_IGNORE, String(b''))
+
+            asyncssh.connection.SSHConnection.send_packet(
+                self, pkttype, *args, **kwargs)
+
+        with patch('asyncssh.connection.SSHClientConnection.send_packet',
+                   send_packet):
+            with self.assertRaises(asyncssh.ProtocolError):
+                await self.connect()
+
+    @asynctest
+    async def test_message_during_kex_strict_kex(self):
+        """Test receiving an unexpected message with strict_kex enabled"""
+
+        def send_packet(self, pkttype, *args, **kwargs):
+            if pkttype == MSG_KEX_ECDH_REPLY:
+                self.send_packet(MSG_IGNORE, String(b''))
+
+            asyncssh.connection.SSHConnection.send_packet(
+                self, pkttype, *args, **kwargs)
+
+        with patch('asyncssh.connection.SSHServerConnection.send_packet',
+                   send_packet):
+            with self.assertRaises(asyncssh.ProtocolError):
+                await self.connect()
+
+    @asynctest
+    async def test_unknown_message_during_kex_strict_kex(self):
+        """Test receiving an unknown message with strict_kex enabled"""
+
+        def send_packet(self, pkttype, *args, **kwargs):
+            if pkttype == MSG_KEX_ECDH_REPLY:
+                self.send_packet(MSG_KEX_LAST)
+
+            asyncssh.connection.SSHConnection.send_packet(
+                self, pkttype, *args, **kwargs)
+
+        with patch('asyncssh.connection.SSHServerConnection.send_packet',
+                   send_packet):
+            with self.assertRaises(asyncssh.ProtocolError):
+                await self.connect()
+
+    @asynctest
     async def test_encryption_algs(self):
         """Test connecting with different encryption algorithms"""
 
@@ -1468,6 +1502,81 @@ class _TestConnection(ServerTestCase):
             await self.create_connection(_InternalErrorClient)
 
 
+@patch_extra_kex
+class _TestConnectionNoStrictKex(ServerTestCase):
+    """Unit tests for connection API with ext info and strict kex disabled"""
+
+    @classmethod
+    async def start_server(cls):
+        """Start an SSH server to connect to"""
+
+        return (await cls.create_server(_TunnelServer, gss_host=(),
+                                        compression_algs='*',
+                                        encryption_algs='*',
+                                        kex_algs='*', mac_algs='*'))
+
+    @asynctest
+    async def test_skip_ext_info(self):
+        """Test not requesting extension info from the server"""
+
+        async with self.connect():
+            pass
+
+    @asynctest
+    async def test_message_before_kexinit(self):
+        """Test receiving a message before KEXINIT"""
+
+        def send_packet(self, pkttype, *args, **kwargs):
+            if pkttype == MSG_KEXINIT:
+                self.send_packet(MSG_IGNORE, String(b''))
+
+            asyncssh.connection.SSHConnection.send_packet(
+                self, pkttype, *args, **kwargs)
+
+        with patch('asyncssh.connection.SSHClientConnection.send_packet',
+                   send_packet):
+            async with self.connect():
+                pass
+
+    @asynctest
+    async def test_message_during_kex(self):
+        """Test receiving an unexpected message in key exchange"""
+
+        def send_packet(self, pkttype, *args, **kwargs):
+            if pkttype == MSG_KEX_ECDH_REPLY:
+                self.send_packet(MSG_IGNORE, String(b''))
+
+            asyncssh.connection.SSHConnection.send_packet(
+                self, pkttype, *args, **kwargs)
+
+        with patch('asyncssh.connection.SSHServerConnection.send_packet',
+                   send_packet):
+            async with self.connect():
+                pass
+
+    @asynctest
+    async def test_sequence_wrap_during_kex(self):
+        """Test sequence wrap during initial key exchange"""
+
+        def send_packet(self, pkttype, *args, **kwargs):
+            if pkttype == MSG_KEXINIT:
+                if self._options.command == 'send':
+                    self._send_seq = 0xfffffffe
+                else:
+                    self._recv_seq = 0xfffffffe
+
+            asyncssh.connection.SSHConnection.send_packet(
+                self, pkttype, *args, **kwargs)
+
+        with patch('asyncssh.connection.SSHClientConnection.send_packet',
+                   send_packet):
+            with self.assertRaises(asyncssh.ProtocolError):
+                await self.connect(command='send')
+
+            with self.assertRaises(asyncssh.ProtocolError):
+                await self.connect(command='recv')
+
+
 class _TestConnectionAsyncAcceptor(ServerTestCase):
     """Unit test for async acceptor"""
 
Index: b/tests/test_connection_auth.py
===================================================================
--- a/tests/test_connection_auth.py
+++ b/tests/test_connection_auth.py
@@ -710,7 +710,7 @@ class _TestHostBasedAuth(ServerTestCase)
 
             return []
 
-        with patch('asyncssh.connection.SSHConnection._get_ext_info_kex_alg',
+        with patch('asyncssh.connection.SSHConnection._get_extra_kex_algs',
                    skip_ext_info):
             async with self.connect(username='user', client_host_keys='skey',
                                     client_username='user'):
@@ -1209,7 +1209,7 @@ class _TestPublicKeyAuth(ServerTestCase)
 
             return []
 
-        with patch('asyncssh.connection.SSHConnection._get_ext_info_kex_alg',
+        with patch('asyncssh.connection.SSHConnection._get_extra_kex_algs',
                    skip_ext_info):
             async with self.connect(username='ckey', client_keys='ckey',
                                     agent_path=None):
Index: b/tests/util.py
===================================================================
--- a/tests/util.py
+++ b/tests/util.py
@@ -96,6 +96,20 @@ def patch_getnameinfo(cls):
     return patch('socket.getnameinfo', getnameinfo)(cls)
 
 
+def patch_extra_kex(cls):
+    """Decorator for skipping extra kex algs"""
+
+    def skip_extra_kex_algs(self):
+        """Don't send extra key exchange algorithms"""
+
+        # pylint: disable=unused-argument
+
+        return []
+
+    return patch('asyncssh.connection.SSHConnection._get_extra_kex_algs',
+                 skip_extra_kex_algs)(cls)
+
+
 def patch_gss(cls):
     """Decorator for patching GSSAPI classes"""
 
