diff --git a/pyas2/admin.py b/pyas2/admin.py index f9f1e37..2bd293a 100644 --- a/pyas2/admin.py +++ b/pyas2/admin.py @@ -72,6 +72,10 @@ class PartnerAdmin(admin.ModelAdmin): "mdn_mode", ] list_filter = ("name", "as2_name") + list_select_related = ( + "encryption_cert", + "signature_cert", + ) fieldsets = ( ( None, @@ -169,6 +173,12 @@ class MessageAdmin(admin.ModelAdmin): "mdn_url", ] + list_select_related = ( + "partner", + "organization", + "mdn", + ) + @staticmethod def mdn_url(obj): """Return the URL to the related MDN if present for the message.""" @@ -212,6 +222,7 @@ class MdnAdmin(admin.ModelAdmin): ) list_display = ("mdn_id", "message", "timestamp", "status") list_filter = ("status",) + list_select_related = ("message",) def has_add_permission(self, request): return False diff --git a/pyas2/management/commands/sendas2message.py b/pyas2/management/commands/sendas2message.py index cd04c41..165279d 100644 --- a/pyas2/management/commands/sendas2message.py +++ b/pyas2/management/commands/sendas2message.py @@ -66,6 +66,7 @@ def handle(self, *args, **options): content_type=partner.content_type, disposition_notification_to=org.email_address or "no-reply@pyas2.com", ) + message, _ = Message.objects.create_from_as2message( as2message=as2message, payload=payload, @@ -73,6 +74,9 @@ def handle(self, *args, **options): direction="OUT", status="P", ) + + message.organization = org + message.partner = partner message.send_message(as2message.headers, as2message.content) # Delete original file if option is set diff --git a/pyas2/models.py b/pyas2/models.py index 7463cdf..8348d58 100644 --- a/pyas2/models.py +++ b/pyas2/models.py @@ -342,9 +342,12 @@ def create_from_as2message( if not filename: filename = f"{uuid4()}.msg" message.headers.save( - name=f"{filename}.header", content=ContentFile(as2message.headers_str) + name=f"{filename}.header", + content=ContentFile(as2message.headers_str), + save=False, ) - message.payload.save(name=filename, content=ContentFile(payload)) + message.payload.save(name=filename, content=ContentFile(payload), save=False) + message.save() # Save the payload to the inbox folder full_filename = None @@ -460,6 +463,13 @@ def status_icon(self): def send_message(self, header, payload): """Send the message to the partner""" + + if self.organization_id and self.partner_id: + self.organization = self.organization or Organization.objects.get( + id=self.organization_id + ) + self.partner = self.partner or Partner.objects.get(id=self.partner_id) + logger.info( f'Sending message {self.message_id} from organization "{self.organization}" ' f'to partner "{self.partner}".' diff --git a/pyas2/tests/test_basic.py b/pyas2/tests/test_basic.py index 7cb259c..3b20ead 100644 --- a/pyas2/tests/test_basic.py +++ b/pyas2/tests/test_basic.py @@ -2,22 +2,23 @@ from email.parser import HeaderParser from unittest import mock -from django.test import TestCase, Client +from django.db import connection +from django.test import Client, TestCase +from django.test.utils import CaptureQueriesContext +from pyas2lib.as2 import Message as As2Message from requests import Response from requests.exceptions import RequestException from pyas2.models import ( - PrivateKey, - PublicCertificate, + Mdn, + Message, Organization, Partner, - Message, - Mdn, + PrivateKey, + PublicCertificate, ) from pyas2.tests import TEST_DIR -from pyas2lib.as2 import Message as As2Message - class BasicServerClientTestCase(TestCase): """Test cases for the AS2 server and client. @@ -534,6 +535,29 @@ def testEncryptSignMessageAsyncSignMdn(self, mock_request): mock_request.side_effect = RequestException() out_message.mdn.send_async_mdn() + def testNumberOfQueries(self): + """Testing against the number of queries executed""" + + # Create the partner with appropriate settings for this case + + partner = Partner.objects.create( + name="AS2 Server", + as2_name="as2server", + target_url="http://localhost:8080/pyas2/as2receive", + mdn=False, + ) + + with CaptureQueriesContext(connection) as queries: + in_message = self.build_and_send(partner) + + # Remove the transaction related queries + filtered_queries = [ + query for query in queries if "SAVEPOINT" not in query["sql"] + ] + + # Number of queries should be 13 + self.assertEqual(len(filtered_queries), 13) + @mock.patch("requests.post") def build_and_send(self, partner, mock_request): # Build and send the message to server diff --git a/pyas2/tests/test_commands.py b/pyas2/tests/test_commands.py index eb4006d..5301448 100644 --- a/pyas2/tests/test_commands.py +++ b/pyas2/tests/test_commands.py @@ -7,11 +7,13 @@ from django.conf import settings from django.core import management from django.core.files.base import ContentFile +from django.db import connection +from django.test.utils import CaptureQueriesContext from pyas2 import settings as app_settings -from pyas2.models import As2Message, Message, Mdn -from pyas2.tests import TEST_DIR from pyas2.management.commands.sendas2bulk import Command as SendBulkCommand +from pyas2.models import As2Message, Mdn, Message +from pyas2.tests import TEST_DIR @pytest.mark.django_db @@ -78,14 +80,23 @@ def test_sendmessage_command(mocker, organization, partner): mocked_delete = mocker.patch( "pyas2.management.commands.sendas2message.default_storage.delete" ) - management.call_command( - "sendas2message", - organization.as2_name, - partner.as2_name, - test_message, - delete=True, - ) + + with CaptureQueriesContext(connection) as queries: + + management.call_command( + "sendas2message", + organization.as2_name, + partner.as2_name, + test_message, + delete=True, + ) + + filtered_queries = [ + query for query in queries if "SAVEPOINT" not in query["sql"] + ] + assert mocked_delete.call_count == 1 + assert len(filtered_queries) == 6 @pytest.mark.django_db diff --git a/pyas2/views.py b/pyas2/views.py index c205871..cdf38f0 100644 --- a/pyas2/views.py +++ b/pyas2/views.py @@ -56,7 +56,11 @@ def check_message_exists(message_id, partner_id): @staticmethod def find_organization(org_id): """Find the org using the As2 Id and return its pyas2 type""" - org = Organization.objects.filter(as2_name=org_id).first() + org = ( + Organization.objects.select_related("encryption_key", "signature_key") + .filter(as2_name=org_id) + .first() + ) if org: return org.as2org return None @@ -64,7 +68,11 @@ def find_organization(org_id): @staticmethod def find_partner(partner_id): """Find the partner using the As2 Id and return its pyas2 type""" - partner = Partner.objects.filter(as2_name=partner_id).first() + partner = ( + Partner.objects.select_related("encryption_cert", "signature_cert") + .filter(as2_name=partner_id) + .first() + ) if partner: return partner.as2partner return None @@ -97,7 +105,7 @@ def post(self, request, *args, **kwargs): status, detailed_status = as2mdn.parse(request_body, self.find_message) if not detailed_status == "mdn-not-found": - message = Message.objects.get( + message = Message.objects.select_related("organization", "partner").get( message_id=as2mdn.orig_message_id, direction="OUT" ) logger.info(