@@ -395,6 +395,12 @@ def line_search(
395395 hager_zhang_initial_found = _Hager_Zhang_Initial_State .NOT_FOUND
396396 hager_zhang_descent_grad = wolfe_descent_new_grad
397397 hager_zhang_state = _Hager_Zhang_State .NONE
398+ hager_zhang_eps = (
399+ jnp .linalg .norm (ravel_pytree (gradient )[0 ])
400+ * varipeps_config .line_search_hager_zhang_eps_grad_norm_factor
401+ if varipeps_config .line_search_hager_zhang_eps_use_grad_norm
402+ else varipeps_config .line_search_hager_zhang_eps
403+ )
398404
399405 new_value = current_value
400406
@@ -608,7 +614,7 @@ def line_search(
608614
609615 if descent_new_grad >= hz_wolfe_2_right :
610616 if hz_wolfe_1_left >= hz_wolfe_1_right and new_value <= (
611- current_value + varipeps_config . line_search_hager_zhang_eps
617+ current_value + hager_zhang_eps
612618 ):
613619 break
614620
@@ -617,7 +623,7 @@ def line_search(
617623 ) * hager_zhang_descent_grad
618624
619625 if hz_approx_wolfe_left >= hager_zhang_descent_grad and new_value <= (
620- current_value + varipeps_config . line_search_hager_zhang_eps
626+ current_value + hager_zhang_eps
621627 ):
622628 break
623629
@@ -635,9 +641,7 @@ def line_search(
635641 hager_zhang_upper_bound_grad = new_gradient
636642 hager_zhang_upper_bound_des_grad = descent_new_grad
637643 hager_zhang_initial_found = _Hager_Zhang_Initial_State .FOUND
638- elif new_value <= (
639- current_value + varipeps_config .line_search_hager_zhang_eps
640- ):
644+ elif new_value <= (current_value + hager_zhang_eps ):
641645 hager_zhang_lower_bound = alpha
642646 hager_zhang_lower_bound_value = new_value
643647 hager_zhang_lower_bound_grad = new_gradient
@@ -700,9 +704,7 @@ def line_search(
700704 hager_zhang_upper_bound_grad = new_gradient
701705 hager_zhang_upper_bound_des_grad = descent_new_grad
702706 hager_zhang_initial_found = _Hager_Zhang_Initial_State .FOUND
703- elif descent_new_grad < 0 and new_value > (
704- current_value + varipeps_config .line_search_hager_zhang_eps
705- ):
707+ elif descent_new_grad < 0 and new_value > (current_value + hager_zhang_eps ):
706708 alpha = varipeps_config .line_search_hager_zhang_theta * alpha
707709 hager_zhang_initial_found = (
708710 _Hager_Zhang_Initial_State .SCALAR_LOWER_VALUE_GREATER
@@ -725,9 +727,7 @@ def line_search(
725727 count += 1
726728 continue
727729 else :
728- if new_value <= (
729- current_value + varipeps_config .line_search_hager_zhang_eps
730- ):
730+ if new_value <= (current_value + hager_zhang_eps ):
731731 hager_zhang_lower_bound = alpha
732732 hager_zhang_lower_bound_value = new_value
733733 hager_zhang_lower_bound_grad = new_gradient
@@ -892,9 +892,7 @@ def line_search(
892892 hager_zhang_upper_bound_grad = new_gradient
893893 hager_zhang_upper_bound_des_grad = descent_new_grad
894894 hager_zhang_state = _Hager_Zhang_State .NONE
895- elif new_value <= (
896- current_value + varipeps_config .line_search_hager_zhang_eps
897- ):
895+ elif new_value <= (current_value + hager_zhang_eps ):
898896 hager_zhang_lower_bound = alpha
899897 hager_zhang_lower_bound_value = new_value
900898 hager_zhang_lower_bound_grad = new_gradient
@@ -938,9 +936,7 @@ def line_search(
938936 hager_zhang_upper_bound_grad = new_gradient
939937 hager_zhang_upper_bound_des_grad = descent_new_grad
940938 hager_zhang_state = _Hager_Zhang_State .NONE
941- elif new_value <= (
942- current_value + varipeps_config .line_search_hager_zhang_eps
943- ):
939+ elif new_value <= (current_value + hager_zhang_eps ):
944940 hager_zhang_lower_bound = alpha
945941 hager_zhang_lower_bound_value = new_value
946942 hager_zhang_lower_bound_grad = new_gradient
0 commit comments