Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 43 additions & 23 deletions python/fedml/api/modules/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,26 +30,29 @@ def __init__(self, data: dict):
class DataType(Enum):
    """Kind of payload handed to the storage upload routine.

    Members are plain ``Enum`` values, so they never compare equal to raw
    strings (``DataType.DIRECTORY == "directory"`` is ``False``). Always
    compare against the member itself.
    """

    # The data path refers to a regular file.
    FILE = "file"
    # The data path refers to a directory.
    DIRECTORY = "directory"
    # The payload is supplied in memory (byte-data flag set) rather than as a path.
    BYTE = "byte"
    # The data path is neither a file nor a directory; the upload is rejected.
    INVALID = "invalid"



# TODO(alaydshah): Store the service name in the metadata.
# TODO(alaydshah): If the data already exists, don't upload it again; suggest using the update command instead.
# TODO(bhargav): Discuss removing the service variable; it may still be needed later.
def upload(data_path, api_key, name, description, tag_list, service, show_progress, out_progress_to_err, progress_desc,
metadata) -> FedMLResponse:
metadata, byte_data_flag=False, byte_data=None) -> FedMLResponse:
api_key = authenticate(api_key)

user_id, message = _get_user_id_from_api_key(api_key)

if user_id is None:
return FedMLResponse(code=ResponseCode.FAILURE, message=message)

data_type = _get_data_type(data_path)
data_type = _get_data_type(data_path, byte_data_flag)

if(data_type == DataType.INVALID):
if data_type == DataType.INVALID:
return FedMLResponse(code=ResponseCode.FAILURE,message="Invalid data path")

if(data_type == DataType.DIRECTORY):
if data_type == DataType.DIRECTORY:
to_upload_path, message = _archive_data(data_path)
name = os.path.splitext(os.path.basename(to_upload_path))[0] if name is None else name
file_name = name + ".zip"
Expand All @@ -67,18 +70,24 @@ def upload(data_path, api_key, name, description, tag_list, service, show_progre

file_name = name

if not to_upload_path:
if not to_upload_path and not byte_data_flag:
return FedMLResponse(code=ResponseCode.FAILURE, message=message)

#TODO(bhargav191098) - Better done on the backend. Remove and pass file_name once completed on backend.
dest_path = os.path.join(user_id, file_name)
file_size = os.path.getsize(to_upload_path)
max_chunk_size = 20 * 1024 * 1024

if byte_data_flag:
file_size = sum(len(chunk) for chunk in get_chunks_from_byte_data(byte_data, max_chunk_size))

file_uploaded_url, message = _upload_multipart(api_key, dest_path, to_upload_path, show_progress,
else:
file_size = os.path.getsize(to_upload_path)

file_uploaded_url, message = _upload_multipart(api_key, dest_path, file_size, max_chunk_size, to_upload_path, show_progress,
out_progress_to_err,
progress_desc, metadata)
progress_desc, metadata, byte_data_flag, byte_data)

if(data_type == "dir"):
if data_type == "dir":
os.remove(to_upload_path)
if not file_uploaded_url:
return FedMLResponse(code=ResponseCode.FAILURE, message=f"Failed to upload file: {to_upload_path}")
Expand Down Expand Up @@ -262,6 +271,13 @@ def get_chunks(file_path, chunk_size):
break
yield chunk

def get_chunks_from_byte_data(byte_data, chunk_size):
    """Yield successive chunks read from an in-memory, file-like object.

    Reading starts at the stream's current position; this function does not
    rewind, so callers that need the whole payload must ``seek(0)`` first.

    Args:
        byte_data: Object exposing ``read(size)`` (for example ``io.BytesIO``).
        chunk_size: Maximum number of bytes to request per read. Must be a
            positive integer.

    Yields:
        Each non-empty result of ``byte_data.read(chunk_size)``, in order,
        until the stream returns an empty result.

    Raises:
        ValueError: If ``chunk_size`` is ``None``, zero or negative. The error
            is raised when iteration starts, because this is a generator.
    """
    # read(0) returns an empty result immediately, which would silently yield
    # nothing, and a negative/None size would return the whole payload as a
    # single chunk. Both defeat chunking, so reject them explicitly.
    if chunk_size is None or chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")

    while True:
        chunk = byte_data.read(chunk_size)
        if not chunk:
            break
        yield chunk


def _get_presigned_url(api_key, request_url, file_name, part_number=None):
cert_path = MLOpsConfigs.get_cert_path_with_version()
Expand All @@ -287,7 +303,7 @@ def _upload_part(url,part_data,session):
return response


def _upload_chunk(presigned_url, chunk, part, pbar=None, max_retries=20,session=None):
def _upload_chunk(presigned_url, chunk, part, pbar=None, max_retries=20,session=None, byte_data_flag= False):
for retry_attempt in range(max_retries):
try:
response = _upload_part(presigned_url,chunk,session)
Expand All @@ -297,11 +313,12 @@ def _upload_chunk(presigned_url, chunk, part, pbar=None, max_retries=20,session=
else:
raise requests.exceptions.RequestException

if(pbar is not None):
pbar.update(chunk.__sizeof__())
if pbar is not None:
pbar.update(len(chunk))
return {'etag': response.headers['ETag'], 'partNumber': part}
raise requests.exceptions.RequestException


def _process_post_response(response):
if response.status_code != 200:
message = (f"Failed to complete multipart upload with status code = {response.status_code}, "
Expand Down Expand Up @@ -345,14 +362,10 @@ def _complete_multipart_upload(api_key, file_key, part_info, upload_id):
return _process_post_response(complete_multipart_response)


def _upload_multipart(api_key: str, file_key, archive_path, show_progress, out_progress_to_err,
progress_desc_text, metadata):
def _upload_multipart(api_key: str, file_key, file_size, max_chunk_size, archive_path, show_progress, out_progress_to_err,
progress_desc_text, metadata, byte_data_flag, byte_data):
request_url = ServerConstants.get_presigned_multi_part_url()

file_size = os.path.getsize(archive_path)

max_chunk_size = 20 * 1024 * 1024

num_chunks = _get_num_chunks(file_size, max_chunk_size)

upload_id = ""
Expand All @@ -379,8 +392,12 @@ def _upload_multipart(api_key: str, file_key, archive_path, show_progress, out_p
upload_id = data['uploadId']
presigned_urls = data['urls']

parts = []
chunks = get_chunks(archive_path, max_chunk_size)
if byte_data_flag:
byte_data.seek(0)
chunks = get_chunks_from_byte_data(byte_data, max_chunk_size)
else:
chunks = get_chunks(archive_path, max_chunk_size)

part_info = []
chunk_count = 0
successful_chunks = 0
Expand All @@ -396,7 +413,7 @@ def _upload_multipart(api_key: str, file_key, archive_path, show_progress, out_p
if show_progress:
try:
part_data = _upload_chunk(presigned_url=presigned_url, chunk=chunk, part=part,
pbar=pbar,session=atomic_session)
pbar=pbar,session=atomic_session, byte_data_flag = byte_data_flag)
part_info.append(part_data)
successful_chunks += 1
except Exception as e:
Expand Down Expand Up @@ -474,8 +491,11 @@ def _get_storage_service(service):
else:
raise NotImplementedError(f"Service {service} not implemented")

def _get_data_type(data_path):
if os.path.isdir(data_path):

def _get_data_type(data_path, byte_data_flag):
if byte_data_flag:
return DataType.BYTE
elif os.path.isdir(data_path):
return DataType.DIRECTORY
elif os.path.isfile(data_path):
return DataType.FILE
Expand Down