|
11 | 11 | import scipy.spatial |
12 | 12 |
|
13 | 13 | from adaptive.learner.base_learner import BaseLearner |
| 14 | +from adaptive.notebook_integration import ensure_holoviews, ensure_plotly |
14 | 15 | from adaptive.learner.triangulation import ( |
15 | 16 | Triangulation, point_in_simplex, circumsphere, |
16 | | - simplex_volume_in_embedding, fast_det |
17 | | -) |
18 | | -from adaptive.notebook_integration import ensure_holoviews, ensure_plotly |
| 17 | + simplex_volume_in_embedding, fast_det) |
19 | 18 | from adaptive.utils import restore, cache_latest |
20 | 19 |
|
21 | 20 |
|
@@ -178,8 +177,14 @@ def __init__(self, func, bounds, loss_per_simplex=None): |
178 | 177 | # triangulation of the pending points inside a specific simplex |
179 | 178 | self._subtriangulations = dict() # simplex → triangulation |
180 | 179 |
|
181 | | - # scale to unit |
| 180 | + # scale to unit hypercube |
| 181 | + # for the input |
182 | 182 | self._transform = np.linalg.inv(np.diag(np.diff(self._bbox).flat)) |
| 183 | + # for the output |
| 184 | + self._min_value = None |
| 185 | + self._max_value = None |
| 186 | + self._output_multiplier = 1 # If we do not know anything, do not scale the values |
| 187 | + self._recompute_losses_factor = 1.1 |
183 | 188 |
|
184 | 189 | # create a private random number generator with fixed seed |
185 | 190 | self._random = random.Random(1) |
@@ -271,6 +276,7 @@ def tell(self, point, value): |
271 | 276 | if not self.inside_bounds(point): |
272 | 277 | return |
273 | 278 |
|
| 279 | + self._update_range(value) |
274 | 280 | if tri is not None: |
275 | 281 | simplex = self._pending_to_simplex.get(point) |
276 | 282 | if simplex is not None and not self._simplex_exists(simplex): |
@@ -338,6 +344,7 @@ def _update_subsimplex_losses(self, simplex, new_subsimplices): |
338 | 344 | subtriangulation = self._subtriangulations[simplex] |
339 | 345 | for subsimplex in new_subsimplices: |
340 | 346 | subloss = subtriangulation.volume(subsimplex) * loss_density |
| 347 | + subloss = round(subloss, ndigits=8) |
341 | 348 | heapq.heappush(self._simplex_queue, |
342 | 349 | (-subloss, simplex, subsimplex)) |
343 | 350 |
|
@@ -448,21 +455,98 @@ def update_losses(self, to_delete: set, to_add: set): |
448 | 455 | if p not in self.data) |
449 | 456 |
|
450 | 457 | for simplex in to_add: |
451 | | - vertices = self.tri.get_vertices(simplex) |
452 | | - values = [self.data[tuple(v)] for v in vertices] |
453 | | - loss = float(self.loss_per_simplex(vertices, values)) |
454 | | - self._losses[simplex] = float(loss) |
| 458 | + loss = self.compute_loss(simplex) |
| 459 | + self._losses[simplex] = loss |
455 | 460 |
|
456 | 461 | for p in pending_points_unbound: |
457 | 462 | self._try_adding_pending_point_to_simplex(p, simplex) |
458 | 463 |
|
459 | 464 | if simplex not in self._subtriangulations: |
| 465 | + loss = round(loss, ndigits=8) |
460 | 466 | heapq.heappush(self._simplex_queue, (-loss, simplex, None)) |
461 | 467 | continue |
462 | 468 |
|
463 | 469 | self._update_subsimplex_losses( |
464 | 470 | simplex, self._subtriangulations[simplex].simplices) |
465 | 471 |
|
| 472 | + def compute_loss(self, simplex): |
| 473 | + # get the loss |
| 474 | + vertices = self.tri.get_vertices(simplex) |
| 475 | + values = [self.data[tuple(v)] for v in vertices] |
| 476 | + |
| 477 | + # scale them to a cube with sides 1 |
| 478 | + vertices = vertices @ self._transform |
| 479 | + values = self._output_multiplier * values |
| 480 | + |
| 481 | + # compute the loss on the scaled simplex |
| 482 | + return float(self.loss_per_simplex(vertices, values)) |
| 483 | + |
| 484 | + def recompute_all_losses(self): |
| 485 | + """Recompute all losses and pending losses.""" |
| 486 | + # amortized O(N) complexity |
| 487 | + if self.tri is None: |
| 488 | + return |
| 489 | + |
| 490 | + # reset the _simplex_queue |
| 491 | + self._simplex_queue = [] |
| 492 | + |
| 493 | + # recompute all losses |
| 494 | + for simplex in self.tri.simplices: |
| 495 | + loss = self.compute_loss(simplex) |
| 496 | + self._losses[simplex] = loss |
| 497 | + |
| 498 | + # now distribute it around the the children if they are present |
| 499 | + if simplex not in self._subtriangulations: |
| 500 | + loss = round(loss, ndigits=8) |
| 501 | + heapq.heappush(self._simplex_queue, (-loss, simplex, None)) |
| 502 | + continue |
| 503 | + |
| 504 | + self._update_subsimplex_losses( |
| 505 | + simplex, self._subtriangulations[simplex].simplices) |
| 506 | + |
| 507 | + @property |
| 508 | + def _scale(self): |
| 509 | + # get the output scale |
| 510 | + return self._max_value - self._min_value |
| 511 | + |
| 512 | + def _update_range(self, new_output): |
| 513 | + if self._min_value is None or self._max_value is None: |
| 514 | + # this is the first point, nothing to do, just set the range |
| 515 | + self._min_value = np.array(new_output) |
| 516 | + self._max_value = np.array(new_output) |
| 517 | + self._old_scale = self._scale |
| 518 | + return False |
| 519 | + |
| 520 | + # if range in one or more directions is doubled, then update all losses |
| 521 | + self._min_value = np.minimum(self._min_value, new_output) |
| 522 | + self._max_value = np.maximum(self._max_value, new_output) |
| 523 | + |
| 524 | + scale_multiplier = 1 / self._scale |
| 525 | + if isinstance(scale_multiplier, float): |
| 526 | + scale_multiplier = np.array([scale_multiplier], dtype=float) |
| 527 | + |
| 528 | + # the maximum absolute value that is in the range. Because this is the |
| 529 | + # largest number, this also has the largest absolute numerical error. |
| 530 | + max_absolute_value_in_range = np.max(np.abs([self._min_value, self._max_value]), axis=0) |
| 531 | + # since a float has a relative error of 1e-15, the absolute error is the value * 1e-15 |
| 532 | + abs_err = 1e-15 * max_absolute_value_in_range |
| 533 | + # when scaling the floats, the error gets increased. |
| 534 | + scaled_err = abs_err * scale_multiplier |
| 535 | + |
| 536 | + allowed_numerical_error = 1e-2 |
| 537 | + |
| 538 | + # do not scale along the axis if the numerical error gets too big |
| 539 | + scale_multiplier[scaled_err > allowed_numerical_error] = 1 |
| 540 | + |
| 541 | + self._output_multiplier = scale_multiplier |
| 542 | + |
| 543 | + scale_factor = np.max(np.nan_to_num(self._scale / self._old_scale)) |
| 544 | + if scale_factor > self._recompute_losses_factor: |
| 545 | + self._old_scale = self._scale |
| 546 | + self.recompute_all_losses() |
| 547 | + return True |
| 548 | + return False |
| 549 | + |
466 | 550 | def losses(self): |
467 | 551 | """Get the losses of each simplex in the current triangulation, as dict |
468 | 552 |
|
|
0 commit comments