@@ -69,6 +69,22 @@ def _wrapped(loss_per_interval):
6969 return _wrapped
7070
7171
72+ def loss_returns (return_type , return_length ):
73+ def _wrapped (loss_per_interval ):
74+ loss_per_interval .return_type = return_type
75+ loss_per_interval .return_length = return_length
76+ return loss_per_interval
77+ return _wrapped
78+
79+
80+ def inf_format (return_type , return_len = None ):
81+ is_iterable = hasattr (return_type , '__iter__' )
82+ if is_iterable :
83+ return return_type (return_len * [np .inf ])
84+ else :
85+ return return_type (np .inf )
86+
87+
7288@uses_nth_neighbors (0 )
7389def uniform_loss (xs , ys ):
7490 """Loss function that samples the domain uniformly.
@@ -287,7 +303,8 @@ def npoints(self):
287303 def loss (self , real = True ):
288304 losses = self .losses if real else self .losses_combined
289305 if not losses :
290- return np .inf
306+ return inf_format (self .loss_per_interval .return_type ,
307+ self .loss_per_interval .return_length )
291308 max_interval , max_loss = losses .peekitem (0 )
292309 return max_loss
293310
@@ -660,8 +677,14 @@ def _set_data(self, data):
660677
661678def loss_manager (x_scale ):
662679 def sort_key (ival , loss ):
663- loss , ival = finite_loss (ival , loss , x_scale )
664- return - loss , ival
680+ if isinstance (loss , Iterable ):
681+ loss , ival = zip (* [finite_loss (ival , l , x_scale ) for l in loss ])
682+ loss = tuple (- x for x in loss )
683+ ival = ival [0 ]
684+ else :
685+ loss , ival = finite_loss (ival , loss , x_scale )
686+ loss = - loss
687+ return loss , ival
665688 sorted_dict = sortedcollections .ItemSortedDict (sort_key )
666689 return sorted_dict
667690
0 commit comments