Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 9f56621

Browse files
committed
Merge branch 'master' into issue_386
2 parents 024bc47 + 7bda9c8 commit 9f56621

File tree

2 files changed

+26
-19
lines changed

2 files changed

+26
-19
lines changed

data_diff/dbt.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class DiffVars:
4646
primary_keys: List[str]
4747
datasource_id: str
4848
connection: Dict[str, str]
49+
threads: Optional[int]
4950

5051

5152
def dbt_diff(
@@ -110,16 +111,16 @@ def _get_diff_vars(
110111
dev_qualified_list = [dev_database, dev_schema, model.alias]
111112
prod_qualified_list = [prod_database, prod_schema, model.alias]
112113

113-
return DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, datasource_id, dbt_parser.connection)
114+
return DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, datasource_id, dbt_parser.connection, dbt_parser.threads)
114115

115116

116117
def _local_diff(diff_vars: DiffVars) -> None:
117118
column_diffs_str = ""
118119
dev_qualified_string = ".".join(diff_vars.dev_path)
119120
prod_qualified_string = ".".join(diff_vars.prod_path)
120121

121-
table1 = connect_to_table(diff_vars.connection, dev_qualified_string, tuple(diff_vars.primary_keys))
122-
table2 = connect_to_table(diff_vars.connection, prod_qualified_string, tuple(diff_vars.primary_keys))
122+
table1 = connect_to_table(diff_vars.connection, dev_qualified_string, tuple(diff_vars.primary_keys), diff_vars.threads)
123+
table2 = connect_to_table(diff_vars.connection, prod_qualified_string, tuple(diff_vars.primary_keys), diff_vars.threads)
123124

124125
table1_columns = list(table1.get_schema())
125126
try:
@@ -260,6 +261,7 @@ def __init__(self, profiles_dir_override: str, project_dir_override: str, is_clo
260261
self.connection = None
261262
self.project_dict = None
262263
self.requires_upper = False
264+
self.threads = None
263265

264266
self.parse_run_results, self.parse_manifest, self.ProfileRenderer, self.yaml = import_dbt()
265267

@@ -345,6 +347,7 @@ def set_connection(self):
345347
"role": credentials.get("role"),
346348
"schema": credentials.get("schema"),
347349
}
350+
self.threads = rendered_credentials.get("threads")
348351
self.requires_upper = True
349352
elif conn_type == "bigquery":
350353
method = credentials.get("method")
@@ -357,6 +360,7 @@ def set_connection(self):
357360
"project": credentials.get("project"),
358361
"dataset": credentials.get("dataset"),
359362
}
363+
self.threads = rendered_credentials.get("threads")
360364
elif conn_type == "duckdb":
361365
conn_info = {
362366
"driver": conn_type,
@@ -373,6 +377,7 @@ def set_connection(self):
373377
"port": credentials.get("port"),
374378
"dbname": credentials.get("dbname"),
375379
}
380+
self.threads = rendered_credentials.get("threads")
376381
elif conn_type == "databricks":
377382
conn_info = {
378383
"driver": conn_type,
@@ -382,6 +387,7 @@ def set_connection(self):
382387
"schema": credentials.get("schema"),
383388
"access_token": credentials.get("token"),
384389
}
390+
self.threads = rendered_credentials.get("threads")
385391
elif conn_type == "postgres":
386392
conn_info = {
387393
"driver": "postgresql",
@@ -391,6 +397,7 @@ def set_connection(self):
391397
"port": credentials.get("port"),
392398
"dbname": credentials.get("dbname") or credentials.get("database"),
393399
}
400+
self.threads = rendered_credentials.get("threads")
394401
else:
395402
raise NotImplementedError(f"Provider {conn_type} is not yet supported for dbt diffs")
396403

tests/test_dbt.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ def test_local_diff(self, mock_diff_tables):
365365
dev_qualified_list = ["dev_db", "dev_schema", "dev_table"]
366366
prod_qualified_list = ["prod_db", "prod_schema", "prod_table"]
367367
expected_keys = ["key"]
368-
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_keys, None, mock_connection)
368+
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_keys, None, mock_connection, None)
369369
with patch("data_diff.dbt.connect_to_table", side_effect=[mock_table1, mock_table2]) as mock_connect:
370370
_local_diff(diff_vars)
371371

@@ -374,8 +374,8 @@ def test_local_diff(self, mock_diff_tables):
374374
)
375375
self.assertEqual(len(mock_diff_tables.call_args[1]["extra_columns"]), 2)
376376
self.assertEqual(mock_connect.call_count, 2)
377-
mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), tuple(expected_keys))
378-
mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), tuple(expected_keys))
377+
mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), tuple(expected_keys), None)
378+
mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), tuple(expected_keys), None)
379379
mock_diff.get_stats_string.assert_called_once()
380380

381381
@patch("data_diff.dbt.diff_tables")
@@ -392,7 +392,7 @@ def test_local_diff_no_diffs(self, mock_diff_tables):
392392
dev_qualified_list = ["dev_db", "dev_schema", "dev_table"]
393393
prod_qualified_list = ["prod_db", "prod_schema", "prod_table"]
394394
expected_keys = ["primary_key_column"]
395-
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_keys, None, mock_connection)
395+
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_keys, None, mock_connection, None)
396396
with patch("data_diff.dbt.connect_to_table", side_effect=[mock_table1, mock_table2]) as mock_connect:
397397
_local_diff(diff_vars)
398398

