Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 5f56d9f

Browse files
authored
Merge pull request #415 from dbeatty10/dbeatty10/dbt-profiles-dbt-project
Use the same logic as dbt-core for the path for the project and profiles
2 parents a614784 + b9c5e91 commit 5f56d9f

File tree

3 files changed

+54
-34
lines changed

3 files changed

+54
-34
lines changed

data_diff/__main__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,15 +217,16 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -
217217
)
218218
@click.option(
219219
"--dbt-profiles-dir",
220+
envvar="DBT_PROFILES_DIR",
220221
default=None,
221222
metavar="PATH",
222-
help="Override the default dbt profile location (~/.dbt).",
223+
help="Which directory to look in for the profiles.yml file. If not set, we follow the default profiles.yml location for the dbt version being used. Can also be set via the DBT_PROFILES_DIR environment variable.",
223224
)
224225
@click.option(
225226
"--dbt-project-dir",
226227
default=None,
227228
metavar="PATH",
228-
help="Override the dbt project directory. Otherwise assumed to be the current directory.",
229+
help="Which directory to look in for the dbt_project.yml file. Default is the current working directory and its parents.",
229230
)
230231
def main(conf, run, **kw):
231232
if kw["table2"] is None and kw["database2"]:

data_diff/dbt.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from dataclasses import dataclass
77
from packaging.version import parse as parse_version
88
from typing import List, Optional, Dict, Tuple
9+
from pathlib import Path
910

1011
import requests
1112

@@ -31,14 +32,30 @@ def import_dbt():
3132
from .utils import get_from_dict_with_raise, run_as_daemon, truncate_error
3233
from . import connect_to_table, diff_tables, Algorithm
3334

34-
RUN_RESULTS_PATH = "/target/run_results.json"
35-
MANIFEST_PATH = "/target/manifest.json"
36-
PROJECT_FILE = "/dbt_project.yml"
37-
PROFILES_FILE = "/profiles.yml"
35+
RUN_RESULTS_PATH = "target/run_results.json"
36+
MANIFEST_PATH = "target/manifest.json"
37+
PROJECT_FILE = "dbt_project.yml"
38+
PROFILES_FILE = "profiles.yml"
3839
LOWER_DBT_V = "1.0.0"
3940
UPPER_DBT_V = "1.4.2"
4041

4142

43+
# https://github.com/dbt-labs/dbt-core/blob/c952d44ec5c2506995fbad75320acbae49125d3d/core/dbt/cli/resolvers.py#L6
44+
def default_project_dir() -> Path:
45+
paths = list(Path.cwd().parents)
46+
paths.insert(0, Path.cwd())
47+
return next((x for x in paths if (x / PROJECT_FILE).exists()), Path.cwd())
48+
49+
50+
# https://github.com/dbt-labs/dbt-core/blob/c952d44ec5c2506995fbad75320acbae49125d3d/core/dbt/cli/resolvers.py#L12
51+
def default_profiles_dir() -> Path:
52+
return Path.cwd() if (Path.cwd() / PROFILES_FILE).exists() else Path.home() / ".dbt"
53+
54+
55+
def legacy_profiles_dir() -> Path:
56+
return Path.home() / ".dbt"
57+
58+
4259
@dataclass
4360
class DiffVars:
4461
dev_path: List[str]
@@ -251,12 +268,9 @@ def _cloud_diff(diff_vars: DiffVars) -> None:
251268

252269

253270
class DbtParser:
254-
DEFAULT_PROFILES_DIR = os.path.expanduser("~") + "/.dbt"
255-
DEFAULT_PROJECT_DIR = os.getcwd()
256-
257271
def __init__(self, profiles_dir_override: str, project_dir_override: str, is_cloud: bool) -> None:
258-
self.profiles_dir = profiles_dir_override or self.DEFAULT_PROFILES_DIR
259-
self.project_dir = project_dir_override or self.DEFAULT_PROJECT_DIR
272+
self.profiles_dir = Path(profiles_dir_override or default_profiles_dir())
273+
self.project_dir = Path(project_dir_override or default_project_dir())
260274
self.is_cloud = is_cloud
261275
self.connection = None
262276
self.project_dict = None
@@ -269,18 +283,21 @@ def get_datadiff_variables(self) -> dict:
269283
return self.project_dict.get("vars").get("data_diff")
270284

271285
def get_models(self):
272-
with open(self.project_dir + RUN_RESULTS_PATH) as run_results:
286+
with open(self.project_dir / RUN_RESULTS_PATH) as run_results:
273287
run_results_dict = json.load(run_results)
274288
run_results_obj = self.parse_run_results(run_results=run_results_dict)
275289

276290
dbt_version = parse_version(run_results_obj.metadata.dbt_version)
277291

