1- # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
1+ # Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
22#
33# Licensed under the Apache License, Version 2.0 (the "License"). You
44# may not use this file except in compliance with the License. A copy of
1414
1515import unittest .mock
1616
17- import pytest
18- from sagemaker .lineage import visualizer
1917import pandas as pd
2018from collections import OrderedDict
2119
2220
23- @pytest .fixture
24- def sagemaker_session ():
25- return unittest .mock .Mock ()
26-
27-
28- @pytest .fixture
29- def vizualizer (sagemaker_session ):
30- return visualizer .LineageTableVisualizer (sagemaker_session )
31-
32-
def test_friendly_name_short_uri(viz, sagemaker_session):
    """Check that a short artifact source URI is used verbatim as the friendly name.

    Args:
        viz: Visualizer fixture under test (provided outside this block).
        sagemaker_session: Session fixture whose ``sagemaker_client`` is stubbed here.
    """
    uri = "s3://f-069083975568/train.txt"
    arn = "test_arn"

    # Stub the artifact lookup so the visualizer sees `uri` as the source URI.
    sagemaker_session.sagemaker_client.describe_artifact.return_value = {
        "Source": {"SourceUri": uri, "SourceTypes": ""}
    }

    # No explicit name is supplied, so the result must come from the stubbed artifact.
    actual_name = viz._get_friendly_name(name=None, arn=arn, entity_type="artifact")

    assert uri == actual_name