@@ -401,8 +401,8 @@ def test_local_diff_no_diffs(self, mock_diff_tables):
401401
)
402402
self.assertEqual(len(mock_diff_tables.call_args[1]["extra_columns"]), 2)
403403
self.assertEqual(mock_connect.call_count, 2)
404-
mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), tuple(expected_keys))
405-
mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), tuple(expected_keys))
404+
mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), tuple(expected_keys), None)
405+
mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), tuple(expected_keys), None)
406406
mock_diff.get_stats_string.assert_not_called()
407407

408408
@patch("data_diff.dbt.rich.print")
@@ -419,7 +419,7 @@ def test_cloud_diff(self, mock_request, mock_os_environ, mock_print):
419419
expected_datasource_id = 1
420420
expected_primary_keys = ["primary_key_column"]
421421
diff_vars = DiffVars(
422-
dev_qualified_list, prod_qualified_list, expected_primary_keys, expected_datasource_id, None
422+
dev_qualified_list, prod_qualified_list, expected_primary_keys, expected_datasource_id, None, None
423423
)
424424
_cloud_diff(diff_vars)
425425

@@ -449,7 +449,7 @@ def test_cloud_diff_ds_id_none(self, mock_request, mock_os_environ, mock_print):
449449
prod_qualified_list = ["prod_db", "prod_schema", "prod_table"]
450450
expected_datasource_id = None
451451
primary_keys = ["primary_key_column"]
452-
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, expected_datasource_id, None)
452+
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, expected_datasource_id, None, None)
453453
with self.assertRaises(ValueError):
454454
_cloud_diff(diff_vars)
455455

@@ -469,7 +469,7 @@ def test_cloud_diff_api_key_none(self, mock_request, mock_os_environ, mock_print
469469
prod_qualified_list = ["prod_db", "prod_schema", "prod_table"]
470470
expected_datasource_id = 1
471471
primary_keys = ["primary_key_column"]
472-
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, expected_datasource_id, None)
472+
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, expected_datasource_id, None, None)
473473
with self.assertRaises(ValueError):
474474
_cloud_diff(diff_vars)
475475

@@ -493,7 +493,7 @@ def test_diff_is_cloud(self, mock_print, mock_dbt_parser, mock_cloud_diff, mock_
493493
mock_dbt_parser.return_value = mock_dbt_parser_inst
494494
mock_dbt_parser_inst.get_models.return_value = [mock_model]
495495
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
496-
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None)
496+
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None, None)
497497
mock_get_diff_vars.return_value = expected_diff_vars
498498
dbt_diff(is_cloud=True)
499499
mock_dbt_parser_inst.get_models.assert_called_once()
@@ -520,7 +520,7 @@ def test_diff_is_not_cloud(self, mock_print, mock_dbt_parser, mock_cloud_diff, m
520520
}
521521
mock_dbt_parser_inst.get_models.return_value = [mock_model]
522522
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
523-
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None)
523+
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None, None)
524524
mock_get_diff_vars.return_value = expected_diff_vars
525525
dbt_diff(is_cloud=False)
526526

@@ -548,7 +548,7 @@ def test_diff_no_prod_configs(
548548

549549
mock_dbt_parser_inst.get_models.return_value = [mock_model]
550550
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
551-
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None)
551+
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None, None)
552552
mock_get_diff_vars.return_value = expected_diff_vars
553553
with self.assertRaises(ValueError):
554554
dbt_diff(is_cloud=False)
@@ -576,7 +576,7 @@ def test_diff_only_prod_db(self, mock_print, mock_dbt_parser, mock_cloud_diff, m
576576
}
577577
mock_dbt_parser_inst.get_models.return_value = [mock_model]
578578
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
579-
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None)
579+
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None, None)
580580
mock_get_diff_vars.return_value = expected_diff_vars
581581
dbt_diff(is_cloud=False)
582582

@@ -605,7 +605,7 @@ def test_diff_only_prod_schema(
605605

606606
mock_dbt_parser_inst.get_models.return_value = [mock_model]
607607
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
608-
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None)
608+
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None, None)
609609
mock_get_diff_vars.return_value = expected_diff_vars
610610
with self.assertRaises(ValueError):
611611
dbt_diff(is_cloud=False)
@@ -637,7 +637,7 @@ def test_diff_is_cloud_no_pks(
637637

638638
mock_dbt_parser_inst.get_models.return_value = [mock_model]
639639
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
640-
expected_diff_vars = DiffVars(["dev"], ["prod"], [], 123, None)
640+
expected_diff_vars = DiffVars(["dev"], ["prod"], [], 123, None, None)
641641
mock_get_diff_vars.return_value = expected_diff_vars
642642
dbt_diff(is_cloud=True)
643643

@@ -668,7 +668,7 @@ def test_diff_not_is_cloud_no_pks(
668668
mock_dbt_parser_inst.get_models.return_value = [mock_model]
669669
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
670670

671-
expected_diff_vars = DiffVars(["dev"], ["prod"], [], 123, None)
671+
expected_diff_vars = DiffVars(["dev"], ["prod"], [], 123, None, None)
672672
mock_get_diff_vars.return_value = expected_diff_vars
673673
dbt_diff(is_cloud=False)
674674
mock_dbt_parser_inst.get_models.assert_called_once()

0 commit comments

Comments
 (0)