1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
|
import operator
import unittest
from collections import namedtuple
from contextlib import contextmanager
from django.db import connection, models
from django.test import TestCase
from ..models import Person
@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL tests')
class ServerSideCursorsPostgres(TestCase):
    """
    Tests around the server-side cursors opened by QuerySet.iterator().

    Open cursors are observed by querying the ``pg_cursors`` view on the
    test connection; each row is wrapped in a ``PostgresCursor`` namedtuple.
    """
    # Columns selected from pg_cursors; also the field names of PostgresCursor.
    cursor_fields = 'name, statement, is_holdable, is_binary, is_scrollable, creation_time'
    PostgresCursor = namedtuple('PostgresCursor', cursor_fields)

    @classmethod
    def setUpTestData(cls):
        """Create two Person rows so that iterators have something to yield."""
        Person.objects.create(first_name='a', last_name='a')
        Person.objects.create(first_name='b', last_name='b')

    def inspect_cursors(self):
        """Return the rows of pg_cursors as a list of PostgresCursor tuples."""
        with connection.cursor() as cursor:
            cursor.execute('SELECT {fields} FROM pg_cursors;'.format(fields=self.cursor_fields))
            rows = cursor.fetchall()
        return [self.PostgresCursor._make(row) for row in rows]

    @contextmanager
    def override_db_setting(self, **kwargs):
        """
        Temporarily set entries of ``connection.settings_dict``.

        Each keyword argument is written into the settings dict for the
        duration of the ``with`` block. On exit (including on error) every
        touched key is put back to its previous value, or removed if it was
        not present before.
        """
        settings_dict = connection.settings_dict
        # Sentinel distinguishing "key was absent" from "key was set to None".
        missing = object()
        originals = {name: settings_dict.get(name, missing) for name in kwargs}
        settings_dict.update(kwargs)
        try:
            yield
        finally:
            for name, value in originals.items():
                if value is missing:
                    settings_dict.pop(name, None)
                else:
                    settings_dict[name] = value

    def assertUsesCursor(self, queryset, num_expected=1):
        """
        Advance ``queryset`` once and check the cursors listed in pg_cursors.

        Exactly ``num_expected`` cursors must be listed, and each must have
        '_django_curs_' in its name and be neither scrollable, holdable nor
        binary.
        """
        next(queryset)  # Open a server-side cursor
        cursors = self.inspect_cursors()
        self.assertEqual(len(cursors), num_expected)
        for cursor in cursors:
            self.assertIn('_django_curs_', cursor.name)
            self.assertFalse(cursor.is_scrollable)
            self.assertFalse(cursor.is_holdable)
            self.assertFalse(cursor.is_binary)

    def assertNotUsesCursor(self, queryset):
        """Advance ``queryset`` once and check that pg_cursors lists no cursor."""
        self.assertUsesCursor(queryset, num_expected=0)

    # Backward-compatible alias for the original, misspelled method name.
    asserNotUsesCursor = assertNotUsesCursor

    def test_server_side_cursor(self):
        self.assertUsesCursor(Person.objects.iterator())

    def test_values(self):
        self.assertUsesCursor(Person.objects.values('first_name').iterator())

    def test_values_list(self):
        self.assertUsesCursor(Person.objects.values_list('first_name').iterator())

    def test_values_list_flat(self):
        self.assertUsesCursor(Person.objects.values_list('first_name', flat=True).iterator())

    def test_values_list_fields_not_equal_to_names(self):
        expr = models.Count('id')
        queryset = Person.objects.annotate(id__count=expr).values_list(expr, 'id__count')
        self.assertUsesCursor(queryset.iterator())

    def test_server_side_cursor_many_cursors(self):
        persons = Person.objects.iterator()
        persons2 = Person.objects.iterator()
        next(persons)  # Open a server-side cursor
        # Advancing the second iterator opens another cursor: two in total.
        self.assertUsesCursor(persons2, num_expected=2)

    def test_closed_server_side_cursor(self):
        persons = Person.objects.iterator()
        next(persons)  # Open a server-side cursor
        # Dropping the only reference to the iterator is expected to close
        # its cursor.
        del persons
        cursors = self.inspect_cursors()
        self.assertEqual(len(cursors), 0)

    def test_server_side_cursors_setting(self):
        with self.override_db_setting(DISABLE_SERVER_SIDE_CURSORS=False):
            persons = Person.objects.iterator()
            self.assertUsesCursor(persons)
            del persons  # Close server-side cursor

        with self.override_db_setting(DISABLE_SERVER_SIDE_CURSORS=True):
            self.assertNotUsesCursor(Person.objects.iterator())
|