Skip to content

Commit 04ad8e6

Browse files
committed
DataSet apply的时候可以传入use_tqdm和tqdm_desc
1 parent 4886dbf commit 04ad8e6

File tree

3 files changed

+64
-2
lines changed

3 files changed

+64
-2
lines changed

fastNLP/core/dataset.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,10 @@ def __call__(self, ins_list): # 实现该类的__call__函数
371371
from .instance import Instance
372372
from .utils import pretty_table_printer
373373
from .collate_fn import Collater
374+
try:
375+
from tqdm.auto import tqdm
376+
except:
377+
from .utils import _pseudo_tqdm as tqdm
374378

375379

376380
class ApplyResultException(Exception):
@@ -860,6 +864,11 @@ def apply_field(self, func, field_name, new_field_name=None, **kwargs):
860864
2. is_target: bool, 如果为True则将名为 `new_field_name` 的field设置为target
861865
862866
3. ignore_type: bool, 如果为True则将名为 `new_field_name` 的field的ignore_type设置为true, 忽略其类型
867+
868+
4. use_tqdm: bool, 是否使用tqdm显示预处理进度
869+
870+
5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称
871+
863872
:return List[Any]: 里面的元素为func的返回值,所以list长度为DataSet的长度
864873
"""
865874
assert len(self) != 0, "Null DataSet cannot use apply_field()."
@@ -887,6 +896,10 @@ def apply_field_more(self, func, field_name, modify_fields=True, **kwargs):
887896
888897
3. ignore_type: bool, 如果为True则将被修改的field的ignore_type设置为true, 忽略其类型
889898
899+
4. use_tqdm: bool, 是否使用tqdm显示预处理进度
900+
901+
5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称
902+
890903
:return Dict[str:Field]: 返回一个字典
891904
"""
892905
assert len(self) != 0, "Null DataSet cannot use apply_field()."
@@ -949,6 +962,10 @@ def apply_more(self, func, modify_fields=True, **kwargs):
949962
950963
3. ignore_type: bool, 如果为True则将被修改的的field的ignore_type设置为true, 忽略其类型
951964
965+
4. use_tqdm: bool, 是否使用tqdm显示预处理进度
966+
967+
5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称
968+
952969
:return Dict[str:Field]: 返回一个字典
953970
"""
954971
# 返回 dict , 检查是否一直相同
@@ -957,7 +974,9 @@ def apply_more(self, func, modify_fields=True, **kwargs):
957974
idx = -1
958975
try:
959976
results = {}
960-
for idx, ins in enumerate(self._inner_iter()):
977+
for idx, ins in tqdm(enumerate(self._inner_iter()), total=len(self), dynamic_ncols=True,
978+
desc=kwargs.get('tqdm_desc', ''),
979+
leave=False, disable=not kwargs.get('use_tqdm', False)):
961980
if "_apply_field" in kwargs:
962981
res = func(ins[kwargs["_apply_field"]])
963982
else:
@@ -1001,6 +1020,10 @@ def apply(self, func, new_field_name=None, **kwargs):
10011020
2. is_target: bool, 如果为True则将 `new_field_name` 的field设置为target
10021021
10031022
3. ignore_type: bool, 如果为True则将 `new_field_name` 的field的ignore_type设置为true, 忽略其类型
1023+
1024+
4. use_tqdm: bool, 是否使用tqdm显示预处理进度
1025+
1026+
5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称
10041027
10051028
:return List[Any]: 里面的元素为func的返回值,所以list长度为DataSet的长度
10061029
"""
@@ -1009,7 +1032,9 @@ def apply(self, func, new_field_name=None, **kwargs):
10091032
idx = -1
10101033
try:
10111034
results = []
1012-
for idx, ins in enumerate(self._inner_iter()):
1035+
for idx, ins in tqdm(enumerate(self._inner_iter()), total=len(self), dynamic_ncols=True, leave=False,
1036+
desc=kwargs.get('tqdm_desc', ''),
1037+
disable=not kwargs.get('use_tqdm', False)):
10131038
if "_apply_field" in kwargs:
10141039
results.append(func(ins[kwargs["_apply_field"]]))
10151040
else:

