1 file changed: +16 -14 lines changed
@@ -14,7 +14,15 @@
 
 import numpy as np
 import torch
-
+try:
+    # NOTE opt_einsum needed to avoid blowing up memory with einsum ops
+    import opt_einsum
+    import torch.backends.opt_einsum
+    torch.backends.opt_einsum.enabled = True
+    torch.backends.opt_einsum.strategy = "auto-hq"
+    has_opt_einsum = True
+except ImportError:
+    has_opt_einsum = False
 
 try:
     torch._dynamo.config.cache_size_limit = 1_000_000
@@ -26,11 +34,11 @@
 
 
 def precond_update_prob_schedule(
-    n: float,
-    max_prob: float = 1.0,
-    min_prob: float = 0.03,
-    decay: float = 0.001,
-    flat_start: float = 500,
+    n: float,
+    max_prob: float = 1.0,
+    min_prob: float = 0.03,
+    decay: float = 0.001,
+    flat_start: float = 500,
 ) -> torch.Tensor:
     """Anneal preconditioner update probability during beginning of training.
 
@@ -92,14 +100,8 @@ def __init__(
         flatten_dim: bool = False,
         deterministic: bool = False,
     ):
-        try:
-            # NOTE opt_einsum needed to avoid blowing up memory with einsum ops
-            import opt_einsum
-            opt_einsum.enabled = True
-            opt_einsum.strategy = "auto-hq"
-            import torch.backends.opt_einsum
-        except ImportError:
-            warnings.warn("It is highly recommended to have 'opt_einsum' installed for this optimizer.")
+        if not has_opt_einsum:
+            warnings.warn("It is highly recommended to have 'opt_einsum' installed for this optimizer.")
 
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
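
For context on the NOTE in the diff: the change moves the opt_einsum setup from __init__ to import time and sets the flags on torch.backends.opt_einsum (the removed code set them on the opt_einsum package itself, where they have no effect). Without opt_einsum, torch.einsum contracts its operands strictly left to right, which can materialize very large intermediates; with it enabled, PyTorch searches for a cheaper contraction path ("auto-hq" trades extra search time for higher-quality paths). A minimal illustration of why order matters, with made-up shapes that are not from the PR:

import torch

# Chain contraction where order matters. Evaluated left to right,
# a @ b materializes a 1024x1024 intermediate (~1M elements); an
# optimized contraction path does b @ c first, keeping the
# intermediate at 8x8.
a = torch.randn(1024, 8)
b = torch.randn(8, 1024)
c = torch.randn(1024, 8)
out = torch.einsum("ij,jk,kl->il", a, b, c)  # shape: (1024, 8)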
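
The body of precond_update_prob_schedule sits outside this diff's context, so only the signature and first docstring line are visible. Purely as an illustrative sketch of a schedule with this shape (hold max_prob for flat_start steps, then decay exponentially toward min_prob); the real implementation may differ:

import torch

def precond_update_prob_schedule_sketch(
    n: float,
    max_prob: float = 1.0,
    min_prob: float = 0.03,
    decay: float = 0.001,
    flat_start: float = 500,
) -> torch.Tensor:
    # Hypothetical body, not taken from the PR: flat at max_prob for
    # flat_start steps, then exponential decay clamped at min_prob.
    steps_past_flat = torch.clamp(torch.as_tensor(n) - flat_start, min=0.0)
    prob = max_prob * torch.exp(-decay * steps_past_flat)
    return prob.clamp(min=min_prob, max=max_prob)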