|
11 | 11 | # ANY KIND, either express or implied. See the License for the specific |
12 | 12 | # language governing permissions and limitations under the License. |
13 | 13 | """Placeholder docstring""" |
14 | | -from __future__ import absolute_import |
| 14 | +from __future__ import absolute_import, annotations |
15 | 15 |
|
16 | 16 | import logging |
17 | | -import os |
18 | 17 | import platform |
19 | 18 | from datetime import datetime |
| 19 | +from typing import Dict |
20 | 20 |
|
21 | 21 | import boto3 |
22 | 22 | from botocore.exceptions import ClientError |
| 23 | +import jsonschema |
23 | 24 |
|
24 | 25 | from sagemaker.config import ( |
25 | | - load_sagemaker_config, |
26 | | - validate_sagemaker_config, |
| 26 | + SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA, |
27 | 27 | SESSION_DEFAULT_S3_BUCKET_PATH, |
28 | 28 | SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH, |
| 29 | + load_local_mode_config, |
| 30 | + load_sagemaker_config, |
| 31 | + validate_sagemaker_config, |
29 | 32 | ) |
30 | 33 | from sagemaker.local.image import _SageMakerContainer |
31 | 34 | from sagemaker.local.utils import get_docker_host |
@@ -83,7 +86,7 @@ def create_processing_job( |
83 | 86 | Environment=None, |
84 | 87 | ProcessingInputs=None, |
85 | 88 | ProcessingOutputConfig=None, |
86 | | - **kwargs |
| 89 | + **kwargs, |
87 | 90 | ): |
88 | 91 | """Creates a processing job in Local Mode |
89 | 92 |
|
@@ -165,7 +168,7 @@ def create_training_job( |
165 | 168 | ResourceConfig, |
166 | 169 | InputDataConfig=None, |
167 | 170 | Environment=None, |
168 | | - **kwargs |
| 171 | + **kwargs, |
169 | 172 | ): |
170 | 173 | """Create a training job in Local Mode. |
171 | 174 |
|
@@ -230,7 +233,7 @@ def create_transform_job( |
230 | 233 | TransformInput, |
231 | 234 | TransformOutput, |
232 | 235 | TransformResources, |
233 | | - **kwargs |
| 236 | + **kwargs, |
234 | 237 | ): |
235 | 238 | """Create the transform job. |
236 | 239 |
|
@@ -537,7 +540,21 @@ def __init__(self, config=None): |
537 | 540 | self.http = urllib3.PoolManager() |
538 | 541 | self.serving_port = 8080 |
539 | 542 | self.config = config |
540 | | - self.serving_port = get_config_value("local.serving_port", config) or 8080 |
| 543 | + |
| 544 | + @property |
| 545 | + def config(self) -> dict: |
| 546 | + """Local config getter""" |
| 547 | + return self._config |
| 548 | + |
| 549 | + @config.setter |
| 550 | + def config(self, value: dict): |
| 551 | + """Local config setter, this method also updates the `serving_port` attribute. |
| 552 | +
|
| 553 | + Args: |
| 554 | + value (dict): the new config value |
| 555 | + """ |
| 556 | + self._config = value |
| 557 | + self.serving_port = get_config_value("local.serving_port", self._config) or 8080 |
541 | 558 |
|
542 | 559 | def invoke_endpoint( |
543 | 560 | self, |
@@ -686,6 +703,7 @@ def _initialize( |
686 | 703 |
|
687 | 704 | self.sagemaker_client = LocalSagemakerClient(self) |
688 | 705 | self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config) |
| 706 | + |
689 | 707 | self.local_mode = True |
690 | 708 | sagemaker_config = kwargs.get("sagemaker_config", None) |
691 | 709 | if sagemaker_config: |
@@ -726,17 +744,26 @@ def _initialize( |
726 | 744 | sagemaker_session=self, |
727 | 745 | ) |
728 | 746 |
|
729 | | - local_mode_config_file = os.path.join(os.path.expanduser("~"), ".sagemaker", "config.yaml") |
730 | | - if os.path.exists(local_mode_config_file): |
| 747 | + self.config = load_local_mode_config() |
| 748 | + if self._disable_local_code and self.config and "local" in self.config: |
| 749 | + self.config["local"]["local_code"] = False |
| 750 | + |
| 751 | + @Session.config.setter |
| 752 | + def config(self, value: Dict | None): |
| 753 | + """Setter of the local mode config""" |
| 754 | + if value is not None: |
731 | 755 | try: |
732 | | - import yaml |
733 | | - except ImportError as e: |
734 | | - logger.error(_module_import_error("yaml", "Local mode", "local")) |
| 756 | + jsonschema.validate(value, SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA) |
| 757 | + except jsonschema.ValidationError as e: |
| 758 | + logger.error("Failed to validate the local mode config") |
735 | 759 | raise e |
| 760 | + self._config = value |
| 761 | + else: |
| 762 | + self._config = value |
736 | 763 |
|
737 | | - self.config = yaml.safe_load(open(local_mode_config_file, "r")) |
738 | | - if self._disable_local_code and "local" in self.config: |
739 | | - self.config["local"]["local_code"] = False |
| 764 | + # update the runtime client on config changed |
| 765 | + if getattr(self, "sagemaker_runtime_client", None): |
| 766 | + self.sagemaker_runtime_client.config = self._config |
740 | 767 |
|
741 | 768 | def logs_for_job(self, job_name, wait=False, poll=5, log_type="All"): |
742 | 769 | """A no-op method meant to override the sagemaker client. |
|
0 commit comments