@@ -29,15 +29,7 @@ def to_list(inp: float) -> List[float]:
2929 return [inp ]
3030
3131
32- def volume (
33- simplex : Union [
34- List [Tuple [float , float ]],
35- List [Tuple [float , float ]],
36- List [Tuple [float , float ]],
37- np .ndarray ,
38- ],
39- ys : None = None ,
40- ) -> float :
32+ def volume (simplex : List [Tuple [float , float ]], ys : None = None ,) -> float :
4133 # Notice the parameter ys is there so you can use this volume method as
4234 # as loss function
4335 matrix = np .subtract (simplex [:- 1 ], simplex [- 1 ], dtype = float )
@@ -207,13 +199,7 @@ def curvature_loss(simplex, values, value_scale, neighbors, neighbor_values):
207199
208200
209201def choose_point_in_simplex (
210- simplex : Union [
211- List [Union [Tuple [int , int ], Tuple [float , float ]]],
212- List [Union [Tuple [float , float , float ], Tuple [int , int , int ]]],
213- List [Tuple [float , float , float ]],
214- List [Tuple [float , float ]],
215- ],
216- transform : Optional [np .ndarray ] = None ,
202+ simplex : np .ndarray , transform : Optional [np .ndarray ] = None ,
217203) -> np .ndarray :
218204 """Choose a new point in inside a simplex.
219205
@@ -318,13 +304,7 @@ class LearnerND(BaseLearner):
318304 def __init__ (
319305 self ,
320306 func : Callable ,
321- bounds : Union [
322- Tuple [Tuple [int , int ], Tuple [int , int ], Tuple [int , int ]],
323- np .ndarray ,
324- Tuple [Tuple [int , int ], Tuple [int , int ]],
325- List [Tuple [int , int ]],
326- ConvexHull ,
327- ],
307+ bounds : Union [Tuple [Tuple [float , float ], ...], ConvexHull ],
328308 loss_per_simplex : Optional [Callable ] = None ,
329309 ) -> None :
330310 self ._vdim = None
@@ -452,17 +432,7 @@ def points(self) -> np.ndarray:
452432 """Get the points from `data` as a numpy array."""
453433 return np .array (list (self .data .keys ()), dtype = float )
454434
455- def tell (
456- self ,
457- point : Union [
458- Tuple [float , float ],
459- Tuple [int , int ],
460- Tuple [int , int , int ],
461- Tuple [float , float , float ],
462- Tuple [float , float , float ],
463- ],
464- value : Union [List [int ], float , float , np .ndarray ],
465- ) -> None :
435+ def tell (self , point : Tuple [float , ...], value : Union [float , np .ndarray ],) -> None :
466436 point = tuple (point )
467437
468438 if point in self .data :
@@ -486,20 +456,11 @@ def tell(
486456 to_delete , to_add = tri .add_point (point , simplex , transform = self ._transform )
487457 self ._update_losses (to_delete , to_add )
488458
489- def _simplex_exists (self , simplex : Any ) -> bool :
459+ def _simplex_exists (self , simplex : Any ) -> bool : # XXX: specify simplex: Any
490460 simplex = tuple (sorted (simplex ))
491461 return simplex in self .tri .simplices
492462
493- def inside_bounds (
494- self ,
495- point : Union [
496- Tuple [float , float ],
497- Tuple [float , float , float ],
498- Tuple [int , int , int ],
499- Tuple [int , int ],
500- Tuple [float , float , float ],
501- ],
502- ) -> Union [bool , np .bool_ ]:
463+ def inside_bounds (self , point : Tuple [float , ...],) -> Union [bool , np .bool_ ]:
503464 """Check whether a point is inside the bounds."""
504465 if hasattr (self , "_interior" ):
505466 return self ._interior .find_simplex (point , tol = 1e-8 ) >= 0
@@ -509,17 +470,7 @@ def inside_bounds(
509470 (mn - eps ) <= p <= (mx + eps ) for p , (mn , mx ) in zip (point , self ._bbox )
510471 )
511472
512- def tell_pending (
513- self ,
514- point : Union [
515- Tuple [int , int ],
516- Tuple [float , float , float ],
517- Tuple [float , float ],
518- Tuple [int , int , int ],
519- ],
520- * ,
521- simplex = None ,
522- ) -> None :
473+ def tell_pending (self , point : Tuple [float , ...], * , simplex = None ,) -> None :
523474 point = tuple (point )
524475 if not self .inside_bounds (point ):
525476 return
@@ -547,9 +498,7 @@ def tell_pending(
547498 self ._update_subsimplex_losses (simpl , to_add )
548499
549500 def _try_adding_pending_point_to_simplex (
550- self ,
551- point : Union [Tuple [float , float , float ], Tuple [float , float ]],
552- simplex : Any ,
501+ self , point : Tuple [float , ...], simplex : Any , # XXX: specify simplex: Any
553502 ) -> Any :
554503 # try to insert it
555504 if not self .tri .point_in_simplex (point , simplex ):
@@ -562,7 +511,9 @@ def _try_adding_pending_point_to_simplex(
562511 self ._pending_to_simplex [point ] = simplex
563512 return self ._subtriangulations [simplex ].add_point (point )
564513
565- def _update_subsimplex_losses (self , simplex : Any , new_subsimplices : Any ) -> None :
514+ def _update_subsimplex_losses (
515+ self , simplex : Any , new_subsimplices : Any
516+ ) -> None : # XXX: specify simplex: Any
566517 loss = self ._losses [simplex ]
567518
568519 loss_density = loss / self .tri .volume (simplex )
@@ -583,14 +534,7 @@ def ask(self, n: int, tell_pending: bool = True) -> Any:
583534 else :
584535 return self ._ask_and_tell_pending (n )
585536
586- def _ask_bound_point (
587- self ,
588- ) -> Union [
589- Tuple [Tuple [int , int , int ], float ],
590- Tuple [Tuple [int , int ], float ],
591- Tuple [Tuple [float , float ], float ],
592- Tuple [Tuple [float , float , float ], float ],
593- ]:
537+ def _ask_bound_point (self ,) -> Tuple [Tuple [float , ...], float ]:
594538 # get the next bound point that is still available
595539 new_point = next (
596540 p
@@ -600,11 +544,7 @@ def _ask_bound_point(
600544 self .tell_pending (new_point )
601545 return new_point , np .inf
602546
603- def _ask_point_without_known_simplices (
604- self ,
605- ) -> Union [
606- Tuple [Tuple [float , float ], float ], Tuple [Tuple [float , float , float ], float ],
607- ]:
547+ def _ask_point_without_known_simplices (self ,) -> Tuple [Tuple [float , ...], float ]:
608548 assert not self ._bounds_available
609549 # pick a random point inside the bounds
610550 # XXX: change this into picking a point based on volume loss
@@ -645,11 +585,7 @@ def _pop_highest_existing_simplex(self) -> Any:
645585 " be a simplex available if LearnerND.tri() is not None."
646586 )
647587
648- def _ask_best_point (
649- self ,
650- ) -> Union [
651- Tuple [Tuple [float , float ], float ], Tuple [Tuple [float , float , float ], float ],
652- ]:
588+ def _ask_best_point (self ,) -> Tuple [Tuple [float , ...], float ]:
653589 assert self .tri is not None
654590
655591 loss , simplex , subsimplex = self ._pop_highest_existing_simplex ()
@@ -676,14 +612,7 @@ def _bounds_available(self) -> bool:
676612 for p in self ._bounds_points
677613 )
678614
679- def _ask (
680- self ,
681- ) -> Union [
682- Tuple [Tuple [int , int , int ], float ],
683- Tuple [Tuple [float , float , float ], float ],
684- Tuple [Tuple [float , float ], float ],
685- Tuple [Tuple [int , int ], float ],
686- ]:
615+ def _ask (self ,) -> Tuple [Tuple [float , ...], float ]:
687616 if self ._bounds_available :
688617 return self ._ask_bound_point () # O(1)
689618
@@ -695,7 +624,7 @@ def _ask(
695624
696625 return self ._ask_best_point () # O(log N)
697626
698- def _compute_loss (self , simplex : Any ) -> float :
627+ def _compute_loss (self , simplex : Any ) -> float : # XXX: specify simplex: Any
699628 # get the loss
700629 vertices = self .tri .get_vertices (simplex )
701630 values = [self .data [tuple (v )] for v in vertices ]
0 commit comments