@@ -231,6 +231,65 @@ def default_loss(ip: LinearNDInterpolator) -> np.ndarray:
231231 return losses
232232
233233
234+ def thresholded_loss_function (
235+ lower_threshold : float | None = None ,
236+ upper_threshold : float | None = None ,
237+ priority_factor : float = 0.1 ,
238+ ) -> Callable [[LinearNDInterpolator ], np .ndarray ]:
239+ """
240+ Factory function to create a custom loss function that deprioritizes
241+ values above an upper threshold and below a lower threshold.
242+
243+ Parameters
244+ ----------
245+ lower_threshold : float, optional
246+ The lower threshold for deprioritizing values. If None (default),
247+ there is no lower threshold.
248+ upper_threshold : float, optional
249+ The upper threshold for deprioritizing values. If None (default),
250+ there is no upper threshold.
251+ priority_factor : float, default: 0.1
252+ The factor by which the loss is multiplied for values outside
253+ the specified thresholds.
254+
255+ Returns
256+ -------
257+ custom_loss : Callable[[LinearNDInterpolator], np.ndarray]
258+ A custom loss function that can be used with Learner2D.
259+ """
260+
261+ def custom_loss (ip : LinearNDInterpolator ) -> np .ndarray :
262+ """Loss function that deprioritizes values outside an upper and lower threshold.
263+
264+ Parameters
265+ ----------
266+ ip : `scipy.interpolate.LinearNDInterpolator` instance
267+
268+ Returns
269+ -------
270+ losses : numpy.ndarray
271+ Loss per triangle in ``ip.tri``.
272+ """
273+ losses = default_loss (ip )
274+
275+ if lower_threshold is not None or upper_threshold is not None :
276+ simplices = ip .tri .simplices
277+ values = ip .values [simplices ]
278+ if lower_threshold is not None :
279+ mask_lower = (values < lower_threshold ).all (axis = (1 , - 1 ))
280+ if mask_lower .any ():
281+ losses [mask_lower ] *= priority_factor
282+
283+ if upper_threshold is not None :
284+ mask_upper = (values > upper_threshold ).all (axis = (1 , - 1 ))
285+ if mask_upper .any ():
286+ losses [mask_upper ] *= priority_factor
287+
288+ return losses
289+
290+ return custom_loss
291+
292+
234293def choose_point_in_triangle (triangle : np .ndarray , max_badness : int ) -> np .ndarray :
235294 """Choose a new point in inside a triangle.
236295
0 commit comments