File: 0002-fix-cve-2024-28849-async-write.patch

Package: starlette 0.46.1-3
From: Yang Wang <yang.wang@windriver.com>
Date: Mon, 28 Jul 2025 11:41:09 +0200
Subject: fix-cve-2024-28849-async-write

Fix CVE-2025-54121: avoid blocking the event loop during multipart file
uploads by performing disk writes in a thread pool, so that the handler no
longer blocks synchronously when a SpooledTemporaryFile rolls over to
disk. (Closes: #1109805)
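
For context, here is a minimal, self-contained sketch of the symptom
(illustrative only, not part of the patch): once a write pushes a
SpooledTemporaryFile past max_size, rollover() copies the buffered bytes
to an on-disk temporary file synchronously, on the event-loop thread, and
every other coroutine stalls until the copy completes. The pause is short
on a fast disk but unbounded on a slow or contended one:

    import asyncio
    import time
    from tempfile import SpooledTemporaryFile

    async def heartbeat() -> None:
        # Ticks every ~10 ms while the event loop is responsive.
        while True:
            await asyncio.sleep(0.01)
            print("tick", time.monotonic())

    async def main() -> None:
        task = asyncio.create_task(heartbeat())
        await asyncio.sleep(0.05)  # let a few ticks through first
        f = SpooledTemporaryFile(max_size=1024 * 1024)
        # Crossing max_size triggers rollover(): the in-memory buffer is
        # copied to disk synchronously, on this thread, so no ticks can
        # appear until the copy completes.
        f.write(b"x" * (64 * 1024 * 1024))
        await asyncio.sleep(0.05)
        task.cancel()

    asyncio.run(main())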
---
 starlette/datastructures.py | 22 +++++++++++++---
 tests/test_formparsers.py   | 63 ++++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 80 insertions(+), 5 deletions(-)
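
The remedy, sketched in isolation: hand a write to a worker thread only
when it may touch the disk, and keep cheap in-memory writes on the event
loop. run_in_threadpool is starlette's real helper from
starlette.concurrency; the write_chunk coroutine is a hypothetical
stand-in for the patched UploadFile.write shown in the diff below:

    from tempfile import SpooledTemporaryFile

    from starlette.concurrency import run_in_threadpool

    async def write_chunk(f: SpooledTemporaryFile[bytes], chunk: bytes) -> None:
        # _rolled and _max_size are the same CPython internals the patch reads.
        will_roll = f._rolled or bool(
            f._max_size and f.tell() + len(chunk) > f._max_size
        )
        if will_roll:
            # Disk I/O (or an imminent rollover): run on a worker thread.
            await run_in_threadpool(f.write, chunk)
        else:
            # Pure in-memory write: no need to pay the thread-hop cost.
            f.write(chunk)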

diff --git a/starlette/datastructures.py b/starlette/datastructures.py
index f5d74d2..9957090 100644
--- a/starlette/datastructures.py
+++ b/starlette/datastructures.py
@@ -424,6 +424,10 @@ class UploadFile:
         self.size = size
         self.headers = headers or Headers()
 
+        # Capture max size from SpooledTemporaryFile if one is provided. This slightly speeds up future checks.
+        # Note 0 means unlimited mirroring SpooledTemporaryFile's __init__
+        self._max_mem_size = getattr(self.file, "_max_size", 0)
+
     @property
     def content_type(self) -> str | None:
         return self.headers.get("content-type", None)
@@ -434,14 +438,24 @@ class UploadFile:
         rolled_to_disk = getattr(self.file, "_rolled", True)
         return not rolled_to_disk
 
+    def _will_roll(self, size_to_add: int) -> bool:
+        # If we're not in_memory then we will always roll
+        if not self._in_memory:
+            return True
+
+        # Check for SpooledTemporaryFile._max_size
+        future_size = self.file.tell() + size_to_add
+        return bool(future_size > self._max_mem_size) if self._max_mem_size else False
+
     async def write(self, data: bytes) -> None:
+        new_data_len = len(data)
         if self.size is not None:
-            self.size += len(data)
+            self.size += new_data_len
 
-        if self._in_memory:
-            self.file.write(data)
-        else:
+        if self._will_roll(new_data_len):
             await run_in_threadpool(self.file.write, data)
+        else:
+            self.file.write(data)
 
     async def read(self, size: int = -1) -> bytes:
         if self._in_memory:
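
As a quick, illustrative check (not part of the patch) of the
SpooledTemporaryFile internals that _in_memory and _will_roll inspect:
_max_size and _rolled are CPython implementation details, which is why the
code above reads them defensively via getattr:

    from tempfile import SpooledTemporaryFile

    f = SpooledTemporaryFile(max_size=1024)
    f.write(b"x" * 1024)
    print(f._rolled)  # False: tell() == max_size, data still in a BytesIO
    f.write(b"x")     # tell() now exceeds max_size
    print(f._rolled)  # True: contents were copied to a file on disk

Note that CPython only rolls over once tell() strictly exceeds max_size,
which is why _will_roll uses a strict comparison as well.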
diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py
index b18fd6c..63577d6 100644
--- a/tests/test_formparsers.py
+++ b/tests/test_formparsers.py
@@ -1,15 +1,20 @@
 from __future__ import annotations
 
 import os
+import threading
+from collections.abc import Generator
 import typing
 from contextlib import nullcontext as does_not_raise
+from io import BytesIO
 from pathlib import Path
+from tempfile import SpooledTemporaryFile
+from unittest import mock
 
 import pytest
 
 from starlette.applications import Starlette
 from starlette.datastructures import UploadFile
-from starlette.formparsers import MultiPartException, _user_safe_decode
+from starlette.formparsers import MultiPartException, MultiPartParser, _user_safe_decode
 from starlette.requests import Request
 from starlette.responses import JSONResponse
 from starlette.routing import Mount
@@ -104,6 +109,22 @@ async def app_read_body(scope: Scope, receive: Receive, send: Send) -> None:
     await response(scope, receive, send)
 
 
+async def app_monitor_thread(scope: Scope, receive: Receive, send: Send) -> None:
+    """Helper app to monitor what thread the app was called on.
+
+    This can later be used to validate thread/event loop operations.
+    """
+    request = Request(scope, receive)
+
+    # Make sure we parse the form
+    await request.form()
+    await request.close()
+
+    # Send back the current thread id
+    response = JSONResponse({"thread_ident": threading.current_thread().ident})
+    await response(scope, receive, send)
+
+
 def make_app_max_parts(max_files: int = 1000, max_fields: int = 1000, max_part_size: int = 1024 * 1024) -> ASGIApp:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         request = Request(scope, receive)
@@ -302,6 +323,46 @@ def test_multipart_request_mixed_files_and_data(tmpdir: Path, test_client_factor
         "field1": "value1",
     }
 
+class ThreadTrackingSpooledTemporaryFile(SpooledTemporaryFile[bytes]):
+    """Helper class to track which threads performed the rollover operation.
+
+    This is not threadsafe/multi-test safe.
+    """
+
+    rollover_threads: typing.ClassVar[set[int | None]] = set()
+
+    def rollover(self) -> None:
+        ThreadTrackingSpooledTemporaryFile.rollover_threads.add(threading.current_thread().ident)
+        super().rollover()
+
+
+@pytest.fixture
+def mock_spooled_temporary_file() -> Generator[None]:
+    try:
+        with mock.patch("starlette.formparsers.SpooledTemporaryFile", ThreadTrackingSpooledTemporaryFile):
+            yield
+    finally:
+        ThreadTrackingSpooledTemporaryFile.rollover_threads.clear()
+
+
+def test_multipart_request_large_file_rollover_in_background_thread(
+    mock_spooled_temporary_file: None, test_client_factory: TestClientFactory
+) -> None:
+    """Test that Spooled file rollovers happen in background threads."""
+    data = BytesIO(b" " * (MultiPartParser.spool_max_size + 1))
+
+    client = test_client_factory(app_monitor_thread)
+    response = client.post("/", files=[("test_large", data)])
+    assert response.status_code == 200
+
+    # Parse the event thread id from the API response and ensure we have one
+    app_thread_ident = response.json().get("thread_ident")
+    assert app_thread_ident is not None
+
+    # Ensure the app thread was not the same as the rollover one and that a rollover thread exists
+    assert app_thread_ident not in ThreadTrackingSpooledTemporaryFile.rollover_threads
+    assert len(ThreadTrackingSpooledTemporaryFile.rollover_threads) == 1
+
 
 def test_multipart_request_with_charset_for_filename(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     client = test_client_factory(app)