1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216
|
import asyncio
import logging
from contextlib import asynccontextmanager
from typing import AsyncIterator, List
import pytest
import httpx
import uvicorn
from async_timeout import timeout
import portend
from tenacity import retry, stop_after_attempt, wait_exponential
logger = logging.getLogger(__name__)
class ServerManager:
"""Manages the lifecycle of a test server instance"""
def __init__(self, app_path: str, host: str = "localhost", port: int = None):
self.app_path = app_path
self.host = host
self.port = port or portend.find_available_local_port()
self.server = None
self._startup_complete = asyncio.Event()
self._shutdown_complete = asyncio.Event()
async def startup(self) -> None:
"""Start the server in a separate task"""
config = uvicorn.Config(
app=self.app_path,
host=self.host,
port=self.port,
log_level="error",
loop="asyncio",
)
self.server = uvicorn.Server(config=config)
# Store the original startup handler
original_startup = self.server.startup
# Create a wrapper that preserves the original signature
async def startup_wrapper(*args, **kwargs):
await original_startup(*args, **kwargs)
self._startup_complete.set()
self.server.startup = startup_wrapper
# Start the server
self._server_task = asyncio.create_task(self.server.serve())
try:
async with timeout(10): # 10 second timeout for startup
await self._startup_complete.wait()
# Additional health check
retry_count = 0
while retry_count < 5: # Try 5 times with exponential backoff
if await self.health_check():
break
retry_count += 1
await asyncio.sleep(0.2 * (2**retry_count))
else:
raise RuntimeError("Server health check failed after retries")
except Exception as e:
# If startup fails, ensure we clean up
await self.shutdown()
if isinstance(e, asyncio.TimeoutError):
raise RuntimeError("Server failed to start within timeout") from e
raise
async def shutdown(self) -> None:
"""Shutdown the server gracefully"""
if self.server and not self._shutdown_complete.is_set():
try:
self.server.should_exit = True
if hasattr(self, "_server_task"):
try:
async with timeout(5): # 5 second timeout for shutdown
await self._server_task
except asyncio.TimeoutError:
# Force cancel if graceful shutdown fails
self._server_task.cancel()
try:
await self._server_task
except asyncio.CancelledError:
pass
finally:
self._shutdown_complete.set()
self.server = None
self._startup_complete.clear()
@property
def url(self) -> str:
return f"http://{self.host}:{self.port}"
async def health_check(self) -> bool:
"""Check if server is responding"""
try:
async with httpx.AsyncClient() as client:
async with timeout(1): # 1 second timeout for health check
response = await client.get(f"{self.url}/health")
return response.status_code == 200
except Exception:
return False
@asynccontextmanager
async def server_context(app_path: str) -> AsyncIterator[ServerManager]:
"""Context manager for server lifecycle"""
server = ServerManager(app_path)
try:
await server.startup()
yield server
finally:
await server.shutdown()
class SSEClient:
"""Client for consuming SSE streams"""
def __init__(self, url: str, expected_lines: int):
self.url = url
self.expected_lines = expected_lines
self.received_lines = 0
self.errors: List[Exception] = []
@retry(
stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10)
)
async def connect_and_consume(self) -> None:
"""Connect to SSE stream and consume messages"""
try:
async with httpx.AsyncClient() as client:
async with timeout(20): # 20 second timeout for stream consumption
async with client.stream("GET", self.url) as response:
async for line in response.aiter_lines():
if line.strip(): # Only count non-empty lines
self.received_lines += 1
if self.received_lines >= self.expected_lines:
break
except Exception as e:
self.errors.append(e)
raise
@pytest.mark.asyncio
@pytest.mark.experimentation
@pytest.mark.parametrize(
"app_path,expected_lines",
[
("tests.integration.main_endless:app", 14),
("tests.integration.main_endless_conditional:app", 2),
],
)
async def test_sse_multiple_consumers(
app_path: str, expected_lines: int, num_consumers: int = 3
):
"""Test multiple consumers connecting to SSE endpoint"""
async with server_context(app_path) as server:
# Create and start consumers
clients = [
SSEClient(f"{server.url}/endless", expected_lines)
for _ in range(num_consumers)
]
# Run consumers concurrently with timeout
async with timeout(30): # 30 second timeout for entire test
try:
# Create tasks for all consumers
consumer_tasks = [
asyncio.create_task(client.connect_and_consume())
for client in clients
]
# Wait for all consumers or first error
done, pending = await asyncio.wait(
consumer_tasks, return_when=asyncio.FIRST_EXCEPTION
)
# Cancel any pending tasks
for task in pending:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
# Check results and gather errors
errors = []
for task in done:
try:
await task
except Exception as e:
errors.append(e)
# Verify expectations
for i, client in enumerate(clients):
assert (
client.received_lines == expected_lines
), f"Client {i} received {client.received_lines} lines, expected {expected_lines}"
assert not errors, f"Consumers encountered errors: {errors}"
except asyncio.TimeoutError:
raise RuntimeError("Test timed out waiting for consumers")
finally:
# Ensure all tasks are properly cleaned up
for task in consumer_tasks:
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
|