Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 7bda9c8

Browse files
authored
Merge pull request #435 from dlawin/issue_425
issue 425 parse and use threads
2 parents 0a8f3f3 + 62ec256 commit 7bda9c8

File tree

2 files changed

+26
-19
lines changed

2 files changed

+26
-19
lines changed

data_diff/dbt.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class DiffVars:
4646
primary_keys: List[str]
4747
datasource_id: str
4848
connection: Dict[str, str]
49+
threads: Optional[int]
4950

5051

5152
def dbt_diff(
@@ -110,16 +111,16 @@ def _get_diff_vars(
110111
dev_qualified_list = [dev_database, dev_schema, model.alias]
111112
prod_qualified_list = [prod_database, prod_schema, model.alias]
112113

113-
return DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, datasource_id, dbt_parser.connection)
114+
return DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, datasource_id, dbt_parser.connection, dbt_parser.threads)
114115

115116

116117
def _local_diff(diff_vars: DiffVars) -> None:
117118
column_diffs_str = ""
118119
dev_qualified_string = ".".join(diff_vars.dev_path)
119120
prod_qualified_string = ".".join(diff_vars.prod_path)
120121

121-
table1 = connect_to_table(diff_vars.connection, dev_qualified_string, tuple(diff_vars.primary_keys))
122-
table2 = connect_to_table(diff_vars.connection, prod_qualified_string, tuple(diff_vars.primary_keys))
122+
table1 = connect_to_table(diff_vars.connection, dev_qualified_string, tuple(diff_vars.primary_keys), diff_vars.threads)
123+
table2 = connect_to_table(diff_vars.connection, prod_qualified_string, tuple(diff_vars.primary_keys), diff_vars.threads)
123124

124125
table1_columns = list(table1.get_schema())
125126
try:
@@ -260,6 +261,7 @@ def __init__(self, profiles_dir_override: str, project_dir_override: str, is_clo
260261
self.connection = None
261262
self.project_dict = None
262263
self.requires_upper = False
264+
self.threads = None
263265

264266
self.parse_run_results, self.parse_manifest, self.ProfileRenderer, self.yaml = import_dbt()
265267

@@ -345,6 +347,7 @@ def set_connection(self):
345347
"role": rendered_credentials.get("role"),
346348
"schema": rendered_credentials.get("schema"),
347349
}
350+
self.threads = rendered_credentials.get("threads")
348351
self.requires_upper = True
349352
elif conn_type == "bigquery":
350353
method = rendered_credentials.get("method")
@@ -357,6 +360,7 @@ def set_connection(self):
357360
"project": rendered_credentials.get("project"),
358361
"dataset": rendered_credentials.get("dataset"),
359362
}
363+
self.threads = rendered_credentials.get("threads")
360364
elif conn_type == "duckdb":
361365
conn_info = {
362366
"driver": conn_type,
@@ -373,6 +377,7 @@ def set_connection(self):
373377
"port": rendered_credentials.get("port"),
374378
"dbname": rendered_credentials.get("dbname"),
375379
}
380+
self.threads = rendered_credentials.get("threads")
376381
elif conn_type == "databricks":
377382
conn_info = {
378383
"driver": conn_type,
@@ -382,6 +387,7 @@ def set_connection(self):
382387
"schema": rendered_credentials.get("schema"),
383388
"access_token": rendered_credentials.get("token"),
384389
}
390+
self.threads = rendered_credentials.get("threads")
385391
elif conn_type == "postgres":
386392
conn_info = {
387393
"driver": "postgresql",
@@ -391,6 +397,7 @@ def set_connection(self):
391397
"port": rendered_credentials.get("port"),
392398
"dbname": rendered_credentials.get("dbname") or rendered_credentials.get("database"),
393399
}
400+
self.threads = rendered_credentials.get("threads")
394401
else:
395402
raise NotImplementedError(f"Provider {conn_type} is not yet supported for dbt diffs")
396403

tests/test_dbt.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ def test_local_diff(self, mock_diff_tables):
359359
dev_qualified_list = ["dev_db", "dev_schema", "dev_table"]
360360
prod_qualified_list = ["prod_db", "prod_schema", "prod_table"]
361361
expected_keys = ["key"]
362-
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_keys, None, mock_connection)
362+
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_keys, None, mock_connection, None)
363363
with patch("data_diff.dbt.connect_to_table", side_effect=[mock_table1, mock_table2]) as mock_connect:
364364
_local_diff(diff_vars)
365365

@@ -368,8 +368,8 @@ def test_local_diff(self, mock_diff_tables):
368368
)
369369
self.assertEqual(len(mock_diff_tables.call_args[1]['extra_columns']), 2)
370370
self.assertEqual(mock_connect.call_count, 2)
371-
mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), tuple(expected_keys))
372-
mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), tuple(expected_keys))
371+
mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), tuple(expected_keys), None)
372+
mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), tuple(expected_keys), None)
373373
mock_diff.get_stats_string.assert_called_once()
374374

375375
@patch("data_diff.dbt.diff_tables")
@@ -386,7 +386,7 @@ def test_local_diff_no_diffs(self, mock_diff_tables):
386386
dev_qualified_list = ["dev_db", "dev_schema", "dev_table"]
387387
prod_qualified_list = ["prod_db", "prod_schema", "prod_table"]
388388
expected_keys = ["primary_key_column"]
389-
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_keys, None, mock_connection)
389+
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_keys, None, mock_connection, None)
390390
with patch("data_diff.dbt.connect_to_table", side_effect=[mock_table1, mock_table2]) as mock_connect:
391391
_local_diff(diff_vars)
392392

