Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 6fd1c32

Browse files
authored
Merge branch 'master' into dbeatty10/dbt-profiles-dbt-project
2 parents 9798d79 + a614784 commit 6fd1c32

File tree

2 files changed: 96 additions and 83 deletions.

2 files changed: 96 additions and 83 deletions.

data_diff/dbt.py

Lines changed: 49 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -63,6 +63,7 @@ class DiffVars:
6363
primary_keys: List[str]
6464
datasource_id: str
6565
connection: Dict[str, str]
66+
threads: Optional[int]
6667

6768

6869
def dbt_diff(
@@ -127,16 +128,16 @@ def _get_diff_vars(
127128
dev_qualified_list = [dev_database, dev_schema, model.alias]
128129
prod_qualified_list = [prod_database, prod_schema, model.alias]
129130

130-
return DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, datasource_id, dbt_parser.connection)
131+
return DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, datasource_id, dbt_parser.connection, dbt_parser.threads)
131132

132133

133134
def _local_diff(diff_vars: DiffVars) -> None:
134135
column_diffs_str = ""
135136
dev_qualified_string = ".".join(diff_vars.dev_path)
136137
prod_qualified_string = ".".join(diff_vars.prod_path)
137138

138-
table1 = connect_to_table(diff_vars.connection, dev_qualified_string, tuple(diff_vars.primary_keys))
139-
table2 = connect_to_table(diff_vars.connection, prod_qualified_string, tuple(diff_vars.primary_keys))
139+
table1 = connect_to_table(diff_vars.connection, dev_qualified_string, tuple(diff_vars.primary_keys), diff_vars.threads)
140+
table2 = connect_to_table(diff_vars.connection, prod_qualified_string, tuple(diff_vars.primary_keys), diff_vars.threads)
140141

