@@ -30,8 +30,9 @@ class TestDbtDiffer(unittest.TestCase):
3030 def test_integration_basic_dbt (self ):
3131 artifacts_path = os .getcwd () + "/tests/dbt_artifacts"
3232 test_project_path = os .environ .get ("DATA_DIFF_DBT_PROJ" ) or artifacts_path
33+ test_profiles_path = os .environ .get ("DATA_DIFF_DBT_PROFILES" ) or artifacts_path
3334 diff = run_datadiff_cli (
34- "--dbt" , "--dbt-project-dir" , test_project_path , "--dbt-profiles-dir" , test_project_path
35+ "--dbt" , "--dbt-project-dir" , test_project_path , "--dbt-profiles-dir" , test_profiles_path
3536 )
3637
3738 # assertions for the diff that exists in tests/dbt_artifacts/jaffle_shop.duckdb
@@ -933,3 +934,37 @@ def test_get_diff_vars_call_get_prod_path_from_manifest(
933934 mock_prod_path_from_manifest .assert_called_once_with (mock_model , mock_dbt_parser .prod_manifest_obj )
934935 self .assertEqual (diff_vars .prod_path [0 ], mock_prod_path_from_manifest .return_value [0 ])
935936 self .assertEqual (diff_vars .prod_path [1 ], mock_prod_path_from_manifest .return_value [1 ])
937+
938+ @patch ("data_diff.dbt._get_prod_path_from_config" )
939+ @patch ("data_diff.dbt._get_prod_path_from_manifest" )
940+ def test_get_diff_vars_cli_columns (self , mock_prod_path_from_manifest , mock_prod_path_from_config ):
941+ config = TDatadiffConfig (prod_database = "prod_db" )
942+ mock_model = Mock ()
943+ primary_keys = ["a_primary_key" ]
944+ mock_model .database = "a_dev_db"
945+ mock_model .schema_ = "a_schema"
946+ mock_model .config .schema_ = None
947+ mock_model .config .database = None
948+ mock_model .alias = "a_model_name"
949+ mock_model .unique_id = "unique_id"
950+ mock_tdatadiffmodelconfig = Mock ()
951+ mock_tdatadiffmodelconfig .where_filter = "where"
952+ mock_tdatadiffmodelconfig .include_columns = ["include" ]
953+ mock_tdatadiffmodelconfig .exclude_columns = ["exclude" ]
954+ mock_dbt_parser = Mock ()
955+ mock_dbt_parser .get_datadiff_model_config .return_value = mock_tdatadiffmodelconfig
956+ mock_dbt_parser .connection = {}
957+ mock_dbt_parser .threads = 0
958+ mock_dbt_parser .get_pk_from_model .return_value = primary_keys
959+ mock_dbt_parser .requires_upper = False
960+ mock_dbt_parser .prod_manifest_obj = None
961+ mock_prod_path_from_config .return_value = ("prod_db" , "prod_schema" )
962+ cli_columns = ("col1" , "col2" )
963+
964+ diff_vars = _get_diff_vars (mock_dbt_parser , config , mock_model , where_flag = None , columns_flag = cli_columns )
965+
966+ mock_dbt_parser .get_pk_from_model .assert_called_once ()
967+ mock_prod_path_from_config .assert_called_once_with (config , mock_model , mock_model .database , mock_model .schema_ )
968+ mock_prod_path_from_manifest .assert_not_called ()
969+ self .assertEqual (diff_vars .include_columns , list (cli_columns ))
970+ self .assertEqual (diff_vars .exclude_columns , [])
0 commit comments