@@ -69,6 +69,22 @@ def _wrapped(loss_per_interval):
6969 return _wrapped
7070
7171
72+ def loss_returns (return_type , return_length ):
73+ def _wrapped (loss_per_interval ):
74+ loss_per_interval .return_type = return_type
75+ loss_per_interval .return_length = return_length
76+ return loss_per_interval
77+ return _wrapped
78+
79+
80+ def inf_format (return_type , return_len = None ):
81+ is_iterable = hasattr (return_type , '__iter__' )
82+ if is_iterable :
83+ return return_type (return_len * [np .inf ])
84+ else :
85+ return return_type (np .inf )
86+
87+
7288@uses_nth_neighbors (0 )
7389def uniform_loss (xs , ys ):
7490 """Loss function that samples the domain uniformly.
@@ -287,7 +303,8 @@ def npoints(self):
287303 def loss (self , real = True ):
288304 losses = self .losses if real else self .losses_combined
289305 if not losses :
290- return np .inf
306+ return inf_format (self .loss_per_interval .return_type ,
307+ self .loss_per_interval .return_length )
291308 max_interval , max_loss = losses .peekitem (0 )
292309 return max_loss
293310
@@ -661,7 +678,8 @@ def _set_data(self, data):
661678def loss_manager (x_scale ):
662679 def sort_key (ival , loss ):
663680 loss , ival = finite_loss (ival , loss , x_scale )
664- return - loss , ival
681+ loss = tuple (- l for l in loss ) if isinstance (loss , tuple ) else - loss
682+ return loss , ival
665683 sorted_dict = sortedcollections .ItemSortedDict (sort_key )
666684 return sorted_dict
667685
0 commit comments