Skip to content

Commit af6752e

Browse files
authored
Merge pull request #661 from aai-institute/feature/max-samples
Feature/max samples
2 parents 3a56a1d + 52c23f6 commit af6752e

File tree

3 files changed

+126
-32
lines changed

3 files changed

+126
-32
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
### Added
77

8+
- New stopping criterion `MaxSamples`
9+
[PR #661](https://github.com/aai-institute/pyDVL/pull/661)
810
- Introduced `UtilityModel` and two implementations `IndicatorUtilityModel`
911
and `DeepSetsUtilityModel` for data utility learning
1012
[PR #650](https://github.com/aai-institute/pyDVL/pull/650)

src/pydvl/valuation/stopping.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,10 @@
187187
"AbsoluteStandardError",
188188
"HistoryDeviation",
189189
"MaxChecks",
190+
"MaxSamples",
191+
"MaxTime",
190192
"MaxUpdates",
191193
"MinUpdates",
192-
"MaxTime",
193194
"NoStopping",
194195
"RankCorrelation",
195196
"StoppingCriterion",
@@ -614,9 +615,52 @@ def completion(self) -> float:
614615
return 0.0
615616

616617
def __str__(self) -> str:
618+
if self.sampler is not None:
619+
return f"NoStopping({self.sampler.__class__.__name__})"
617620
return "NoStopping()"
618621

619622

623+
class MaxSamples(StoppingCriterion):
624+
"""Run until the sampler has sampled the given number of samples.
625+
626+
!!! warning
627+
If the sampler is batched, and the valuation method runs in parallel, the check
628+
might be off by the sampler's batch size.
629+
630+
Args:
631+
sampler: The sampler to check.
632+
n_samples: The number of samples to run until.
633+
modify_result: If `True` the status of the input
634+
[ValuationResult][pydvl.valuation.result.ValuationResult] is modified in
635+
place after the call.
636+
"""
637+
638+
def __init__(
639+
self, sampler: IndexSampler, n_samples: int, modify_result: bool = True
640+
):
641+
if n_samples <= 0:
642+
raise ValueError("n_samples must be positive")
643+
super().__init__(modify_result=modify_result)
644+
self.sampler = sampler
645+
self.n_samples = n_samples
646+
self._completion = 0.0
647+
648+
def _check(self, result: ValuationResult) -> Status:
649+
self._completion = np.clip(self.sampler.n_samples / self.n_samples, 0.0, 1.0)
650+
if self.sampler.n_samples >= self.n_samples:
651+
self._converged = np.full_like(result.indices, True, dtype=bool)
652+
return Status.Converged
653+
return Status.Pending
654+
655+
def completion(self) -> float:
656+
return self._completion
657+
658+
def __str__(self) -> str:
659+
return (
660+
f"MaxSamples({self.sampler.__class__.__name__}, n_samples={self.n_samples})"
661+
)
662+
663+
620664
class MinUpdates(StoppingCriterion):
621665
"""Terminate as soon as all value updates exceed or equal the given threshold.
622666
@@ -965,4 +1009,4 @@ def reset(self) -> Self:
9651009
return super().reset()
9661010

9671011
def __str__(self):
968-
return f"RankCorrelation({self.rtol=}, {self.burn_in=}, {self.fraction=})"
1012+
return f"RankCorrelation(rtol={self.rtol}, burn_in={self.burn_in}, fraction={self.fraction})"

tests/valuation/test_stopping.py

Lines changed: 78 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
AbsoluteStandardError,
1313
HistoryDeviation,
1414
MaxChecks,
15+
MaxSamples,
1516
MaxTime,
1617
MaxUpdates,
1718
MinUpdates,
@@ -349,25 +350,41 @@ def test_no_stopping_without_sampler():
349350
assert str(no_stop) == "NoStopping()"
350351

351352

352-
def test_no_stopping_with_finite_sampler():
353-
class DummyFiniteSampler(IndexSampler):
354-
def __init__(self, total_samples: int = 10, batch_size: int = 1):
355-
super().__init__(batch_size=batch_size)
356-
self.total_samples = total_samples
353+
class DummyFiniteSampler(IndexSampler):
354+
def __init__(self, total_samples: int = 10, batch_size: int = 1):
355+
super().__init__(batch_size=batch_size)
356+
self.total_samples = total_samples
357+
358+
def sample_limit(self, indices):
359+
return self.total_samples
360+
361+
def generate(self, indices):
362+
for i in range(self.total_samples):
363+
yield i, set()
364+
365+
def log_weight(self, n, subset_len):
366+
return 0.0
367+
368+
def make_strategy(self, utility, log_coefficient=None):
369+
return None
370+
371+
372+
class DummyInfiniteSampler(IndexSampler):
373+
def sample_limit(self, indices):
374+
return None # Indicates an infinite sampler.
357375

358-
def sample_limit(self, indices):
359-
return self.total_samples
376+
def generate(self, indices):
377+
while True:
378+
yield (0, set())
360379

361-
def generate(self, indices):
362-
for i in range(self.total_samples):
363-
yield i, set()
380+
def log_weight(self, n, subset_len):
381+
return 0.0
364382

365-
def log_weight(self, n, subset_len):
366-
return 0.0
383+
def make_strategy(self, utility, log_coefficient=None):
384+
return None
367385

368-
def make_strategy(self, utility, log_coefficient=None):
369-
return None
370386

387+
def test_no_stopping_with_finite_sampler():
371388
r = ValuationResult.from_random(5)
372389
total_samples = 10
373390
batch_size = 3
@@ -397,28 +414,59 @@ def make_strategy(self, utility, log_coefficient=None):
397414

398415

399416
def test_no_stopping_infinite_sampler():
400-
class DummyInfiniteSampler(IndexSampler):
401-
def sample_limit(self, indices):
402-
return None # Indicates an infinite sampler.
403-
404-
def generate(self, indices):
405-
while True:
406-
yield (0, set())
407-
408-
def log_weight(self, n, subset_len):
409-
return 0.0
410-
411-
def make_strategy(self, utility, log_coefficient=None):
412-
return None
413-
414417
sampler = DummyInfiniteSampler(batch_size=1)
415418
no_stop = NoStopping(sampler=sampler)
416419

417-
batches = list(islice(sampler.generate_batches(np.array([0])), 10))
418-
assert sampler.n_samples == len(batches)
420+
_ = list(islice(sampler.generate_batches(np.array([0])), 10))
419421

420422
# Verify that calling the criterion still returns Pending and marks no index as converged.
421423
result = ValuationResult.from_random(5)
422424
status = no_stop(result)
423425
assert status == Status.Pending
426+
assert no_stop.completion() == 0.0
424427
np.testing.assert_equal(no_stop.converged, False)
428+
429+
430+
def test_max_samples_pending_and_convergence():
431+
sampler = DummyInfiniteSampler(batch_size=1)
432+
threshold = 10
433+
max_samples = MaxSamples(sampler, n_samples=threshold)
434+
result = ValuationResult.from_random(5) # Create a result with 5 indices
435+
436+
status = max_samples(result)
437+
assert status == Status.Pending
438+
np.testing.assert_allclose(max_samples.completion(), 0.0)
439+
assert not max_samples.converged.all()
440+
441+
# Set sampler.n_samples below threshold.
442+
_ = list(islice(sampler.generate_batches(np.array([0])), 5))
443+
status = max_samples(result)
444+
assert status == Status.Pending
445+
np.testing.assert_allclose(max_samples.completion(), 5 / threshold)
446+
assert not max_samples.converged.all()
447+
448+
# Set sampler.n_samples exactly equal to threshold.
449+
_ = list(islice(sampler.generate_batches(np.array([0])), 10))
450+
status = max_samples(result)
451+
assert status == Status.Converged
452+
np.testing.assert_allclose(max_samples.completion(), 1.0)
453+
assert max_samples.converged.all()
454+
455+
# Set sampler.n_samples above threshold.
456+
_ = list(islice(sampler.generate_batches(np.array([0])), 15))
457+
status = max_samples(result)
458+
assert status == Status.Converged
459+
np.testing.assert_allclose(max_samples.completion(), 1.0)
460+
assert max_samples.converged.all()
461+
462+
463+
def test_max_samples_str_and_invalid():
464+
sampler = DummyFiniteSampler(total_samples=0)
465+
max_samples = MaxSamples(sampler, 10)
466+
expected_str = f"MaxSamples({sampler.__class__.__name__}, n_samples=10)"
467+
assert str(max_samples) == expected_str
468+
469+
with pytest.raises(ValueError):
470+
MaxSamples(sampler, 0)
471+
with pytest.raises(ValueError):
472+
MaxSamples(sampler, -5)

0 commit comments

Comments
 (0)