Skip to content

Commit 8c29166

Browse files
authored
feat: support codeartifact for installing requirements.txt packages (#130)
- Add a dependency on boto3 for calling the CodeArtifact APIs.
- Update model_server to check for the presence of the CodeArtifact (CA_*-prefixed) environment variable.
- If the environment variable is present, build the authenticated index URL and add it to the pip install command; otherwise, keep using the PyPI index.

Closes #85.
1 parent 3167082 commit 8c29166

File tree

3 files changed

+114
-3
lines changed

3 files changed

+114
-3
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def read_version():
2929

3030
packages = setuptools.find_packages(where="src", exclude=("test",))
3131

32-
required_packages = ["numpy", "six", "psutil", "retrying>=1.3.3,<1.4", "scipy"]
32+
required_packages = ["boto3", "numpy", "six", "psutil", "retrying>=1.3.3,<1.4", "scipy"]
3333

3434
# enum is introduced in Python 3.4. Installing enum back port
3535
if sys.version_info < (3, 4):

src/sagemaker_inference/model_server.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@
1515
from __future__ import absolute_import
1616

1717
import os
18+
import re
1819
import signal
1920
import subprocess
2021
import sys
2122

23+
import boto3
2224
import pkg_resources
2325
import psutil
2426
from retrying import retry
@@ -199,14 +201,67 @@ def _terminate(signo, frame): # pylint: disable=unused-argument
199201
def _install_requirements():
    """
    Install the packages listed in the requirements file with pip.

    When the CA_REPOSITORY_ARN environment variable is set to a non-empty value,
    the authenticated codeartifact index url is passed to pip through "-i";
    otherwise no index option is added to the command.

    :raises ValueError: if the pip command exits with a non-zero status
    """
    logger.info("installing packages from requirements.txt...")

    use_codeartifact = bool(os.getenv("CA_REPOSITORY_ARN"))
    index_args = ["-i", _get_codeartifact_index()] if use_codeartifact else []
    pip_install_cmd = [
        sys.executable,
        "-m",
        "pip",
        "install",
        "-r",
        REQUIREMENTS_PATH,
    ] + index_args

    try:
        subprocess.check_call(pip_install_cmd)
    except subprocess.CalledProcessError:
        logger.error("failed to install required packages, exiting")
        raise ValueError("failed to install required packages")
208213

209214

215+
def _get_codeartifact_index():
    """
    Build the authenticated codeartifact index url
    https://docs.aws.amazon.com/codeartifact/latest/ug/python-configure-pip.html
    https://docs.aws.amazon.com/service-authorization/latest/reference/list_awscodeartifact.html#awscodeartifact-resources-for-iam-policies

    The repository is read from the CA_REPOSITORY_ARN environment variable, which must
    have the form
    arn:<partition>:codeartifact:<region>:<account>:repository/<domain>/<repository>

    :return: authenticated codeartifact index url
    :raises Exception: if the arn is missing or malformed, or if the authorization token
        or the repository endpoint cannot be retrieved
    """
    repository_arn = os.getenv("CA_REPOSITORY_ARN")
    arn_regex = (
        r"arn:(?P<partition>[^:]+):codeartifact:(?P<region>[^:]+):(?P<account>[^:]+)"
        r":repository/(?P<domain>[^/]+)/(?P<repository>.+)"
    )
    # an unset variable is reported like a malformed arn instead of letting
    # re.match fail with a TypeError on None
    m = re.match(arn_regex, repository_arn or "")
    if not m:
        raise Exception("invalid CodeArtifact repository arn {}".format(repository_arn))
    domain = m.group("domain")
    owner = m.group("account")
    repository = m.group("repository")
    region = m.group("region")

    logger.info(
        "configuring pip to use codeartifact "
        "(domain: %s, domain owner: %s, repository: %s, region: %s)",
        domain,
        owner,
        repository,
        region,
    )
    try:
        client = boto3.client("codeartifact", region_name=region)
        auth_token_response = client.get_authorization_token(domain=domain, domainOwner=owner)
        token = auth_token_response["authorizationToken"]
        endpoint_response = client.get_repository_endpoint(
            domain=domain, domainOwner=owner, repository=repository, format="pypi"
        )
        unauthenticated_index = endpoint_response["repositoryEndpoint"]

        # point pip at the "simple" index below the repository endpoint; the repository
        # name is escaped so that characters such as "." only match themselves, and the
        # replacement is a callable so that it is inserted verbatim
        index = re.sub(
            "{}/?$".format(re.escape(repository)),
            lambda _match: "{}/simple/".format(repository),
            unauthenticated_index,
        )

        # insert the credentials right after the leading scheme only, using plain string
        # operations so the token is never interpreted as a regex replacement template
        scheme = "https://"
        if index.startswith(scheme):
            index = "{}aws:{}@{}".format(scheme, token, index[len(scheme):])
        return index
    except Exception:
        # logger.exception keeps the traceback of the underlying failure in the logs,
        # since the exception raised below only carries a generic message
        logger.exception("failed to configure pip to use codeartifact")
        raise Exception("failed to configure pip to use codeartifact")
263+
264+
210265
def _retry_retrieve_mms_server_process(startup_timeout):
211266
retrieve_mms_server_process = retry(wait_fixed=1000, stop_max_delay=startup_timeout * 1000)(
212267
_retrieve_mms_server_process

test/unit/test_model_server.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,12 @@
1313
import os
1414
import signal
1515
import subprocess
16+
import sys
1617
import types
1718

18-
from mock import ANY, Mock, patch
19+
import botocore.session
20+
from botocore.stub import Stubber
21+
from mock import ANY, MagicMock, Mock, patch
1922
import pytest
2023

2124
from sagemaker_inference import environment, model_server
@@ -224,6 +227,16 @@ def test_add_sigterm_handler(signal_call):
224227
def test_install_requirements(check_call):
225228
model_server._install_requirements()
226229

230+
install_cmd = [
231+
sys.executable,
232+
"-m",
233+
"pip",
234+
"install",
235+
"-r",
236+
"/opt/ml/model/code/requirements.txt",
237+
]
238+
check_call.assert_called_once_with(install_cmd)
239+
227240

228241
@patch("subprocess.check_call", side_effect=subprocess.CalledProcessError(0, "cmd"))
229242
def test_install_requirements_installation_failed(check_call):
@@ -233,6 +246,49 @@ def test_install_requirements_installation_failed(check_call):
233246
assert "failed to install required packages" in str(e.value)
234247

235248

249+
@patch.dict(os.environ, {"CA_REPOSITORY_ARN": "invalid_arn"}, clear=True)
def test_install_requirements_codeartifact_invalid_arn_installation_failed():
    """A CA_REPOSITORY_ARN value that is not a repository arn makes the install raise."""
    expected_message = "invalid CodeArtifact repository arn invalid_arn"

    with pytest.raises(Exception) as excinfo:
        model_server._install_requirements()

    assert expected_message in str(excinfo.value)
255+
256+
257+
@patch("subprocess.check_call")
@patch.dict(
    os.environ,
    {
        "CA_REPOSITORY_ARN": "arn:aws:codeartifact:my_region:012345678900:repository/my_domain/my_repo"
    },
    clear=True,
)
def test_install_requirements_codeartifact(check_call):
    """pip is pointed at the authenticated codeartifact index when CA_REPOSITORY_ARN is set."""
    # mock/stub codeartifact client and its responses; the region used to build the stub
    # client does not take part in the assertions because boto3.client itself is patched
    endpoint = "https://domain-012345678900.d.codeartifact.region.amazonaws.com/pypi/my_repo/"
    codeartifact = botocore.session.get_session().create_client(
        "codeartifact", region_name="myregion"
    )
    stubber = Stubber(codeartifact)
    # expected_params make the stubbed calls fail unless the values parsed from the arn
    # are exactly the ones sent to the service
    stubber.add_response(
        "get_authorization_token",
        {"authorizationToken": "the-auth-token"},
        expected_params={"domain": "my_domain", "domainOwner": "012345678900"},
    )
    stubber.add_response(
        "get_repository_endpoint",
        {"repositoryEndpoint": endpoint},
        expected_params={
            "domain": "my_domain",
            "domainOwner": "012345678900",
            "repository": "my_repo",
            "format": "pypi",
        },
    )
    boto3_client = MagicMock(return_value=codeartifact)

    # the context manager activates the stubber and deactivates it again afterwards
    with stubber, patch("boto3.client", boto3_client):
        model_server._install_requirements()

    # the client must be created for the region found in the arn
    boto3_client.assert_called_once_with("codeartifact", region_name="my_region")
    # both stubbed api calls must have been consumed
    stubber.assert_no_pending_responses()

    install_cmd = [
        sys.executable,
        "-m",
        "pip",
        "install",
        "-r",
        "/opt/ml/model/code/requirements.txt",
        "-i",
        "https://aws:the-auth-token@domain-012345678900.d.codeartifact.region.amazonaws.com/pypi/my_repo/simple/",
    ]
    check_call.assert_called_once_with(install_cmd)
290+
291+
236292
@patch("psutil.process_iter")
237293
def test_retrieve_mms_server_process(process_iter):
238294
server = Mock()

0 commit comments

Comments
 (0)