File: test_prefetch_related.py

package info (click to toggle)
djangorestframework 3.14.0-2+deb12u1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 16,984 kB
  • sloc: python: 27,822; javascript: 25,191; makefile: 26; sh: 6
file content (58 lines) | stat: -rw-r--r-- 2,016 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from django.contrib.auth.models import Group, User
from django.test import TestCase

from rest_framework import generics, serializers
from rest_framework.test import APIRequestFactory

# Module-level request factory shared by the tests below. Requests it builds
# are handed straight to the view callable (``view(request, pk=...)``), so no
# URL routing is involved.
factory = APIRequestFactory()


class UserSerializer(serializers.ModelSerializer):
    """Model serializer exposing a fixed subset of ``User`` fields.

    The ``groups`` many-to-many relation is included so the tests can update
    it and inspect how it is rendered in the response.
    """

    class Meta:
        model = User
        fields = (
            'id',
            'username',
            'email',
            'groups',
        )


class UserUpdate(generics.UpdateAPIView):
    # Update-only endpoint whose lookup queryset both filters and prefetches:
    #   * users named 'exclude' are left out, so renaming a user to 'exclude'
    #     makes the saved instance no longer match this queryset;
    #   * 'groups' is prefetched, so the looked-up instance carries a cached
    #     copy of the relation that the update then changes.
    serializer_class = UserSerializer
    queryset = (
        User.objects
        .exclude(username='exclude')
        .prefetch_related('groups')
    )


class TestPrefetchRelatedUpdates(TestCase):
    """PUT updates through ``UserUpdate``, whose queryset prefetches ``groups``.

    ``setUp`` puts the user in two groups. Each test PUTs a single group and
    checks that both the database and the response body reflect the new
    one-group state, not the two groups present when the instance was looked up.
    """

    def setUp(self):
        self.user = User.objects.create(username='tom', email='tom@example.com')
        self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')]
        self.user.groups.set(self.groups)

    def test_prefetch_related_updates(self):
        """Updating the M2M relation is reflected in the response data."""
        view = UserUpdate.as_view()
        pk = self.user.pk
        groups_pk = self.groups[0].pk
        request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json')
        response = view(request, pk=pk)
        # Fail early with the serializer errors if the update was rejected.
        assert response.status_code == 200, response.data
        assert User.objects.get(pk=pk).groups.count() == 1
        expected = {
            'id': pk,
            'username': 'new',
            # Compare against the real primary key rather than a literal 1:
            # pks are assigned by the database, and sequences are not reset
            # between TestCase tests on every backend.
            'groups': [groups_pk],
            'email': 'tom@example.com'
        }
        assert response.data == expected

    def test_prefetch_related_excluding_instance_from_original_queryset(self):
        """
        Regression test for https://github.com/encode/django-rest-framework/issues/4661

        Renaming the user to 'exclude' means the updated instance no longer
        matches the view's queryset; the update must still succeed and the
        response must reflect the new group membership.
        """
        view = UserUpdate.as_view()
        pk = self.user.pk
        groups_pk = self.groups[0].pk
        request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json')
        response = view(request, pk=pk)
        # Fail early with the serializer errors if the update was rejected.
        assert response.status_code == 200, response.data
        assert User.objects.get(pk=pk).groups.count() == 1
        expected = {
            'id': pk,
            'username': 'exclude',
            # Database-assigned pk; do not assume it is 1.
            'groups': [groups_pk],
            'email': 'tom@example.com'
        }
        assert response.data == expected