Skip to content

Commit 2f03f11

Browse files
committed
Add custom group admin with user selection
Signed-off-by: Keshav Priyadarshi <git@keshav.space>
1 parent 94dd104 commit 2f03f11

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

vulnerabilities/admin.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99

1010
from django import forms
1111
from django.contrib import admin
12+
from django.contrib.admin.widgets import FilteredSelectMultiple
13+
from django.contrib.auth.admin import GroupAdmin as BasicGroupAdmin
14+
from django.contrib.auth.models import Group
15+
from django.contrib.auth.models import User
1216
from django.core.validators import validate_email
1317

1418
from vulnerabilities.models import ApiUser
@@ -97,3 +101,49 @@ def get_form(self, request, obj=None, **kwargs):
97101
defaults["form"] = self.add_form
98102
defaults.update(kwargs)
99103
return super().get_form(request, obj, **defaults)
104+
105+
106+
class GroupWithUsersForm(forms.ModelForm):
107+
users = forms.ModelMultipleChoiceField(
108+
queryset=User.objects.all(),
109+
required=False,
110+
widget=FilteredSelectMultiple("Users", is_stacked=False),
111+
label="Users",
112+
)
113+
114+
class Meta:
115+
model = Group
116+
fields = "__all__"
117+
118+
def __init__(self, *args, **kwargs):
119+
super().__init__(*args, **kwargs)
120+
self.fields["users"].label_from_instance = lambda user: (
121+
f"{user.username} | {user.email}" if user.email else user.username
122+
)
123+
if self.instance.pk:
124+
self.fields["users"].initial = self.instance.user_set.all()
125+
126+
def save(self, commit=True):
127+
group = super().save(commit=commit)
128+
self.save_m2m()
129+
group.user_set.set(self.cleaned_data["users"])
130+
return group
131+
132+
133+
admin.site.unregister(Group)
134+
135+
136+
@admin.register(Group)
137+
class GroupAdmin(admin.ModelAdmin):
138+
form = GroupWithUsersForm
139+
search_fields = ("name",)
140+
ordering = ("name",)
141+
filter_horizontal = ("permissions",)
142+
143+
def formfield_for_manytomany(self, db_field, request=None, **kwargs):
144+
if db_field.name == "permissions":
145+
qs = kwargs.get("queryset", db_field.remote_field.model.objects)
146+
# Avoid a major performance hit resolving permission names which
147+
# triggers a content_type load:
148+
kwargs["queryset"] = qs.select_related("content_type")
149+
return super().formfield_for_manytomany(db_field, request=request, **kwargs)

0 commit comments

Comments
 (0)