1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154
|
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
import asyncio
import asyncpg
from asyncpg import connection as pg_connection
from asyncpg import _testbase as tb
# Ceiling handed to ``assertRunUnder`` by the tests in this module: a call
# that times out or is cancelled must give control back within this budget
# (presumably seconds, matching the asyncio timeouts it guards).
MAX_RUNTIME: float = 0.5
class TestTimeout(tb.ConnectedTestCase):
    """Per-call ``timeout=`` handling on connections, statements and cursors.

    After provoking a timeout (or a cancellation) the tests generally run
    ``select 1`` to confirm the connection is still usable.
    """

    async def test_timeout_01(self):
        """Connection query methods raise TimeoutError on a slow query."""
        # A tuple, not a set literal: set iteration order over strings
        # changes between runs (hash randomization), which makes failures
        # harder to reproduce.
        for methname in ('fetch', 'fetchrow', 'fetchval', 'execute'):
            with self.subTest(methname=methname):
                with self.assertRaises(asyncio.TimeoutError), \
                        self.assertRunUnder(MAX_RUNTIME):
                    meth = getattr(self.con, methname)
                    await meth('select pg_sleep(10)', timeout=0.02)
                self.assertEqual(await self.con.fetch('select 1'), [(1,)])

    async def test_timeout_02(self):
        """Prepared-statement query methods raise TimeoutError."""
        st = await self.con.prepare('select pg_sleep(10)')
        for methname in ('fetch', 'fetchrow', 'fetchval'):
            with self.subTest(methname=methname):
                with self.assertRaises(asyncio.TimeoutError), \
                        self.assertRunUnder(MAX_RUNTIME):
                    meth = getattr(st, methname)
                    await meth(timeout=0.02)
                self.assertEqual(await self.con.fetch('select 1'), [(1,)])

    async def test_timeout_03(self):
        """Cancelling a running query before its timeout fires."""
        task = self.loop.create_task(
            self.con.fetch('select pg_sleep(10)', timeout=0.2))
        # Cancel well before the 0.2 s per-call timeout would expire.
        await asyncio.sleep(0.05)
        task.cancel()
        with self.assertRaises(asyncio.CancelledError), \
                self.assertRunUnder(MAX_RUNTIME):
            await task
        self.assertEqual(await self.con.fetch('select 1'), [(1,)])

    async def test_timeout_04(self):
        """Timeouts on cursors created from a prepared statement."""
        # Timeout given when the cursor is created, hit while iterating.
        st = await self.con.prepare('select pg_sleep(10)', timeout=0.1)
        with self.assertRaises(asyncio.TimeoutError), \
                self.assertRunUnder(MAX_RUNTIME):
            async with self.con.transaction():
                async for _ in st.cursor(timeout=0.1):  # NOQA
                    pass
        self.assertEqual(await self.con.fetch('select 1'), [(1,)])

        # Timeout given on an explicit cursor fetch() call.
        st = await self.con.prepare('select pg_sleep(10)', timeout=0.1)
        async with self.con.transaction():
            cur = await st.cursor()
            with self.assertRaises(asyncio.TimeoutError), \
                    self.assertRunUnder(MAX_RUNTIME):
                await cur.fetch(1, timeout=0.1)
        # Checked outside the transaction block, which the timeout leaves
        # in a failed state (see the InFailedSQLTransactionError check in
        # test_timeout_06).
        self.assertEqual(await self.con.fetch('select 1'), [(1,)])

    async def test_timeout_05(self):
        """Stress-test immediate timeouts followed by another query."""
        # Stress-test timeouts - try to trigger a race condition
        # between a cancellation request to Postgres and next
        # query (SELECT 1)
        for _ in range(500):
            with self.assertRaises(asyncio.TimeoutError):
                # A 1e-10 s timeout expires essentially at once.
                await self.con.fetch('SELECT pg_sleep(1)', timeout=1e-10)
            self.assertEqual(await self.con.fetch('SELECT 1'), [(1,)])

    async def test_timeout_06(self):
        """Timeouts on cursors created directly from the connection."""
        # Timeout given when the cursor is created, hit while iterating.
        async with self.con.transaction():
            with self.assertRaises(asyncio.TimeoutError), \
                    self.assertRunUnder(MAX_RUNTIME):
                async for _ in self.con.cursor(  # NOQA
                        'select pg_sleep(10)', timeout=0.1):
                    pass
        self.assertEqual(await self.con.fetch('select 1'), [(1,)])

        # Per-call timeout on Cursor.fetch().
        async with self.con.transaction():
            cur = await self.con.cursor('select pg_sleep(10)')
            with self.assertRaises(asyncio.TimeoutError), \
                    self.assertRunUnder(MAX_RUNTIME):
                await cur.fetch(1, timeout=0.1)

        # Per-call timeout on Cursor.forward().
        async with self.con.transaction():
            cur = await self.con.cursor('select pg_sleep(10)')
            with self.assertRaises(asyncio.TimeoutError), \
                    self.assertRunUnder(MAX_RUNTIME):
                await cur.forward(1, timeout=1e-10)

        # Per-call timeout on Cursor.fetchrow().
        async with self.con.transaction():
            cur = await self.con.cursor('select pg_sleep(10)')
            with self.assertRaises(asyncio.TimeoutError), \
                    self.assertRunUnder(MAX_RUNTIME):
                await cur.fetchrow(timeout=0.1)

        # After a timeout inside a transaction, further cursor use in that
        # same transaction fails with InFailedSQLTransactionError.
        async with self.con.transaction():
            cur = await self.con.cursor('select pg_sleep(10)')
            with self.assertRaises(asyncio.TimeoutError), \
                    self.assertRunUnder(MAX_RUNTIME):
                await cur.fetchrow(timeout=0.1)

            with self.assertRaises(asyncpg.InFailedSQLTransactionError):
                await cur.fetch(1)

        self.assertEqual(await self.con.fetch('select 1'), [(1,)])

    async def test_invalid_timeout(self):
        """Invalid timeout values are rejected with ValueError."""
        for command_timeout in ('a', False, -1):
            with self.subTest(command_timeout=command_timeout):
                with self.assertRaisesRegex(ValueError,
                                            'invalid command_timeout'):
                    await self.connect(command_timeout=command_timeout)

        # Note: negative timeouts are OK for method calls.
        # NOTE(review): this check used to sit inside a loop over
        # fetch/fetchrow/fetchval/execute whose variable was never used;
        # it called ``execute`` every time, repeating one check four times.
        # Only ``execute`` is exercised here. Extending the check to the
        # other methods first requires confirming that they reject these
        # values with the same ValueError.
        for timeout in ('a', False):
            with self.subTest(timeout=timeout):
                with self.assertRaisesRegex(ValueError, 'invalid timeout'):
                    await self.con.execute('SELECT 1', timeout=timeout)
