From 8b97933b259f34e5c66a4a1ae46c6fc176e26999 Mon Sep 17 00:00:00 2001
From: "Nathaniel J. Smith" <njs@anthropic.com>
Date: Thu, 9 Jan 2025 23:41:42 -0800
Subject: Validate Chunked-Encoding chunk footer
Also add a bit more thoroughness to some tests that I noticed while I
was working on it.
Thanks to Jeppe Bonde Weikop for the report.
---
h11/_readers.py | 23 +++++++++++--------
h11/tests/test_io.py | 54 +++++++++++++++++++++++++++++++-------------
2 files changed, 51 insertions(+), 26 deletions(-)
diff --git a/h11/_readers.py b/h11/_readers.py
index 08a9574..1348565 100644
--- a/h11/_readers.py
+++ b/h11/_readers.py
@@ -148,10 +148,9 @@ chunk_header_re = re.compile(chunk_header.encode("ascii"))
class ChunkedReader:
def __init__(self) -> None:
self._bytes_in_chunk = 0
- # After reading a chunk, we have to throw away the trailing \r\n; if
- # this is >0 then we discard that many bytes before resuming regular
- # de-chunkification.
- self._bytes_to_discard = 0
+ # After reading a chunk, we have to throw away the trailing \r\n.
+ # This tracks the bytes that we need to match and throw away.
+ self._bytes_to_discard = b""
self._reading_trailer = False
def __call__(self, buf: ReceiveBuffer) -> Union[Data, EndOfMessage, None]:
@@ -160,15 +159,19 @@ class ChunkedReader:
if lines is None:
return None
return EndOfMessage(headers=list(_decode_header_lines(lines)))
- if self._bytes_to_discard > 0:
- data = buf.maybe_extract_at_most(self._bytes_to_discard)
+ if self._bytes_to_discard:
+ data = buf.maybe_extract_at_most(len(self._bytes_to_discard))
if data is None:
return None
- self._bytes_to_discard -= len(data)
- if self._bytes_to_discard > 0:
+ if data != self._bytes_to_discard[:len(data)]:
+ raise LocalProtocolError(
+ f"malformed chunk footer: {data!r} (expected {self._bytes_to_discard!r})"
+ )
+ self._bytes_to_discard = self._bytes_to_discard[len(data):]
+ if self._bytes_to_discard:
return None
# else, fall through and read some more
- assert self._bytes_to_discard == 0
+ assert self._bytes_to_discard == b""
if self._bytes_in_chunk == 0:
# We need to refill our chunk count
chunk_header = buf.maybe_extract_next_line()
@@ -194,7 +197,7 @@ class ChunkedReader:
return None
self._bytes_in_chunk -= len(data)
if self._bytes_in_chunk == 0:
- self._bytes_to_discard = 2
+ self._bytes_to_discard = b"\r\n"
chunk_end = True
else:
chunk_end = False
diff --git a/h11/tests/test_io.py b/h11/tests/test_io.py
index 2b47c0e..634c49d 100644
--- a/h11/tests/test_io.py
+++ b/h11/tests/test_io.py
@@ -360,22 +360,34 @@ def _run_reader(*args: Any) -> List[Event]:
return normalize_data_events(events)
-def t_body_reader(thunk: Any, data: bytes, expected: Any, do_eof: bool = False) -> None:
+def t_body_reader(thunk: Any, data: bytes, expected: list, do_eof: bool = False) -> None:
# Simple: consume whole thing
print("Test 1")
buf = makebuf(data)
- assert _run_reader(thunk(), buf, do_eof) == expected
+ try:
+ assert _run_reader(thunk(), buf, do_eof) == expected
+ except LocalProtocolError:
+ if LocalProtocolError in expected:
+ pass
+ else:
+ raise
# Incrementally growing buffer
print("Test 2")
reader = thunk()
buf = ReceiveBuffer()
events = []
- for i in range(len(data)):
- events += _run_reader(reader, buf, False)
- buf += data[i : i + 1]
- events += _run_reader(reader, buf, do_eof)
- assert normalize_data_events(events) == expected
+ try:
+ for i in range(len(data)):
+ events += _run_reader(reader, buf, False)
+ buf += data[i : i + 1]
+ events += _run_reader(reader, buf, do_eof)
+ assert normalize_data_events(events) == expected
+ except LocalProtocolError:
+ if LocalProtocolError in expected:
+ pass
+ else:
+ raise
is_complete = any(type(event) is EndOfMessage for event in expected)
if is_complete and not do_eof:
@@ -436,14 +448,12 @@ def test_ChunkedReader() -> None:
)
# refuses arbitrarily long chunk integers
- with pytest.raises(LocalProtocolError):
- # Technically this is legal HTTP/1.1, but we refuse to process chunk
- # sizes that don't fit into 20 characters of hex
- t_body_reader(ChunkedReader, b"9" * 100 + b"\r\nxxx", [Data(data=b"xxx")])
+ # Technically this is legal HTTP/1.1, but we refuse to process chunk
+ # sizes that don't fit into 20 characters of hex
+ t_body_reader(ChunkedReader, b"9" * 100 + b"\r\nxxx", [LocalProtocolError])
# refuses garbage in the chunk count
- with pytest.raises(LocalProtocolError):
- t_body_reader(ChunkedReader, b"10\x00\r\nxxx", None)
+ t_body_reader(ChunkedReader, b"10\x00\r\nxxx", [LocalProtocolError])
# handles (and discards) "chunk extensions" omg wtf
t_body_reader(
@@ -457,10 +467,22 @@ def test_ChunkedReader() -> None:
t_body_reader(
ChunkedReader,
- b"5 \r\n01234\r\n" + b"0\r\n\r\n",
+ b"5 \t \r\n01234\r\n" + b"0\r\n\r\n",
[Data(data=b"01234"), EndOfMessage()],
)
+ # Chunked encoding with bad chunk termination characters are refused. Originally we
+ # simply dropped the 2 bytes after a chunk, instead of validating that the bytes
+ # were \r\n -- so we would successfully decode the data below as b"xxxa". And
+ # apparently there are other HTTP processors that ignore the chunk length and just
+ # keep reading until they see \r\n, so they would decode it as b"xxx__1a". Any time
+ # two HTTP processors accept the same input but interpret it differently, there's a
+ # possibility of request smuggling shenanigans. So we now reject this.
+ t_body_reader(ChunkedReader, b"3\r\nxxx__1a\r\n", [LocalProtocolError])
+
+ # Confirm we check both bytes individually
+ t_body_reader(ChunkedReader, b"3\r\nxxx\r_1a\r\n", [LocalProtocolError])
+ t_body_reader(ChunkedReader, b"3\r\nxxx_\n1a\r\n", [LocalProtocolError])
def test_ContentLengthWriter() -> None:
w = ContentLengthWriter(5)
@@ -483,8 +505,8 @@ def test_ContentLengthWriter() -> None:
dowrite(w, EndOfMessage())
w = ContentLengthWriter(5)
- dowrite(w, Data(data=b"123")) == b"123"
- dowrite(w, Data(data=b"45")) == b"45"
+ assert dowrite(w, Data(data=b"123")) == b"123"
+ assert dowrite(w, Data(data=b"45")) == b"45"
with pytest.raises(LocalProtocolError):
dowrite(w, EndOfMessage(headers=[("Etag", "asdf")]))
--
2.30.2