1+ from contextlib import nullcontext
12import json
23import os
34import re
4243 run_as_daemon ,
4344 truncate_error ,
4445 print_version_info ,
46+ LogStatusHandler ,
4547)
4648
4749logger = getLogger (__name__ )
@@ -67,6 +69,7 @@ def dbt_diff(
6769 dbt_selection : Optional [str ] = None ,
6870 json_output : bool = False ,
6971 state : Optional [str ] = None ,
72+ log_status_handler : Optional [LogStatusHandler ] = None ,
7073 where_flag : Optional [str ] = None ,
7174) -> None :
7275 print_version_info ()
@@ -89,7 +92,6 @@ def dbt_diff(
8992 if not api :
9093 return
9194 org_meta = api .get_org_meta ()
92-
9395 if config .datasource_id is None :
9496 rich .print ("[red]Data source ID not found in dbt_project.yml" )
9597 raise DataDiffNoDatasourceIdError (
@@ -103,48 +105,54 @@ def dbt_diff(
103105 else :
104106 dbt_parser .set_connection ()
105107
106- for model in models :
107- diff_vars = _get_diff_vars (dbt_parser , config , model , where_flag )
108-
109- # we won't always have a prod path when using state
110- # when the model DNE in prod manifest, skip the model diff
111- if (
112- state and len (diff_vars .prod_path ) < 2
113- ): # < 2 because some providers like databricks can legitimately have *only* 2
114- diff_output_str = _diff_output_base ("." .join (diff_vars .dev_path ), "." .join (diff_vars .prod_path ))
115- diff_output_str += "[green]New model: nothing to diff![/] \n "
116- rich .print (diff_output_str )
117- continue
118-
119- if diff_vars .primary_keys :
120- if is_cloud :
121- diff_thread = run_as_daemon (_cloud_diff , diff_vars , config .datasource_id , api , org_meta )
122- diff_threads .append (diff_thread )
123- else :
124- _local_diff (diff_vars , json_output )
125- else :
126- if json_output :
127- print (
128- json .dumps (
129- jsonify_error (
130- table1 = diff_vars .prod_path ,
131- table2 = diff_vars .dev_path ,
132- dbt_model = diff_vars .dbt_model ,
133- error = "No primary key found. Add uniqueness tests, meta, or tags." ,
134- )
135- ),
136- flush = True ,
137- )
108+ with log_status_handler .status if log_status_handler else nullcontext ():
109+ for model in models :
110+ if log_status_handler :
111+ log_status_handler .set_prefix (f"Diffing { model .alias } \n " )
112+
113+ diff_vars = _get_diff_vars (dbt_parser , config , model , where_flag )
114+
115+ # we won't always have a prod path when using state
116+ # when the model DNE in prod manifest, skip the model diff
117+ if (
118+ state and len (diff_vars .prod_path ) < 2
119+ ): # < 2 because some providers like databricks can legitimately have *only* 2
120+ diff_output_str = _diff_output_base ("." .join (diff_vars .dev_path ), "." .join (diff_vars .prod_path ))
121+ diff_output_str += "[green]New model: nothing to diff![/] \n "
122+ rich .print (diff_output_str )
123+ continue
124+
125+ if diff_vars .primary_keys :
126+ if is_cloud :
127+ diff_thread = run_as_daemon (
128+ _cloud_diff , diff_vars , config .datasource_id , api , org_meta , log_status_handler
129+ )
130+ diff_threads .append (diff_thread )
131+ else :
132+ _local_diff (diff_vars , json_output )
138133 else :
139- rich .print (
140- _diff_output_base ("." .join (diff_vars .dev_path ), "." .join (diff_vars .prod_path ))
141- + "Skipped due to unknown primary key. Add uniqueness tests, meta, or tags.\n "
142- )
143-
144- # wait for all threads
145- if diff_threads :
146- for thread in diff_threads :
147- thread .join ()
134+ if json_output :
135+ print (
136+ json .dumps (
137+ jsonify_error (
138+ table1 = diff_vars .prod_path ,
139+ table2 = diff_vars .dev_path ,
140+ dbt_model = diff_vars .dbt_model ,
141+ error = "No primary key found. Add uniqueness tests, meta, or tags." ,
142+ )
143+ ),
144+ flush = True ,
145+ )
146+ else :
147+ rich .print (
148+ _diff_output_base ("." .join (diff_vars .dev_path ), "." .join (diff_vars .prod_path ))
149+ + "Skipped due to unknown primary key. Add uniqueness tests, meta, or tags.\n "
150+ )
151+
152+ # wait for all threads
153+ if diff_threads :
154+ for thread in diff_threads :
155+ thread .join ()
148156
149157
150158def _get_diff_vars (
@@ -348,7 +356,15 @@ def _initialize_api() -> Optional[DatafoldAPI]:
348356 return DatafoldAPI (api_key = api_key , host = datafold_host )
349357
350358
351- def _cloud_diff (diff_vars : TDiffVars , datasource_id : int , api : DatafoldAPI , org_meta : TCloudApiOrgMeta ) -> None :
359+ def _cloud_diff (
360+ diff_vars : TDiffVars ,
361+ datasource_id : int ,
362+ api : DatafoldAPI ,
363+ org_meta : TCloudApiOrgMeta ,
364+ log_status_handler : Optional [LogStatusHandler ] = None ,
365+ ) -> None :
366+ if log_status_handler :
367+ log_status_handler .cloud_diff_started (diff_vars .dev_path [- 1 ])
352368 diff_output_str = _diff_output_base ("." .join (diff_vars .dev_path ), "." .join (diff_vars .prod_path ))
353369 payload = TCloudApiDataDiff (
354370 data_source1_id = datasource_id ,
@@ -417,6 +433,8 @@ def _cloud_diff(diff_vars: TDiffVars, datasource_id: int, api: DatafoldAPI, org_
417433 diff_output_str += f"\n { diff_url } \n { no_differences_template ()} \n "
418434 rich .print (diff_output_str )
419435
436+ if log_status_handler :
437+ log_status_handler .cloud_diff_finished (diff_vars .dev_path [- 1 ])
420438 except BaseException as ex : # Catch KeyboardInterrupt too
421439 error = ex
422440 finally :
0 commit comments