|
1 | 1 | # -*- coding: utf-8 -*- |
2 | | -from collections import defaultdict |
| 2 | +from collections import defaultdict, Iterable |
3 | 3 | from contextlib import suppress |
4 | 4 | from functools import partial |
5 | 5 | from operator import itemgetter |
@@ -317,70 +317,67 @@ def from_product(cls, f, learner_type, learner_kwargs, combos): |
317 | 317 | learners.append(learner) |
318 | 318 | return cls(learners, cdims=arguments) |
319 | 319 |
|
320 | | - def save(self, folder, compress=True): |
| 320 | + def save(self, fname, compress=True): |
321 | 321 | """Save the data of the child learners into pickle files |
322 | 322 | in a directory. |
323 | 323 |
|
324 | 324 | Parameters |
325 | 325 | ---------- |
326 | | - folder : str |
327 | | - Directory in which the learners's data will be saved. |
| 326 | + fname: callable or sequence of strings |
| 327 | + Given a learner, returns a filename into which to save the data. |
| 328 | + Or a list (or iterable) with filenames. |
328 | 329 | compress : bool, default True |
329 | 330 | Compress the data upon saving using `gzip`. When saving |
330 | 331 | using compression, one must load it with compression too. |
331 | 332 |
|
332 | | - Notes |
333 | | - ----- |
334 | | - The child learners need to have a 'fname' attribute in order to use |
335 | | - this method. |
336 | | -
|
337 | 333 | Example |
338 | 334 | ------- |
339 | | - >>> def combo_fname(val): |
340 | | - ... return '__'.join([f'{k}_{v}.p' for k, v in val.items()]) |
341 | | - ... |
342 | | - ... def f(x, a, b): return a * x**2 + b |
343 | | - ... |
344 | | - >>> learners = [] |
345 | | - >>> for combo in adaptive.utils.named_product(a=[1, 2], b=[1]): |
346 | | - ... l = Learner1D(functools.partial(f, combo=combo)) |
347 | | - ... l.fname = combo_fname(combo) # 'a_1__b_1.p', 'a_2__b_1.p' etc. |
348 | | - ... learners.append(l) |
349 | | - ... learner = BalancingLearner(learners) |
350 | | - ... # Run the learner |
351 | | - ... runner = adaptive.Runner(learner) |
352 | | - ... # Then save |
353 | | - ... learner.save('data_folder') # use 'load' in the same way |
| 335 | + >>> def combo_fname(learner): |
| 336 | + ... val = learner.function.keywords # because functools.partial |
| 337 | + ... fname = '__'.join([f'{k}_{v}.pickle' for k, v in val]) |
| 338 | + ... return 'data_folder/' + fname |
| 339 | + >>> |
| 340 | + >>> def f(x, a, b): return a * x**2 + b |
| 341 | + >>> |
| 342 | + >>> learners = [Learner1D(functools.partial(f, **combo), (-1, 1)) |
| 343 | + ... for combo in adaptive.utils.named_product(a=[1, 2], b=[1]] |
| 344 | + >>> |
| 345 | + >>> learner = BalancingLearner(learners) |
| 346 | + >>> # Run the learner |
| 347 | + >>> runner = adaptive.Runner(learner) |
| 348 | + >>> # Then save |
| 349 | + >>> learner.save(combo_fname) # use 'load' in the same way |
354 | 350 | """ |
355 | | - if len(self.learners) != len(set(l.fname for l in self.learners)): |
356 | | - raise RuntimeError("The 'learner.fname's are not all unique.") |
357 | | - |
358 | | - for l in self.learners: |
359 | | - l.save(os.path.join(folder, l.fname), compress=compress) |
| 351 | + if isinstance(fname, Iterable): |
| 352 | + for l, _fname in zip(fname, self.learners): |
| 353 | + l.save(_fname, compress=compress) |
| 354 | + else: |
| 355 | + for l in self.learners: |
| 356 | + l.save(fname(l), compress=compress) |
360 | 357 |
|
361 | | - def load(self, folder, compress=True): |
| 358 | + def load(self, fname, compress=True): |
362 | 359 | """Load the data of the child learners from pickle files |
363 | 360 | in a directory. |
364 | 361 |
|
365 | 362 | Parameters |
366 | 363 | ---------- |
367 | | - folder : str |
368 | | - Directory from which the learners's data will be loaded. |
| 364 | + fname: callable or sequence of strings |
| 365 | + Given a learner, returns a filename from which to load the data. |
| 366 | + Or a list (or iterable) with filenames. |
369 | 367 | compress : bool, default True |
370 | 368 | If the data is compressed when saved, one must load it |
371 | 369 | with compression too. |
372 | 370 |
|
373 | | - Notes |
374 | | - ----- |
375 | | - The child learners need to have a 'fname' attribute in order to use |
376 | | - this method. |
377 | | -
|
378 | 371 | Example |
379 | 372 | ------- |
380 | 373 | See the example in the `BalancingLearner.save` doc-string. |
381 | 374 | """ |
382 | | - for l in self.learners: |
383 | | - l.load(os.path.join(folder, l.fname), compress=compress) |
| 375 | + if isinstance(fname, Iterable): |
| 376 | + for l, _fname in zip(fname, self.learners): |
| 377 | + l.load(_fname, compress=compress) |
| 378 | + else: |
| 379 | + for l in self.learners: |
| 380 | + l.load(fname(l), compress=compress) |
384 | 381 |
|
385 | 382 | def _get_data(self): |
386 | 383 | return [l._get_data() for l in learner.learners] |
|
0 commit comments