@@ -49,64 +49,17 @@ def test_trial_component_name(viz, sagemaker_session):
4949 "TrialComponentArn" : "tc-arn" ,
5050 }
5151
52- sagemaker_session .sagemaker_client .list_associations .side_effect = [
53- {
54- "AssociationSummaries" : [
55- {
56- "SourceArn" : "a:b:c:d:e:artifact/src-arn-1" ,
57- "SourceName" : "source-name-1" ,
58- "SourceType" : "source-type-1" ,
59- "DestinationArn" : "a:b:c:d:e:artifact/dest-arn-1" ,
60- "DestinationName" : "dest-name-1" ,
61- "DestinationType" : "dest-type-1" ,
62- "AssociationType" : "type-1" ,
63- }
64- ]
65- },
66- {
67- "AssociationSummaries" : [
68- {
69- "SourceArn" : "a:b:c:d:e:artifact/src-arn-2" ,
70- "SourceName" : "source-name-2" ,
71- "SourceType" : "source-type-2" ,
72- "DestinationArn" : "a:b:c:d:e:artifact/dest-arn-2" ,
73- "DestinationName" : "dest-name-2" ,
74- "DestinationType" : "dest-type-2" ,
75- "AssociationType" : "type-2" ,
76- }
77- ]
78- },
79- ]
52+ get_list_associations_side_effect (sagemaker_session )
8053
8154 df = viz .show (trial_component_name = name )
8255
8356 sagemaker_session .sagemaker_client .describe_trial_component .assert_called_with (
8457 TrialComponentName = name ,
8558 )
8659
87- expected_calls = [
88- unittest .mock .call (
89- DestinationArn = "tc-arn" ,
90- ),
91- unittest .mock .call (
92- SourceArn = "tc-arn" ,
93- ),
94- ]
95- assert expected_calls == sagemaker_session .sagemaker_client .list_associations .mock_calls
60+ assert_list_associations_mock_calls (sagemaker_session )
9661
97- expected_dataframe = pd .DataFrame .from_dict (
98- OrderedDict (
99- [
100- ("Name/Source" , ["source-name-1" , "dest-name-2" ]),
101- ("Direction" , ["Input" , "Output" ]),
102- ("Type" , ["source-type-1" , "dest-type-2" ]),
103- ("Association Type" , ["type-1" , "type-2" ]),
104- ("Lineage Type" , ["artifact" , "artifact" ]),
105- ]
106- )
107- )
108-
109- pd .testing .assert_frame_equal (expected_dataframe , df )
62+ pd .testing .assert_frame_equal (get_expected_dataframe (), df )
11063
11164
11265def test_model_package_arn (viz , sagemaker_session ):
@@ -116,34 +69,7 @@ def test_model_package_arn(viz, sagemaker_session):
11669 "ArtifactSummaries" : [{"ArtifactArn" : "artifact-arn" }]
11770 }
11871
119- sagemaker_session .sagemaker_client .list_associations .side_effect = [
120- {
121- "AssociationSummaries" : [
122- {
123- "SourceArn" : "a:b:c:d:e:artifact/src-arn-1" ,
124- "SourceName" : "source-name-1" ,
125- "SourceType" : "source-type-1" ,
126- "DestinationArn" : "a:b:c:d:e:artifact/dest-arn-1" ,
127- "DestinationName" : "dest-name-1" ,
128- "DestinationType" : "dest-type-1" ,
129- "AssociationType" : "type-1" ,
130- }
131- ]
132- },
133- {
134- "AssociationSummaries" : [
135- {
136- "SourceArn" : "a:b:c:d:e:artifact/src-arn-2" ,
137- "SourceName" : "source-name-2" ,
138- "SourceType" : "source-type-2" ,
139- "DestinationArn" : "a:b:c:d:e:artifact/dest-arn-2" ,
140- "DestinationName" : "dest-name-2" ,
141- "DestinationType" : "dest-type-2" ,
142- "AssociationType" : "type-2" ,
143- }
144- ]
145- },
146- ]
72+ get_list_associations_side_effect (sagemaker_session )
14773
14874 df = viz .show (model_package_arn = name )
14975
@@ -161,19 +87,7 @@ def test_model_package_arn(viz, sagemaker_session):
16187 ]
16288 assert expected_calls == sagemaker_session .sagemaker_client .list_associations .mock_calls
16389
164- expected_dataframe = pd .DataFrame .from_dict (
165- OrderedDict (
166- [
167- ("Name/Source" , ["source-name-1" , "dest-name-2" ]),
168- ("Direction" , ["Input" , "Output" ]),
169- ("Type" , ["source-type-1" , "dest-type-2" ]),
170- ("Association Type" , ["type-1" , "type-2" ]),
171- ("Lineage Type" , ["artifact" , "artifact" ]),
172- ]
173- )
174- )
175-
176- pd .testing .assert_frame_equal (expected_dataframe , df )
90+ pd .testing .assert_frame_equal (get_expected_dataframe (), df )
17791
17892
17993def test_endpoint_arn (viz , sagemaker_session ):
@@ -183,34 +97,7 @@ def test_endpoint_arn(viz, sagemaker_session):
18397 "ContextSummaries" : [{"ContextArn" : "context-arn" }]
18498 }
18599
186- sagemaker_session .sagemaker_client .list_associations .side_effect = [
187- {
188- "AssociationSummaries" : [
189- {
190- "SourceArn" : "a:b:c:d:e:context/src-arn-1" ,
191- "SourceName" : "source-name-1" ,
192- "SourceType" : "source-type-1" ,
193- "DestinationArn" : "a:b:c:d:e:context/dest-arn-1" ,
194- "DestinationName" : "dest-name-1" ,
195- "DestinationType" : "dest-type-1" ,
196- "AssociationType" : "type-1" ,
197- }
198- ]
199- },
200- {
201- "AssociationSummaries" : [
202- {
203- "SourceArn" : "a:b:c:d:e:context/src-arn-2" ,
204- "SourceName" : "source-name-2" ,
205- "SourceType" : "source-type-2" ,
206- "DestinationArn" : "a:b:c:d:e:context/dest-arn-2" ,
207- "DestinationName" : "dest-name-2" ,
208- "DestinationType" : "dest-type-2" ,
209- "AssociationType" : "type-2" ,
210- }
211- ]
212- },
213- ]
100+ get_list_associations_side_effect (sagemaker_session )
214101
215102 df = viz .show (endpoint_arn = name )
216103
@@ -228,27 +115,74 @@ def test_endpoint_arn(viz, sagemaker_session):
228115 ]
229116 assert expected_calls == sagemaker_session .sagemaker_client .list_associations .mock_calls
230117
231- expected_dataframe = pd .DataFrame .from_dict (
232- OrderedDict (
233- [
234- ("Name/Source" , ["source-name-1" , "dest-name-2" ]),
235- ("Direction" , ["Input" , "Output" ]),
236- ("Type" , ["source-type-1" , "dest-type-2" ]),
237- ("Association Type" , ["type-1" , "type-2" ]),
238- ("Lineage Type" , ["context" , "context" ]),
239- ]
240- )
118+ pd .testing .assert_frame_equal (get_expected_dataframe (), df )
119+
120+
121+ def test_processing_job_pipeline_execution_step (viz , sagemaker_session ):
122+
123+ sagemaker_session .sagemaker_client .list_trial_components .return_value = {
124+ "TrialComponentSummaries" : [{"TrialComponentArn" : "tc-arn" }]
125+ }
126+
127+ get_list_associations_side_effect (sagemaker_session )
128+
129+ step = {"Metadata" : {"ProcessingJob" : {"Arn" : "proc-job-arn" }}}
130+
131+ df = viz .show (pipeline_execution_step = step )
132+
133+ sagemaker_session .sagemaker_client .list_trial_components .assert_called_with (
134+ SourceArn = "proc-job-arn" ,
241135 )
242136
243- pd . testing . assert_frame_equal ( expected_dataframe , df )
137+ assert_list_associations_mock_calls ( sagemaker_session )
244138
139+ pd .testing .assert_frame_equal (get_expected_dataframe (), df )
245140
246- def test_processing_job_pipeline_execution_step (viz , sagemaker_session ):
141+
142+ def test_training_job_pipeline_execution_step (viz , sagemaker_session ):
247143
248144 sagemaker_session .sagemaker_client .list_trial_components .return_value = {
249145 "TrialComponentSummaries" : [{"TrialComponentArn" : "tc-arn" }]
250146 }
251147
148+ get_list_associations_side_effect (sagemaker_session )
149+
150+ step = {"Metadata" : {"TrainingJob" : {"Arn" : "training-job-arn" }}}
151+
152+ df = viz .show (pipeline_execution_step = step )
153+
154+ sagemaker_session .sagemaker_client .list_trial_components .assert_called_with (
155+ SourceArn = "training-job-arn" ,
156+ )
157+
158+ assert_list_associations_mock_calls (sagemaker_session )
159+
160+ pd .testing .assert_frame_equal (get_expected_dataframe (), df )
161+
162+
163+ def test_transform_job_pipeline_execution_step (viz , sagemaker_session ):
164+
165+ sagemaker_session .sagemaker_client .list_trial_components .return_value = {
166+ "TrialComponentSummaries" : [{"TrialComponentArn" : "tc-arn" }]
167+ }
168+
169+ get_list_associations_side_effect (sagemaker_session )
170+
171+ step = {"Metadata" : {"TransformJob" : {"Arn" : "transform-job-arn" }}}
172+
173+ df = viz .show (pipeline_execution_step = step )
174+
175+ sagemaker_session .sagemaker_client .list_trial_components .assert_called_with (
176+ SourceArn = "transform-job-arn" ,
177+ )
178+
179+ assert_list_associations_mock_calls (sagemaker_session )
180+
181+ pd .testing .assert_frame_equal (get_expected_dataframe (), df )
182+
183+
184+ def get_list_associations_side_effect (sagemaker_session ):
185+
252186 sagemaker_session .sagemaker_client .list_associations .side_effect = [
253187 {
254188 "AssociationSummaries" : [
@@ -278,13 +212,8 @@ def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
278212 },
279213 ]
280214
281- step = {"Metadata" : {"ProcessingJob" : {"Arn" : "proc-job-arn" }}}
282-
283- df = viz .show (pipeline_execution_step = step )
284215
285- sagemaker_session .sagemaker_client .list_trial_components .assert_called_with (
286- SourceArn = "proc-job-arn" ,
287- )
216+ def assert_list_associations_mock_calls (sagemaker_session ):
288217
289218 expected_calls = [
290219 unittest .mock .call (
@@ -296,6 +225,9 @@ def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
296225 ]
297226 assert expected_calls == sagemaker_session .sagemaker_client .list_associations .mock_calls
298227
228+
229+ def get_expected_dataframe ():
230+
299231 expected_dataframe = pd .DataFrame .from_dict (
300232 OrderedDict (
301233 [
@@ -308,4 +240,4 @@ def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
308240 )
309241 )
310242
311- pd . testing . assert_frame_equal ( expected_dataframe , df )
243+ return expected_dataframe
0 commit comments