@@ -371,6 +371,10 @@ def __call__(self, ins_list): # 实现该类的__call__函数
371371from .instance import Instance
372372from .utils import pretty_table_printer
373373from .collate_fn import Collater
374+ try :
375+ from tqdm .auto import tqdm
376+ except :
377+ from .utils import _pseudo_tqdm as tqdm
374378
375379
376380class ApplyResultException (Exception ):
@@ -860,6 +864,11 @@ def apply_field(self, func, field_name, new_field_name=None, **kwargs):
860864 2. is_target: bool, 如果为True则将名为 `new_field_name` 的field设置为target
861865
862866 3. ignore_type: bool, 如果为True则将名为 `new_field_name` 的field的ignore_type设置为true, 忽略其类型
867+
868+ 4. use_tqdm: bool, 是否使用tqdm显示预处理进度
869+
870+ 5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称
871+
863872 :return List[Any]: 里面的元素为func的返回值,所以list长度为DataSet的长度
864873 """
865874 assert len (self ) != 0 , "Null DataSet cannot use apply_field()."
@@ -887,6 +896,10 @@ def apply_field_more(self, func, field_name, modify_fields=True, **kwargs):
887896
888897 3. ignore_type: bool, 如果为True则将被修改的field的ignore_type设置为true, 忽略其类型
889898
899+ 4. use_tqdm: bool, 是否使用tqdm显示预处理进度
900+
901+ 5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称
902+
890903 :return Dict[str:Field]: 返回一个字典
891904 """
892905 assert len (self ) != 0 , "Null DataSet cannot use apply_field()."
@@ -949,6 +962,10 @@ def apply_more(self, func, modify_fields=True, **kwargs):
949962
950963 3. ignore_type: bool, 如果为True则将被修改的的field的ignore_type设置为true, 忽略其类型
951964
965+ 4. use_tqdm: bool, 是否使用tqdm显示预处理进度
966+
967+ 5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称
968+
952969 :return Dict[str:Field]: 返回一个字典
953970 """
954971 # 返回 dict , 检查是否一直相同
@@ -957,7 +974,9 @@ def apply_more(self, func, modify_fields=True, **kwargs):
957974 idx = - 1
958975 try :
959976 results = {}
960- for idx , ins in enumerate (self ._inner_iter ()):
977+ for idx , ins in tqdm (enumerate (self ._inner_iter ()), total = len (self ), dynamic_ncols = True ,
978+ desc = kwargs .get ('tqdm_desc' , '' ),
979+ leave = False , disable = not kwargs .get ('use_tqdm' , False )):
961980 if "_apply_field" in kwargs :
962981 res = func (ins [kwargs ["_apply_field" ]])
963982 else :
@@ -1001,6 +1020,10 @@ def apply(self, func, new_field_name=None, **kwargs):
10011020 2. is_target: bool, 如果为True则将 `new_field_name` 的field设置为target
10021021
10031022 3. ignore_type: bool, 如果为True则将 `new_field_name` 的field的ignore_type设置为true, 忽略其类型
1023+
1024+ 4. use_tqdm: bool, 是否使用tqdm显示预处理进度
1025+
1026+ 5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称
10041027
10051028 :return List[Any]: 里面的元素为func的返回值,所以list长度为DataSet的长度
10061029 """
@@ -1009,7 +1032,9 @@ def apply(self, func, new_field_name=None, **kwargs):
10091032 idx = - 1
10101033 try :
10111034 results = []
1012- for idx , ins in enumerate (self ._inner_iter ()):
1035+ for idx , ins in tqdm (enumerate (self ._inner_iter ()), total = len (self ), dynamic_ncols = True , leave = False ,
1036+ desc = kwargs .get ('tqdm_desc' , '' ),
1037+ disable = not kwargs .get ('use_tqdm' , False )):
10131038 if "_apply_field" in kwargs :
10141039 results .append (func (ins [kwargs ["_apply_field" ]]))
10151040 else :
0 commit comments