292+
if dbt_version < parse_version("1.3.0"):
293+
self.profiles_dir = legacy_profiles_dir()
294+
278295
if dbt_version < parse_version(LOWER_DBT_V) or dbt_version >= parse_version(UPPER_DBT_V):
279296
raise Exception(
280297
f"Found dbt: v{dbt_version} Expected the dbt project's version to be >= {LOWER_DBT_V} and < {UPPER_DBT_V}"
281298
)
282299

283-
with open(self.project_dir + MANIFEST_PATH) as manifest:
300+
with open(self.project_dir / MANIFEST_PATH) as manifest:
284301
manifest_dict = json.load(manifest)
285302
manifest_obj = self.parse_manifest(manifest=manifest_dict)
286303

@@ -296,11 +313,11 @@ def get_primary_keys(self, model):
296313
return list((x.name for x in model.columns.values() if "primary-key" in x.tags))
297314

298315
def set_project_dict(self):
299-
with open(self.project_dir + PROJECT_FILE) as project:
316+
with open(self.project_dir / PROJECT_FILE) as project:
300317
self.project_dict = self.yaml.safe_load(project)
301318

302319
def _get_connection_creds(self) -> Tuple[Dict[str, str], str]:
303-
profiles_path = self.profiles_dir + PROFILES_FILE
320+
profiles_path = self.profiles_dir / PROFILES_FILE
304321
with open(profiles_path) as profiles:
305322
profiles = self.yaml.safe_load(profiles)
306323

tests/test_dbt.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22

3+
from pathlib import Path
34
import yaml
45
from data_diff.diff_tables import Algorithm
56
from .test_cli import run_datadiff_cli
@@ -51,7 +52,7 @@ def test_get_datadiff_variables_empty(self):
5152
def test_get_models(self, mock_open):
5253
expected_value = "expected_value"
5354
mock_self = Mock()
54-
mock_self.project_dir = ""
55+
mock_self.project_dir = Path()
5556
mock_run_results = Mock()
5657
mock_success_result = Mock()
5758
mock_failed_result = Mock()
@@ -69,47 +70,47 @@ def test_get_models(self, mock_open):
6970
models = DbtParser.get_models(mock_self)
7071

7172
self.assertEqual(expected_value, models[0])
72-
mock_open.assert_any_call(RUN_RESULTS_PATH)
73-
mock_open.assert_any_call(MANIFEST_PATH)
73+
mock_open.assert_any_call(Path(RUN_RESULTS_PATH))
74+
mock_open.assert_any_call(Path(MANIFEST_PATH))
7475
mock_self.parse_run_results.assert_called_once_with(run_results={})
7576
mock_self.parse_manifest.assert_called_once_with(manifest={})
7677

7778
@patch("builtins.open", new_callable=mock_open, read_data="{}")
7879
def test_get_models_bad_lower_dbt_version(self, mock_open):
7980
mock_self = Mock()
80-
mock_self.project_dir = ""
81+
mock_self.project_dir = Path()
8182
mock_run_results = Mock()
8283
mock_self.parse_run_results.return_value = mock_run_results
8384
mock_run_results.metadata.dbt_version = "0.19.0"
8485

8586
with self.assertRaises(Exception) as ex:
8687
DbtParser.get_models(mock_self)
8788

88-
mock_open.assert_called_once_with(RUN_RESULTS_PATH)
89+
mock_open.assert_called_once_with(Path(RUN_RESULTS_PATH))
8990
mock_self.parse_run_results.assert_called_once_with(run_results={})
9091
mock_self.parse_manifest.assert_not_called()
9192
self.assertIn("version to be", ex.exception.args[0])
9293

9394
@patch("builtins.open", new_callable=mock_open, read_data="{}")
9495
def test_get_models_bad_upper_dbt_version(self, mock_open):
9596
mock_self = Mock()
96-
mock_self.project_dir = ""
97+
mock_self.project_dir = Path()
9798
mock_run_results = Mock()
9899
mock_self.parse_run_results.return_value = mock_run_results
99100
mock_run_results.metadata.dbt_version = "1.5.1"
100101

101102
with self.assertRaises(Exception) as ex:
102103
DbtParser.get_models(mock_self)
103104

104-
mock_open.assert_called_once_with(RUN_RESULTS_PATH)
105+
mock_open.assert_called_once_with(Path(RUN_RESULTS_PATH))
105106
mock_self.parse_run_results.assert_called_once_with(run_results={})
106107
mock_self.parse_manifest.assert_not_called()
107108
self.assertIn("version to be", ex.exception.args[0])
108109

109110
@patch("builtins.open", new_callable=mock_open, read_data="{}")
110111
def test_get_models_no_success(self, mock_open):
111112
mock_self = Mock()
112-
mock_self.project_dir = ""
113+
mock_self.project_dir = Path()
113114
mock_run_results = Mock()
114115
mock_success_result = Mock()
115116
mock_failed_result = Mock()
@@ -126,21 +127,22 @@ def test_get_models_no_success(self, mock_open):
126127
with self.assertRaises(Exception):
127128
DbtParser.get_models(mock_self)
128129

