@@ -484,6 +484,7 @@ def __init__(
         jac: bool = True,
         hess: bool = False,
         hessp: bool = False,
+        use_vectorized_jac: bool = False,
         optimizer_kwargs: dict | None = None,
     ):
         if not cast(TensorVariable, objective).ndim == 0:
@@ -496,6 +497,7 @@ def __init__(
             )
 
         self.fgraph = FunctionGraph([x, *args], [objective])
+        self.use_vectorized_jac = use_vectorized_jac
 
         if jac:
             grad_wrt_x = cast(
@@ -505,7 +507,12 @@ def __init__(
 
         if hess:
             hess_wrt_x = cast(
-                Variable, hessian(self.fgraph.outputs[0], self.fgraph.inputs[0])
+                Variable,
+                jacobian(
+                    self.fgraph.outputs[-1],
+                    self.fgraph.inputs[0],
+                    vectorize=use_vectorized_jac,
+                ),
             )
             self.fgraph.add_output(hess_wrt_x)
 
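The refactor above swaps `hessian(...)` for `jacobian` applied to the gradient output appended just before it, since the Hessian of a scalar objective is the Jacobian of its gradient; this is what lets the new `vectorize=` flag reach the Hessian graph. A quick sanity-check sketch of that identity, using pytensor's public helpers (assumes a pytensor version where `jacobian` accepts `vectorize`, as this patch does):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.gradient import grad, hessian, jacobian

x = pt.vector("x")
f = (x**3).sum()  # scalar objective

h_direct = hessian(f, x)                             # Hessian built directly
h_via_jac = jacobian(grad(f, x), x, vectorize=True)  # Jacobian of the gradient

fn = pytensor.function([x], [h_direct, h_via_jac])
a, b = fn(np.array([1.0, 2.0]))
assert np.allclose(a, b)  # both give diag(6 * x)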
@@ -564,7 +571,7 @@ def L_op(self, inputs, outputs, output_grads):
             implicit_f,
             [inner_x, *inner_args],
             disconnected_inputs="ignore",
-            vectorize=True,
+            vectorize=self.use_vectorized_jac,
         )
         grad_wrt_args = implict_optimization_grads(
             df_dx=df_dx,
@@ -584,6 +591,7 @@ def minimize(
     method: str = "BFGS",
     jac: bool = True,
     hess: bool = False,
+    use_vectorized_jac: bool = False,
     optimizer_kwargs: dict | None = None,
 ) -> tuple[TensorVariable, TensorVariable]:
     """
@@ -593,18 +601,21 @@ def minimize(
     ----------
     objective : TensorVariable
         The objective function to minimize. This should be a pytensor variable representing a scalar value.
-
-    x : TensorVariable
+    x: TensorVariable
         The variable with respect to which the objective function is minimized. It must be an input to the
         computational graph of `objective`.
-
-    method : str, optional
+    method: str, optional
         The optimization method to use. Default is "BFGS". See scipy.optimize.minimize for other options.
-
-    jac : bool, optional
-        Whether to compute and use the gradient of teh objective function with respect to x for optimization.
+    jac: bool, optional
+        Whether to compute and use the gradient of the objective function with respect to x for optimization.
         Default is True.
-
+    hess: bool, optional
+        Whether to compute and use the Hessian of the objective function with respect to x for optimization.
+        Default is False. Note that some methods require this, while others do not support it.
+    use_vectorized_jac: bool, optional
+        Whether to use a vectorized graph (vmap) to compute the Jacobian (and/or Hessian) matrix. If False, a
+        scan is used instead. This is a memory/compute trade-off: vectorized graphs can be faster, but use
+        more memory. Default is False.
     optimizer_kwargs
         Additional keyword arguments to pass to scipy.optimize.minimize
 
@@ -627,6 +638,7 @@ def minimize(
         method=method,
         jac=jac,
         hess=hess,
+        use_vectorized_jac=use_vectorized_jac,
         optimizer_kwargs=optimizer_kwargs,
     )
 
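A minimal usage sketch for the new `minimize` flag (assumes this branch; `minimize` lives in `pytensor.tensor.optimize`, the file being patched, and returns a `(solution, success)` pair):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.optimize import minimize

x = pt.vector("x")  # search variable; its value at call time is the initial guess
a = pt.vector("a")  # extra parameter the solution can be differentiated against
objective = ((x - a) ** 2).sum()  # scalar objective, as the Op requires

# hess=True now builds the Hessian as the Jacobian of the gradient, so
# use_vectorized_jac=True vectorizes that graph (vmap) instead of using a scan.
x_star, success = minimize(
    objective,
    x,
    method="Newton-CG",
    jac=True,
    hess=True,
    use_vectorized_jac=True,
)

fn = pytensor.function([x, a], [x_star, success])
print(fn(np.zeros(3), np.array([1.0, 2.0, 3.0])))  # x_star ≈ a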
@@ -807,6 +819,7 @@ def __init__(
         method: str = "hybr",
         jac: bool = True,
         optimizer_kwargs: dict | None = None,
+        use_vectorized_jac: bool = False,
     ):
         if cast(TensorVariable, variables).ndim != cast(TensorVariable, equations).ndim:
             raise ValueError(
@@ -821,7 +834,9 @@ def __init__(
 
         if jac:
             jac_wrt_x = jacobian(
-                self.fgraph.outputs[0], self.fgraph.inputs[0], vectorize=True
+                self.fgraph.outputs[0],
+                self.fgraph.inputs[0],
+                vectorize=use_vectorized_jac,
             )
             self.fgraph.add_output(atleast_2d(jac_wrt_x))
 
@@ -928,6 +943,7 @@ def root(
     variables: TensorVariable,
     method: str = "hybr",
     jac: bool = True,
+    use_vectorized_jac: bool = False,
     optimizer_kwargs: dict | None = None,
 ) -> tuple[TensorVariable, TensorVariable]:
     """
@@ -946,6 +962,10 @@ def root(
     jac : bool, optional
         Whether to compute and use the Jacobian of the `equations` with respect to `variables`.
         Default is True. Most methods require this.
+    use_vectorized_jac: bool, optional
+        Whether to use a vectorized graph (vmap) to compute the Jacobian matrix. If False, a scan is used
+        instead. This is a memory/compute trade-off: vectorized graphs can be faster, but use more memory.
+        Default is False.
     optimizer_kwargs : dict, optional
         Additional keyword arguments to pass to `scipy.optimize.root`.
 
@@ -969,6 +989,7 @@ def root(
         method=method,
         jac=jac,
         optimizer_kwargs=optimizer_kwargs,
+        use_vectorized_jac=use_vectorized_jac,
     )
 
     solution, success = cast(
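And the same flag on `root`, solving a small nonlinear system (again a sketch against this branch; the default `use_vectorized_jac=False` keeps the lower-memory scan-based Jacobian):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.optimize import root

x = pt.vector("x", shape=(2,))  # unknowns; call-time value is the initial guess
theta = pt.scalar("theta")

# Two equations in two unknowns: x0 + 2*x1 = theta and x0**2 = x1.
equations = pt.stack([x[0] + 2 * x[1] - theta, x[0] ** 2 - x[1]])

solution, success = root(
    equations,
    x,
    method="hybr",
    jac=True,
    use_vectorized_jac=False,
)

fn = pytensor.function([x, theta], [solution, success])
print(fn(np.array([0.5, 0.5]), 3.0))  # converges to x = [1.0, 1.0]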