@@ -60,6 +60,7 @@ def dbt_diff(
6060 config_prod_database = datadiff_variables .get ("prod_database" )
6161 config_prod_schema = datadiff_variables .get ("prod_schema" )
6262 datasource_id = datadiff_variables .get ("datasource_id" )
63+ custom_schema = datadiff_variables .get ("custom_schema" )
6364
6465 if not is_cloud :
6566 dbt_parser .set_connection ()
@@ -70,7 +71,9 @@ def dbt_diff(
7071 )
7172
7273 for model in models :
73- diff_vars = _get_diff_vars (dbt_parser , config_prod_database , config_prod_schema , model , datasource_id )
74+ diff_vars = _get_diff_vars (
75+ dbt_parser , config_prod_database , config_prod_schema , model , datasource_id , custom_schema
76+ )
7477
7578 if is_cloud and len (diff_vars .primary_keys ) > 0 :
7679 _cloud_diff (diff_vars )
@@ -95,6 +98,7 @@ def _get_diff_vars(
9598 config_prod_schema : Optional [str ],
9699 model ,
97100 datasource_id : int ,
101+ custom_schema : Optional [bool ],
98102) -> DiffVars :
99103 dev_database = model .database
100104 dev_schema = model .schema_
@@ -103,6 +107,12 @@ def _get_diff_vars(
103107 prod_database = config_prod_database if config_prod_database else dev_database
104108 prod_schema = config_prod_schema if config_prod_schema else dev_schema
105109
110+ # if project has custom_schemas: True
111+ # need to construct the prod schema as <prod_target_schema>_<custom_schema>
112+ # https://docs.getdbt.com/docs/build/custom-schemas
113+ if custom_schema and model .config .schema_ :
114+ prod_schema = prod_schema + "_" + model .config .schema_
115+
106116 if dbt_parser .requires_upper :
107117 dev_qualified_list = [x .upper () for x in [dev_database , dev_schema , model .alias ]]
108118 prod_qualified_list = [x .upper () for x in [prod_database , prod_schema , model .alias ]]
@@ -111,16 +121,22 @@ def _get_diff_vars(
111121 dev_qualified_list = [dev_database , dev_schema , model .alias ]
112122 prod_qualified_list = [prod_database , prod_schema , model .alias ]
113123
114- return DiffVars (dev_qualified_list , prod_qualified_list , primary_keys , datasource_id , dbt_parser .connection , dbt_parser .threads )
124+ return DiffVars (
125+ dev_qualified_list , prod_qualified_list , primary_keys , datasource_id , dbt_parser .connection , dbt_parser .threads
126+ )
115127
116128
117129def _local_diff (diff_vars : DiffVars ) -> None :
118130 column_diffs_str = ""
119131 dev_qualified_string = "." .join (diff_vars .dev_path )
120132 prod_qualified_string = "." .join (diff_vars .prod_path )
121133
122- table1 = connect_to_table (diff_vars .connection , dev_qualified_string , tuple (diff_vars .primary_keys ), diff_vars .threads )
123- table2 = connect_to_table (diff_vars .connection , prod_qualified_string , tuple (diff_vars .primary_keys ), diff_vars .threads )
134+ table1 = connect_to_table (
135+ diff_vars .connection , dev_qualified_string , tuple (diff_vars .primary_keys ), diff_vars .threads
136+ )
137+ table2 = connect_to_table (
138+ diff_vars .connection , prod_qualified_string , tuple (diff_vars .primary_keys ), diff_vars .threads
139+ )
124140
125141 table1_columns = list (table1 .get_schema ())
126142 try :
0 commit comments