class TestConnectionCommandTimeout(tb.ConnectedTestCase):
    """Queries without ``timeout=`` are bounded by ``command_timeout``."""

    @tb.with_connection_options(command_timeout=0.2)
    async def test_command_timeout_01(self):
        """Each query method times out under a 0.2 s command_timeout."""
        # A tuple, not a set literal, so methods run in a fixed order
        # from one run to the next.
        for methname in ('fetch', 'fetchrow', 'fetchval', 'execute'):
            with self.subTest(methname=methname):
                with self.assertRaises(asyncio.TimeoutError), \
                        self.assertRunUnder(MAX_RUNTIME):
                    meth = getattr(self.con, methname)
                    # No per-call timeout: the connection-level
                    # command_timeout is the only limit in effect.
                    await meth('select pg_sleep(10)')
                self.assertEqual(await self.con.fetch('select 1'), [(1,)])
class SlowPrepareConnection(pg_connection.Connection):
    """Connection subclass whose statement lookup is artificially slow.

    Every call to ``_get_statement`` first sleeps for 0.3 s and then defers
    to the regular implementation. Timeout tests use it to spend part of
    the time budget before the query itself runs.
    """

    async def _get_statement(self, query, timeout, **kwargs):
        # Burn wall-clock time up front, then do the normal lookup.
        await asyncio.sleep(0.3)
        statement = await super()._get_statement(query, timeout, **kwargs)
        return statement
class TestTimeoutCoversPrepare(tb.ConnectedTestCase):
    """The command timeout also counts time spent obtaining the statement."""

    @tb.with_connection_options(connection_class=SlowPrepareConnection,
                                command_timeout=0.3)
    async def test_timeout_covers_prepare_01(self):
        """A slow statement lookup plus a short query exceeds the timeout."""
        # SlowPrepareConnection stalls in ``_get_statement`` before
        # preparing. pg_sleep(0.2) alone would fit within the 0.3 s
        # command_timeout; together with the stall it should not.
        # There is no ``select 1`` follow-up as in the other tests: it would
        # go through the same slow lookup and would presumably time out too.
        for methname in ('fetch', 'fetchrow', 'fetchval', 'execute'):
            with self.subTest(methname=methname):
                with self.assertRaises(asyncio.TimeoutError):
                    meth = getattr(self.con, methname)
                    await meth('select pg_sleep($1)', 0.2)
|