Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit dbfbfeb

Browse files
committed
support custom schemas
1 parent 7bda9c8 commit dbfbfeb

File tree

1 file changed

+20
-4
lines changed

1 file changed

+20
-4
lines changed

data_diff/dbt.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def dbt_diff(
6060
config_prod_database = datadiff_variables.get("prod_database")
6161
config_prod_schema = datadiff_variables.get("prod_schema")
6262
datasource_id = datadiff_variables.get("datasource_id")
63+
custom_schema = datadiff_variables.get("custom_schema")
6364

6465
if not is_cloud:
6566
dbt_parser.set_connection()
@@ -70,7 +71,9 @@ def dbt_diff(
7071
)
7172

7273
for model in models:
73-
diff_vars = _get_diff_vars(dbt_parser, config_prod_database, config_prod_schema, model, datasource_id)
74+
diff_vars = _get_diff_vars(
75+
dbt_parser, config_prod_database, config_prod_schema, model, datasource_id, custom_schema
76+
)
7477

7578
if is_cloud and len(diff_vars.primary_keys) > 0:
7679
_cloud_diff(diff_vars)
@@ -95,6 +98,7 @@ def _get_diff_vars(
9598
config_prod_schema: Optional[str],
9699
model,
97100
datasource_id: int,
101+
custom_schema: Optional[bool],
98102
) -> DiffVars:
99103
dev_database = model.database
100104
dev_schema = model.schema_
@@ -103,6 +107,12 @@ def _get_diff_vars(
103107
prod_database = config_prod_database if config_prod_database else dev_database
104108
prod_schema = config_prod_schema if config_prod_schema else dev_schema
105109

110+
# if project has custom_schemas: True
111+
# need to construct the prod schema as <prod_target_schema>_<custom_schema>
112+
# https://docs.getdbt.com/docs/build/custom-schemas
113+
if custom_schema and model.config.schema_:
114+
prod_schema = prod_schema + "_" + model.config.schema_
115+
106116
if dbt_parser.requires_upper:
107117
dev_qualified_list = [x.upper() for x in [dev_database, dev_schema, model.alias]]
108118
prod_qualified_list = [x.upper() for x in [prod_database, prod_schema, model.alias]]
@@ -111,16 +121,22 @@ def _get_diff_vars(
111121
dev_qualified_list = [dev_database, dev_schema, model.alias]
112122
prod_qualified_list = [prod_database, prod_schema, model.alias]
113123

114-
return DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, datasource_id, dbt_parser.connection, dbt_parser.threads)
124+
return DiffVars(
125+
dev_qualified_list, prod_qualified_list, primary_keys, datasource_id, dbt_parser.connection, dbt_parser.threads
126+
)
115127

116128

117129
def _local_diff(diff_vars: DiffVars) -> None:
118130
column_diffs_str = ""
119131
dev_qualified_string = ".".join(diff_vars.dev_path)
120132
prod_qualified_string = ".".join(diff_vars.prod_path)
121133

122-
table1 = connect_to_table(diff_vars.connection, dev_qualified_string, tuple(diff_vars.primary_keys), diff_vars.threads)
123-
table2 = connect_to_table(diff_vars.connection, prod_qualified_string, tuple(diff_vars.primary_keys), diff_vars.threads)
134+
table1 = connect_to_table(
135+
diff_vars.connection, dev_qualified_string, tuple(diff_vars.primary_keys), diff_vars.threads
136+
)
137+
table2 = connect_to_table(
138+
diff_vars.connection, prod_qualified_string, tuple(diff_vars.primary_keys), diff_vars.threads
139+
)
124140

125141
table1_columns = list(table1.get_schema())
126142
try:

0 commit comments

Comments
 (0)