|
15 | 15 | import copy |
16 | 16 | import json |
17 | 17 |
|
| 18 | +import os |
18 | 19 | import pytest |
19 | 20 | from mock import Mock |
20 | 21 |
|
|
26 | 27 | HyperparameterTuner, _TuningJob, WarmStartConfig, create_identical_dataset_and_algorithm_tuner, \ |
27 | 28 | create_transfer_learning_tuner, WarmStartTypes |
28 | 29 | from sagemaker.mxnet import MXNet |
| 30 | + |
| 31 | +DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data') |
29 | 32 | MODEL_DATA = "s3://bucket/model.tar.gz" |
30 | 33 |
|
31 | 34 | JOB_NAME = 'tuning_job' |
@@ -488,6 +491,22 @@ def test_delete_endpoint(tuner): |
488 | 491 | tuner.sagemaker_session.delete_endpoint.assert_called_with(JOB_NAME) |
489 | 492 |
|
490 | 493 |
|
| 494 | +def test_fit_no_inputs(tuner, sagemaker_session): |
| 495 | + script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'failure_script.py') |
| 496 | + tuner.estimator = MXNet(entry_point=script_path, |
| 497 | + role=ROLE, |
| 498 | + framework_version=FRAMEWORK_VERSION, |
| 499 | + train_instance_count=TRAIN_INSTANCE_COUNT, |
| 500 | + train_instance_type=TRAIN_INSTANCE_TYPE, |
| 501 | + sagemaker_session=sagemaker_session) |
| 502 | + |
| 503 | + tuner.fit() |
| 504 | + |
| 505 | + _, _, tune_kwargs = sagemaker_session.tune.mock_calls[0] |
| 506 | + |
| 507 | + assert tune_kwargs['input_config'] is None |
| 508 | + |
| 509 | + |
491 | 510 | def test_identical_dataset_and_algorithm_tuner(sagemaker_session): |
492 | 511 | job_details = copy.deepcopy(TUNING_JOB_DETAILS) |
493 | 512 | sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(name='describe_tuning_job', |
@@ -523,6 +542,8 @@ def test_transfer_learning_tuner(sagemaker_session): |
523 | 542 | assert parent_tuner.warm_start_config.type == WarmStartTypes.TRANSFER_LEARNING |
524 | 543 | assert parent_tuner.warm_start_config.parents == {tuner.latest_tuning_job.name, "p1", "p2"} |
525 | 544 | assert parent_tuner.estimator == tuner.estimator |
| 545 | + |
| 546 | + |
526 | 547 | ################################################################################# |
527 | 548 | # _ParameterRange Tests |
528 | 549 |
|
|
0 commit comments