3838)
3939
4040from ..datasets import load_dataset
41+ from ..datasets .special_params import assign_case_special_values_on_run
4142from ..datasets .transformer import split_and_transform_data
4243from ..utils .bench_case import get_bench_case_value
4344from ..utils .common import convert_to_numpy , custom_format , get_module_members
@@ -511,8 +512,6 @@ def measure_sklearn_estimator(
511512 bench_case ,
512513 task ,
513514 estimator_class ,
514- estimator_methods ,
515- estimator_params ,
516515):
517516 enable_modelbuilders = get_bench_case_value (
518517 bench_case , "algorithm:enable_modelbuilders" , False
@@ -530,17 +529,31 @@ def measure_sklearn_estimator(
530529 )
531530 sklearnex_logging_stream = get_sklearnex_logging_stream ()
532531
532+ is_dataset_sequence = (
533+ get_bench_case_value (bench_case , "data:dataset_sequence" ) is not None
534+ )
535+ # TODO Consider if it is possible to do without additional dataset loading
536+ if not is_dataset_sequence :
537+ dataset_info = get_bench_case_value (bench_case , "data" )
538+ data , data_descriptor = load_dataset (bench_case , dataset_info )
539+ assign_case_special_values_on_run (bench_case , data , data_descriptor )
540+
541+ # get estimator parameters
542+ estimator_params = get_bench_case_value (
543+ bench_case , "algorithm:estimator_params" , dict ()
544+ )
545+
546+ # get estimator methods for measurement
547+ estimator_methods = get_estimator_methods (bench_case )
548+
533549 metrics = dict ()
550+
534551 estimator_instance = estimator_class (** estimator_params )
535552 for stage in estimator_methods .keys ():
536553 for method in estimator_methods [stage ]:
537554 if hasattr (estimator_instance , method ):
538555 method_instance = getattr (estimator_instance , method )
539556 if method == "partial_fit" :
540- is_dataset_sequence = (
541- get_bench_case_value (bench_case , "data:dataset_sequence" )
542- is not None
543- )
544557 if is_dataset_sequence :
545558 function_to_measure = create_online_function_for_big_data (
546559 bench_case , estimator_instance , method_instance , stage
@@ -606,14 +619,6 @@ def main(bench_case: BenchCase, filters: List[BenchCase]):
606619 estimator_class = get_estimator (library_name , estimator_name )
607620 task = estimator_to_task (estimator_name )
608621
609- # get estimator parameters
610- estimator_params = get_bench_case_value (
611- bench_case , "algorithm:estimator_params" , dict ()
612- )
613-
614- # get estimator methods for measurement
615- estimator_methods = get_estimator_methods (bench_case )
616-
617622 # benchmark case filtering
618623 if not bench_case_filter (bench_case , filters ):
619624 logger .warning ("Benchmarking case was filtered." )
@@ -626,8 +631,6 @@ def main(bench_case: BenchCase, filters: List[BenchCase]):
626631 bench_case ,
627632 task ,
628633 estimator_class ,
629- estimator_methods ,
630- estimator_params ,
631634 )
632635
633636 result_template = {
@@ -648,6 +651,7 @@ def main(bench_case: BenchCase, filters: List[BenchCase]):
648651 "training" : data_description ["x_train" ],
649652 "inference" : data_description ["x_test" ],
650653 }
654+ estimator_methods = get_estimator_methods (bench_case )
651655 for stage in estimator_methods .keys ():
652656 data_descs [stage ].update (
653657 {
0 commit comments