|
31 | 31 | import sagemaker.logs |
32 | 32 | from sagemaker import vpc_utils |
33 | 33 |
|
| 34 | +from sagemaker._studio import _append_project_tags |
34 | 35 | from sagemaker.deprecations import deprecated_class |
35 | 36 | from sagemaker.inputs import ShuffleConfig, TrainingInput |
36 | 37 | from sagemaker.user_agent import prepend_user_agent |
@@ -534,6 +535,7 @@ def train( # noqa: C901 |
534 | 535 | Returns: |
535 | 536 | str: ARN of the training job, if it is created. |
536 | 537 | """ |
| 538 | + tags = _append_project_tags(tags) |
537 | 539 | train_request = self._get_train_request( |
538 | 540 | input_mode=input_mode, |
539 | 541 | input_config=input_config, |
@@ -779,6 +781,7 @@ def process( |
779 | 781 | three optional keys, 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'. |
780 | 782 | (default: ``None``) |
781 | 783 | """ |
| 784 | + tags = _append_project_tags(tags) |
782 | 785 | process_request = self._get_process_request( |
783 | 786 | inputs=inputs, |
784 | 787 | output_config=output_config, |
@@ -1019,6 +1022,7 @@ def create_monitoring_schedule( |
1019 | 1022 | "NetworkConfig" |
1020 | 1023 | ] = network_config |
1021 | 1024 |
|
| 1025 | + tags = _append_project_tags(tags) |
1022 | 1026 | if tags is not None: |
1023 | 1027 | monitoring_schedule_request["Tags"] = tags |
1024 | 1028 |
|
@@ -1527,6 +1531,8 @@ def auto_ml( |
1527 | 1531 | auto_ml_job_request["AutoMLJobObjective"] = job_objective |
1528 | 1532 | if problem_type is not None: |
1529 | 1533 | auto_ml_job_request["ProblemType"] = problem_type |
| 1534 | + |
| 1535 | + tags = _append_project_tags(tags) |
1530 | 1536 | if tags is not None: |
1531 | 1537 | auto_ml_job_request["Tags"] = tags |
1532 | 1538 |
|
@@ -1719,6 +1725,7 @@ def compile_model( |
1719 | 1725 | "CompilationJobName": job_name, |
1720 | 1726 | } |
1721 | 1727 |
|
| 1728 | + tags = _append_project_tags(tags) |
1722 | 1729 | if tags is not None: |
1723 | 1730 | compilation_job_request["Tags"] = tags |
1724 | 1731 |
|
@@ -1868,6 +1875,7 @@ def tune( # noqa: C901 |
1868 | 1875 | if warm_start_config is not None: |
1869 | 1876 | tune_request["WarmStartConfig"] = warm_start_config |
1870 | 1877 |
|
| 1878 | + tags = _append_project_tags(tags) |
1871 | 1879 | if tags is not None: |
1872 | 1880 | tune_request["Tags"] = tags |
1873 | 1881 |
|
@@ -1925,6 +1933,7 @@ def create_tuning_job( |
1925 | 1933 | if warm_start_config is not None: |
1926 | 1934 | tune_request["WarmStartConfig"] = warm_start_config |
1927 | 1935 |
|
| 1936 | + tags = _append_project_tags(tags) |
1928 | 1937 | if tags is not None: |
1929 | 1938 | tune_request["Tags"] = tags |
1930 | 1939 |
|
@@ -2315,6 +2324,7 @@ def transform( |
2315 | 2324 | job. Dictionary contains two optional keys, |
2316 | 2325 | 'InvocationsTimeoutInSeconds', and 'InvocationsMaxRetries'. |
2317 | 2326 | """ |
| 2327 | + tags = _append_project_tags(tags) |
2318 | 2328 | transform_request = self._get_transform_request( |
2319 | 2329 | job_name=job_name, |
2320 | 2330 | model_name=model_name, |
@@ -2430,6 +2440,7 @@ def create_model( |
2430 | 2440 | Returns: |
2431 | 2441 | str: Name of the Amazon SageMaker ``Model`` created. |
2432 | 2442 | """ |
| 2443 | + tags = _append_project_tags(tags) |
2433 | 2444 | create_model_request = self._create_model_request( |
2434 | 2445 | name=name, |
2435 | 2446 | role=role, |
@@ -2754,6 +2765,7 @@ def create_endpoint_config( |
2754 | 2765 | ], |
2755 | 2766 | } |
2756 | 2767 |
|
| 2768 | + tags = _append_project_tags(tags) |
2757 | 2769 | if tags is not None: |
2758 | 2770 | request["Tags"] = tags |
2759 | 2771 |
|
@@ -2823,6 +2835,7 @@ def create_endpoint_config_from_existing( |
2823 | 2835 | request_tags = new_tags or self.list_tags( |
2824 | 2836 | existing_endpoint_config_desc["EndpointConfigArn"] |
2825 | 2837 | ) |
| 2838 | + request_tags = _append_project_tags(request_tags) |
2826 | 2839 | if request_tags: |
2827 | 2840 | request["Tags"] = request_tags |
2828 | 2841 |
|
@@ -2857,6 +2870,7 @@ def create_endpoint(self, endpoint_name, config_name, tags=None, wait=True): |
2857 | 2870 | LOGGER.info("Creating endpoint with name %s", endpoint_name) |
2858 | 2871 |
|
2859 | 2872 | tags = tags or [] |
| 2873 | + tags = _append_project_tags(tags) |
2860 | 2874 |
|
2861 | 2875 | self.sagemaker_client.create_endpoint( |
2862 | 2876 | EndpointName=endpoint_name, EndpointConfigName=config_name, Tags=tags |
@@ -3336,6 +3350,7 @@ def endpoint_from_production_variants( |
3336 | 3350 | lambda: self.sagemaker_client.describe_endpoint_config(EndpointConfigName=name) |
3337 | 3351 | ): |
3338 | 3352 | config_options = {"EndpointConfigName": name, "ProductionVariants": production_variants} |
| 3353 | + tags = _append_project_tags(tags) |
3339 | 3354 | if tags: |
3340 | 3355 | config_options["Tags"] = tags |
3341 | 3356 | if kms_key: |
@@ -3728,6 +3743,7 @@ def create_feature_group( |
3728 | 3743 | Returns: |
3729 | 3744 | Response dict from service. |
3730 | 3745 | """ |
| 3746 | + tags = _append_project_tags(tags) |
3731 | 3747 | kwargs = dict( |
3732 | 3748 | FeatureGroupName=feature_group_name, |
3733 | 3749 | RecordIdentifierFeatureName=record_identifier_name, |
|
0 commit comments