11from __future__ import annotations
22
3+ import io
34from typing import TYPE_CHECKING
45
5- from boto .exception import S3ResponseError
6- from boto .s3 import connect_to_region , connection
7- from boto .s3 .key import Key
6+ import boto3
7+ from botocore .exceptions import ClientError
88
99from . import logger , settings
10+ from .swf .mapper .exceptions import extract_error_code
1011
1112if TYPE_CHECKING :
1213 from typing import Optional , Tuple # NOQA
1314
14- from boto .s3 .bucket import Bucket # NOQA
15- from boto .s3 .bucketlistresultset import BucketListResultSet # NOQA
15+ from mypy_boto3_s3 .service_resource import Bucket , ObjectSummary # NOQA
1616
1717BUCKET_CACHE = {}
1818BUCKET_LOCATIONS_CACHE = {}
1919
2020
def get_client():
    """Return a new low-level S3 client bound to a fresh default session.

    A dedicated ``Session`` is created per call because the module-level
    default session is not guaranteed to be thread-safe.

    NOTE: the previous return annotation (``boto3.session.Session.client``)
    referred to a bound *method*, not a type, so it is intentionally omitted;
    the concrete return type is botocore's generated S3 client.
    """
    return boto3.session.Session().client("s3")
23+
24+
def get_resource(host_or_region: str):
    """Return an S3 service resource for a given endpoint host or region.

    :param host_or_region: either a DNS host name (detected by the presence
        of a dot), used verbatim as an HTTPS endpoint URL, or an AWS region
        name (e.g. ``us-east-1``).

    An explicit ``Session`` is used for consistency with :func:`get_client`
    (the module-level ``boto3.resource`` shortcut shares a default session
    that is not guaranteed to be thread-safe). The previous return annotation
    (``boto3.session.Session.resource``) was a bound method, not a type, so
    it is omitted.
    """
    session = boto3.session.Session()

    # first case: we got a valid DNS name (host) -> pin the endpoint URL
    if "." in host_or_region:
        return session.resource("s3", endpoint_url=f"https://{host_or_region}")

    # second case: we got a region name
    return session.resource("s3", region_name=host_or_region)
2832
2933
3034def sanitize_bucket_and_host (bucket : str ) -> tuple [str , str ]:
@@ -48,20 +52,18 @@ def sanitize_bucket_and_host(bucket: str) -> tuple[str, str]:
4852
4953 # second case: we got a bucket name, we need to figure out which region it's in
5054 try :
51- conn0 = connection .S3Connection ()
52- bucket_obj = conn0 .get_bucket (bucket , validate = False )
53-
54- # get_location() returns a region or an empty string for us-east-1,
55- # historically named "US Standard" in some places. Maybe other S3
56- # calls support an empty string as region, but I prefer to be
57- # explicit here.
58- location = bucket_obj .get_location () or "us-east-1"
55+ # get_bucket_location() returns a region or an empty string for us-east-1,
56+ # historically named "US Standard" in some places. Maybe other S3 calls
57+ # support an empty string as region, but I prefer to be explicit here.
58+ location = get_client ().get_bucket_location (Bucket = bucket )["LocationConstraint" ] or "us-east-1"
5959
6060 # save location for later use
6161 BUCKET_LOCATIONS_CACHE [bucket ] = location
62- except S3ResponseError as e :
63- if e .error_code == "AccessDenied" :
62+ except ClientError as e :
63+ error_code = extract_error_code (e )
64+ if error_code == "AccessDenied" :
6465 # probably not allowed to perform GetBucketLocation on this bucket
66+ # TODO: consider raising instead? who forbids GetBucketLocation anyway?
6567 logger .warning (f"Access denied while trying to get location of bucket { bucket } " )
6668 location = ""
6769 else :
@@ -74,45 +76,47 @@ def sanitize_bucket_and_host(bucket: str) -> tuple[str, str]:
7476 return bucket , location
7577
7678
def get_bucket(bucket_name: str) -> "Bucket":
    """Return a (cached) Bucket resource for *bucket_name*.

    The name is first normalized by ``sanitize_bucket_and_host()``, which
    also yields the bucket's location; the sanitized name is the cache key.

    :param bucket_name: bucket name (or host/bucket form accepted by
        ``sanitize_bucket_and_host``).
    :return: a boto3 ``Bucket`` resource, possibly shared across calls.
    """
    bucket_name, location = sanitize_bucket_and_host(bucket_name)
    # Check the cache *before* building a resource: the original code created
    # a fresh S3 resource (session + client) on every call, even on cache hits.
    if bucket_name not in BUCKET_CACHE:
        BUCKET_CACHE[bucket_name] = get_resource(location).Bucket(bucket_name)
    return BUCKET_CACHE[bucket_name]
8486
8587
def pull(bucket: str, path: str, dest_file: str) -> None:
    """Download the S3 object at *path* in *bucket* into the local file *dest_file*."""
    get_bucket(bucket).download_file(path, dest_file)
9091
9192
def pull_content(bucket: str, path: str) -> str:
    """Return the content of the S3 object at *path* in *bucket*, decoded as UTF-8."""
    buffer = io.BytesIO()
    get_bucket(bucket).download_fileobj(path, buffer)
    return buffer.getvalue().decode()
9698
9799
def push(bucket: str, path: str, src_file: str, content_type: str | None = None) -> None:
    """Upload the local file *src_file* to *path* in *bucket*.

    :param content_type: optional MIME type to set on the object.
    """
    extra_args: dict = {}
    if content_type:
        extra_args["ContentType"] = content_type
    # Honour the project-wide server-side-encryption setting.
    if settings.SIMPLEFLOW_S3_SSE:
        extra_args["ServerSideEncryption"] = "AES256"
    get_bucket(bucket).upload_file(src_file, path, ExtraArgs=extra_args)
105108
106109
def push_content(bucket: str, path: str, content: str, content_type: str | None = None) -> None:
    """Upload the string *content* (UTF-8 encoded) to *path* in *bucket*.

    :param content_type: optional MIME type to set on the object.
    """
    extra_args: dict = {}
    if content_type:
        extra_args["ContentType"] = content_type
    # Honour the project-wide server-side-encryption setting.
    if settings.SIMPLEFLOW_S3_SSE:
        extra_args["ServerSideEncryption"] = "AES256"
    payload = io.BytesIO(content.encode())
    get_bucket(bucket).upload_fileobj(payload, path, ExtraArgs=extra_args)
114118
115119
def list_keys(bucket: str, path: str | None = None) -> list["ObjectSummary"]:
    """Return the object summaries under the given key prefix in *bucket*.

    :param path: key prefix to filter on; ``None`` or empty lists the whole
        bucket. (Annotation fixed: the default was ``None`` while the
        parameter was typed plain ``str``.)
    :return: materialized list of boto3 ``ObjectSummary`` objects.
    """
    bucket_resource = get_bucket(bucket)
    # list(...) materializes the lazy collection directly: the comprehension
    # copy was unnecessary, and chaining .all() after .filter() is redundant
    # since .filter() already returns an iterable collection.
    return list(bucket_resource.objects.filter(Prefix=path or ""))