@@ -3342,33 +3342,57 @@ def test_d4rl_iteration(self, task, split_trajs):
33423342_MINARI_DATASETS = []
33433343
33443344
3345- def _minari_selected_datasets ():
3346- if not _has_minari or not _has_gymnasium :
3347- return
3345+ def _minari_init ():
3346+ """Initialize Minari datasets list. Returns True if already initialized."""
33483347 global _MINARI_DATASETS
3349- import minari
3348+ if _MINARI_DATASETS and not all (
3349+ isinstance (x , str ) and x .isdigit () for x in _MINARI_DATASETS
3350+ ):
3351+ return True # Already initialized with real dataset names
33503352
3351- torch .manual_seed (0 )
3353+ if not _has_minari or not _has_gymnasium :
3354+ return False
33523355
3353- total_keys = sorted (
3354- minari .list_remote_datasets (latest_version = True , compatible_minari_version = True )
3355- )
3356- indices = torch .randperm (len (total_keys ))[:20 ]
3357- keys = [total_keys [idx ] for idx in indices ]
3356+ try :
3357+ import minari
3358+
3359+ torch .manual_seed (0 )
33583360
3359- assert len (keys ) > 5 , keys
3360- _MINARI_DATASETS += keys
3361+ total_keys = sorted (
3362+ minari .list_remote_datasets (
3363+ latest_version = True , compatible_minari_version = True
3364+ )
3365+ )
3366+ indices = torch .randperm (len (total_keys ))[:20 ]
3367+ keys = [total_keys [idx ] for idx in indices ]
33613368
3369+ assert len (keys ) > 5 , keys
3370+ _MINARI_DATASETS [:] = keys # Replace the placeholder values
3371+ return True
3372+ except Exception :
3373+ return False
33623374
3363- _minari_selected_datasets ()
3375+
3376+ # Initialize with placeholder values for parametrization
3377+ # These will be replaced with actual dataset names when the first Minari test runs
3378+ _MINARI_DATASETS = [str (i ) for i in range (20 )]
33643379
33653380
33663381@pytest .mark .skipif (not _has_minari or not _has_gymnasium , reason = "Minari not found" )
33673382@pytest .mark .slow
33683383class TestMinari :
33693384 @pytest .mark .parametrize ("split" , [False , True ])
3370- @pytest .mark .parametrize ("selected_dataset" , _MINARI_DATASETS )
3371- def test_load (self , selected_dataset , split ):
3385+ @pytest .mark .parametrize ("dataset_idx" , range (20 ))
3386+ def test_load (self , dataset_idx , split ):
3387+ # Initialize Minari datasets if not already done
3388+ if not _minari_init ():
3389+ pytest .skip ("Failed to initialize Minari datasets" )
3390+
3391+ # Get the actual dataset name from the initialized list
3392+ if dataset_idx >= len (_MINARI_DATASETS ):
3393+ pytest .skip (f"Dataset index { dataset_idx } out of range" )
3394+
3395+ selected_dataset = _MINARI_DATASETS [dataset_idx ]
33723396 torchrl_logger .info (f"dataset { selected_dataset } " )
33733397 data = MinariExperienceReplay (
33743398 selected_dataset , batch_size = 32 , split_trajs = split
0 commit comments