11# -*- coding: utf-8 -*-
2- from collections import defaultdict
2+ from collections import defaultdict , Iterable
33from contextlib import suppress
44from functools import partial
55from operator import itemgetter
@@ -323,8 +323,9 @@ def save(self, fname, compress=True):
323323
324324 Parameters
325325 ----------
326- fname: callable
327- Given a learner, returns a filename into which to save the data
326+ fname: callable or sequence of strings
327+ Given a learner, returns a filename into which to save the data.
328+ Or a list (or iterable) with filenames.
328329 compress : bool, default True
329330 Compress the data upon saving using `gzip`. When saving
330331 using compression, one must load it with compression too.
@@ -347,17 +348,22 @@ def save(self, fname, compress=True):
347348 >>> # Then save
348349 >>> learner.save(combo_fname) # use 'load' in the same way
349350 """
350- for l in self .learners :
351- l .save (fname (l ), compress = compress )
351+ if isinstance (fname , Iterable ):
352+ for l , _fname in zip (fname , self .learners ):
353+ l .save (_fname , compress = compress )
354+ else :
355+ for l in self .learners :
356+ l .save (fname (l ), compress = compress )
352357
353358 def load (self , fname , compress = True ):
354359 """Load the data of the child learners from pickle files
355360 in a directory.
356361
357362 Parameters
358363 ----------
359- fname: callable
360- Given a learner, returns a filename into which to save the data
364+ fname: callable or sequence of strings
365+ Given a learner, returns a filename from which to load the data.
366+ Or a list (or iterable) with filenames.
361367 compress : bool, default True
362368 If the data is compressed when saved, one must load it
363369 with compression too.
@@ -366,8 +372,12 @@ def load(self, fname, compress=True):
366372 -------
367373 See the example in the `BalancingLearner.save` doc-string.
368374 """
369- for l in self .learners :
370- l .load (fname (l ), compress = compress )
375+ if isinstance (fname , Iterable ):
376+ for l , _fname in zip (fname , self .learners ):
377+ l .load (_fname , compress = compress )
378+ else :
379+ for l in self .learners :
380+ l .load (fname (l ), compress = compress )
371381
372382 def _get_data (self ):
373383 return [l ._get_data () for l in learner .learners ]
0 commit comments