From ede4582879f31cc29be54fdcdf8bc168dc7ea6e3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Helleu?= <flashcode@flashtux.org>
Date: Sat, 4 Sep 2021 23:09:19 +0200
Subject: relay: fix crash when decoding a malformed websocket frame

---
 src/plugins/relay/relay-websocket.c | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/src/plugins/relay/relay-websocket.c b/src/plugins/relay/relay-websocket.c
index e3b768d0a..789f67e20 100644
--- a/src/plugins/relay/relay-websocket.c
+++ b/src/plugins/relay/relay-websocket.c
@@ -278,7 +278,7 @@ relay_websocket_decode_frame (const unsigned char *buffer,
     index_buffer = 0;
 
     /* loop to decode all frames in message */
-    while (index_buffer + 2 <= buffer_length)
+    while (index_buffer + 1 < buffer_length)
     {
         opcode = buffer[index_buffer] & 15;
 
@@ -293,10 +293,12 @@ relay_websocket_decode_frame (const unsigned char *buffer,
         length_frame_size = 1;
         length_frame = buffer[index_buffer + 1] & 127;
         index_buffer += 2;
+        if (index_buffer >= buffer_length)
+            return 0;
         if ((length_frame == 126) || (length_frame == 127))
         {
             length_frame_size = (length_frame == 126) ? 2 : 8;
-            if (buffer_length < 1 + length_frame_size)
+            if (index_buffer + length_frame_size > buffer_length)
                 return 0;
             length_frame = 0;
             for (i = 0; i < length_frame_size; i++)
@@ -306,10 +308,9 @@ relay_websocket_decode_frame (const unsigned char *buffer,
             index_buffer += length_frame_size;
         }
 
-        if (buffer_length < 1 + length_frame_size + 4 + length_frame)
-            return 0;
-
         /* read masks (4 bytes) */
+        if (index_buffer + 4 > buffer_length)
+            return 0;
         int masks[4];
         for (i = 0; i < 4; i++)
         {
@@ -333,6 +334,11 @@ relay_websocket_decode_frame (const unsigned char *buffer,
         *decoded_length += 1;
 
         /* decode data using masks */
+        if ((length_frame > buffer_length)
+            || (index_buffer + length_frame > buffer_length))
+        {
+            return 0;
+        }
         for (i = 0; i < length_frame; i++)
         {
             decoded[*decoded_length + i] = (int)((unsigned char)buffer[index_buffer + i]) ^ masks[i % 4];
-- 
2.20.1