4129
4230
43- def test_friendly_name_long_uri (vizualizer , sagemaker_session ):
31+ def test_friendly_name_long_uri (viz , sagemaker_session ):
4432 uri = (
4533 "s3://flintstone-end-to-end-tests-gamma-us-west-2-069083975568/results/canary-auto-1608761252626/"
4634 "preprocessed-data/tuning_data/train.txt"
@@ -49,12 +37,12 @@ def test_friendly_name_long_uri(vizualizer, sagemaker_session):
4937 sagemaker_session .sagemaker_client .describe_artifact .return_value = {
5038 "Source" : {"SourceUri" : uri , "SourceTypes" : "" }
5139 }
52- actual_name = vizualizer ._get_friendly_name (name = None , arn = arn , entity_type = "artifact" )
40+ actual_name = viz ._get_friendly_name (name = None , arn = arn , entity_type = "artifact" )
5341 expected_name = "s3://.../preprocessed-data/tuning_data/train.txt"
5442 assert expected_name == actual_name
5543
5644
57- def test_trial_component_name (sagemaker_session , vizualizer ):
45+ def test_trial_component_name (viz , sagemaker_session ):
5846 name = "tc-name"
5947
6048 sagemaker_session .sagemaker_client .describe_trial_component .return_value = {
@@ -90,7 +78,7 @@ def test_trial_component_name(sagemaker_session, vizualizer):
9078 },
9179 ]
9280
93- df = vizualizer .show (trial_component_name = name )
81+ df = viz .show (trial_component_name = name )
9482
9583 sagemaker_session .sagemaker_client .describe_trial_component .assert_called_with (
9684 TrialComponentName = name ,
@@ -121,7 +109,7 @@ def test_trial_component_name(sagemaker_session, vizualizer):
121109 pd .testing .assert_frame_equal (expected_dataframe , df )
122110
123111
124- def test_model_package_arn (sagemaker_session , vizualizer ):
112+ def test_model_package_arn (viz , sagemaker_session ):
125113 name = "model_package_arn"
126114
127115 sagemaker_session .sagemaker_client .list_artifacts .return_value = {
@@ -157,7 +145,7 @@ def test_model_package_arn(sagemaker_session, vizualizer):
157145 },
158146 ]
159147
160- df = vizualizer .show (model_package_arn = name )
148+ df = viz .show (model_package_arn = name )
161149
162150 sagemaker_session .sagemaker_client .list_artifacts .assert_called_with (
163151 SourceUri = name ,
@@ -188,7 +176,7 @@ def test_model_package_arn(sagemaker_session, vizualizer):
188176 pd .testing .assert_frame_equal (expected_dataframe , df )
189177
190178
191- def test_endpoint_arn (sagemaker_session , vizualizer ):
179+ def test_endpoint_arn (viz , sagemaker_session ):
192180 name = "endpoint_arn"
193181
194182 sagemaker_session .sagemaker_client .list_contexts .return_value = {
@@ -224,7 +212,7 @@ def test_endpoint_arn(sagemaker_session, vizualizer):
224212 },
225213 ]
226214
227- df = vizualizer .show (endpoint_arn = name )
215+ df = viz .show (endpoint_arn = name )
228216
229217 sagemaker_session .sagemaker_client .list_contexts .assert_called_with (
230218 SourceUri = name ,
@@ -253,3 +241,71 @@ def test_endpoint_arn(sagemaker_session, vizualizer):
253241 )
254242
255243 pd .testing .assert_frame_equal (expected_dataframe , df )
244+
245+
def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
    """Check ``show`` for a pipeline execution step backed by a processing job.

    The stubbed client returns one trial component for the processing job ARN and
    two successive ``list_associations`` responses. The test asserts the lookups
    made against the client and the exact lineage table that ``show`` returns.

    Args:
        viz: Visualizer fixture under test (provided outside this block).
        sagemaker_session: Session fixture whose ``sagemaker_client`` is stubbed here.
    """
    sagemaker_session.sagemaker_client.list_trial_components.return_value = {
        "TrialComponentSummaries": [{"TrialComponentArn": "tc-arn"}]
    }

    # `side_effect` hands out one response per call, in order: the first call
    # is expected to be the DestinationArn lookup, the second the SourceArn
    # lookup (see `expected_calls` below).
    sagemaker_session.sagemaker_client.list_associations.side_effect = [
        {
            "AssociationSummaries": [
                {
                    "SourceArn": "a:b:c:d:e:artifact/src-arn-1",
                    "SourceName": "source-name-1",
                    "SourceType": "source-type-1",
                    "DestinationArn": "a:b:c:d:e:artifact/dest-arn-1",
                    "DestinationName": "dest-name-1",
                    "DestinationType": "dest-type-1",
                    "AssociationType": "type-1",
                }
            ]
        },
        {
            "AssociationSummaries": [
                {
                    "SourceArn": "a:b:c:d:e:artifact/src-arn-2",
                    "SourceName": "source-name-2",
                    "SourceType": "source-type-2",
                    "DestinationArn": "a:b:c:d:e:artifact/dest-arn-2",
                    "DestinationName": "dest-name-2",
                    "DestinationType": "dest-type-2",
                    "AssociationType": "type-2",
                }
            ]
        },
    ]

    step = {"Metadata": {"ProcessingJob": {"Arn": "proc-job-arn"}}}

    df = viz.show(pipeline_execution_step=step)

    # The processing job ARN from the step metadata must drive the trial
    # component lookup.
    sagemaker_session.sagemaker_client.list_trial_components.assert_called_with(
        SourceArn="proc-job-arn",
    )

    # Exactly two association queries, both keyed on the trial component ARN.
    expected_calls = [
        unittest.mock.call(
            DestinationArn="tc-arn",
        ),
        unittest.mock.call(
            SourceArn="tc-arn",
        ),
    ]
    assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls

    # Row 1 is built from the source side of the first response ("Input"),
    # row 2 from the destination side of the second response ("Output").
    # NOTE(review): "Lineage Type" presumably comes from the "artifact/" segment
    # of the ARNs - confirm against the visualizer implementation.
    expected_dataframe = pd.DataFrame.from_dict(
        OrderedDict(
            [
                ("Name/Source", ["source-name-1", "dest-name-2"]),
                ("Direction", ["Input", "Output"]),
                ("Type", ["source-type-1", "dest-type-2"]),
                ("Association Type", ["type-1", "type-2"]),
                ("Lineage Type", ["artifact", "artifact"]),
            ]
        )
    )

    pd.testing.assert_frame_equal(expected_dataframe, df)
0 commit comments