Skip to content

Commit dd724eb

Browse files
authored
handle max_request_size param set by SM platofrm (#121)
* handle max_request_size param set by SM platofrm * fix lint * update test with a readable value
1 parent 709949b commit dd724eb

File tree

4 files changed

+16
-0
lines changed

4 files changed

+16
-0
lines changed

src/sagemaker_inference/environment.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
DEFAULT_STARTUP_TIMEOUT = "600" # 10 minutes
2727
DEFAULT_HTTP_PORT = "8080"
2828
DEFAULT_VMARGS = "-XX:-UseContainerSupport"
29+
DEFAULT_MAX_REQUEST_SIZE = None
2930

3031
SAGEMAKER_BASE_PATH = os.path.join("/opt", "ml") # type: str
3132

@@ -81,6 +82,9 @@ def __init__(self):
8182
self._management_http_port = os.environ.get(parameters.BIND_TO_PORT_ENV, DEFAULT_HTTP_PORT)
8283
self._safe_port_range = os.environ.get(parameters.SAFE_PORT_RANGE_ENV)
8384
self._vmargs = os.environ.get(parameters.MODEL_SERVER_VMARGS, DEFAULT_VMARGS)
85+
self._max_request_size_in_mb = os.environ.get(
86+
parameters.MAX_REQUEST_SIZE, DEFAULT_MAX_REQUEST_SIZE
87+
)
8488

8589
@staticmethod
8690
def _parse_module_name(program_param):
@@ -147,3 +151,11 @@ def safe_port_range(self): # type: () -> str
147151
def vmargs(self): # type: () -> str
148152
"""str: vmargs can be provided for the JVM, to be overriden"""
149153
return self._vmargs
154+
155+
@property
156+
def max_request_size(self): # type: () -> str
157+
"""str: max request size set by Sagemaker platform in bytes"""
158+
if self._max_request_size_in_mb is not None:
159+
return int(self._max_request_size_in_mb) * 1024 * 1024
160+
else:
161+
return None

src/sagemaker_inference/model_server.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def _generate_mms_config_properties(env, handler_service=None):
160160
"inference_address": "http://0.0.0.0:{}".format(env.inference_http_port),
161161
"management_address": "http://0.0.0.0:{}".format(env.management_http_port),
162162
"vmargs": env.vmargs,
163+
"max_request_size": env.max_request_size,
163164
}
164165
# If provided, add handler service to user config
165166
if handler_service:

src/sagemaker_inference/parameters.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,4 @@
2525
BIND_TO_PORT_ENV = "SAGEMAKER_BIND_TO_PORT" # type: str
2626
SAFE_PORT_RANGE_ENV = "SAGEMAKER_SAFE_PORT_RANGE" # type: str
2727
MULTI_MODEL_ENV = "SAGEMAKER_MULTI_MODEL" # type: str
28+
MAX_REQUEST_SIZE = "SAGEMAKER_MAX_PAYLOAD_IN_MB" # type: str

test/unit/test_environment.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
parameters.BIND_TO_PORT_ENV: "1738",
3030
parameters.SAFE_PORT_RANGE_ENV: "1111-2222",
3131
parameters.MODEL_SERVER_VMARGS: "-XX:-UseContainerSupport",
32+
parameters.MAX_REQUEST_SIZE: "10",
3233
},
3334
clear=True,
3435
)
@@ -47,6 +48,7 @@ def test_env():
4748
assert env.management_http_port == "1738"
4849
assert env.safe_port_range == "1111-2222"
4950
assert "-XX:-UseContainerSupport" in env.vmargs
51+
assert env.max_request_size == 10 * 1024 * 1024
5052

5153

5254
@pytest.mark.parametrize("sagemaker_program", ["program.py", "program"])

0 commit comments

Comments
 (0)