@@ -29,15 +29,7 @@ def to_list(inp: float) -> List[float]:
2929 return [inp ]
3030
3131
32- def volume (
33- simplex : Union [
34- List [Tuple [float , float ]],
35- List [Tuple [float , float ]],
36- List [Tuple [float , float ]],
37- np .ndarray ,
38- ],
39- ys : None = None ,
40- ) -> float :
32+ def volume (simplex : List [Tuple [float , float ]], ys : None = None ,) -> float :
4133 # Notice the parameter ys is there so you can use this volume method as
4234 # as loss function
4335 matrix = np .subtract (simplex [:- 1 ], simplex [- 1 ], dtype = float )
@@ -207,13 +199,7 @@ def curvature_loss(simplex, values, value_scale, neighbors, neighbor_values):
207199
208200
209201def choose_point_in_simplex (
210- simplex : Union [
211- List [Union [Tuple [int , int ], Tuple [float , float ]]],
212- List [Union [Tuple [float , float , float ], Tuple [int , int , int ]]],
213- List [Tuple [float , float , float ]],
214- List [Tuple [float , float ]],
215- ],
216- transform : Optional [np .ndarray ] = None ,
202+ simplex : np .ndarray , transform : Optional [np .ndarray ] = None ,
217203) -> np .ndarray :
218204 """Choose a new point in inside a simplex.
219205
@@ -318,13 +304,7 @@ class LearnerND(BaseLearner):
318304 def __init__ (
319305 self ,
320306 func : Callable ,
321- bounds : Union [
322- Tuple [Tuple [int , int ], Tuple [int , int ], Tuple [int , int ]],
323- np .ndarray ,
324- Tuple [Tuple [int , int ], Tuple [int , int ]],
325- List [Tuple [int , int ]],
326- ConvexHull ,
327- ],
307+ bounds : Union [Tuple [Tuple [float , float ], ...], ConvexHull ],
328308 loss_per_simplex : Optional [Callable ] = None ,
329309 ) -> None :
330310 self ._vdim = None
@@ -452,17 +432,7 @@ def points(self) -> np.ndarray:
452432 """Get the points from `data` as a numpy array."""
453433 return np .array (list (self .data .keys ()), dtype = float )
454434
455- def tell (
456- self ,
457- point : Union [
458- Tuple [float , float ],
459- Tuple [int , int ],
460- Tuple [int , int , int ],
461- Tuple [float , float , float ],
462- Tuple [float , float , float ],
463- ],
464- value : Union [List [int ], float , float , np .ndarray ],
465- ) -> None :
435+ def tell (self , point : Tuple [float , ...], value : Union [float , np .ndarray ],) -> None :
466436 point = tuple (point )
467437
468438 if point in self .data :
@@ -486,7 +456,7 @@ def tell(
486456 to_delete , to_add = tri .add_point (point , simplex , transform = self ._transform )
487457 self ._update_losses (to_delete , to_add )
488458
489- def _simplex_exists (self , simplex : Any ) -> bool :
459+ def _simplex_exists (self , simplex : Any ) -> bool : # XXX: specify simplex: Any
490460 simplex = tuple (sorted (simplex ))
491461 return simplex in self .tri .simplices
492462
@@ -547,9 +517,7 @@ def tell_pending(
547517 self ._update_subsimplex_losses (simpl , to_add )
548518
549519 def _try_adding_pending_point_to_simplex (
550- self ,
551- point : Union [Tuple [float , float , float ], Tuple [float , float ]],
552- simplex : Any ,
520+ self , point : Tuple [float , ...], simplex : Any , # XXX: specify simplex: Any
553521 ) -> Any :
554522 # try to insert it
555523 if not self .tri .point_in_simplex (point , simplex ):
@@ -562,7 +530,9 @@ def _try_adding_pending_point_to_simplex(
562530 self ._pending_to_simplex [point ] = simplex
563531 return self ._subtriangulations [simplex ].add_point (point )
564532
565- def _update_subsimplex_losses (self , simplex : Any , new_subsimplices : Any ) -> None :
533+ def _update_subsimplex_losses (
534+ self , simplex : Any , new_subsimplices : Any
535+ ) -> None : # XXX: specify simplex: Any
566536 loss = self ._losses [simplex ]
567537
568538 loss_density = loss / self .tri .volume (simplex )
@@ -583,14 +553,7 @@ def ask(self, n: int, tell_pending: bool = True) -> Any:
583553 else :
584554 return self ._ask_and_tell_pending (n )
585555
586- def _ask_bound_point (
587- self ,
588- ) -> Union [
589- Tuple [Tuple [int , int , int ], float ],
590- Tuple [Tuple [int , int ], float ],
591- Tuple [Tuple [float , float ], float ],
592- Tuple [Tuple [float , float , float ], float ],
593- ]:
556+ def _ask_bound_point (self ,) -> Tuple [Tuple [float , ...], float ]:
594557 # get the next bound point that is still available
595558 new_point = next (
596559 p
@@ -600,11 +563,7 @@ def _ask_bound_point(
600563 self .tell_pending (new_point )
601564 return new_point , np .inf
602565
603- def _ask_point_without_known_simplices (
604- self ,
605- ) -> Union [
606- Tuple [Tuple [float , float ], float ], Tuple [Tuple [float , float , float ], float ],
607- ]:
566+ def _ask_point_without_known_simplices (self ,) -> Tuple [Tuple [float , ...], float ]:
608567 assert not self ._bounds_available
609568 # pick a random point inside the bounds
610569 # XXX: change this into picking a point based on volume loss
@@ -645,11 +604,7 @@ def _pop_highest_existing_simplex(self) -> Any:
645604 " be a simplex available if LearnerND.tri() is not None."
646605 )
647606
648- def _ask_best_point (
649- self ,
650- ) -> Union [
651- Tuple [Tuple [float , float ], float ], Tuple [Tuple [float , float , float ], float ],
652- ]:
607+ def _ask_best_point (self ,) -> Tuple [Tuple [float , ...], float ]:
653608 assert self .tri is not None
654609
655610 loss , simplex , subsimplex = self ._pop_highest_existing_simplex ()
@@ -676,14 +631,7 @@ def _bounds_available(self) -> bool:
676631 for p in self ._bounds_points
677632 )
678633
679- def _ask (
680- self ,
681- ) -> Union [
682- Tuple [Tuple [int , int , int ], float ],
683- Tuple [Tuple [float , float , float ], float ],
684- Tuple [Tuple [float , float ], float ],
685- Tuple [Tuple [int , int ], float ],
686- ]:
634+ def _ask (self ,) -> Tuple [Tuple [float , ...], float ]:
687635 if self ._bounds_available :
688636 return self ._ask_bound_point () # O(1)
689637
@@ -695,7 +643,7 @@ def _ask(
695643
696644 return self ._ask_best_point () # O(log N)
697645
698- def _compute_loss (self , simplex : Any ) -> float :
646+ def _compute_loss (self , simplex : Any ) -> float : # XXX: specify simplex: Any
699647 # get the loss
700648 vertices = self .tri .get_vertices (simplex )
701649 values = [self .data [tuple (v )] for v in vertices ]
0 commit comments