diff --git a/sqlalchemy_media/stores/s3.py b/sqlalchemy_media/stores/s3.py index 2e95c34..55d7c8d 100644 --- a/sqlalchemy_media/stores/s3.py +++ b/sqlalchemy_media/stores/s3.py @@ -10,6 +10,7 @@ # Importing optional stuff required by S3 store try: from requests_aws4auth import AWS4Auth + from botocore.session import Session except ImportError: # pragma: no cover AWS4Auth = None @@ -25,6 +26,8 @@ class S3Store(Store): """ Store for dealing with s3 of aws + Refreshable credentials suitable for use with IAM roles + can be enabled by leaving access_key and secret_key blank .. versionadded:: 0.9.0 @@ -35,7 +38,7 @@ class S3Store(Store): """ base_url = 'https://{0}.s3.amazonaws.com' - def __init__(self, bucket: str, access_key: str, secret_key: str, + def __init__(self, bucket: str, access_key: str = None, secret_key: str = None, region: str, max_age: int = DEFAULT_MAX_AGE, prefix: str = None, base_url: str = None, cdn_url: str = None, cdn_prefix_ignore: bool = False, @@ -65,14 +68,23 @@ def __init__(self, bucket: str, access_key: str, secret_key: str, self.cdn_url = cdn_url + def _auth(self): + ensure_aws4auth() + + if self.access_key is not None and self.secret_key is not None: + return AWS4Auth(self.access_key, self.secret_key, self.region, 's3') + else: + credentials = Session().get_credentials() + return AWS4Auth(region=self.region, service='s3', + refreshable_credentials=credentials) + + def _get_s3_url(self, filename: str): return '{0}/{1}'.format(self.base_url, filename) def _upload_file(self, url: str, data: str, content_type: str, rrs: bool = False): - ensure_aws4auth() - - auth = AWS4Auth(self.access_key, self.secret_key, self.region, 's3') + auth = self._auth() if rrs: storage_class = 'REDUCED_REDUNDANCY' else: @@ -97,17 +109,15 @@ def put(self, filename: str, stream: FileLike): return len(data) def delete(self, filename: str): - ensure_aws4auth() url = self._get_s3_url(filename) - auth = AWS4Auth(self.access_key, self.secret_key, self.region, 's3') + auth = self._auth() res = requests.delete(url, auth=auth) if not 200 <= res.status_code < 300: raise S3Error(res.text) def open(self, filename: str, mode: str = 'rb') -> FileLike: - ensure_aws4auth() url = self._get_s3_url(filename) - auth = AWS4Auth(self.access_key, self.secret_key, self.region, 's3') + auth = self._auth() res = requests.get(url, auth=auth) if not 200 <= res.status_code < 300: raise S3Error(res.text)