@@ -63,6 +63,7 @@ class DiffVars:
6363 primary_keys : List [str ]
6464 datasource_id : str
6565 connection : Dict [str , str ]
66+ threads : Optional [int ]
6667
6768
6869def dbt_diff (
@@ -127,16 +128,16 @@ def _get_diff_vars(
127128 dev_qualified_list = [dev_database , dev_schema , model .alias ]
128129 prod_qualified_list = [prod_database , prod_schema , model .alias ]
129130
130- return DiffVars (dev_qualified_list , prod_qualified_list , primary_keys , datasource_id , dbt_parser .connection )
131+ return DiffVars (dev_qualified_list , prod_qualified_list , primary_keys , datasource_id , dbt_parser .connection , dbt_parser . threads )
131132
132133
133134def _local_diff (diff_vars : DiffVars ) -> None :
134135 column_diffs_str = ""
135136 dev_qualified_string = "." .join (diff_vars .dev_path )
136137 prod_qualified_string = "." .join (diff_vars .prod_path )
137138
138- table1 = connect_to_table (diff_vars .connection , dev_qualified_string , tuple (diff_vars .primary_keys ))
139- table2 = connect_to_table (diff_vars .connection , prod_qualified_string , tuple (diff_vars .primary_keys ))
139+ table1 = connect_to_table (diff_vars .connection , dev_qualified_string , tuple (diff_vars .primary_keys ), diff_vars . threads )
140+ table2 = connect_to_table (diff_vars .connection , prod_qualified_string , tuple (diff_vars .primary_keys ), diff_vars . threads )
140141
141142 table1_columns = list (table1 .get_schema ())
142143 try :
@@ -274,6 +275,7 @@ def __init__(self, profiles_dir_override: str, project_dir_override: str, is_clo
274275 self .connection = None
275276 self .project_dict = None
276277 self .requires_upper = False
278+ self .threads = None
277279
278280 self .parse_run_results , self .parse_manifest , self .ProfileRenderer , self .yaml = import_dbt ()
279281
@@ -319,95 +321,100 @@ def _get_connection_creds(self) -> Tuple[Dict[str, str], str]:
319321 with open (profiles_path ) as profiles :
320322 profiles = self .yaml .safe_load (profiles )
321323
322- dbt_profile = self .project_dict .get ("profile" )
324+ dbt_profile_var = self .project_dict .get ("profile" )
323325
324- profile_outputs = get_from_dict_with_raise (
325- profiles , dbt_profile , f"No profile '{ dbt_profile } ' found in '{ profiles_path } '."
326+ profile = get_from_dict_with_raise (
327+ profiles , dbt_profile_var , f"No profile '{ dbt_profile_var } ' found in '{ profiles_path } '."
326328 )
329+ # values can contain env_vars
330+ rendered_profile = self .ProfileRenderer ().render_data (profile )
327331 profile_target = get_from_dict_with_raise (
328- profile_outputs , "target" , f"No target found in profile '{ dbt_profile } ' in '{ profiles_path } '."
332+ rendered_profile , "target" , f"No target found in profile '{ dbt_profile_var } ' in '{ profiles_path } '."
329333 )
330334 outputs = get_from_dict_with_raise (
331- profile_outputs , "outputs" , f"No outputs found in profile '{ dbt_profile } ' in '{ profiles_path } '."
335+ rendered_profile , "outputs" , f"No outputs found in profile '{ dbt_profile_var } ' in '{ profiles_path } '."
332336 )
333337 credentials = get_from_dict_with_raise (
334338 outputs ,
335339 profile_target ,
336- f"No credentials found for target '{ profile_target } ' in profile '{ dbt_profile } ' in '{ profiles_path } '." ,
340+ f"No credentials found for target '{ profile_target } ' in profile '{ dbt_profile_var } ' in '{ profiles_path } '." ,
337341 )
338342 conn_type = get_from_dict_with_raise (
339343 credentials ,
340344 "type" ,
341- f"No type found for target '{ profile_target } ' in profile '{ dbt_profile } ' in '{ profiles_path } '." ,
345+ f"No type found for target '{ profile_target } ' in profile '{ dbt_profile_var } ' in '{ profiles_path } '." ,
342346 )
343347 conn_type = conn_type .lower ()
344348
345- # values can contain env_vars
346- rendered_credentials = self .ProfileRenderer ().render_data (credentials )
347- return rendered_credentials , conn_type
349+ return credentials , conn_type
348350
349351 def set_connection (self ):
350- rendered_credentials , conn_type = self ._get_connection_creds ()
352+ credentials , conn_type = self ._get_connection_creds ()
351353
352354 if conn_type == "snowflake" :
353- if rendered_credentials .get ("password" ) is None or rendered_credentials .get ("private_key_path" ) is not None :
355+ if credentials .get ("password" ) is None or credentials .get ("private_key_path" ) is not None :
354356 raise Exception ("Only password authentication is currently supported for Snowflake." )
355357 conn_info = {
356358 "driver" : conn_type ,
357- "user" : rendered_credentials .get ("user" ),
358- "password" : rendered_credentials .get ("password" ),
359- "account" : rendered_credentials .get ("account" ),
360- "database" : rendered_credentials .get ("database" ),
361- "warehouse" : rendered_credentials .get ("warehouse" ),
362- "role" : rendered_credentials .get ("role" ),
363- "schema" : rendered_credentials .get ("schema" ),
359+ "user" : credentials .get ("user" ),
360+ "password" : credentials .get ("password" ),
361+ "account" : credentials .get ("account" ),
362+ "database" : credentials .get ("database" ),
363+ "warehouse" : credentials .get ("warehouse" ),
364+ "role" : credentials .get ("role" ),
365+ "schema" : credentials .get ("schema" ),
364366 }
367+ self .threads = credentials .get ("threads" )
365368 self .requires_upper = True
366369 elif conn_type == "bigquery" :
367- method = rendered_credentials .get ("method" )
370+ method = credentials .get ("method" )
368371 # there are many connection types https://docs.getdbt.com/reference/warehouse-setups/bigquery-setup#oauth-via-gcloud
369372 # this assumes that the user is auth'd via `gcloud auth application-default login`
370373 if method is None or method != "oauth" :
371374 raise Exception ("Oauth is the current method supported for Big Query." )
372375 conn_info = {
373376 "driver" : conn_type ,
374- "project" : rendered_credentials .get ("project" ),
375- "dataset" : rendered_credentials .get ("dataset" ),
377+ "project" : credentials .get ("project" ),
378+ "dataset" : credentials .get ("dataset" ),
376379 }
380+ self .threads = credentials .get ("threads" )
377381 elif conn_type == "duckdb" :
378382 conn_info = {
379383 "driver" : conn_type ,
380- "filepath" : rendered_credentials .get ("path" ),
384+ "filepath" : credentials .get ("path" ),
381385 }
382386 elif conn_type == "redshift" :
383- if rendered_credentials .get ("password" ) is None or rendered_credentials .get ("method" ) == "iam" :
387+ if credentials .get ("password" ) is None or credentials .get ("method" ) == "iam" :
384388 raise Exception ("Only password authentication is currently supported for Redshift." )
385389 conn_info = {
386390 "driver" : conn_type ,
387- "host" : rendered_credentials .get ("host" ),
388- "user" : rendered_credentials .get ("user" ),
389- "password" : rendered_credentials .get ("password" ),
390- "port" : rendered_credentials .get ("port" ),
391- "dbname" : rendered_credentials .get ("dbname" ),
391+ "host" : credentials .get ("host" ),
392+ "user" : credentials .get ("user" ),
393+ "password" : credentials .get ("password" ),
394+ "port" : credentials .get ("port" ),
395+ "dbname" : credentials .get ("dbname" ),
392396 }
397+ self .threads = credentials .get ("threads" )
393398 elif conn_type == "databricks" :
394399 conn_info = {
395400 "driver" : conn_type ,
396- "catalog" : rendered_credentials .get ("catalog" ),
397- "server_hostname" : rendered_credentials .get ("host" ),
398- "http_path" : rendered_credentials .get ("http_path" ),
399- "schema" : rendered_credentials .get ("schema" ),
400- "access_token" : rendered_credentials .get ("token" ),
401+ "catalog" : credentials .get ("catalog" ),
402+ "server_hostname" : credentials .get ("host" ),
403+ "http_path" : credentials .get ("http_path" ),
404+ "schema" : credentials .get ("schema" ),
405+ "access_token" : credentials .get ("token" ),
401406 }
407+ self .threads = credentials .get ("threads" )
402408 elif conn_type == "postgres" :
403409 conn_info = {
404410 "driver" : "postgresql" ,
405- "host" : rendered_credentials .get ("host" ),
406- "user" : rendered_credentials .get ("user" ),
407- "password" : rendered_credentials .get ("password" ),
408- "port" : rendered_credentials .get ("port" ),
409- "dbname" : rendered_credentials .get ("dbname" ) or rendered_credentials .get ("database" ),
411+ "host" : credentials .get ("host" ),
412+ "user" : credentials .get ("user" ),
413+ "password" : credentials .get ("password" ),
414+ "port" : credentials .get ("port" ),
415+ "dbname" : credentials .get ("dbname" ) or credentials .get ("database" ),
410416 }
417+ self .threads = credentials .get ("threads" )
411418 else :
412419 raise NotImplementedError (f"Provider { conn_type } is not yet supported for dbt diffs" )
413420
0 commit comments