@@ -47,30 +47,78 @@ def orientation(simplex):
4747 return sign
4848
4949
50- def uniform_loss (simplex , ys = None ):
50+ def uniform_loss (simplex , values , value_scale ):
51+ """
52+ Uniform loss.
53+
54+ Parameters
55+ ----------
56+ simplex : list of tuples
57+ Each entry is one point of the simplex.
58+ values : list of values
59+ The scaled function values of each of the simplex points.
60+ value_scale : float
61+ The scale of values, where ``values = function_values * value_scale``.
62+
63+ Returns
64+ -------
65+ loss : float
66+ """
5167 return volume (simplex )
5268
5369
54- def std_loss (simplex , ys ):
55- r = np .linalg .norm (np .std (ys , axis = 0 ))
70+ def std_loss (simplex , values , value_scale ):
71+ """
72+ Computes the loss of the simplex based on the standard deviation.
73+
74+ Parameters
75+ ----------
76+ simplex : list of tuples
77+ Each entry is one point of the simplex.
78+ values : list of values
79+ The scaled function values of each of the simplex points.
80+ value_scale : float
81+ The scale of values, where ``values = function_values * value_scale``.
82+
83+ Returns
84+ -------
85+ loss : float
86+ """
87+
88+ r = np .linalg .norm (np .std (values , axis = 0 ))
5689 vol = volume (simplex )
5790
5891 dim = len (simplex ) - 1
5992
6093 return r .flat * np .power (vol , 1.0 / dim ) + vol
6194
6295
63- def default_loss (simplex , ys ):
64- # return std_loss(simplex, ys)
65- if isinstance (ys [0 ], Iterable ):
66- pts = [(* x , * y ) for x , y in zip (simplex , ys )]
96+ def default_loss (simplex , values , value_scale ):
97+ """
98+ Computes the average of the volumes of the simplex.
99+
100+ Parameters
101+ ----------
102+ simplex : list of tuples
103+ Each entry is one point of the simplex.
104+ values : list of values
105+ The scaled function values of each of the simplex points.
106+ value_scale : float
107+ The scale of values, where ``values = function_values * value_scale``.
108+
109+ Returns
110+ -------
111+ loss : float
112+ """
113+ if isinstance (values [0 ], Iterable ):
114+ pts = [(* x , * y ) for x , y in zip (simplex , values )]
67115 else :
68- pts = [(* x , y ) for x , y in zip (simplex , ys )]
116+ pts = [(* x , y ) for x , y in zip (simplex , values )]
69117 return simplex_volume_in_embedding (pts )
70118
71119
72120@uses_nth_neighbors (1 )
73- def triangle_loss (simplex , values , neighbors , neighbor_values ):
121+ def triangle_loss (simplex , values , value_scale , neighbors , neighbor_values ):
74122 """
75123 Computes the average of the volumes of the simplex combined with each
76124 neighbouring point.
@@ -80,7 +128,9 @@ def triangle_loss(simplex, values, neighbors, neighbor_values):
80128 simplex : list of tuples
81129 Each entry is one point of the simplex.
82130 values : list of values
83- The function values of each of the simplex points.
131+ The scaled function values of each of the simplex points.
132+ value_scale : float
133+ The scale of values, where ``values = function_values * value_scale``.
84134 neighbors : list of tuples
85135 The neighboring points of the simplex, ordered such that simplex[0]
86136 exacly opposes neighbors[0], etc.
@@ -108,20 +158,22 @@ def triangle_loss(simplex, values, neighbors, neighbor_values):
108158def curvature_loss_function (exploration = 0.05 ):
109159 # XXX: add doc-string!
110160 @uses_nth_neighbors (1 )
111- def curvature_loss (simplex , values , neighbors , neighbor_values ):
161+ def curvature_loss (simplex , values , value_scale , neighbors , neighbor_values ):
112162 """Compute the curvature loss of a simplex.
113163
114164 Parameters
115165 ----------
116166 simplex : list of tuples
117167 Each entry is one point of the simplex.
118168 values : list of values
119- The function values of each of the simplex points.
169+ The scaled function values of each of the simplex points.
170+ value_scale : float
171+ The scale of values, where ``values = function_values * value_scale``.
120172 neighbors : list of tuples
121173 The neighboring points of the simplex, ordered such that simplex[0]
122174 exacly opposes neighbors[0], etc.
123175 neighbor_values : list of values
124- The function values for each of the neighboring points.
176+ The scaled function values for each of the neighboring points.
125177
126178 Returns
127179 -------
@@ -130,7 +182,9 @@ def curvature_loss(simplex, values, neighbors, neighbor_values):
130182 dim = len (simplex [0 ]) # the number of coordinates
131183 loss_input_volume = volume (simplex )
132184
133- loss_curvature = triangle_loss (simplex , values , neighbors , neighbor_values )
185+ loss_curvature = triangle_loss (
186+ simplex , values , value_scale , neighbors , neighbor_values
187+ )
134188 return (
135189 loss_curvature + exploration * loss_input_volume ** ((2 + dim ) / dim )
136190 ) ** (1 / (2 + dim ))
@@ -563,7 +617,9 @@ def _compute_loss(self, simplex):
563617
564618 if self .nth_neighbors == 0 :
565619 # compute the loss on the scaled simplex
566- return float (self .loss_per_simplex (vertices , values ))
620+ return float (
621+ self .loss_per_simplex (vertices , values , self ._output_multiplier )
622+ )
567623
568624 # We do need the neighbors
569625 neighbors = self .tri .get_opposing_vertices (simplex )
@@ -580,7 +636,13 @@ def _compute_loss(self, simplex):
580636 neighbor_values [i ] = self ._output_multiplier * value
581637
582638 return float (
583- self .loss_per_simplex (vertices , values , neighbor_points , neighbor_values )
639+ self .loss_per_simplex (
640+ vertices ,
641+ values ,
642+ self ._output_multiplier ,
643+ neighbor_points ,
644+ neighbor_values ,
645+ )
584646 )
585647
586648 def _update_losses (self , to_delete : set , to_add : set ):
0 commit comments