File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -222,14 +222,15 @@ def forward(self, x: Tensor) -> Tensor:
222222
223223
224224def precompute_freqs_cis (
225- seq_len : int , n_elem : int , base : int = 10000
225+ seq_len : int , n_elem : int , base : int = 10000 ,
226+ dtype : torch .dtype = torch .bfloat16
226227) -> Tensor :
227228 freqs = 1.0 / (base ** (torch .arange (0 , n_elem , 2 )[: (n_elem // 2 )].float () / n_elem ))
228229 t = torch .arange (seq_len , device = freqs .device )
229230 freqs = torch .outer (t , freqs )
230231 freqs_cis = torch .polar (torch .ones_like (freqs ), freqs )
231232 cache = torch .stack ([freqs_cis .real , freqs_cis .imag ], dim = - 1 )
232- return cache .to (dtype = torch . bfloat16 )
233+ return cache .to (dtype = dtype )
233234
234235
235236def apply_rotary_emb (x : Tensor , freqs_cis : Tensor ) -> Tensor :
You can’t perform that action at this time.
0 commit comments