@@ -395,8 +395,8 @@ def test_local_diff_no_diffs(self, mock_diff_tables):
395395
)
396396
self.assertEqual(len(mock_diff_tables.call_args[1]['extra_columns']), 2)
397397
self.assertEqual(mock_connect.call_count, 2)
398-
mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), tuple(expected_keys))
399-
mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), tuple(expected_keys))
398+
mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), tuple(expected_keys), None)
399+
mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), tuple(expected_keys), None)
400400
mock_diff.get_stats_string.assert_not_called()
401401

402402
@patch("data_diff.dbt.rich.print")
@@ -413,7 +413,7 @@ def test_cloud_diff(self, mock_request, mock_os_environ, mock_print):
413413
expected_datasource_id = 1
414414
expected_primary_keys = ["primary_key_column"]
415415
diff_vars = DiffVars(
416-
dev_qualified_list, prod_qualified_list, expected_primary_keys, expected_datasource_id, None
416+
dev_qualified_list, prod_qualified_list, expected_primary_keys, expected_datasource_id, None, None
417417
)
418418
_cloud_diff(diff_vars)
419419

@@ -443,7 +443,7 @@ def test_cloud_diff_ds_id_none(self, mock_request, mock_os_environ, mock_print):
443443
prod_qualified_list = ["prod_db", "prod_schema", "prod_table"]
444444
expected_datasource_id = None
445445
primary_keys = ["primary_key_column"]
446-
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, expected_datasource_id, None)
446+
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, expected_datasource_id, None, None)
447447
with self.assertRaises(ValueError):
448448
_cloud_diff(diff_vars)
449449

@@ -463,7 +463,7 @@ def test_cloud_diff_api_key_none(self, mock_request, mock_os_environ, mock_print
463463
prod_qualified_list = ["prod_db", "prod_schema", "prod_table"]
464464
expected_datasource_id = 1
465465
primary_keys = ["primary_key_column"]
466-
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, expected_datasource_id, None)
466+
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, expected_datasource_id, None, None)
467467
with self.assertRaises(ValueError):
468468
_cloud_diff(diff_vars)
469469

@@ -487,7 +487,7 @@ def test_diff_is_cloud(self, mock_print, mock_dbt_parser, mock_cloud_diff, mock_
487487
mock_dbt_parser.return_value = mock_dbt_parser_inst
488488
mock_dbt_parser_inst.get_models.return_value = [mock_model]
489489
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
490-
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None)
490+
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None, None)
491491
mock_get_diff_vars.return_value = expected_diff_vars
492492
dbt_diff(is_cloud=True)
493493
mock_dbt_parser_inst.get_models.assert_called_once()
@@ -514,7 +514,7 @@ def test_diff_is_not_cloud(self, mock_print, mock_dbt_parser, mock_cloud_diff, m
514514
}
515515
mock_dbt_parser_inst.get_models.return_value = [mock_model]
516516
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
517-
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None)
517+
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None, None)
518518
mock_get_diff_vars.return_value = expected_diff_vars
519519
dbt_diff(is_cloud=False)
520520

@@ -542,7 +542,7 @@ def test_diff_no_prod_configs(
542542

543543
mock_dbt_parser_inst.get_models.return_value = [mock_model]
544544
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
545-
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None)
545+
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None, None)
546546
mock_get_diff_vars.return_value = expected_diff_vars
547547
with self.assertRaises(ValueError):
548548
dbt_diff(is_cloud=False)
@@ -570,7 +570,7 @@ def test_diff_only_prod_db(self, mock_print, mock_dbt_parser, mock_cloud_diff, m
570570
}
571571
mock_dbt_parser_inst.get_models.return_value = [mock_model]
572572
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
573-
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None)
573+
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None, None)
574574
mock_get_diff_vars.return_value = expected_diff_vars
575575
dbt_diff(is_cloud=False)
576576

@@ -599,7 +599,7 @@ def test_diff_only_prod_schema(
599599

600600
mock_dbt_parser_inst.get_models.return_value = [mock_model]
601601
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
602-
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None)
602+
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], 123, None, None)
603603
mock_get_diff_vars.return_value = expected_diff_vars
604604
with self.assertRaises(ValueError):
605605
dbt_diff(is_cloud=False)
@@ -631,7 +631,7 @@ def test_diff_is_cloud_no_pks(
631631

632632
mock_dbt_parser_inst.get_models.return_value = [mock_model]
633633
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
634-
expected_diff_vars = DiffVars(["dev"], ["prod"], [], 123, None)
634+
expected_diff_vars = DiffVars(["dev"], ["prod"], [], 123, None, None)
635635
mock_get_diff_vars.return_value = expected_diff_vars
636636
dbt_diff(is_cloud=True)
637637

@@ -662,7 +662,7 @@ def test_diff_not_is_cloud_no_pks(
662662
mock_dbt_parser_inst.get_models.return_value = [mock_model]
663663
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
664664

665-
expected_diff_vars = DiffVars(["dev"], ["prod"], [], 123, None)
665+
expected_diff_vars = DiffVars(["dev"], ["prod"], [], 123, None, None)
666666
mock_get_diff_vars.return_value = expected_diff_vars
667667
dbt_diff(is_cloud=False)
668668
mock_dbt_parser_inst.get_models.assert_called_once()

0 commit comments

Comments
 (0)