44from typing import cast
55
66import numpy as np
7- from scipy .optimize import minimize as scipy_minimize
8- from scipy .optimize import minimize_scalar as scipy_minimize_scalar
9- from scipy .optimize import root as scipy_root
10- from scipy .optimize import root_scalar as scipy_root_scalar
117
128import pytensor .scalar as ps
13- from pytensor import Variable , function , graph_replace
9+ from pytensor . compile . function import function
1410from pytensor .gradient import grad , hessian , jacobian
1511from pytensor .graph import Apply , Constant , FunctionGraph
1612from pytensor .graph .basic import ancestors , truncated_graph_inputs
1713from pytensor .graph .op import ComputeMapType , HasInnerGraph , Op , StorageMapType
14+ from pytensor .graph .replace import graph_replace
1815from pytensor .tensor .basic import (
1916 atleast_2d ,
2017 concatenate ,
2421)
2522from pytensor .tensor .math import dot
2623from pytensor .tensor .slinalg import solve
27- from pytensor .tensor .variable import TensorVariable
24+ from pytensor .tensor .variable import TensorVariable , Variable
25+
26+
27+ # scipy.optimize can be slow to import, and will not be used by most users
28+ # We import scipy.optimize lazily inside optimization perform methods to avoid this.
29+ optimize = None
2830
2931
3032_log = logging .getLogger (__name__ )
@@ -352,8 +354,6 @@ def implict_optimization_grads(
352354
353355
354356class MinimizeScalarOp (ScipyScalarWrapperOp ):
355- __props__ = ("method" ,)
356-
357357 def __init__ (
358358 self ,
359359 x : Variable ,
@@ -377,15 +377,22 @@ def __init__(
377377 self ._fn = None
378378 self ._fn_wrapped = None
379379
380+ def __str__ (self ):
381+ return f"{ self .__class__ .__name__ } (method={ self .method } )"
382+
380383 def perform (self , node , inputs , outputs ):
384+ global optimize
385+ if optimize is None :
386+ import scipy .optimize as optimize
387+
381388 f = self .fn_wrapped
382389 f .clear_cache ()
383390
384391 # minimize_scalar doesn't take x0 as an argument. The Op still needs this input (to symbolically determine
385392 # the args of the objective function), but it is not used in the optimization.
386393 x0 , * args = inputs
387394
388- res = scipy_minimize_scalar (
395+ res = optimize . minimize_scalar (
389396 fun = f .value ,
390397 args = tuple (args ),
391398 method = self .method ,
@@ -426,6 +433,27 @@ def minimize_scalar(
426433):
427434 """
428435 Minimize a scalar objective function using scipy.optimize.minimize_scalar.
436+
437+ Parameters
438+ ----------
439+ objective : TensorVariable
440+ The objective function to minimize. This should be a PyTensor variable representing a scalar value.
441+ x : TensorVariable
442+ The variable with respect to which the objective function is minimized. It must be a scalar and an
443+ input to the computational graph of `objective`.
444+ method : str, optional
445+ The optimization method to use. Default is "brent". See `scipy.optimize.minimize_scalar` for other options.
446+ optimizer_kwargs : dict, optional
447+ Additional keyword arguments to pass to `scipy.optimize.minimize_scalar`.
448+
449+ Returns
450+ -------
451+ solution: TensorVariable
452+ Value of `x` that minimizes `objective(x, *args)`. If the success flag is False, this will be the
453+ final state returned by the minimization routine, not necessarily a minimum.
454+ success : TensorVariable
455+ Symbolic boolean flag indicating whether the minimization routine reported convergence to a minimum
456+ value, based on the requested convergence criteria.
429457 """
430458
431459 args = _find_optimization_parameters (objective , x )
@@ -438,12 +466,14 @@ def minimize_scalar(
438466 optimizer_kwargs = optimizer_kwargs ,
439467 )
440468
441- return minimize_scalar_op (x , * args )
469+ solution , success = cast (
470+ tuple [TensorVariable , TensorVariable ], minimize_scalar_op (x , * args )
471+ )
442472
473+ return solution , success
443474
444- class MinimizeOp (ScipyWrapperOp ):
445- __props__ = ("method" , "jac" , "hess" , "hessp" )
446475
476+ class MinimizeOp (ScipyWrapperOp ):
447477 def __init__ (
448478 self ,
449479 x : Variable ,
@@ -487,11 +517,24 @@ def __init__(
487517 self ._fn = None
488518 self ._fn_wrapped = None
489519
520+ def __str__ (self ):
521+ str_args = ", " .join (
522+ [
523+ f"{ arg } ={ getattr (self , arg )} "
524+ for arg in ["method" , "jac" , "hess" , "hessp" ]
525+ ]
526+ )
527+ return f"{ self .__class__ .__name__ } ({ str_args } )"
528+
490529 def perform (self , node , inputs , outputs ):
530+ global optimize
531+ if optimize is None :
532+ import scipy .optimize as optimize
533+
491534 f = self .fn_wrapped
492535 x0 , * args = inputs
493536
494- res = scipy_minimize (
537+ res = optimize . minimize (
495538 fun = f .value_and_grad if self .jac else f .value ,
496539 jac = self .jac ,
497540 x0 = x0 ,
@@ -538,7 +581,7 @@ def minimize(
538581 jac : bool = True ,
539582 hess : bool = False ,
540583 optimizer_kwargs : dict | None = None ,
541- ):
584+ ) -> tuple [ TensorVariable , TensorVariable ] :
542585 """
543586 Minimize a scalar objective function using scipy.optimize.minimize.
544587
@@ -563,9 +606,13 @@ def minimize(
563606
564607 Returns
565608 -------
566- TensorVariable
567- The optimized value of x that minimizes the objective function.
609+ solution: TensorVariable
610+ The optimized value of the vector of inputs `x` that minimizes `objective(x, *args)`. If the success flag
611+ is False, this will be the final state of the minimization routine, but not necessarily a minimum.
568612
613+ success: TensorVariable
614+ Symbolic boolean flag indicating whether the minimization routine reported convergence to a minimum
615+ value, based on the requested convergence criteria.
569616 """
570617 args = _find_optimization_parameters (objective , x )
571618
@@ -579,12 +626,14 @@ def minimize(
579626 optimizer_kwargs = optimizer_kwargs ,
580627 )
581628
582- return minimize_op (x , * args )
629+ solution , success = cast (
630+ tuple [TensorVariable , TensorVariable ], minimize_op (x , * args )
631+ )
632+
633+ return solution , success
583634
584635
585636class RootScalarOp (ScipyScalarWrapperOp ):
586- __props__ = ("method" , "jac" , "hess" )
587-
588637 def __init__ (
589638 self ,
590639 variables ,
@@ -633,14 +682,24 @@ def __init__(
633682 self ._fn = None
634683 self ._fn_wrapped = None
635684
685+ def __str__ (self ):
686+ str_args = ", " .join (
687+ [f"{ arg } ={ getattr (self , arg )} " for arg in ["method" , "jac" , "hess" ]]
688+ )
689+ return f"{ self .__class__ .__name__ } ({ str_args } )"
690+
636691 def perform (self , node , inputs , outputs ):
692+ global optimize
693+ if optimize is None :
694+ import scipy .optimize as optimize
695+
637696 f = self .fn_wrapped
638697 f .clear_cache ()
639698 # f.copy_x = True
640699
641700 variables , * args = inputs
642701
643- res = scipy_root_scalar (
702+ res = optimize . root_scalar (
644703 f = f .value ,
645704 fprime = f .grad if self .jac else None ,
646705 fprime2 = f .hess if self .hess else None ,
@@ -676,19 +735,48 @@ def L_op(self, inputs, outputs, output_grads):
676735
677736def root_scalar (
678737 equation : TensorVariable ,
679- variables : TensorVariable ,
738+ variable : TensorVariable ,
680739 method : str = "secant" ,
681740 jac : bool = False ,
682741 hess : bool = False ,
683742 optimizer_kwargs : dict | None = None ,
684- ):
743+ ) -> tuple [ TensorVariable , TensorVariable ] :
685744 """
686745 Find roots of a scalar equation using scipy.optimize.root_scalar.
746+
747+ Parameters
748+ ----------
749+ equation : TensorVariable
750+ The equation for which to find roots. This should be a PyTensor variable representing a single equation in one
751+ variable. The function will find `variables` such that `equation(variables, *args) = 0`.
752+ variable : TensorVariable
753+ The variable with respect to which the equation is solved. It must be a scalar and an input to the
754+ computational graph of `equation`.
755+ method : str, optional
756+ The root-finding method to use. Default is "secant". See `scipy.optimize.root_scalar` for other options.
757+ jac : bool, optional
758+ Whether to compute and use the first derivative of the equation with respect to `variables`.
759+ Default is False. Some methods require this.
760+ hess : bool, optional
761+ Whether to compute and use the second derivative of the equation with respect to `variables`.
762+ Default is False. Some methods require this.
763+ optimizer_kwargs : dict, optional
764+ Additional keyword arguments to pass to `scipy.optimize.root_scalar`.
765+
766+ Returns
767+ -------
768+ solution: TensorVariable
769+ The final state of the root-finding routine. When `success` is True, this is the value of `variables` that
770+ causes `equation` to evaluate to zero. Otherwise it is the final state returned by the root-finding
771+ routine, but not necessarily a root.
772+
773+ success: TensorVariable
774+ Boolean indicating whether the root-finding was successful. If True, the solution is a root of the equation
687775 """
688- args = _find_optimization_parameters (equation , variables )
776+ args = _find_optimization_parameters (equation , variable )
689777
690778 root_scalar_op = RootScalarOp (
691- variables ,
779+ variable ,
692780 * args ,
693781 equation = equation ,
694782 method = method ,
@@ -697,7 +785,11 @@ def root_scalar(
697785 optimizer_kwargs = optimizer_kwargs ,
698786 )
699787
700- return root_scalar_op (variables , * args )
788+ solution , success = cast (
789+ tuple [TensorVariable , TensorVariable ], root_scalar_op (variable , * args )
790+ )
791+
792+ return solution , success
701793
702794
703795class RootOp (ScipyWrapperOp ):
@@ -734,6 +826,12 @@ def __init__(
734826 self ._fn = None
735827 self ._fn_wrapped = None
736828
829+ def __str__ (self ):
830+ str_args = ", " .join (
831+ [f"{ arg } ={ getattr (self , arg )} " for arg in ["method" , "jac" ]]
832+ )
833+ return f"{ self .__class__ .__name__ } ({ str_args } )"
834+
737835 def build_fn (self ):
738836 outputs = self .inner_outputs
739837 variables , * args = self .inner_inputs
@@ -761,13 +859,17 @@ def build_fn(self):
761859 self ._fn_wrapped = LRUCache1 (fn )
762860
763861 def perform (self , node , inputs , outputs ):
862+ global optimize
863+ if optimize is None :
864+ import scipy .optimize as optimize
865+
764866 f = self .fn_wrapped
765867 f .clear_cache ()
766868 f .copy_x = True
767869
768870 variables , * args = inputs
769871
770- res = scipy_root (
872+ res = optimize . root (
771873 fun = f ,
772874 jac = self .jac ,
773875 x0 = variables ,
@@ -815,8 +917,36 @@ def root(
815917 method : str = "hybr" ,
816918 jac : bool = True ,
817919 optimizer_kwargs : dict | None = None ,
818- ):
819- """Find roots of a system of equations using scipy.optimize.root."""
920+ ) -> tuple [TensorVariable , TensorVariable ]:
921+ """
922+ Find roots of a system of equations using scipy.optimize.root.
923+
924+ Parameters
925+ ----------
926+ equations : TensorVariable
927+ The system of equations for which to find roots. This should be a PyTensor variable representing a
928+ vector (or scalar) value. The function will find `variables` such that `equations(variables, *args) = 0`.
929+ variables : TensorVariable
930+ The variable(s) with respect to which the system of equations is solved. It must be an input to the
931+ computational graph of `equations` and have the same number of dimensions as `equations`.
932+ method : str, optional
933+ The root-finding method to use. Default is "hybr". See `scipy.optimize.root` for other options.
934+ jac : bool, optional
935+ Whether to compute and use the Jacobian of the `equations` with respect to `variables`.
936+ Default is True. Most methods require this.
937+ optimizer_kwargs : dict, optional
938+ Additional keyword arguments to pass to `scipy.optimize.root`.
939+
940+ Returns
941+ -------
942+ solution: TensorVariable
943+ The final state of the root-finding routine. When `success` is True, this is the value of `variables` that
944+ causes all `equations` to evaluate to zero. Otherwise it is the final state returned by the root-finding
945+ routine, but not necessarily a root.
946+
947+ success: TensorVariable
948+ Boolean indicating whether the root-finding was successful. If True, the solution is a root of the equation
949+ """
820950
821951 args = _find_optimization_parameters (equations , variables )
822952
@@ -829,7 +959,11 @@ def root(
829959 optimizer_kwargs = optimizer_kwargs ,
830960 )
831961
832- return root_op (variables , * args )
962+ solution , success = cast (
963+ tuple [TensorVariable , TensorVariable ], root_op (variables , * args )
964+ )
965+
966+ return solution , success
833967
834968
# Public API of this module: the user-facing wrappers around scipy.optimize.
__all__ = [
    "minimize_scalar",
    "minimize",
    "root_scalar",
    "root",
]
0 commit comments