File tree Expand file tree Collapse file tree 2 files changed +11
-3
lines changed Expand file tree Collapse file tree 2 files changed +11
-3
lines changed Original file line number Diff line number Diff line change @@ -141,8 +141,11 @@ def lion(
141141 r"""Functional API that performs Lion algorithm computation.
142142 """
143143 if foreach is None :
144- # Placeholder for more complex foreach logic to be added when value is not set
145- foreach = True
144+ try :
145+ # cannot do foreach if this overload doesn't exist when caution enabled
146+ foreach = not caution or 'Scalar' in torch .ops .aten ._foreach_maximum .overloads ()
147+ except :
148+ foreach = False
146149
147150 if foreach and torch .jit .is_scripting ():
148151 raise RuntimeError ('torch.jit.script not supported with foreach optimizers' )
Original file line number Diff line number Diff line change @@ -169,7 +169,12 @@ def nadamw(
169169 ' singleton tensors' )
170170
171171 if foreach is None :
172- foreach = True
172+ try :
173+ # cannot do foreach if this overload doesn't exist when caution enabled
174+ foreach = not caution or 'Scalar' in torch .ops .aten ._foreach_maximum .overloads ()
175+ except :
176+ foreach = False
177+
173178 if foreach and not torch .jit .is_scripting ():
174179 func = _multi_tensor_nadamw
175180 else :
You can’t perform that action at this time.
0 commit comments