@@ -358,7 +358,7 @@ def create_tar_file(source_files, target=None):
358358
359359
360360@contextlib .contextmanager
361- def _tmpdir (suffix = "" , prefix = "tmp" ):
361+ def _tmpdir (suffix = "" , prefix = "tmp" , directory = None ):
362362 """Create a temporary directory with a context manager.
363363
364364 The file is deleted when the context exits.
@@ -369,11 +369,18 @@ def _tmpdir(suffix="", prefix="tmp"):
369369 suffix, otherwise there will be no suffix.
370370 prefix (str): If prefix is specified, the file name will begin with that
371371 prefix; otherwise, a default prefix is used.
372+ directory (str): If a directory is specified, the file will be downloaded
373+ in this directory; otherwise, a default directory is used.
372374
373375 Returns:
374376 str: path to the directory
375377 """
376- tmp = tempfile .mkdtemp (suffix = suffix , prefix = prefix , dir = None )
378+ if directory is not None and not (os .path .exists (directory ) and os .path .isdir (directory )):
379+ raise ValueError (
380+ "Inputted directory for storing newly generated temporary "
381+ f"directory does not exist: '{ directory } '"
382+ )
383+ tmp = tempfile .mkdtemp (suffix = suffix , prefix = prefix , dir = directory )
377384 yield tmp
378385 shutil .rmtree (tmp )
379386
@@ -427,7 +434,13 @@ def repack_model(
427434 """
428435 dependencies = dependencies or []
429436
430- with _tmpdir () as tmp :
437+ local_download_dir = (
438+ None
439+ if sagemaker_session .settings is None
440+ or sagemaker_session .settings .local_download_dir is None
441+ else sagemaker_session .settings .local_download_dir
442+ )
443+ with _tmpdir (directory = local_download_dir ) as tmp :
431444 model_dir = _extract_model (model_uri , sagemaker_session , tmp )
432445
433446 _create_or_update_code_dir (
0 commit comments