 from parameterized import parameterized
 
 from data_diff import databases as db
+from data_diff.utils import number_to_human
 from data_diff.diff_tables import TableDiffer, TableSegment, DEFAULT_BISECTION_THRESHOLD
-from .common import CONN_STRINGS, N_SAMPLES, BENCHMARK, GIT_REVISION, random_table_suffix
+from .common import CONN_STRINGS, N_SAMPLES, N_THREADS, BENCHMARK, GIT_REVISION, random_table_suffix
 
 
-CONNS = {k: db.connect_to_uri(v, 1) for k, v in CONN_STRINGS.items()}
+CONNS = {k: db.connect_to_uri(v, N_THREADS) for k, v in CONN_STRINGS.items()}
 
 CONNS[db.MySQL].query("SET @@session.time_zone='+00:00'", None)
 
@@ -258,7 +259,6 @@ def __iter__(self):
         "int": [
             # all 38 digits with 0 precision, don't need to test all
             "int",
-            "integer",
            "bigint",
             # "smallint",
             # "tinyint",
@@ -385,17 +385,6 @@ def sanitize(name):
     return parameterized.to_safe_name(name)
 
 
-def number_to_human(n):
-    millnames = ["", "k", "m", "b"]
-    n = float(n)
-    millidx = max(
-        0,
-        min(len(millnames) - 1, int(math.floor(0 if n == 0 else math.log10(abs(n)) / 3))),
-    )
-
-    return "{:.0f}{}".format(n / 10 ** (3 * millidx), millnames[millidx])
-
-
 # Pass --verbose to test run to get a nice output.
 def expand_params(testcase_func, param_num, param):
     source_db, target_db, source_type, target_type, type_category = param.args
@@ -431,6 +420,10 @@ def _insert_to_table(conn, table, values, type):
     if isinstance(conn, db.Oracle):
         default_insertion_query = f"INSERT INTO {table} (id, col)"
 
+    batch_size = 8000
+    if isinstance(conn, db.BigQuery):
+        batch_size = 1000
+
     insertion_query = default_insertion_query
     selects = []
     for j, sample in values:
@@ -453,7 +446,7 @@ def _insert_to_table(conn, table, values, type):
 
         # Some databases want small batch sizes...
         # Need to also insert on the last row, might not divide cleanly!
-        if j % 8000 == 0 or j == N_SAMPLES:
+        if j % batch_size == 0 or j == N_SAMPLES:
             if isinstance(conn, db.Oracle):
                 insertion_query += " UNION ALL ".join(selects)
                 conn.query(insertion_query, None)
@@ -594,7 +587,7 @@ def test_types(self, source_db, target_db, source_type, target_type, type_category):
         # configuration with each segment being ~250k rows.
         ch_factor = min(max(int(N_SAMPLES / 250_000), 2), 128) if BENCHMARK else 2
         ch_threshold = min(DEFAULT_BISECTION_THRESHOLD, int(N_SAMPLES / ch_factor)) if BENCHMARK else 3
-        ch_threads = 1
+        ch_threads = N_THREADS
         differ = TableDiffer(
             bisection_threshold=ch_threshold,
             bisection_factor=ch_factor,
@@ -615,7 +608,7 @@ def test_types(self, source_db, target_db, source_type, target_type, type_category):
         # parallel, using the existing implementation.
         dl_factor = max(int(N_SAMPLES / 100_000), 2) if BENCHMARK else 2
         dl_threshold = int(N_SAMPLES / dl_factor) + 1 if BENCHMARK else math.inf
-        dl_threads = 1
+        dl_threads = N_THREADS
        differ = TableDiffer(
             bisection_threshold=dl_threshold, bisection_factor=dl_factor, max_threadpool_size=dl_threads
         )
@@ -634,6 +627,7 @@ def test_types(self, source_db, target_db, source_type, target_type, type_category):
             "git_revision": GIT_REVISION,
             "rows": N_SAMPLES,
             "rows_human": number_to_human(N_SAMPLES),
+            "name_human": f"{source_db.__name__}/{sanitize(source_type)} <-> {target_db.__name__}/{sanitize(target_type)}",
             "src_table": src_table[1:-1],  # remove quotes
             "target_table": dst_table[1:-1],
             "source_type": source_type,
@@ -642,6 +636,7 @@ def test_types(self, source_db, target_db, source_type, target_type, type_category):
             "insertion_target_sec": round(insertion_target_duration, 3),
             "count_source_sec": round(count_source_duration, 3),
             "count_target_sec": round(count_target_duration, 3),
+            "count_max_sec": max(round(count_target_duration, 3), round(count_source_duration, 3)),
             "checksum_sec": round(checksum_duration, 3),
             "download_sec": round(download_duration, 3),
             "download_bisection_factor": dl_factor,
@@ -655,7 +650,7 @@ def test_types(self, source_db, target_db, source_type, target_type, type_category):
         if BENCHMARK:
             print(json.dumps(result, indent=2))
             file_name = f"benchmark_{GIT_REVISION}.jsonl"
-            with open(file_name, "a") as file:
+            with open(file_name, "a", encoding="utf-8") as file:
                 file.write(json.dumps(result) + "\n")
                 file.flush()
             print(f"Written to {file_name}")
0 commit comments