Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 024bc47

Browse files
committed
squash issue 386 -- render entire profile
1 parent 0a8f3f3 commit 024bc47

File tree

2 files changed

+70
-64
lines changed

2 files changed

+70
-64
lines changed

data_diff/dbt.py

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -302,94 +302,94 @@ def _get_connection_creds(self) -> Tuple[Dict[str, str], str]:
302302
with open(profiles_path) as profiles:
303303
profiles = self.yaml.safe_load(profiles)
304304

305-
dbt_profile = self.project_dict.get("profile")
305+
dbt_profile_var = self.project_dict.get("profile")
306306

307-
profile_outputs = get_from_dict_with_raise(
308-
profiles, dbt_profile, f"No profile '{dbt_profile}' found in '{profiles_path}'."
307+
profile = get_from_dict_with_raise(
308+
profiles, dbt_profile_var, f"No profile '{dbt_profile_var}' found in '{profiles_path}'."
309309
)
310+
# values can contain env_vars
311+
rendered_profile = self.ProfileRenderer().render_data(profile)
310312
profile_target = get_from_dict_with_raise(
311-
profile_outputs, "target", f"No target found in profile '{dbt_profile}' in '{profiles_path}'."
313+
rendered_profile, "target", f"No target found in profile '{dbt_profile_var}' in '{profiles_path}'."
312314
)
313315
outputs = get_from_dict_with_raise(
314-
profile_outputs, "outputs", f"No outputs found in profile '{dbt_profile}' in '{profiles_path}'."
316+
rendered_profile, "outputs", f"No outputs found in profile '{dbt_profile_var}' in '{profiles_path}'."
315317
)
316318
credentials = get_from_dict_with_raise(
317319
outputs,
318320
profile_target,
319-
f"No credentials found for target '{profile_target}' in profile '{dbt_profile}' in '{profiles_path}'.",
321+
f"No credentials found for target '{profile_target}' in profile '{dbt_profile_var}' in '{profiles_path}'.",
320322
)
321323
conn_type = get_from_dict_with_raise(
322324
credentials,
323325
"type",
324-
f"No type found for target '{profile_target}' in profile '{dbt_profile}' in '{profiles_path}'.",
326+
f"No type found for target '{profile_target}' in profile '{dbt_profile_var}' in '{profiles_path}'.",
325327
)
326328
conn_type = conn_type.lower()
327329

328-
# values can contain env_vars
329-
rendered_credentials = self.ProfileRenderer().render_data(credentials)
330-
return rendered_credentials, conn_type
330+
return credentials, conn_type
331331

332332
def set_connection(self):
333-
rendered_credentials, conn_type = self._get_connection_creds()
333+
credentials, conn_type = self._get_connection_creds()
334334

335335
if conn_type == "snowflake":
336-
if rendered_credentials.get("password") is None or rendered_credentials.get("private_key_path") is not None:
336+
if credentials.get("password") is None or credentials.get("private_key_path") is not None:
337337
raise Exception("Only password authentication is currently supported for Snowflake.")
338338
conn_info = {
339339
"driver": conn_type,
340-
"user": rendered_credentials.get("user"),
341-
"password": rendered_credentials.get("password"),
342-
"account": rendered_credentials.get("account"),
343-
"database": rendered_credentials.get("database"),
344-
"warehouse": rendered_credentials.get("warehouse"),
345-
"role": rendered_credentials.get("role"),
346-
"schema": rendered_credentials.get("schema"),
340+
"user": credentials.get("user"),
341+
"password": credentials.get("password"),
342+
"account": credentials.get("account"),
343+
"database": credentials.get("database"),
344+
"warehouse": credentials.get("warehouse"),
345+
"role": credentials.get("role"),
346+
"schema": credentials.get("schema"),
347347
}
348348
self.requires_upper = True
349349
elif conn_type == "bigquery":
350-
method = rendered_credentials.get("method")
350+
method = credentials.get("method")
351351
# there are many connection types https://docs.getdbt.com/reference/warehouse-setups/bigquery-setup#oauth-via-gcloud
352352
# this assumes that the user is auth'd via `gcloud auth application-default login`
353353
if method is None or method != "oauth":
354354
raise Exception("Oauth is the current method supported for Big Query.")
355355
conn_info = {
356356
"driver": conn_type,
357-
"project": rendered_credentials.get("project"),
358-
"dataset": rendered_credentials.get("dataset"),
357+
"project": credentials.get("project"),
358+
"dataset": credentials.get("dataset"),
359359
}
360360
elif conn_type == "duckdb":
361361
conn_info = {
362362
"driver": conn_type,
363-
"filepath": rendered_credentials.get("path"),
363+
"filepath": credentials.get("path"),
364364
}
365365
elif conn_type == "redshift":
366-
if rendered_credentials.get("password") is None or rendered_credentials.get("method") == "iam":
366+
if credentials.get("password") is None or credentials.get("method") == "iam":
367367
raise Exception("Only password authentication is currently supported for Redshift.")
368368
conn_info = {
369369
"driver": conn_type,
370-
"host": rendered_credentials.get("host"),
371-
"user": rendered_credentials.get("user"),
372-
"password": rendered_credentials.get("password"),
373-
"port": rendered_credentials.get("port"),
374-
"dbname": rendered_credentials.get("dbname"),
370+
"host": credentials.get("host"),
371+
"user": credentials.get("user"),
372+
"password": credentials.get("password"),
373+
"port": credentials.get("port"),
374+
"dbname": credentials.get("dbname"),
375375
}
376376
elif conn_type == "databricks":
377377
conn_info = {
378378
"driver": conn_type,
379-
"catalog": rendered_credentials.get("catalog"),
380-
"server_hostname": rendered_credentials.get("host"),
381-
"http_path": rendered_credentials.get("http_path"),
382-
"schema": rendered_credentials.get("schema"),
383-
"access_token": rendered_credentials.get("token"),
379+
"catalog": credentials.get("catalog"),
380+
"server_hostname": credentials.get("host"),
381+
"http_path": credentials.get("http_path"),
382+
"schema": credentials.get("schema"),
383+
"access_token": credentials.get("token"),
384384
}
385385
elif conn_type == "postgres":
386386
conn_info = {
387387
"driver": "postgresql",
388-
"host": rendered_credentials.get("host"),
389-
"user": rendered_credentials.get("user"),
390-
"password": rendered_credentials.get("password"),
391-
"port": rendered_credentials.get("port"),
392-
"dbname": rendered_credentials.get("dbname") or rendered_credentials.get("database"),
388+
"host": credentials.get("host"),
389+
"user": credentials.get("user"),
390+
"password": credentials.get("password"),
391+
"port": credentials.get("port"),
392+
"dbname": credentials.get("dbname") or credentials.get("database"),
393393
}
394394
else:
395395
raise NotImplementedError(f"Provider {conn_type} is not yet supported for dbt diffs")

tests/test_dbt.py

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -211,37 +211,40 @@ def test_set_connection_not_implemented(self):
211211

212212
@patch("builtins.open", new_callable=mock_open, read_data="")
213213
def test_get_connection_creds_success(self, mock_open):
214-
profile_dict = {
214+
profiles_dict = {
215215
"a_profile": {
216216
"outputs": {
217217
"a_target": {"type": "TYPE1", "credential_1": "credential_1", "credential_2": "credential_2"}
218218
},
219219
"target": "a_target",
220220
}
221221
}
222-
expected_credentials = profile_dict["a_profile"]["outputs"]["a_target"]
222+
profile = profiles_dict["a_profile"]
223+
expected_credentials = profiles_dict["a_profile"]["outputs"]["a_target"]
223224
mock_self = Mock()
224225
mock_self.profiles_dir = ""
225226
mock_self.project_dict = {"profile": "a_profile"}
226-
mock_self.yaml.safe_load.return_value = profile_dict
227-
mock_self.ProfileRenderer().render_data.return_value = expected_credentials
227+
mock_self.yaml.safe_load.return_value = profiles_dict
228+
mock_self.ProfileRenderer().render_data.return_value = profile
228229
credentials, conn_type = DbtParser._get_connection_creds(mock_self)
229230
self.assertEqual(credentials, expected_credentials)
230231
self.assertEqual(conn_type, "type1")
231232

232233
@patch("builtins.open", new_callable=mock_open, read_data="")
233234
def test_get_connection_no_matching_profile(self, mock_open):
234-
profile_dict = {"a_profile": {}}
235+
profiles_dict = {"a_profile": {}}
235236
mock_self = Mock()
236237
mock_self.profiles_dir = ""
237238
mock_self.project_dict = {"profile": "wrong_profile"}
238-
mock_self.yaml.safe_load.return_value = profile_dict
239+
mock_self.yaml.safe_load.return_value = profiles_dict
240+
profile = profiles_dict["a_profile"]
241+
mock_self.ProfileRenderer().render_data.return_value = profile
239242
with self.assertRaises(ValueError):
240243
_, _ = DbtParser._get_connection_creds(mock_self)
241244

242245
@patch("builtins.open", new_callable=mock_open, read_data="")
243246
def test_get_connection_no_target(self, mock_open):
244-
profile_dict = {
247+
profiles_dict = {
245248
"a_profile": {
246249
"outputs": {
247250
"a_target": {"type": "TYPE1", "credential_1": "credential_1", "credential_2": "credential_2"}
@@ -250,8 +253,10 @@ def test_get_connection_no_target(self, mock_open):
250253
}
251254
mock_self = Mock()
252255
mock_self.profiles_dir = ""
256+
profile = profiles_dict["a_profile"]
257+
mock_self.ProfileRenderer().render_data.return_value = profile
253258
mock_self.project_dict = {"profile": "a_profile"}
254-
mock_self.yaml.safe_load.return_value = profile_dict
259+
mock_self.yaml.safe_load.return_value = profiles_dict
255260
with self.assertRaises(ValueError):
256261
_, _ = DbtParser._get_connection_creds(mock_self)
257262

@@ -262,24 +267,19 @@ def test_get_connection_no_target(self, mock_open):
262267

263268
@patch("builtins.open", new_callable=mock_open, read_data="")
264269
def test_get_connection_no_outputs(self, mock_open):
265-
profile_dict = {"a_profile": {"target": "a_target"}}
270+
profiles_dict = {"a_profile": {"target": "a_target"}}
266271
mock_self = Mock()
267272
mock_self.profiles_dir = ""
268273
mock_self.project_dict = {"profile": "a_profile"}
269-
mock_self.yaml.safe_load.return_value = profile_dict
274+
profile = profiles_dict["a_profile"]
275+
mock_self.ProfileRenderer().render_data.return_value = profile
276+
mock_self.yaml.safe_load.return_value = profiles_dict
270277
with self.assertRaises(ValueError):
271278
_, _ = DbtParser._get_connection_creds(mock_self)
272279

273-
profile_yaml_no_credentials = """
274-
a_profile:
275-
outputs:
276-
a_target:
277-
target: a_target
278-
"""
279-
280280
@patch("builtins.open", new_callable=mock_open, read_data="")
281281
def test_get_connection_no_credentials(self, mock_open):
282-
profile_dict = {
282+
profiles_dict = {
283283
"a_profile": {
284284
"outputs": {"a_target": {}},
285285
"target": "a_target",
@@ -288,13 +288,15 @@ def test_get_connection_no_credentials(self, mock_open):
288288
mock_self = Mock()
289289
mock_self.profiles_dir = ""
290290
mock_self.project_dict = {"profile": "a_profile"}
291-
mock_self.yaml.safe_load.return_value = profile_dict
291+
mock_self.yaml.safe_load.return_value = profiles_dict
292+
profile = profiles_dict["a_profile"]
293+
mock_self.ProfileRenderer().render_data.return_value = profile
292294
with self.assertRaises(ValueError):
293295
_, _ = DbtParser._get_connection_creds(mock_self)
294296

295297
@patch("builtins.open", new_callable=mock_open, read_data="")
296298
def test_get_connection_no_target_credentials(self, mock_open):
297-
profile_dict = {
299+
profiles_dict = {
298300
"a_profile": {
299301
"outputs": {
300302
"a_target": {"type": "TYPE1", "credential_1": "credential_1", "credential_2": "credential_2"}
@@ -305,13 +307,15 @@ def test_get_connection_no_target_credentials(self, mock_open):
305307
mock_self = Mock()
306308
mock_self.profiles_dir = ""
307309
mock_self.project_dict = {"profile": "a_profile"}
308-
mock_self.yaml.safe_load.return_value = profile_dict
310+
profile = profiles_dict["a_profile"]
311+
mock_self.ProfileRenderer().render_data.return_value = profile
312+
mock_self.yaml.safe_load.return_value = profiles_dict
309313
with self.assertRaises(ValueError):
310314
_, _ = DbtParser._get_connection_creds(mock_self)
311315

312316
@patch("builtins.open", new_callable=mock_open, read_data="")
313317
def test_get_connection_no_type(self, mock_open):
314-
profile_dict = {
318+
profiles_dict = {
315319
"a_profile": {
316320
"outputs": {"a_target": {"credential_1": "credential_1", "credential_2": "credential_2"}},
317321
"target": "a_target",
@@ -320,7 +324,9 @@ def test_get_connection_no_type(self, mock_open):
320324
mock_self = Mock()
321325
mock_self.profiles_dir = ""
322326
mock_self.project_dict = {"profile": "a_profile"}
323-
mock_self.yaml.safe_load.return_value = profile_dict
327+
mock_self.yaml.safe_load.return_value = profiles_dict
328+
profile = profiles_dict["a_profile"]
329+
mock_self.ProfileRenderer().render_data.return_value = profile
324330
with self.assertRaises(ValueError):
325331
_, _ = DbtParser._get_connection_creds(mock_self)
326332

@@ -366,7 +372,7 @@ def test_local_diff(self, mock_diff_tables):
366372
mock_diff_tables.assert_called_once_with(
367373
mock_table1, mock_table2, threaded=True, algorithm=Algorithm.JOINDIFF, extra_columns=ANY
368374
)
369-
self.assertEqual(len(mock_diff_tables.call_args[1]['extra_columns']), 2)
375+
self.assertEqual(len(mock_diff_tables.call_args[1]["extra_columns"]), 2)
370376
self.assertEqual(mock_connect.call_count, 2)
371377
mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), tuple(expected_keys))
372378
mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), tuple(expected_keys))
@@ -393,7 +399,7 @@ def test_local_diff_no_diffs(self, mock_diff_tables):
393399
mock_diff_tables.assert_called_once_with(
394400
mock_table1, mock_table2, threaded=True, algorithm=Algorithm.JOINDIFF, extra_columns=ANY
395401
)
396-
self.assertEqual(len(mock_diff_tables.call_args[1]['extra_columns']), 2)
402+
self.assertEqual(len(mock_diff_tables.call_args[1]["extra_columns"]), 2)
397403
self.assertEqual(mock_connect.call_count, 2)
398404
mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), tuple(expected_keys))
399405
mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), tuple(expected_keys))

0 commit comments

Comments
 (0)