fastNLP/io/data_bundle.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,8 +321,15 @@ def apply_field(self, func, field_name: str, new_field_name: str, ignore_miss_da
321321
2. is_target: bool, 如果为True则将名为 `new_field_name` 的field设置为target
322322
323323
3. ignore_type: bool, 如果为True则将名为 `new_field_name` 的field的ignore_type设置为true, 忽略其类型
324+
325+
4. use_tqdm: bool, 是否显示tqdm进度条
326+
327+
5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称
324328
"""
329+
tqdm_desc = kwargs.get('tqdm_desc', '')
325330
for name, dataset in self.datasets.items():
331+
if tqdm_desc != '':
332+
kwargs['tqdm_desc'] = tqdm_desc + f' for `{name}`'
326333
if dataset.has_field(field_name=field_name):
327334
dataset.apply_field(func=func, field_name=field_name, new_field_name=new_field_name, **kwargs)
328335
elif not ignore_miss_dataset:
@@ -350,10 +357,17 @@ def apply_field_more(self, func, field_name, modify_fields=True, ignore_miss_dat
350357
351358
3. ignore_type: bool, 如果为True则将被修改的field的ignore_type设置为true, 忽略其类型
352359
360+
4. use_tqdm: bool, 是否显示tqdm进度条
361+
362+
5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称
363+
353364
:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字
354365
"""
355366
res = {}
367+
tqdm_desc = kwargs.get('tqdm_desc', '')
356368
for name, dataset in self.datasets.items():
369+
if tqdm_desc != '':
370+
kwargs['tqdm_desc'] = tqdm_desc + f' for `{name}`'
357371
if dataset.has_field(field_name=field_name):
358372
res[name] = dataset.apply_field_more(func=func, field_name=field_name, modify_fields=modify_fields, **kwargs)
359373
elif not ignore_miss_dataset:
@@ -376,8 +390,16 @@ def apply(self, func, new_field_name: str, **kwargs):
376390
2. is_target: bool, 如果为True则将名为 `new_field_name` 的field设置为target
377391
378392
3. ignore_type: bool, 如果为True则将名为 `new_field_name` 的field的ignore_type设置为true, 忽略其类型
393+
394+
4. use_tqdm: bool, 是否显示tqdm进度条
395+
396+
5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称
397+
379398
"""
399+
tqdm_desc = kwargs.get('tqdm_desc', '')
380400
for name, dataset in self.datasets.items():
401+
if tqdm_desc != '':
402+
kwargs['tqdm_desc'] = tqdm_desc + f' for `{name}`'
381403
dataset.apply(func, new_field_name=new_field_name, **kwargs)
382404
return self
383405

@@ -399,10 +421,17 @@ def apply_more(self, func, modify_fields=True, **kwargs):
399421
400422
3. ignore_type: bool, 如果为True则将被修改的的field的ignore_type设置为true, 忽略其类型
401423
424+
4. use_tqdm: bool, 是否显示tqdm进度条
425+
426+
5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称
427+
402428
:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字
403429
"""
404430
res = {}
431+
tqdm_desc = kwargs.get('tqdm_desc', '')
405432
for name, dataset in self.datasets.items():
433+
if tqdm_desc!='':
434+
kwargs['tqdm_desc'] = tqdm_desc + f' for `{name}`'
406435
res[name] = dataset.apply_more(func, modify_fields=modify_fields, **kwargs)
407436
return res
408437

tests/core/test_dataset.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,14 @@ def test_apply(self):
136136
ds.apply(lambda ins: (len(ins["x"]), "hahaha"), new_field_name="k", ignore_type=True)
137137
# expect no exception raised
138138

139+
def test_apply_tqdm(self):
140+
import time
141+
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
142+
def do_nothing(ins):
143+
time.sleep(0.01)
144+
ds.apply(do_nothing, use_tqdm=True)
145+
ds.apply_field(do_nothing, field_name='x', use_tqdm=True)
146+
139147
def test_apply_cannot_modify_instance(self):
140148
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
141149
def modify_inplace(instance):

0 commit comments

Comments
 (0)