141142
table1_columns = list(table1.get_schema())
142143
try:
@@ -274,6 +275,7 @@ def __init__(self, profiles_dir_override: str, project_dir_override: str, is_clo
274275
self.connection = None
275276
self.project_dict = None
276277
self.requires_upper = False
278+
self.threads = None
277279

278280
self.parse_run_results, self.parse_manifest, self.ProfileRenderer, self.yaml = import_dbt()
279281

@@ -319,95 +321,100 @@ def _get_connection_creds(self) -> Tuple[Dict[str, str], str]:
319321
with open(profiles_path) as profiles:
320322
profiles = self.yaml.safe_load(profiles)
321323

322-
dbt_profile = self.project_dict.get("profile")
324+
dbt_profile_var = self.project_dict.get("profile")
323325

324-
profile_outputs = get_from_dict_with_raise(
325-
profiles, dbt_profile, f"No profile '{dbt_profile}' found in '{profiles_path}'."
326+
profile = get_from_dict_with_raise(
327+
profiles, dbt_profile_var, f"No profile '{dbt_profile_var}' found in '{profiles_path}'."
326328
)
329+
# values can contain env_vars
330+
rendered_profile = self.ProfileRenderer().render_data(profile)
327331
profile_target = get_from_dict_with_raise(
328-
profile_outputs, "target", f"No target found in profile '{dbt_profile}' in '{profiles_path}'."
332+
rendered_profile, "target", f"No target found in profile '{dbt_profile_var}' in '{profiles_path}'."
329333
)
330334
outputs = get_from_dict_with_raise(
331-
profile_outputs, "outputs", f"No outputs found in profile '{dbt_profile}' in '{profiles_path}'."
335+
rendered_profile, "outputs", f"No outputs found in profile '{dbt_profile_var}' in '{profiles_path}'."
332336
)
333337
credentials = get_from_dict_with_raise(
334338
outputs,
335339
profile_target,
336-
f"No credentials found for target '{profile_target}' in profile '{dbt_profile}' in '{profiles_path}'.",
340+
f"No credentials found for target '{profile_target}' in profile '{dbt_profile_var}' in '{profiles_path}'.",
337341
)
338342
conn_type = get_from_dict_with_raise(
339343
credentials,
340344
"type",
341-
f"No type found for target '{profile_target}' in profile '{dbt_profile}' in '{profiles_path}'.",
345+
f"No type found for target '{profile_target}' in profile '{dbt_profile_var}' in '{profiles_path}'.",
342346
)
343347
conn_type = conn_type.lower()
344348

345-
# values can contain env_vars
346-
rendered_credentials = self.ProfileRenderer().render_data(credentials)
347-
return rendered_credentials, conn_type
349+
return credentials, conn_type
348350

349351
def set_connection(self):
350-
rendered_credentials, conn_type = self._get_connection_creds()
352+
credentials, conn_type = self._get_connection_creds()
351353

352354
if conn_type == "snowflake":
353-
if rendered_credentials.get("password") is None or rendered_credentials.get("private_key_path") is not None:
355+
if credentials.get("password") is None or credentials.get("private_key_path") is not None:
354356
raise Exception("Only password authentication is currently supported for Snowflake.")
355357
conn_info = {
356358
"driver": conn_type,
357-
"user": rendered_credentials.get("user"),
358-
"password": rendered_credentials.get("password"),
359-
"account": rendered_credentials.get("account"),
360-
"database": rendered_credentials.get("database"),
361-
"warehouse": rendered_credentials.get("warehouse"),
362-
"role": rendered_credentials.get("role"),
363-
"schema": rendered_credentials.get("schema"),
359+
"user": credentials.get("user"),
360+
"password": credentials.get("password"),
361+
"account": credentials.get("account"),
362+
"database": credentials.get("database"),
363+
"warehouse": credentials.get("warehouse"),
364+
"role": credentials.get("role"),
365+
"schema": credentials.get("schema"),
364366
}
367+
self.threads = credentials.get("threads")
365368
self.requires_upper = True
366369
elif conn_type == "bigquery":
367-
method = rendered_credentials.get("method")
370+
method = credentials.get("method")
368371
# there are many connection types https://docs.getdbt.com/reference/warehouse-setups/bigquery-setup#oauth-via-gcloud
369372
# this assumes that the user is auth'd via `gcloud auth application-default login`
370373
if method is None or method != "oauth":
371374
raise Exception("Oauth is the current method supported for Big Query.")
372375
conn_info = {
373376
"driver": conn_type,
374-
"project": rendered_credentials.get("project"),
375-
"dataset": rendered_credentials.get("dataset"),
377+
"project": credentials.get("project"),
378+
"dataset": credentials.get("dataset"),
376379
}
380+
self.threads = credentials.get("threads")
377381
elif conn_type == "duckdb":
378382
conn_info = {
379383
"driver": conn_type,
380-
"filepath": rendered_credentials.get("path"),
384+
"filepath": credentials.get("path"),
381385
}
382386
elif conn_type == "redshift":
383-
if rendered_credentials.get("password") is None or rendered_credentials.get("method") == "iam":
387+
if credentials.get("password") is None or credentials.get("method") == "iam":
384388
raise Exception("Only password authentication is currently supported for Redshift.")
385389
conn_info = {
386390
"driver": conn_type,
387-
"host": rendered_credentials.get("host"),
388-
"user": rendered_credentials.get("user"),
389-
"password": rendered_credentials.get("password"),
390-
"port": rendered_credentials.get("port"),
391-
"dbname": rendered_credentials.get("dbname"),
391+
"host": credentials.get("host"),
392+
"user": credentials.get("user"),
393+
"password": credentials.get("password"),
394+
"port": credentials.get("port"),
395+
"dbname": credentials.get("dbname"),
392396
}
397+
self.threads = credentials.get("threads")
393398
elif conn_type == "databricks":
394399
conn_info = {
395400
"driver": conn_type,
396-
"catalog": rendered_credentials.get("catalog"),
397-
"server_hostname": rendered_credentials.get("host"),
398-
"http_path": rendered_credentials.get("http_path"),
399-
"schema": rendered_credentials.get("schema"),
400-
"access_token": rendered_credentials.get("token"),
401+
"catalog": credentials.get("catalog"),
402+
"server_hostname": credentials.get("host"),
403+
"http_path": credentials.get("http_path"),
404+
"schema": credentials.get("schema"),
405+
"access_token": credentials.get("token"),
401406
}
407+
self.threads = credentials.get("threads")
402408
elif conn_type == "postgres":
403409
conn_info = {
404410
"driver": "postgresql",
405-
"host": rendered_credentials.get("host"),
406-
"user": rendered_credentials.get("user"),
407-
"password": rendered_credentials.get("password"),
408-
"port": rendered_credentials.get("port"),
409-
"dbname": rendered_credentials.get("dbname") or rendered_credentials.get("database"),
411+
"host": credentials.get("host"),
412+
"user": credentials.get("user"),
413+
"password": credentials.get("password"),
414+
"port": credentials.get("port"),
415+
"dbname": credentials.get("dbname") or credentials.get("database"),
410416
}
417+
self.threads = credentials.get("threads")
411418
else:
412419
raise NotImplementedError(f"Provider {conn_type} is not yet supported for dbt diffs")
413420

0 commit comments

Comments
 (0)