|
1 | 1 | from django.contrib.auth.models import Group, User |
| 2 | +from django.db.models.query import Prefetch |
2 | 3 | from django.test import TestCase |
3 | 4 |
|
4 | 5 | from rest_framework import generics, serializers |
|
8 | 9 |
|
9 | 10 |
|
10 | 11 | class UserSerializer(serializers.ModelSerializer): |
| 12 | + permissions = serializers.SerializerMethodField() |
| 13 | + |
| 14 | + def get_permissions(self, obj): |
| 15 | + ret = [] |
| 16 | + for g in obj.groups.all(): |
| 17 | + ret.extend([p.pk for p in g.permissions.all()]) |
| 18 | + return ret |
| 19 | + |
11 | 20 | class Meta: |
12 | 21 | model = User |
13 | | - fields = ('id', 'username', 'email', 'groups') |
| 22 | + fields = ('id', 'username', 'email', 'groups', 'permissions') |
| 23 | + |
| 24 | + |
| 25 | +class UserRetrieveUpdate(generics.RetrieveUpdateAPIView): |
| 26 | + queryset = User.objects.exclude(username='exclude').prefetch_related( |
| 27 | + Prefetch('groups', queryset=Group.objects.exclude(name='exclude')), |
| 28 | + 'groups__permissions', |
| 29 | + ) |
| 30 | + serializer_class = UserSerializer |
14 | 31 |
|
15 | 32 |
|
16 | | -class UserUpdate(generics.UpdateAPIView): |
17 | | - queryset = User.objects.exclude(username='exclude').prefetch_related('groups') |
| 33 | +class UserUpdateWithoutPrefetchRelated(generics.UpdateAPIView): |
| 34 | + queryset = User.objects.exclude(username='exclude') |
18 | 35 | serializer_class = UserSerializer |
19 | 36 |
|
20 | 37 |
|
21 | 38 | class TestPrefetchRelatedUpdates(TestCase): |
22 | 39 | def setUp(self): |
23 | 40 | self.user = User.objects.create(username='tom', email='tom@example.com') |
24 | | - self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')] |
| 41 | + self.groups = [Group.objects.create(name=f'group {i}') for i in range(10)] |
25 | 42 | self.user.groups.set(self.groups) |
| 43 | + self.user.groups.add(Group.objects.create(name='exclude')) |
| 44 | + self.expected = { |
| 45 | + 'id': self.user.pk, |
| 46 | + 'username': 'tom', |
| 47 | + 'groups': [group.pk for group in self.groups], |
| 48 | + 'email': 'tom@example.com', |
| 49 | + 'permissions': [], |
| 50 | + } |
| 51 | + self.view = UserRetrieveUpdate.as_view() |
26 | 52 |
|
27 | 53 | def test_prefetch_related_updates(self): |
28 | | - view = UserUpdate.as_view() |
29 | | - pk = self.user.pk |
30 | | - groups_pk = self.groups[0].pk |
31 | | - request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json') |
32 | | - response = view(request, pk=pk) |
33 | | - assert User.objects.get(pk=pk).groups.count() == 1 |
34 | | - expected = { |
35 | | - 'id': pk, |
36 | | - 'username': 'new', |
37 | | - 'groups': [1], |
38 | | - 'email': 'tom@example.com' |
39 | | - } |
40 | | - assert response.data == expected |
| 54 | + self.groups.append(Group.objects.create(name='c')) |
| 55 | + request = factory.put( |
| 56 | + '/', {'username': 'new', 'groups': [group.pk for group in self.groups]}, format='json' |
| 57 | + ) |
| 58 | + self.expected['username'] = 'new' |
| 59 | + self.expected['groups'] = [group.pk for group in self.groups] |
| 60 | + response = self.view(request, pk=self.user.pk) |
| 61 | + assert User.objects.get(pk=self.user.pk).groups.count() == 12 |
| 62 | + assert response.data == self.expected |
| 63 | + # Update and fetch should get same result |
| 64 | + request = factory.get('/') |
| 65 | + response = self.view(request, pk=self.user.pk) |
| 66 | + assert response.data == self.expected |
41 | 67 |
|
42 | 68 | def test_prefetch_related_excluding_instance_from_original_queryset(self): |
43 | 69 | """ |
44 | 70 | Regression test for https://github.com/encode/django-rest-framework/issues/4661 |
45 | 71 | """ |
46 | | - view = UserUpdate.as_view() |
47 | | - pk = self.user.pk |
48 | | - groups_pk = self.groups[0].pk |
49 | | - request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json') |
50 | | - response = view(request, pk=pk) |
51 | | - assert User.objects.get(pk=pk).groups.count() == 1 |
52 | | - expected = { |
53 | | - 'id': pk, |
54 | | - 'username': 'exclude', |
55 | | - 'groups': [1], |
56 | | - 'email': 'tom@example.com' |
57 | | - } |
58 | | - assert response.data == expected |
| 72 | + request = factory.put( |
| 73 | + '/', {'username': 'exclude', 'groups': [self.groups[0].pk]}, format='json' |
| 74 | + ) |
| 75 | + response = self.view(request, pk=self.user.pk) |
| 76 | + assert User.objects.get(pk=self.user.pk).groups.count() == 2 |
| 77 | + self.expected['username'] = 'exclude' |
| 78 | + self.expected['groups'] = [self.groups[0].pk] |
| 79 | + assert response.data == self.expected |
| 80 | + |
| 81 | + def test_db_query_count(self): |
| 82 | + request = factory.put( |
| 83 | + '/', {'username': 'new'}, format='json' |
| 84 | + ) |
| 85 | + with self.assertNumQueries(7): |
| 86 | + self.view(request, pk=self.user.pk) |
| 87 | + |
| 88 | + request = factory.put( |
| 89 | + '/', {'username': 'new2'}, format='json' |
| 90 | + ) |
| 91 | + with self.assertNumQueries(16): |
| 92 | + UserUpdateWithoutPrefetchRelated.as_view()(request, pk=self.user.pk) |
0 commit comments