55from collections import OrderedDict
66from copy import copy
77from math import sqrt
8- from typing import Callable , Iterable
8+ from typing import TYPE_CHECKING , Callable , Iterable
99
1010import cloudpickle
1111import numpy as np
2222 partial_function_from_dataframe ,
2323)
2424
25+ if TYPE_CHECKING :
26+ import holoviews
27+
2528try :
2629 import pandas
2730
@@ -40,11 +43,11 @@ def deviations(ip: LinearNDInterpolator) -> list[np.ndarray]:
4043
4144 Parameters
4245 ----------
43- ip : `scipy.interpolate.LinearNDInterpolator` instance
46+ ip
4447
4548 Returns
4649 -------
47- deviations : list
50+ deviations
4851 The deviation per triangle.
4952 """
5053 values = ip .values / (ip .values .ptp (axis = 0 ).max () or 1 )
@@ -79,11 +82,11 @@ def areas(ip: LinearNDInterpolator) -> np.ndarray:
7982
8083 Parameters
8184 ----------
82- ip : `scipy.interpolate.LinearNDInterpolator` instance
85+ ip
8386
8487 Returns
8588 -------
86- areas : numpy.ndarray
89+ areas
8790 The area per triangle in ``ip.tri``.
8891 """
8992 p = ip .tri .points [ip .tri .simplices ]
@@ -99,11 +102,11 @@ def uniform_loss(ip: LinearNDInterpolator) -> np.ndarray:
99102
100103 Parameters
101104 ----------
102- ip : `scipy.interpolate.LinearNDInterpolator` instance
105+ ip
103106
104107 Returns
105108 -------
106- losses : numpy.ndarray
109+ losses
107110 Loss per triangle in ``ip.tri``.
108111
109112 Examples
@@ -136,7 +139,7 @@ def resolution_loss_function(
136139
137140 Returns
138141 -------
139- loss_function : callable
142+ loss_function
140143
141144 Examples
142145 --------
@@ -173,11 +176,11 @@ def minimize_triangle_surface_loss(ip: LinearNDInterpolator) -> np.ndarray:
173176
174177 Parameters
175178 ----------
176- ip : `scipy.interpolate.LinearNDInterpolator` instance
179+ ip
177180
178181 Returns
179182 -------
180- losses : numpy.ndarray
183+ losses
181184 Loss per triangle in ``ip.tri``.
182185
183186 Examples
@@ -217,11 +220,11 @@ def default_loss(ip: LinearNDInterpolator) -> np.ndarray:
217220
218221 Parameters
219222 ----------
220- ip : `scipy.interpolate.LinearNDInterpolator` instance
223+ ip
221224
222225 Returns
223226 -------
224- losses : numpy.ndarray
227+ losses
225228 Loss per triangle in ``ip.tri``.
226229 """
227230 dev = np .sum (deviations (ip ), axis = 0 )
@@ -241,15 +244,15 @@ def choose_point_in_triangle(triangle: np.ndarray, max_badness: int) -> np.ndarr
241244
242245 Parameters
243246 ----------
244- triangle : numpy.ndarray
247+ triangle
245248 The coordinates of a triangle with shape (3, 2).
246- max_badness : int
249+ max_badness
247250 The badness at which the point is either chosen on a edge or
248251 in the middle.
249252
250253 Returns
251254 -------
252- point : numpy.ndarray
255+ point
253256 The x and y coordinate of the suggested new point.
254257 """
255258 a , b , c = triangle
@@ -267,17 +270,17 @@ def choose_point_in_triangle(triangle: np.ndarray, max_badness: int) -> np.ndarr
267270 return point
268271
269272
270- def triangle_loss (ip ) :
273+ def triangle_loss (ip : LinearNDInterpolator ) -> list [ float ] :
271274 r"""Computes the average of the volumes of the simplex combined with each
272275 neighbouring point.
273276
274277 Parameters
275278 ----------
276- ip : `scipy.interpolate.LinearNDInterpolator` instance
279+ ip
277280
278281 Returns
279282 -------
280- triangle_loss : list
283+ triangle_loss
281284 The mean volume per triangle.
282285
283286 Notes
@@ -311,13 +314,13 @@ class Learner2D(BaseLearner):
311314
312315 Parameters
313316 ----------
314- function : callable
317+ function
315318 The function to learn. Must take a tuple of two real
316319 parameters and return a real number.
317- bounds : list of 2-tuples
320+ bounds
318321 A list ``[(a1, b1), (a2, b2)]`` containing bounds,
319322 one per dimension.
320- loss_per_triangle : callable, optional
323+ loss_per_triangle
321324 A function that returns the loss for every triangle.
322325 If not provided, then a default is used, which uses
323326 the deviation from a linear estimate, as well as
@@ -424,19 +427,19 @@ def to_dataframe(
424427
425428 Parameters
426429 ----------
427- with_default_function_args : bool, optional
430+ with_default_function_args
428431 Include the ``learner.function``'s default arguments as a
429432 column, by default True
430- function_prefix : str, optional
433+ function_prefix
431434 Prefix to the ``learner.function``'s default arguments' names,
432435 by default "function."
433- seed_name : str, optional
436+ seed_name
434437 Name of the seed parameter, by default "seed"
435- x_name : str, optional
438+ x_name
436439 Name of the input x value, by default "x"
437- y_name : str, optional
440+ y_name
438441 Name of the input y value, by default "y"
439- z_name : str, optional
442+ z_name
440443 Name of the output value, by default "z"
441444
442445 Returns
@@ -475,18 +478,18 @@ def load_dataframe(
475478
476479 Parameters
477480 ----------
478- df : pandas.DataFrame
481+ df
479482 The data to load.
480- with_default_function_args : bool, optional
483+ with_default_function_args
481484 The ``with_default_function_args`` used in ``to_dataframe()``,
482485 by default True
483- function_prefix : str, optional
486+ function_prefix
484487 The ``function_prefix`` used in ``to_dataframe``, by default "function."
485- x_name : str, optional
488+ x_name
486489 The ``x_name`` used in ``to_dataframe``, by default "x"
487- y_name : str, optional
490+ y_name
488491 The ``y_name`` used in ``to_dataframe``, by default "y"
489- z_name : str, optional
492+ z_name
490493 The ``z_name`` used in ``to_dataframe``, by default "z"
491494 """
492495 data = df .set_index ([x_name , y_name ])[z_name ].to_dict ()
@@ -538,7 +541,7 @@ def interpolated_on_grid(
538541
539542 Parameters
540543 ----------
541- n : int, optional
544+ n
542545 Number of points in x and y. If None (default) this number is
543546 evaluated by looking at the size of the smallest triangle.
544547
@@ -611,14 +614,14 @@ def interpolator(self, *, scaled: bool = False) -> LinearNDInterpolator:
611614
612615 Parameters
613616 ----------
614- scaled : bool
617+ scaled
615618 Use True if all points are inside the
616619 unit-square [(-0.5, 0.5), (-0.5, 0.5)] or False if
617620 the data points are inside the ``learner.bounds``.
618621
619622 Returns
620623 -------
621- interpolator : `scipy.interpolate.LinearNDInterpolator`
624+ interpolator
622625
623626 Examples
624627 --------
@@ -755,7 +758,9 @@ def remove_unfinished(self) -> None:
755758 if p not in self .data :
756759 self ._stack [p ] = np .inf
757760
758- def plot (self , n = None , tri_alpha = 0 ):
761+ def plot (
762+ self , n : int = None , tri_alpha : float = 0
763+ ) -> holoviews .Overlay | holoviews .HoloMap :
759764 r"""Plot the Learner2D's current state.
760765
761766 This plot function interpolates the data on a regular grid.
@@ -764,16 +769,16 @@ def plot(self, n=None, tri_alpha=0):
764769
765770 Parameters
766771 ----------
767- n : int
772+ n
768773 Number of points in x and y. If None (default) this number is
769774 evaluated by looking at the size of the smallest triangle.
770- tri_alpha : float
775+ tri_alpha
771776 The opacity ``(0 <= tri_alpha <= 1)`` of the triangles overlayed
772777 on top of the image. By default the triangulation is not visible.
773778
774779 Returns
775780 -------
776- plot : `holoviews.core.Overlay` or `holoviews.core.HoloMap`
781+ plot
777782 A `holoviews.core.Overlay` of
778783 ``holoviews.Image * holoviews.EdgePaths``. If the
779784 `learner.function` returns a vector output, a
0 commit comments