1+ from contextlib import nullcontext
12import json
23import os
34import re
4243 run_as_daemon ,
4344 truncate_error ,
4445 print_version_info ,
46+ LogStatusHandler ,
4547)
4648
4749logger = getLogger (__name__ )
@@ -67,6 +69,7 @@ def dbt_diff(
6769 dbt_selection : Optional [str ] = None ,
6870 json_output : bool = False ,
6971 state : Optional [str ] = None ,
72+ log_status_handler : Optional [LogStatusHandler ] = None ,
7073) -> None :
7174 print_version_info ()
7275 diff_threads = []
@@ -88,7 +91,6 @@ def dbt_diff(
8891 if not api :
8992 return
9093 org_meta = api .get_org_meta ()
91-
9294 if config .datasource_id is None :
9395 rich .print ("[red]Data source ID not found in dbt_project.yml" )
9496 raise DataDiffNoDatasourceIdError (
@@ -102,48 +104,54 @@ def dbt_diff(
102104 else :
103105 dbt_parser .set_connection ()
104106
105- for model in models :
106- diff_vars = _get_diff_vars (dbt_parser , config , model )
107-
108- # we won't always have a prod path when using state
109- # when the model DNE in prod manifest, skip the model diff
110- if (
111- state and len (diff_vars .prod_path ) < 2
112- ): # < 2 because some providers like databricks can legitimately have *only* 2
113- diff_output_str = _diff_output_base ("." .join (diff_vars .dev_path ), "." .join (diff_vars .prod_path ))
114- diff_output_str += "[green]New model: nothing to diff![/] \n "
115- rich .print (diff_output_str )
116- continue
117-
118- if diff_vars .primary_keys :
119- if is_cloud :
120- diff_thread = run_as_daemon (_cloud_diff , diff_vars , config .datasource_id , api , org_meta )
121- diff_threads .append (diff_thread )
122- else :
123- _local_diff (diff_vars , json_output )
124- else :
125- if json_output :
126- print (
127- json .dumps (
128- jsonify_error (
129- table1 = diff_vars .prod_path ,
130- table2 = diff_vars .dev_path ,
131- dbt_model = diff_vars .dbt_model ,
132- error = "No primary key found. Add uniqueness tests, meta, or tags." ,
133- )
134- ),
135- flush = True ,
136- )
107+ with log_status_handler .status if log_status_handler else nullcontext ():
108+ for model in models :
109+ if log_status_handler :
110+ log_status_handler .set_prefix (f"Diffing { model .alias } \n " )
111+
112+ diff_vars = _get_diff_vars (dbt_parser , config , model )
113+
114+ # we won't always have a prod path when using state
115+ # when the model DNE in prod manifest, skip the model diff
116+ if (
117+ state and len (diff_vars .prod_path ) < 2
118+ ): # < 2 because some providers like databricks can legitimately have *only* 2
119+ diff_output_str = _diff_output_base ("." .join (diff_vars .dev_path ), "." .join (diff_vars .prod_path ))
120+ diff_output_str += "[green]New model: nothing to diff![/] \n "
121+ rich .print (diff_output_str )
122+ continue
123+
124+ if diff_vars .primary_keys :
125+ if is_cloud :
126+ diff_thread = run_as_daemon (
127+ _cloud_diff , diff_vars , config .datasource_id , api , org_meta , log_status_handler
128+ )
129+ diff_threads .append (diff_thread )
130+ else :
131+ _local_diff (diff_vars , json_output )
137132 else :
138- rich .print (
139- _diff_output_base ("." .join (diff_vars .dev_path ), "." .join (diff_vars .prod_path ))
140- + "Skipped due to unknown primary key. Add uniqueness tests, meta, or tags.\n "
141- )
142-
143- # wait for all threads
144- if diff_threads :
145- for thread in diff_threads :
146- thread .join ()
133+ if json_output :
134+ print (
135+ json .dumps (
136+ jsonify_error (
137+ table1 = diff_vars .prod_path ,
138+ table2 = diff_vars .dev_path ,
139+ dbt_model = diff_vars .dbt_model ,
140+ error = "No primary key found. Add uniqueness tests, meta, or tags." ,
141+ )
142+ ),
143+ flush = True ,
144+ )
145+ else :
146+ rich .print (
147+ _diff_output_base ("." .join (diff_vars .dev_path ), "." .join (diff_vars .prod_path ))
148+ + "Skipped due to unknown primary key. Add uniqueness tests, meta, or tags.\n "
149+ )
150+
151+ # wait for all threads
152+ if diff_threads :
153+ for thread in diff_threads :
154+ thread .join ()
147155
148156
149157def _get_diff_vars (
@@ -345,7 +353,15 @@ def _initialize_api() -> Optional[DatafoldAPI]:
345353 return DatafoldAPI (api_key = api_key , host = datafold_host )
346354
347355
348- def _cloud_diff (diff_vars : TDiffVars , datasource_id : int , api : DatafoldAPI , org_meta : TCloudApiOrgMeta ) -> None :
356+ def _cloud_diff (
357+ diff_vars : TDiffVars ,
358+ datasource_id : int ,
359+ api : DatafoldAPI ,
360+ org_meta : TCloudApiOrgMeta ,
361+ log_status_handler : Optional [LogStatusHandler ] = None ,
362+ ) -> None :
363+ if log_status_handler :
364+ log_status_handler .cloud_diff_started (diff_vars .dev_path [- 1 ])
349365 diff_output_str = _diff_output_base ("." .join (diff_vars .dev_path ), "." .join (diff_vars .prod_path ))
350366 payload = TCloudApiDataDiff (
351367 data_source1_id = datasource_id ,
@@ -414,6 +430,8 @@ def _cloud_diff(diff_vars: TDiffVars, datasource_id: int, api: DatafoldAPI, org_
414430 diff_output_str += f"\n { diff_url } \n { no_differences_template ()} \n "
415431 rich .print (diff_output_str )
416432
433+ if log_status_handler :
434+ log_status_handler .cloud_diff_finished (diff_vars .dev_path [- 1 ])
417435 except BaseException as ex : # Catch KeyboardInterrupt too
418436 error = ex
419437 finally :
0 commit comments