diff --git a/django_sns_view/__init__.py b/django_sns_view/__init__.py index a9d1134..b2336be 100644 --- a/django_sns_view/__init__.py +++ b/django_sns_view/__init__.py @@ -1,2 +1,2 @@ # -*- coding: utf-8 -*- -__version__ = '0.1.2-sl.1' # pragma: no cover \ No newline at end of file +__version__ = '0.1.3' # pragma: no cover \ No newline at end of file diff --git a/django_sns_view/tests/test_settings.py b/django_sns_view/tests/test_settings.py index 71f52cc..ce51e66 100644 --- a/django_sns_view/tests/test_settings.py +++ b/django_sns_view/tests/test_settings.py @@ -38,15 +38,6 @@ }] EXTERNAL_APPS = [ - 'django.contrib.admin', - 'django.contrib.admindocs', - 'django.contrib.auth', - 'django.contrib.contenttypes', - 'django.contrib.messages', - 'django.contrib.sessions', - 'django.contrib.staticfiles', - 'django.contrib.sitemaps', - 'django.contrib.sites', ] INTERNAL_APPS = [ @@ -64,10 +55,10 @@ SECRET_KEY = 'foobar' -TEST_RUNNER = 'django_nose.NoseTestSuiteRunner' +#TEST_RUNNER = 'django_nose.NoseTestSuiteRunner' -NOSE_ARGS = [ - '--with-xunit', - '--nologcapture', - '--cover-package=django_sns_view', -] \ No newline at end of file +# NOSE_ARGS = [ +# '--with-xunit', +# '--nologcapture', +# '--cover-package=django_sns_view', +# ] \ No newline at end of file diff --git a/django_sns_view/tests/urls.py b/django_sns_view/tests/urls.py new file mode 100644 index 0000000..2fe9c7c --- /dev/null +++ b/django_sns_view/tests/urls.py @@ -0,0 +1,3 @@ + + +urlpatterns = [] \ No newline at end of file diff --git a/django_sns_view/tests/views.py b/django_sns_view/tests/views.py index e7cd2ae..0ef34bb 100644 --- a/django_sns_view/tests/views.py +++ b/django_sns_view/tests/views.py @@ -134,3 +134,30 @@ def test_handle_message_sucessfully_called(self, mock): json.loads(json.dumps(SNS_NOTIFICATION)) ) self.assertEqual(response.status_code, 200) + + @override_settings(AWS_ACCOUNT_ID="919599206538") + @patch('django_sns_view.views.confirm_subscription') + @patch.object(SNSEndpoint, 'handle_message') + def test_subscribe_from_correct_account(self, mock, mock_confirm): + """ + Test that subscriptions from the correct account work + with AWS_ACCOUNT_ID set + """ + self.request.META['HTTP_X_AMZ_SNS_MESSAGE_TYPE'] = \ + 'SubscriptionConfirmation' + self.request._body = json.dumps(self.sns_confirmation) + self.endpoint(self.request) + self.assertTrue(mock_confirm.called) + + @override_settings(AWS_ACCOUNT_ID="1010101010") + @patch.object(SNSEndpoint, 'handle_message') + def test_subscribe_from_another_account(self, mock): + """ + Test that subscriptions from another account DO NOT work + if AWS_ACCOUNT_ID is set + """ + self.request.META['HTTP_X_AMZ_SNS_MESSAGE_TYPE'] = \ + 'SubscriptionConfirmation' + self.request._body = json.dumps(self.sns_confirmation) + response = self.endpoint(self.request) + self.assertEqual(response.status_code, 400) diff --git a/django_sns_view/views.py b/django_sns_view/views.py index 1d6de26..c5b8857 100644 --- a/django_sns_view/views.py +++ b/django_sns_view/views.py @@ -13,6 +13,7 @@ from django.views.generic import View from django.views.decorators.csrf import csrf_exempt from django.utils.decorators import method_decorator +from django.conf import settings from django_sns_view.utils import confirm_subscription, verify_notification @@ -40,6 +41,24 @@ def handle_message(self, message, notification): """ raise NotImplementedError + def should_confirm_subscription(self, payload): + """ + Determine if the subscription should be confirmed. + By default, we confirm all subscriptions. + If settings has an AWS_ACCOUNT_ID key, we only confirm subscriptions from that account. + + This behavior can be overridden by subclassing and overriding this method. + """ + if hasattr(settings, 'AWS_ACCOUNT_ID'): + arn = payload['TopicArn'].split(':')[4] + print(arn) + if arn == settings.AWS_ACCOUNT_ID: + return True + else: + logger.warning("Recieved subscription confirmation from account %s, but only accepting from account %s", arn, settings.AWS_ACCOUNT_ID) + return False + return True + def post(self, request): """ Validate and handle an SNS message. @@ -97,6 +116,8 @@ def post(self, request): return HttpResponseBadRequest('Invalid Notification Type') if message_type == 'SubscriptionConfirmation': + if not self.should_confirm_subscription(payload): + return HttpResponseBadRequest("Subscription Denied") return confirm_subscription(payload) elif message_type == 'UnsubscribeConfirmation': # Don't handle unsubscribe notification here, just remove diff --git a/test_requirements.txt b/test_requirements.txt index 7e006f6..6698cbc 100644 --- a/test_requirements.txt +++ b/test_requirements.txt @@ -1,7 +1,5 @@ tox==2.9.1 -nose requests -django-nose coverage mock pyopenssl>=0.13.1