Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit a614784

Browse files
authored
Merge pull request #395 from dlawin/issue_386
render jinja in entire selected profile
2 parents 7bda9c8 + d9894c1 commit a614784

File tree

2 files changed

+75
-69
lines changed

2 files changed

+75
-69
lines changed

data_diff/dbt.py

Lines changed: 44 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -304,100 +304,100 @@ def _get_connection_creds(self) -> Tuple[Dict[str, str], str]:
304304
with open(profiles_path) as profiles:
305305
profiles = self.yaml.safe_load(profiles)
306306

307-
dbt_profile = self.project_dict.get("profile")
307+
dbt_profile_var = self.project_dict.get("profile")
308308

309-
profile_outputs = get_from_dict_with_raise(
310-
profiles, dbt_profile, f"No profile '{dbt_profile}' found in '{profiles_path}'."
309+
profile = get_from_dict_with_raise(
310+
profiles, dbt_profile_var, f"No profile '{dbt_profile_var}' found in '{profiles_path}'."
311311
)
312+
# values can contain env_vars
313+
rendered_profile = self.ProfileRenderer().render_data(profile)
312314
profile_target = get_from_dict_with_raise(
313-
profile_outputs, "target", f"No target found in profile '{dbt_profile}' in '{profiles_path}'."
315+
rendered_profile, "target", f"No target found in profile '{dbt_profile_var}' in '{profiles_path}'."
314316
)
315317
outputs = get_from_dict_with_raise(
316-
profile_outputs, "outputs", f"No outputs found in profile '{dbt_profile}' in '{profiles_path}'."
318+
rendered_profile, "outputs", f"No outputs found in profile '{dbt_profile_var}' in '{profiles_path}'."
317319
)
318320
credentials = get_from_dict_with_raise(
319321
outputs,
320322
profile_target,
321-
f"No credentials found for target '{profile_target}' in profile '{dbt_profile}' in '{profiles_path}'.",
323+
f"No credentials found for target '{profile_target}' in profile '{dbt_profile_var}' in '{profiles_path}'.",
322324
)
323325
conn_type = get_from_dict_with_raise(
324326
credentials,
325327
"type",
326-
f"No type found for target '{profile_target}' in profile '{dbt_profile}' in '{profiles_path}'.",
328+
f"No type found for target '{profile_target}' in profile '{dbt_profile_var}' in '{profiles_path}'.",
327329
)
328330
conn_type = conn_type.lower()
329331

330-
# values can contain env_vars
331-
rendered_credentials = self.ProfileRenderer().render_data(credentials)
332-
return rendered_credentials, conn_type
332+
return credentials, conn_type
333333

334334
def set_connection(self):
335-
rendered_credentials, conn_type = self._get_connection_creds()
335+
credentials, conn_type = self._get_connection_creds()
336336

337337
if conn_type == "snowflake":
338-
if rendered_credentials.get("password") is None or rendered_credentials.get("private_key_path") is not None:
338+
if credentials.get("password") is None or credentials.get("private_key_path") is not None:
339339
raise Exception("Only password authentication is currently supported for Snowflake.")
340340
conn_info = {
341341
"driver": conn_type,
342-
"user": rendered_credentials.get("user"),
343-
"password": rendered_credentials.get("password"),
344-
"account": rendered_credentials.get("account"),
345-
"database": rendered_credentials.get("database"),
346-
"warehouse": rendered_credentials.get("warehouse"),
347-
"role": rendered_credentials.get("role"),
348-
"schema": rendered_credentials.get("schema"),
342+
"user": credentials.get("user"),
343+
"password": credentials.get("password"),
344+
"account": credentials.get("account"),
345+
"database": credentials.get("database"),
346+
"warehouse": credentials.get("warehouse"),
347+
"role": credentials.get("role"),
348+
"schema": credentials.get("schema"),
349349
}
350-
self.threads = rendered_credentials.get("threads")
350+
self.threads = credentials.get("threads")
351351
self.requires_upper = True
352352
elif conn_type == "bigquery":
353-
method = rendered_credentials.get("method")
353+
method = credentials.get("method")
354354
# there are many connection types https://docs.getdbt.com/reference/warehouse-setups/bigquery-setup#oauth-via-gcloud
355355
# this assumes that the user is auth'd via `gcloud auth application-default login`
356356
if method is None or method != "oauth":
357357
raise Exception("Oauth is the current method supported for Big Query.")
358358
conn_info = {
359359
"driver": conn_type,
360-
"project": rendered_credentials.get("project"),
361-
"dataset": rendered_credentials.get("dataset"),
360+
"project": credentials.get("project"),
361+
"dataset": credentials.get("dataset"),
362362
}
363-
self.threads = rendered_credentials.get("threads")
363+
self.threads = credentials.get("threads")
364364
elif conn_type == "duckdb":
365365
conn_info = {
366366
"driver": conn_type,
367-
"filepath": rendered_credentials.get("path"),
367+
"filepath": credentials.get("path"),
368368
}
369369
elif conn_type == "redshift":
370-
if rendered_credentials.get("password") is None or rendered_credentials.get("method") == "iam":
370+
if credentials.get("password") is None or credentials.get("method") == "iam":
371371
raise Exception("Only password authentication is currently supported for Redshift.")
372372
conn_info = {
373373
"driver": conn_type,
374-
"host": rendered_credentials.get("host"),
375-
"user": rendered_credentials.get("user"),
376-
"password": rendered_credentials.get("password"),
377-
"port": rendered_credentials.get("port"),
378-
"dbname": rendered_credentials.get("dbname"),
374+
"host": credentials.get("host"),
375+
"user": credentials.get("user"),
376+
"password": credentials.get("password"),
377+
"port": credentials.get("port"),
378+
"dbname": credentials.get("dbname"),
379379
}
380-
self.threads = rendered_credentials.get("threads")
380+
self.threads = credentials.get("threads")
381381
elif conn_type == "databricks":
382382
conn_info = {
383383
"driver": conn_type,
384-
"catalog": rendered_credentials.get("catalog"),
385-
"server_hostname": rendered_credentials.get("host"),
386-
"http_path": rendered_credentials.get("http_path"),
387-
"schema": rendered_credentials.get("schema"),
388-
"access_token": rendered_credentials.get("token"),
384+
"catalog": credentials.get("catalog"),
385+
"server_hostname": credentials.get("host"),
386+
"http_path": credentials.get("http_path"),
387+
"schema": credentials.get("schema"),
388+
"access_token": credentials.get("token"),
389389
}
390-
self.threads = rendered_credentials.get("threads")
390+
self.threads = credentials.get("threads")
391391
elif conn_type == "postgres":
392392
conn_info = {
393393
"driver": "postgresql",
394-
"host": rendered_credentials.get("host"),
395-
"user": rendered_credentials.get("user"),
396-
"password": rendered_credentials.get("password"),
397-
"port": rendered_credentials.get("port"),
398-
"dbname": rendered_credentials.get("dbname") or rendered_credentials.get("database"),
394+
"host": credentials.get("host"),
395+
"user": credentials.get("user"),
396+
"password": credentials.get("password"),
397+
"port": credentials.get("port"),
398+
"dbname": credentials.get("dbname") or credentials.get("database"),
399399
}
400-
self.threads = rendered_credentials.get("threads")
400+
self.threads = credentials.get("threads")
401401
else:
402402
raise NotImplementedError(f"Provider {conn_type} is not yet supported for dbt diffs")
403403

tests/test_dbt.py

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -211,37 +211,40 @@ def test_set_connection_not_implemented(self):
211211

212212
@patch("builtins.open", new_callable=mock_open, read_data="")
213213
def test_get_connection_creds_success(self, mock_open):
214-
profile_dict = {
214+
profiles_dict = {
215215
"a_profile": {
216216
"outputs": {
217217
"a_target": {"type": "TYPE1", "credential_1": "credential_1", "credential_2": "credential_2"}
218218
},
219219
"target": "a_target",
220220
}
221221
}
222-
expected_credentials = profile_dict["a_profile"]["outputs"]["a_target"]
222+
profile = profiles_dict["a_profile"]
223+
expected_credentials = profiles_dict["a_profile"]["outputs"]["a_target"]
223224
mock_self = Mock()
224225
mock_self.profiles_dir = ""
225226
mock_self.project_dict = {"profile": "a_profile"}
226-
mock_self.yaml.safe_load.return_value = profile_dict
227-
mock_self.ProfileRenderer().render_data.return_value = expected_credentials
227+
mock_self.yaml.safe_load.return_value = profiles_dict
228+
mock_self.ProfileRenderer().render_data.return_value = profile
228229
credentials, conn_type = DbtParser._get_connection_creds(mock_self)
229230
self.assertEqual(credentials, expected_credentials)
230231
self.assertEqual(conn_type, "type1")
231232

232233
@patch("builtins.open", new_callable=mock_open, read_data="")
233234
def test_get_connection_no_matching_profile(self, mock_open):
234-
profile_dict = {"a_profile": {}}
235+
profiles_dict = {"a_profile": {}}
235236
mock_self = Mock()
236237
mock_self.profiles_dir = ""
237238
mock_self.project_dict = {"profile": "wrong_profile"}
238-
mock_self.yaml.safe_load.return_value = profile_dict
239+
mock_self.yaml.safe_load.return_value = profiles_dict
240+
profile = profiles_dict["a_profile"]
241+
mock_self.ProfileRenderer().render_data.return_value = profile
239242
with self.assertRaises(ValueError):
240243
_, _ = DbtParser._get_connection_creds(mock_self)
241244

242245
@patch("builtins.open", new_callable=mock_open, read_data="")
243246
def test_get_connection_no_target(self, mock_open):
244-
profile_dict = {
247+
profiles_dict = {
245248
"a_profile": {
246249
"outputs": {
247250
"a_target": {"type": "TYPE1", "credential_1": "credential_1", "credential_2": "credential_2"}
@@ -250,8 +253,10 @@ def test_get_connection_no_target(self, mock_open):
250253
}
251254
mock_self = Mock()
252255
mock_self.profiles_dir = ""
256+
profile = profiles_dict["a_profile"]
257+
mock_self.ProfileRenderer().render_data.return_value = profile
253258
mock_self.project_dict = {"profile": "a_profile"}
254-
mock_self.yaml.safe_load.return_value = profile_dict
259+
mock_self.yaml.safe_load.return_value = profiles_dict
255260
with self.assertRaises(ValueError):
256261
_, _ = DbtParser._get_connection_creds(mock_self)
257262

@@ -262,24 +267,19 @@ def test_get_connection_no_target(self, mock_open):
262267

263268
@patch("builtins.open", new_callable=mock_open, read_data="")
264269
def test_get_connection_no_outputs(self, mock_open):
265-
profile_dict = {"a_profile": {"target": "a_target"}}
270+
profiles_dict = {"a_profile": {"target": "a_target"}}
266271
mock_self = Mock()
267272
mock_self.profiles_dir = ""
268273
mock_self.project_dict = {"profile": "a_profile"}
269-
mock_self.yaml.safe_load.return_value = profile_dict
274+
profile = profiles_dict["a_profile"]
275+
mock_self.ProfileRenderer().render_data.return_value = profile
276+
mock_self.yaml.safe_load.return_value = profiles_dict
270277
with self.assertRaises(ValueError):
271278
_, _ = DbtParser._get_connection_creds(mock_self)
272279

273-
profile_yaml_no_credentials = """
274-
a_profile:
275-
outputs:
276-
a_target:
277-
target: a_target
278-
"""
279-
280280
@patch("builtins.open", new_callable=mock_open, read_data="")
281281
def test_get_connection_no_credentials(self, mock_open):
282-
profile_dict = {
282+
profiles_dict = {
283283
"a_profile": {
284284
"outputs": {"a_target": {}},
285285
"target": "a_target",
@@ -288,13 +288,15 @@ def test_get_connection_no_credentials(self, mock_open):
288288
mock_self = Mock()
289289
mock_self.profiles_dir = ""
290290
mock_self.project_dict = {"profile": "a_profile"}
291-
mock_self.yaml.safe_load.return_value = profile_dict
291+
mock_self.yaml.safe_load.return_value = profiles_dict
292+
profile = profiles_dict["a_profile"]
293+
mock_self.ProfileRenderer().render_data.return_value = profile
292294
with self.assertRaises(ValueError):
293295
_, _ = DbtParser._get_connection_creds(mock_self)
294296

295297
@patch("builtins.open", new_callable=mock_open, read_data="")
296298
def test_get_connection_no_target_credentials(self, mock_open):
297-
profile_dict = {
299+
profiles_dict = {
298300
"a_profile": {
299301
"outputs": {
300302
"a_target": {"type": "TYPE1", "credential_1": "credential_1", "credential_2": "credential_2"}
@@ -305,13 +307,15 @@ def test_get_connection_no_target_credentials(self, mock_open):
305307
mock_self = Mock()
306308
mock_self.profiles_dir = ""
307309
mock_self.project_dict = {"profile": "a_profile"}
308-
mock_self.yaml.safe_load.return_value = profile_dict
310+
profile = profiles_dict["a_profile"]
311+
mock_self.ProfileRenderer().render_data.return_value = profile
312+
mock_self.yaml.safe_load.return_value = profiles_dict
309313
with self.assertRaises(ValueError):
310314
_, _ = DbtParser._get_connection_creds(mock_self)
311315

312316
@patch("builtins.open", new_callable=mock_open, read_data="")
313317
def test_get_connection_no_type(self, mock_open):
314-
profile_dict = {
318+
profiles_dict = {
315319
"a_profile": {
316320
"outputs": {"a_target": {"credential_1": "credential_1", "credential_2": "credential_2"}},
317321
"target": "a_target",
@@ -320,7 +324,9 @@ def test_get_connection_no_type(self, mock_open):
320324
mock_self = Mock()
321325
mock_self.profiles_dir = ""
322326
mock_self.project_dict = {"profile": "a_profile"}
323-
mock_self.yaml.safe_load.return_value = profile_dict
327+
mock_self.yaml.safe_load.return_value = profiles_dict
328+
profile = profiles_dict["a_profile"]
329+
mock_self.ProfileRenderer().render_data.return_value = profile
324330
with self.assertRaises(ValueError):
325331
_, _ = DbtParser._get_connection_creds(mock_self)
326332

@@ -366,7 +372,7 @@ def test_local_diff(self, mock_diff_tables):
366372
mock_diff_tables.assert_called_once_with(
367373
mock_table1, mock_table2, threaded=True, algorithm=Algorithm.JOINDIFF, extra_columns=ANY
368374
)
369-
self.assertEqual(len(mock_diff_tables.call_args[1]['extra_columns']), 2)
375+
self.assertEqual(len(mock_diff_tables.call_args[1]["extra_columns"]), 2)
370376
self.assertEqual(mock_connect.call_count, 2)
371377
mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), tuple(expected_keys), None)
372378
mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), tuple(expected_keys), None)
@@ -393,7 +399,7 @@ def test_local_diff_no_diffs(self, mock_diff_tables):
393399
mock_diff_tables.assert_called_once_with(
394400
mock_table1, mock_table2, threaded=True, algorithm=Algorithm.JOINDIFF, extra_columns=ANY
395401
)
396-
self.assertEqual(len(mock_diff_tables.call_args[1]['extra_columns']), 2)
402+
self.assertEqual(len(mock_diff_tables.call_args[1]["extra_columns"]), 2)
397403
self.assertEqual(mock_connect.call_count, 2)
398404
mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), tuple(expected_keys), None)
399405
mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), tuple(expected_keys), None)

0 commit comments

Comments
 (0)