|
1 | 1 | from __future__ import absolute_import |
2 | 2 |
|
3 | | -import uuid |
4 | 3 | import time |
| 4 | +import uuid |
| 5 | +from contextlib import contextmanager |
5 | 6 |
|
6 | 7 | import pytest |
7 | 8 |
|
8 | 9 | from sagemaker.analytics import ExperimentAnalytics |
9 | 10 |
|
10 | 11 |
|
11 | | -@pytest.mark.canary_quick |
12 | | -def test_experiment_analytics(sagemaker_session): |
| 12 | +@contextmanager |
| 13 | +def experiment(sagemaker_session): |
13 | 14 | sm = sagemaker_session.sagemaker_client |
| 15 | + trials = {} # for resource cleanup |
14 | 16 |
|
15 | 17 | experiment_name = "experiment-" + str(uuid.uuid4()) |
16 | | - sm.create_experiment(ExperimentName=experiment_name) |
17 | | - |
18 | | - for i in range(5): |
19 | | - trial_name = "trial-" + str(uuid.uuid4()) |
20 | | - sm.create_trial(TrialName=trial_name, ExperimentName=experiment_name) |
21 | | - trial_component_name = "tc-" + str(uuid.uuid4()) |
22 | | - sm.create_trial_component(TrialComponentName=trial_component_name, DisplayName="Training") |
23 | | - sm.update_trial_component( |
24 | | - TrialComponentName=trial_component_name, Parameters={"hp1": {"NumberValue": i}} |
25 | | - ) |
26 | | - sm.associate_trial_component(TrialComponentName=trial_component_name, TrialName=trial_name) |
27 | | - |
28 | | - time.sleep(15) # wait for search to get updated |
29 | | - |
30 | | - analytics = ExperimentAnalytics( |
31 | | - experiment_name=experiment_name, sagemaker_session=sagemaker_session |
32 | | - ) |
| 18 | + try: |
| 19 | + sm.create_experiment(ExperimentName=experiment_name) |
33 | 20 |
|
34 | | - assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"] |
| 21 | + # Search returns 10 results by default. Add 20 trials to verify pagination. |
| 22 | + for i in range(20): |
| 23 | + trial_name = "trial-" + str(uuid.uuid4()) |
| 24 | + sm.create_trial(TrialName=trial_name, ExperimentName=experiment_name) |
35 | 25 |
|
| 26 | + trial_component_name = "tc-" + str(uuid.uuid4()) |
| 27 | + trials[trial_name] = trial_component_name |
36 | 28 |
|
37 | | -def test_experiment_analytics_pagination(sagemaker_session): |
38 | | - sm = sagemaker_session.sagemaker_client |
39 | | - |
40 | | - experiment_name = "experiment" + str(uuid.uuid4()) |
41 | | - sm.create_experiment(ExperimentName=experiment_name) |
42 | | - |
43 | | - # Search returns 10 results by default. Add 20 trials to verify pagination, |
44 | | - for i in range(20): |
45 | | - trial_name = "trial-" + str(uuid.uuid4()) |
46 | | - sm.create_trial(TrialName=trial_name, ExperimentName=experiment_name) |
47 | | - trial_component_name = "tc-" + str(uuid.uuid4()) |
48 | | - sm.create_trial_component(TrialComponentName=trial_component_name, DisplayName="Training") |
49 | | - sm.update_trial_component( |
50 | | - TrialComponentName=trial_component_name, Parameters={"hp1": {"NumberValue": i}} |
51 | | - ) |
52 | | - sm.associate_trial_component(TrialComponentName=trial_component_name, TrialName=trial_name) |
| 29 | + sm.create_trial_component( |
| 30 | + TrialComponentName=trial_component_name, DisplayName="Training" |
| 31 | + ) |
| 32 | + sm.update_trial_component( |
| 33 | + TrialComponentName=trial_component_name, Parameters={"hp1": {"NumberValue": i}} |
| 34 | + ) |
| 35 | + sm.associate_trial_component( |
| 36 | + TrialComponentName=trial_component_name, TrialName=trial_name |
| 37 | + ) |
53 | 38 |
|
54 | | - time.sleep(15) # wait for search to get updated TODO [owen-t]: Replace with retry |
| 39 | + time.sleep(15) # wait for search to get updated |
55 | 40 |
|
56 | | - analytics = ExperimentAnalytics( |
57 | | - experiment_name=experiment_name, sagemaker_session=sagemaker_session |
58 | | - ) |
| 41 | + yield experiment_name |
| 42 | + finally: |
| 43 | + _delete_resources(sm, experiment_name, trials) |
59 | 44 |
|
60 | | - assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"] |
61 | | - assert ( |
62 | | - len(analytics.dataframe()) > 10 |
63 | | - ) # TODO [owen-t] Replace with == 20 and put test in retry block |
64 | 45 |
|
| 46 | +@pytest.mark.canary_quick |
| 47 | +def test_experiment_analytics(sagemaker_session): |
| 48 | + with experiment(sagemaker_session) as experiment_name: |
| 49 | + analytics = ExperimentAnalytics( |
| 50 | + experiment_name=experiment_name, sagemaker_session=sagemaker_session |
| 51 | + ) |
65 | 52 |
|
66 | | -def test_experiment_analytics_search_by_nested_filter(sagemaker_session): |
67 | | - sm = sagemaker_session.sagemaker_client |
| 53 | + assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"] |
68 | 54 |
|
69 | | - experiment_name = "experiment" + str(uuid.uuid4()) |
70 | | - sm.create_experiment(ExperimentName=experiment_name) |
71 | 55 |
|
72 | | - for i in range(20): |
73 | | - trial_name = "trial-" + str(uuid.uuid4()) |
74 | | - sm.create_trial(TrialName=trial_name, ExperimentName=experiment_name) |
75 | | - trial_component_name = "tc-" + str(uuid.uuid4()) |
76 | | - sm.create_trial_component(TrialComponentName=trial_component_name, DisplayName="Training") |
77 | | - sm.update_trial_component( |
78 | | - TrialComponentName=trial_component_name, Parameters={"hp1": {"NumberValue": i}} |
| 56 | +def test_experiment_analytics_pagination(sagemaker_session): |
| 57 | + with experiment(sagemaker_session) as experiment_name: |
| 58 | + analytics = ExperimentAnalytics( |
| 59 | + experiment_name=experiment_name, sagemaker_session=sagemaker_session |
79 | 60 | ) |
80 | | - sm.associate_trial_component(TrialComponentName=trial_component_name, TrialName=trial_name) |
81 | 61 |
|
82 | | - time.sleep(15) # wait for search to get updated TODO [owen-t]: Replace with retry |
| 62 | + assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"] |
| 63 | + assert ( |
| 64 | + len(analytics.dataframe()) > 10 |
| 65 | + ) # TODO [owen-t] Replace with == 20 and put test in retry block |
83 | 66 |
|
84 | | - search_exp = { |
85 | | - "Filters": [ |
86 | | - {"Name": "Parents.ExperimentName", "Operator": "Equals", "Value": experiment_name}, |
87 | | - {"Name": "Parameters.hp1", "Operator": "GreaterThanOrEqualTo", "Value": "10"}, |
88 | | - ] |
89 | | - } |
90 | 67 |
|
91 | | - analytics = ExperimentAnalytics( |
92 | | - sagemaker_session=sagemaker_session, search_expression=search_exp |
93 | | - ) |
| 68 | +def test_experiment_analytics_search_by_nested_filter(sagemaker_session): |
| 69 | + with experiment(sagemaker_session) as experiment_name: |
| 70 | + search_exp = { |
| 71 | + "Filters": [ |
| 72 | + {"Name": "Parents.ExperimentName", "Operator": "Equals", "Value": experiment_name}, |
| 73 | + {"Name": "Parameters.hp1", "Operator": "GreaterThanOrEqualTo", "Value": "10"}, |
| 74 | + ] |
| 75 | + } |
| 76 | + |
| 77 | + analytics = ExperimentAnalytics( |
| 78 | + sagemaker_session=sagemaker_session, search_expression=search_exp |
| 79 | + ) |
94 | 80 |
|
95 | | - assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"] |
96 | | - assert ( |
97 | | - len(analytics.dataframe()) > 5 |
98 | | - ) # TODO [owen-t] Replace with == 10 and put test in retry block |
| 81 | + assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"] |
| 82 | + assert ( |
| 83 | + len(analytics.dataframe()) > 5 |
| 84 | + ) # TODO [owen-t] Replace with == 10 and put test in retry block |
99 | 85 |
|
100 | 86 |
|
101 | 87 | def test_experiment_analytics_search_by_nested_filter_sort_ascending(sagemaker_session): |
102 | | - sm = sagemaker_session.sagemaker_client |
| 88 | + with experiment(sagemaker_session) as experiment_name: |
| 89 | + search_exp = { |
| 90 | + "Filters": [ |
| 91 | + {"Name": "Parents.ExperimentName", "Operator": "Equals", "Value": experiment_name}, |
| 92 | + {"Name": "Parameters.hp1", "Operator": "GreaterThanOrEqualTo", "Value": "10"}, |
| 93 | + ] |
| 94 | + } |
| 95 | + |
| 96 | + analytics = ExperimentAnalytics( |
| 97 | + sagemaker_session=sagemaker_session, |
| 98 | + search_expression=search_exp, |
| 99 | + sort_by="Parameters.hp1", |
| 100 | + sort_order="Ascending", |
| 101 | + ) |
| 102 | + |
| 103 | + assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"] |
| 104 | + assert ( |
| 105 | + len(analytics.dataframe()) > 5 |
| 106 | + ) # TODO [owen-t] Replace with == 10 and put test in retry block |
| 107 | + assert list(analytics.dataframe()["hp1"].values) == sorted( |
| 108 | + analytics.dataframe()["hp1"].values |
| 109 | + ) |
103 | 110 |
|
104 | | - experiment_name = "experiment" + str(uuid.uuid4()) |
105 | | - sm.create_experiment(ExperimentName=experiment_name) |
106 | 111 |
|
107 | | - for i in range(20): |
108 | | - trial_name = "trial-" + str(uuid.uuid4()) |
109 | | - sm.create_trial(TrialName=trial_name, ExperimentName=experiment_name) |
110 | | - trial_component_name = "tc-" + str(uuid.uuid4()) |
111 | | - sm.create_trial_component(TrialComponentName=trial_component_name, DisplayName="Training") |
112 | | - sm.update_trial_component( |
113 | | - TrialComponentName=trial_component_name, Parameters={"hp1": {"NumberValue": i}} |
| 112 | +def test_experiment_analytics_search_by_nested_filter_sort_descending(sagemaker_session): |
| 113 | + with experiment(sagemaker_session) as experiment_name: |
| 114 | + search_exp = { |
| 115 | + "Filters": [ |
| 116 | + {"Name": "Parents.ExperimentName", "Operator": "Equals", "Value": experiment_name}, |
| 117 | + {"Name": "Parameters.hp1", "Operator": "GreaterThanOrEqualTo", "Value": "10"}, |
| 118 | + ] |
| 119 | + } |
| 120 | + |
| 121 | + analytics = ExperimentAnalytics( |
| 122 | + sagemaker_session=sagemaker_session, |
| 123 | + search_expression=search_exp, |
| 124 | + sort_by="Parameters.hp1", |
114 | 125 | ) |
115 | | - sm.associate_trial_component(TrialComponentName=trial_component_name, TrialName=trial_name) |
116 | 126 |
|
117 | | - time.sleep(15) # wait for search to get updated TODO [owen-t]: Replace with retry |
| 127 | + assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"] |
| 128 | + assert ( |
| 129 | + len(analytics.dataframe()) > 5 |
| 130 | + ) # TODO [owen-t] Replace with == 10 and put test in retry block |
| 131 | + assert ( |
| 132 | + list(analytics.dataframe()["hp1"].values) |
| 133 | + == sorted(analytics.dataframe()["hp1"].values)[::-1] |
| 134 | + ) |
118 | 135 |
|
119 | | - search_exp = { |
120 | | - "Filters": [ |
121 | | - {"Name": "Parents.ExperimentName", "Operator": "Equals", "Value": experiment_name}, |
122 | | - {"Name": "Parameters.hp1", "Operator": "GreaterThanOrEqualTo", "Value": "10"}, |
123 | | - ] |
124 | | - } |
125 | 136 |
|
126 | | - analytics = ExperimentAnalytics( |
127 | | - sagemaker_session=sagemaker_session, |
128 | | - search_expression=search_exp, |
129 | | - sort_by="Parameters.hp1", |
130 | | - sort_order="Ascending", |
131 | | - ) |
| 137 | +def _delete_resources(sagemaker_client, experiment_name, trials): |
| 138 | + for trial, tc in trials.items(): |
| 139 | + with _ignore_resource_not_found(sagemaker_client): |
| 140 | + sagemaker_client.disassociate_trial_component(TrialName=trial, TrialComponentName=tc) |
132 | 141 |
|
133 | | - assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"] |
134 | | - assert ( |
135 | | - len(analytics.dataframe()) > 5 |
136 | | - ) # TODO [owen-t] Replace with == 10 and put test in retry block |
137 | | - assert list(analytics.dataframe()["hp1"].values) == sorted(analytics.dataframe()["hp1"].values) |
| 142 | + with _ignore_resource_not_found(sagemaker_client): |
| 143 | + sagemaker_client.delete_trial_component(TrialComponentName=tc) |
138 | 144 |
|
| 145 | + with _ignore_resource_not_found(sagemaker_client): |
| 146 | + sagemaker_client.delete_trial(TrialName=trial) |
139 | 147 |
|
140 | | -def test_experiment_analytics_search_by_nested_filter_sort_descending(sagemaker_session): |
141 | | - sm = sagemaker_session.sagemaker_client |
| 148 | + with _ignore_resource_not_found(sagemaker_client): |
| 149 | + sagemaker_client.delete_experiment(ExperimentName=experiment_name) |
142 | 150 |
|
143 | | - experiment_name = "experiment" + str(uuid.uuid4()) |
144 | | - sm.create_experiment(ExperimentName=experiment_name) |
145 | 151 |
|
146 | | - for i in range(20): |
147 | | - trial_name = "trial-" + str(uuid.uuid4()) |
148 | | - sm.create_trial(TrialName=trial_name, ExperimentName=experiment_name) |
149 | | - trial_component_name = "tc-" + str(uuid.uuid4()) |
150 | | - sm.create_trial_component(TrialComponentName=trial_component_name, DisplayName="Training") |
151 | | - sm.update_trial_component( |
152 | | - TrialComponentName=trial_component_name, Parameters={"hp1": {"NumberValue": i}} |
153 | | - ) |
154 | | - sm.associate_trial_component(TrialComponentName=trial_component_name, TrialName=trial_name) |
155 | | - |
156 | | - time.sleep(15) # wait for search to get updated TODO [owen-t]: Replace with retry |
157 | | - |
158 | | - search_exp = { |
159 | | - "Filters": [ |
160 | | - {"Name": "Parents.ExperimentName", "Operator": "Equals", "Value": experiment_name}, |
161 | | - {"Name": "Parameters.hp1", "Operator": "GreaterThanOrEqualTo", "Value": "10"}, |
162 | | - ] |
163 | | - } |
164 | | - |
165 | | - analytics = ExperimentAnalytics( |
166 | | - sagemaker_session=sagemaker_session, search_expression=search_exp, sort_by="Parameters.hp1" |
167 | | - ) |
168 | | - |
169 | | - assert list(analytics.dataframe().columns) == ["TrialComponentName", "DisplayName", "hp1"] |
170 | | - assert ( |
171 | | - len(analytics.dataframe()) > 5 |
172 | | - ) # TODO [owen-t] Replace with == 10 and put test in retry block |
173 | | - assert ( |
174 | | - list(analytics.dataframe()["hp1"].values) |
175 | | - == sorted(analytics.dataframe()["hp1"].values)[::-1] |
176 | | - ) |
| 152 | +@contextmanager |
| 153 | +def _ignore_resource_not_found(sagemaker_client): |
| 154 | + try: |
| 155 | + yield |
| 156 | + except sagemaker_client.exceptions.ResourceNotFound: |
| 157 | + pass |
0 commit comments