@@ -19,7 +19,7 @@ class AnnotatedObjectsDataset(Dataset):
1919 def __init__ (self , data_path : Union [str , Path ], split : SplitType , keys : List [str ], target_image_size : int ,
2020 min_object_area : float , min_objects_per_image : int , max_objects_per_image : int ,
2121 crop_method : CropMethodType , random_flip : bool , no_tokens : int , use_group_parameter : bool ,
22- encode_crop : bool , category_allow_list_target : str , category_mapping_target : str ,
22+ encode_crop : bool , category_allow_list_target : str = "" , category_mapping_target : str = "" ,
2323 no_object_classes : Optional [int ] = None ):
2424 self .data_path = data_path
2525 self .split = split
@@ -43,6 +43,7 @@ def __init__(self, data_path: Union[str, Path], split: SplitType, keys: List[str
4343 self .transform_functions : List [Callable ] = self .setup_transform (target_image_size , crop_method , random_flip )
4444 self .paths = self .build_paths (self .data_path )
4545 self ._conditional_builders = None
46+ self .category_allow_list = None
4647 if category_allow_list_target :
4748 allow_list = load_object_from_string (category_allow_list_target )
4849 self .category_allow_list = {name for name , _ in allow_list }
0 commit comments