File tree Expand file tree Collapse file tree 4 files changed +26
-3
lines changed
tests/unit/sagemaker/jumpstart Expand file tree Collapse file tree 4 files changed +26
-3
lines changed Original file line number Diff line number Diff line change @@ -746,6 +746,7 @@ see `Model <https://sagemaker.readthedocs.io/en/stable/api/inference/model.html
746746.. code :: python
747747
748748 from sagemaker.model import Model
749+ from sagemaker.predictor import Predictor
749750 from sagemaker.session import Session
750751
751752 # Create the SageMaker model instance
@@ -755,6 +756,7 @@ see `Model <https://sagemaker.readthedocs.io/en/stable/api/inference/model.html
755756 source_dir = script_uri,
756757 entry_point = " inference.py" ,
757758 role = Session().get_caller_identity_arn(),
759+ predictor_cls = Predictor,
758760 )
759761
760762 Save the output from deploying the model to a variable named
@@ -766,12 +768,9 @@ Deployment may take about 5 minutes.
766768
767769.. code :: python
768770
769- from sagemaker.predictor import Predictor
770-
771771 predictor = model.deploy(
772772 initial_instance_count = instance_count,
773773 instance_type = instance_type,
774- predictor_cls = Predictor
775774 )
776775
777776 Because ``catboost `` and ``lightgbm `` rely on the PyTorch Deep Learning Containers
Original file line number Diff line number Diff line change 122122TRAINING_ENTRY_POINT_SCRIPT_NAME = "transfer_learning.py"
123123
124124SUPPORTED_JUMPSTART_SCOPES = set (scope .value for scope in JumpStartScriptScope )
125+
126+ ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE = "AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE"
Original file line number Diff line number Diff line change 1313"""This module contains utilities related to SageMaker JumpStart."""
1414from __future__ import absolute_import
1515import logging
16+ import os
1617from typing import Dict , List , Optional
1718from urllib .parse import urlparse
1819from packaging .version import Version
@@ -60,6 +61,14 @@ def get_jumpstart_content_bucket(region: str) -> str:
6061 Raises:
6162 RuntimeError: If JumpStart is not launched in ``region``.
6263 """
64+
65+ if (
66+ constants .ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE in os .environ
67+ and len (os .environ [constants .ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE ]) > 0
68+ ):
69+ bucket_override = os .environ [constants .ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE ]
70+ LOGGER .info ("Using JumpStart bucket override: '%s'" , bucket_override )
71+ return bucket_override
6372 try :
6473 return constants .JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT [region ].content_bucket
6574 except KeyError :
Original file line number Diff line number Diff line change 1111# ANY KIND, either express or implied. See the License for the specific
1212# language governing permissions and limitations under the License.
1313from __future__ import absolute_import
14+ import os
1415from mock .mock import Mock , patch
1516import pytest
1617import random
1718from sagemaker .jumpstart import utils
1819from sagemaker .jumpstart .constants import (
20+ ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE ,
1921 JUMPSTART_BUCKET_NAME_SET ,
2022 JUMPSTART_REGION_NAME_SET ,
2123 JumpStartScriptScope ,
@@ -40,6 +42,17 @@ def test_get_jumpstart_content_bucket():
4042 utils .get_jumpstart_content_bucket (bad_region )
4143
4244
45+ def test_get_jumpstart_content_bucket_override ():
46+ with patch .dict (os .environ , {ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE : "some-val" }):
47+ with patch ("logging.Logger.info" ) as mocked_info_log :
48+ random_region = "random_region"
49+ assert "some-val" == utils .get_jumpstart_content_bucket (random_region )
50+ mocked_info_log .assert_called_once_with (
51+ "Using JumpStart bucket override: '%s'" ,
52+ "some-val" ,
53+ )
54+
55+
4356def test_get_jumpstart_launched_regions_message ():
4457
4558 with patch ("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET" , {}):
You can’t perform that action at this time.
0 commit comments