129-
mock_open.assert_any_call(RUN_RESULTS_PATH)
130-
mock_open.assert_any_call(MANIFEST_PATH)
130+
mock_open.assert_any_call(Path(RUN_RESULTS_PATH))
131+
mock_open.assert_any_call(Path(MANIFEST_PATH))
131132
mock_self.parse_run_results.assert_called_once_with(run_results={})
132133
mock_self.parse_manifest.assert_called_once_with(manifest={})
133134

134135
@patch("builtins.open", new_callable=mock_open, read_data="key:\n value")
135136
def test_set_project_dict(self, mock_open):
136137
expected_dict = {"key1": "value1"}
137138
mock_self = Mock()
138-
mock_self.project_dir = ""
139+
140+
mock_self.project_dir = Path()
139141
mock_self.yaml.safe_load.return_value = expected_dict
140142
DbtParser.set_project_dict(mock_self)
141143

142144
self.assertEqual(mock_self.project_dict, expected_dict)
143-
mock_open.assert_called_once_with(PROJECT_FILE)
145+
mock_open.assert_called_once_with(Path(PROJECT_FILE))
144146

145147
def test_set_connection_snowflake_success(self):
146148
expected_driver = "snowflake"
@@ -222,7 +224,7 @@ def test_get_connection_creds_success(self, mock_open):
222224
profile = profiles_dict["a_profile"]
223225
expected_credentials = profiles_dict["a_profile"]["outputs"]["a_target"]
224226
mock_self = Mock()
225-
mock_self.profiles_dir = ""
227+
mock_self.profiles_dir = Path()
226228
mock_self.project_dict = {"profile": "a_profile"}
227229
mock_self.yaml.safe_load.return_value = profiles_dict
228230
mock_self.ProfileRenderer().render_data.return_value = profile
@@ -234,7 +236,7 @@ def test_get_connection_creds_success(self, mock_open):
234236
def test_get_connection_no_matching_profile(self, mock_open):
235237
profiles_dict = {"a_profile": {}}
236238
mock_self = Mock()
237-
mock_self.profiles_dir = ""
239+
mock_self.profiles_dir = Path()
238240
mock_self.project_dict = {"profile": "wrong_profile"}
239241
mock_self.yaml.safe_load.return_value = profiles_dict
240242
profile = profiles_dict["a_profile"]
@@ -252,7 +254,7 @@ def test_get_connection_no_target(self, mock_open):
252254
}
253255
}
254256
mock_self = Mock()
255-
mock_self.profiles_dir = ""
257+
mock_self.profiles_dir = Path()
256258
profile = profiles_dict["a_profile"]
257259
mock_self.ProfileRenderer().render_data.return_value = profile
258260
mock_self.project_dict = {"profile": "a_profile"}
@@ -269,7 +271,7 @@ def test_get_connection_no_target(self, mock_open):
269271
def test_get_connection_no_outputs(self, mock_open):
270272
profiles_dict = {"a_profile": {"target": "a_target"}}
271273
mock_self = Mock()
272-
mock_self.profiles_dir = ""
274+
mock_self.profiles_dir = Path()
273275
mock_self.project_dict = {"profile": "a_profile"}
274276
profile = profiles_dict["a_profile"]
275277
mock_self.ProfileRenderer().render_data.return_value = profile
@@ -286,7 +288,7 @@ def test_get_connection_no_credentials(self, mock_open):
286288
}
287289
}
288290
mock_self = Mock()
289-
mock_self.profiles_dir = ""
291+
mock_self.profiles_dir = Path()
290292
mock_self.project_dict = {"profile": "a_profile"}
291293
mock_self.yaml.safe_load.return_value = profiles_dict
292294
profile = profiles_dict["a_profile"]
@@ -305,7 +307,7 @@ def test_get_connection_no_target_credentials(self, mock_open):
305307
}
306308
}
307309
mock_self = Mock()
308-
mock_self.profiles_dir = ""
310+
mock_self.profiles_dir = Path()
309311
mock_self.project_dict = {"profile": "a_profile"}
310312
profile = profiles_dict["a_profile"]
311313
mock_self.ProfileRenderer().render_data.return_value = profile
@@ -322,7 +324,7 @@ def test_get_connection_no_type(self, mock_open):
322324
}
323325
}
324326
mock_self = Mock()
325-
mock_self.profiles_dir = ""
327+
mock_self.profiles_dir = Path()
326328
mock_self.project_dict = {"profile": "a_profile"}
327329
mock_self.yaml.safe_load.return_value = profiles_dict
328330
profile = profiles_dict["a_profile"]

0 commit comments

Comments
 (0)