|
12 | 12 | AbsoluteStandardError, |
13 | 13 | HistoryDeviation, |
14 | 14 | MaxChecks, |
| 15 | + MaxSamples, |
15 | 16 | MaxTime, |
16 | 17 | MaxUpdates, |
17 | 18 | MinUpdates, |
@@ -349,25 +350,41 @@ def test_no_stopping_without_sampler(): |
349 | 350 | assert str(no_stop) == "NoStopping()" |
350 | 351 |
|
351 | 352 |
|
352 | | -def test_no_stopping_with_finite_sampler(): |
353 | | - class DummyFiniteSampler(IndexSampler): |
354 | | - def __init__(self, total_samples: int = 10, batch_size: int = 1): |
355 | | - super().__init__(batch_size=batch_size) |
356 | | - self.total_samples = total_samples |
| 353 | +class DummyFiniteSampler(IndexSampler): |
| 354 | + def __init__(self, total_samples: int = 10, batch_size: int = 1): |
| 355 | + super().__init__(batch_size=batch_size) |
| 356 | + self.total_samples = total_samples |
| 357 | + |
| 358 | + def sample_limit(self, indices): |
| 359 | + return self.total_samples |
| 360 | + |
| 361 | + def generate(self, indices): |
| 362 | + for i in range(self.total_samples): |
| 363 | + yield i, set() |
| 364 | + |
| 365 | + def log_weight(self, n, subset_len): |
| 366 | + return 0.0 |
| 367 | + |
| 368 | + def make_strategy(self, utility, log_coefficient=None): |
| 369 | + return None |
| 370 | + |
| 371 | + |
| 372 | +class DummyInfiniteSampler(IndexSampler): |
| 373 | + def sample_limit(self, indices): |
| 374 | + return None # Indicates an infinite sampler. |
357 | 375 |
|
358 | | - def sample_limit(self, indices): |
359 | | - return self.total_samples |
| 376 | + def generate(self, indices): |
| 377 | + while True: |
| 378 | + yield (0, set()) |
360 | 379 |
|
361 | | - def generate(self, indices): |
362 | | - for i in range(self.total_samples): |
363 | | - yield i, set() |
| 380 | + def log_weight(self, n, subset_len): |
| 381 | + return 0.0 |
364 | 382 |
|
365 | | - def log_weight(self, n, subset_len): |
366 | | - return 0.0 |
| 383 | + def make_strategy(self, utility, log_coefficient=None): |
| 384 | + return None |
367 | 385 |
|
368 | | - def make_strategy(self, utility, log_coefficient=None): |
369 | | - return None |
370 | 386 |
|
| 387 | +def test_no_stopping_with_finite_sampler(): |
371 | 388 | r = ValuationResult.from_random(5) |
372 | 389 | total_samples = 10 |
373 | 390 | batch_size = 3 |
@@ -397,28 +414,59 @@ def make_strategy(self, utility, log_coefficient=None): |
397 | 414 |
|
398 | 415 |
|
399 | 416 | def test_no_stopping_infinite_sampler(): |
400 | | - class DummyInfiniteSampler(IndexSampler): |
401 | | - def sample_limit(self, indices): |
402 | | - return None # Indicates an infinite sampler. |
403 | | - |
404 | | - def generate(self, indices): |
405 | | - while True: |
406 | | - yield (0, set()) |
407 | | - |
408 | | - def log_weight(self, n, subset_len): |
409 | | - return 0.0 |
410 | | - |
411 | | - def make_strategy(self, utility, log_coefficient=None): |
412 | | - return None |
413 | | - |
414 | 417 | sampler = DummyInfiniteSampler(batch_size=1) |
415 | 418 | no_stop = NoStopping(sampler=sampler) |
416 | 419 |
|
417 | | - batches = list(islice(sampler.generate_batches(np.array([0])), 10)) |
418 | | - assert sampler.n_samples == len(batches) |
| 420 | + _ = list(islice(sampler.generate_batches(np.array([0])), 10)) |
419 | 421 |
|
420 | 422 | # Verify that calling the criterion still returns Pending and marks no index as converged. |
421 | 423 | result = ValuationResult.from_random(5) |
422 | 424 | status = no_stop(result) |
423 | 425 | assert status == Status.Pending |
| 426 | + assert no_stop.completion() == 0.0 |
424 | 427 | np.testing.assert_equal(no_stop.converged, False) |
| 428 | + |
| 429 | + |
| 430 | +def test_max_samples_pending_and_convergence(): |
| 431 | + sampler = DummyInfiniteSampler(batch_size=1) |
| 432 | + threshold = 10 |
| 433 | + max_samples = MaxSamples(sampler, n_samples=threshold) |
| 434 | + result = ValuationResult.from_random(5) # Create a result with 5 indices |
| 435 | + |
| 436 | + status = max_samples(result) |
| 437 | + assert status == Status.Pending |
| 438 | + np.testing.assert_allclose(max_samples.completion(), 0.0) |
| 439 | + assert not max_samples.converged.all() |
| 440 | + |
| 441 | + # Set sampler.n_samples below threshold. |
| 442 | + _ = list(islice(sampler.generate_batches(np.array([0])), 5)) |
| 443 | + status = max_samples(result) |
| 444 | + assert status == Status.Pending |
| 445 | + np.testing.assert_allclose(max_samples.completion(), 5 / threshold) |
| 446 | + assert not max_samples.converged.all() |
| 447 | + |
| 448 | + # Set sampler.n_samples exactly equal to threshold. |
| 449 | + _ = list(islice(sampler.generate_batches(np.array([0])), 10)) |
| 450 | + status = max_samples(result) |
| 451 | + assert status == Status.Converged |
| 452 | + np.testing.assert_allclose(max_samples.completion(), 1.0) |
| 453 | + assert max_samples.converged.all() |
| 454 | + |
| 455 | + # Set sampler.n_samples above threshold. |
| 456 | + _ = list(islice(sampler.generate_batches(np.array([0])), 15)) |
| 457 | + status = max_samples(result) |
| 458 | + assert status == Status.Converged |
| 459 | + np.testing.assert_allclose(max_samples.completion(), 1.0) |
| 460 | + assert max_samples.converged.all() |
| 461 | + |
| 462 | + |
| 463 | +def test_max_samples_str_and_invalid(): |
| 464 | + sampler = DummyFiniteSampler(total_samples=0) |
| 465 | + max_samples = MaxSamples(sampler, 10) |
| 466 | + expected_str = f"MaxSamples({sampler.__class__.__name__}, n_samples=10)" |
| 467 | + assert str(max_samples) == expected_str |
| 468 | + |
| 469 | + with pytest.raises(ValueError): |
| 470 | + MaxSamples(sampler, 0) |
| 471 | + with pytest.raises(ValueError): |
| 472 | + MaxSamples(sampler, -5) |
0 commit comments