From: Yang Wang <yang.wang@windriver.com>
Date: Mon, 28 Jul 2025 11:41:09 +0200
Subject: Fix CVE-2025-54121: avoid event loop blocking during multipart file uploads

Write file data to disk via the thread pool when SpooledTemporaryFile rolls
over to disk, so that the synchronous rollover I/O no longer blocks the
event loop. (Closes: #1109805)
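
For context, a minimal sketch of the failure mode (the 1 KiB limit below is
illustrative; the parser actually uses MultiPartParser.spool_max_size, 1 MiB):
SpooledTemporaryFile buffers writes in memory until its max size, and the
first write that crosses the limit calls rollover(), which creates a real
temporary file and copies the buffer to disk. Before this patch, that
synchronous disk I/O ran directly on the event loop.

    from tempfile import SpooledTemporaryFile

    f = SpooledTemporaryFile(max_size=1024)  # buffers up to 1 KiB in memory
    f.write(b"x" * 1024)  # tell() == max_size: still in memory (_rolled is False)
    f.write(b"x")         # crosses max_size: rollover() does blocking disk I/O
    assert f._rolled      # now backed by an on-disk temporary file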
---
starlette/datastructures.py | 22 +++++++++++++---
tests/test_formparsers.py | 63 ++++++++++++++++++++++++++++++++++++++++++++-
2 files changed, 80 insertions(+), 5 deletions(-)
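
Note: the new UploadFile._will_roll() predicate mirrors SpooledTemporaryFile's
own rollover check, so only writes that have hit (or are about to hit) disk
pay the threadpool round-trip. A worked example; the concrete numbers are
illustrative, not from the patch:

    # Spool limit 1024 bytes, 1000 bytes already buffered, not yet rolled:
    #   _will_roll(10)  -> False  (1000 + 10  <= 1024): plain self.file.write()
    #   _will_roll(100) -> True   (1000 + 100 >  1024): run_in_threadpool(...)
    # After rollover, _in_memory is False, so _will_roll() always returns True;
    # with _max_mem_size == 0 (unlimited) it stays False while in memory.
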
diff --git a/starlette/datastructures.py b/starlette/datastructures.py
index f5d74d2..9957090 100644
--- a/starlette/datastructures.py
+++ b/starlette/datastructures.py
@@ -424,6 +424,10 @@ class UploadFile:
         self.size = size
         self.headers = headers or Headers()
 
+        # Capture max size from SpooledTemporaryFile if one is provided. This slightly speeds up future checks.
+        # Note: 0 means unlimited, mirroring SpooledTemporaryFile's __init__.
+        self._max_mem_size = getattr(self.file, "_max_size", 0)
+
     @property
     def content_type(self) -> str | None:
         return self.headers.get("content-type", None)
@@ -434,14 +438,24 @@
         rolled_to_disk = getattr(self.file, "_rolled", True)
         return not rolled_to_disk
 
+    def _will_roll(self, size_to_add: int) -> bool:
+        # If we are no longer in memory, the file has already rolled to disk.
+        if not self._in_memory:
+            return True
+
+        # Compare the projected size against SpooledTemporaryFile._max_size.
+        future_size = self.file.tell() + size_to_add
+        return bool(future_size > self._max_mem_size) if self._max_mem_size else False
+
     async def write(self, data: bytes) -> None:
+        new_data_len = len(data)
         if self.size is not None:
-            self.size += len(data)
+            self.size += new_data_len
 
-        if self._in_memory:
-            self.file.write(data)
-        else:
+        if self._will_roll(new_data_len):
             await run_in_threadpool(self.file.write, data)
+        else:
+            self.file.write(data)
 
     async def read(self, size: int = -1) -> bytes:
         if self._in_memory:
diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py
index b18fd6c..63577d6 100644
--- a/tests/test_formparsers.py
+++ b/tests/test_formparsers.py
@@ -1,15 +1,20 @@
 from __future__ import annotations
 
 import os
+import threading
 import typing
+from collections.abc import Generator
 from contextlib import nullcontext as does_not_raise
+from io import BytesIO
 from pathlib import Path
+from tempfile import SpooledTemporaryFile
+from unittest import mock
 
 import pytest
 
 from starlette.applications import Starlette
 from starlette.datastructures import UploadFile
-from starlette.formparsers import MultiPartException, _user_safe_decode
+from starlette.formparsers import MultiPartException, MultiPartParser, _user_safe_decode
 from starlette.requests import Request
 from starlette.responses import JSONResponse
 from starlette.routing import Mount
@@ -104,6 +109,22 @@ async def app_read_body(scope: Scope, receive: Receive, send: Send) -> None:
     await response(scope, receive, send)
 
 
+async def app_monitor_thread(scope: Scope, receive: Receive, send: Send) -> None:
+    """Helper app to record which thread the app was called on.
+
+    This can later be used to validate thread/event loop operations.
+    """
+    request = Request(scope, receive)
+
+    # Make sure we parse the form
+    await request.form()
+    await request.close()
+
+    # Send back the current thread id
+    response = JSONResponse({"thread_ident": threading.current_thread().ident})
+    await response(scope, receive, send)
+
+
 def make_app_max_parts(max_files: int = 1000, max_fields: int = 1000, max_part_size: int = 1024 * 1024) -> ASGIApp:
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
         request = Request(scope, receive)
@@ -302,6 +323,46 @@ def test_multipart_request_mixed_files_and_data(tmpdir: Path, test_client_factor
         "field1": "value1",
     }
 
 
+class ThreadTrackingSpooledTemporaryFile(SpooledTemporaryFile[bytes]):
+    """Helper class to track which threads performed the rollover operation.
+
+    This is neither thread-safe nor safe to share across tests.
+    """
+
+    rollover_threads: typing.ClassVar[set[int | None]] = set()
+
+    def rollover(self) -> None:
+        ThreadTrackingSpooledTemporaryFile.rollover_threads.add(threading.current_thread().ident)
+        super().rollover()
+
+
+@pytest.fixture
+def mock_spooled_temporary_file() -> Generator[None]:
+    try:
+        with mock.patch("starlette.formparsers.SpooledTemporaryFile", ThreadTrackingSpooledTemporaryFile):
+            yield
+    finally:
+        ThreadTrackingSpooledTemporaryFile.rollover_threads.clear()
+
+
+def test_multipart_request_large_file_rollover_in_background_thread(
+    mock_spooled_temporary_file: None, test_client_factory: TestClientFactory
+) -> None:
+    """Test that spooled file rollovers happen in background threads."""
+    data = BytesIO(b" " * (MultiPartParser.spool_max_size + 1))
+
+    client = test_client_factory(app_monitor_thread)
+    response = client.post("/", files=[("test_large", data)])
+    assert response.status_code == 200
+
+    # Parse the app thread id from the response and make sure we got one
+    app_thread_ident = response.json().get("thread_ident")
+    assert app_thread_ident is not None
+
+    # The rollover must not have run on the app thread, and exactly one rollover thread should exist
+    assert app_thread_ident not in ThreadTrackingSpooledTemporaryFile.rollover_threads
+    assert len(ThreadTrackingSpooledTemporaryFile.rollover_threads) == 1
+
 def test_multipart_request_with_charset_for_filename(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     client = test_client_factory(app)
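
To run just the new regression test after applying the patch (assuming a
checkout with the suite's test dependencies installed; the -k expression
merely selects the test added above):

    pytest tests/test_formparsers.py -k rollover_in_background_thread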