8 files changed, +32 -34 lines changed
File 1 of 8:

@@ -3,14 +3,16 @@
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+import os
+import sys
+
 import torch
-import os, sys
+
 lm_evaluation_harness_path = "/".join(
     os.getcwd().split("/")[:-1] + ["lm-evaluation-harness"]
 )
 sys.path.insert(0, lm_evaluation_harness_path)
 import main as lm_evaluation_harness_main
-
 import torch.fx as fx
 import torch.nn as nn
 import torch.nn.functional as F
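Note: the "/"-splitting above reimplements a parent-directory lookup by hand and assumes POSIX separators. A pathlib equivalent, shown only as an illustration (it is not part of this diff):

    import sys
    from pathlib import Path

    # Sibling checkout of lm-evaluation-harness next to the current working directory.
    lm_evaluation_harness_path = str(Path.cwd().parent / "lm-evaluation-harness")
    sys.path.insert(0, lm_evaluation_harness_path)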
File 2 of 8:

@@ -9,9 +9,9 @@
 from typing import Optional
 
 import torch
-
-import torch._inductor.config
 import torch._dynamo.config
+import torch._inductor.config
+
 torch._dynamo.config.automatic_dynamic_shapes = True
 torch._inductor.config.triton.unique_kernel_names = True
 torch._inductor.config.epilogue_fusion = False
@@ -22,23 +22,21 @@
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))
 
-from model import LLaMA
-from sentencepiece import SentencePieceProcessor
-
 # hacky path setup for lm-evaluation-harness
 import os
 import sys
+
+from sentencepiece import SentencePieceProcessor
+
+from model import LLaMA
+
 lm_evaluation_harness_path = '/'.join(
     os.getcwd().split('/')[:-1] + ['lm-evaluation-harness'])
 sys.path.insert(0, lm_evaluation_harness_path)
-import main as lm_evaluation_harness_main
 import lm_eval
+import main as lm_evaluation_harness_main
 
-from generate import (
-    _load_model,
-    encode_tokens,
-    model_forward,
-)
+from generate import _load_model, encode_tokens, model_forward
 
 
 def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
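Note: `import lm_eval` and `import main` only resolve because the `sys.path.insert` above them runs first, so an import sorter must not hoist them to the top of the file. A common way to pin late imports, sketched here without knowing which formatter this project actually uses:

    sys.path.insert(0, lm_evaluation_harness_path)
    import lm_eval  # noqa: E402  # isort:skip
    import main as lm_evaluation_harness_main  # noqa: E402  # isort:skip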
File 3 of 8:

@@ -3,15 +3,16 @@
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+import itertools
 import sys
 import time
 from pathlib import Path
 from typing import Optional, Tuple
-import itertools
-import torch
 
-import torch._inductor.config
+import torch
 import torch._dynamo.config
+import torch._inductor.config
+
 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.triton.unique_kernel_names = True
 torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
@@ -21,9 +22,11 @@
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))
 
+from sentencepiece import SentencePieceProcessor
+
 from model import Transformer
 from tp import maybe_init_dist
-from sentencepiece import SentencePieceProcessor
+
 
 def multinomial_sample_one_no_sync(probs_sort):  # Does multinomial sampling without a cuda synchronization
     q = torch.empty_like(probs_sort).exponential_(1)
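Note: this hunk ends on the first line of the exponential-race sampler. Dividing the probabilities by i.i.d. Exponential(1) noise and taking the argmax selects index i with probability proportional to probs_sort[i], and unlike torch.multinomial it avoids a device-to-host synchronization. A sketch of the full body under that standard formulation (only the first line is shown in this diff):

    import torch

    def multinomial_sample_one_no_sync(probs_sort):
        # q ~ Exponential(1), drawn entirely on-device.
        q = torch.empty_like(probs_sort).exponential_(1)
        # argmax(p / q) is distributed as Categorical(p); no CUDA sync required.
        return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)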
File 4 of 8:

@@ -3,14 +3,14 @@
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-import math
 from dataclasses import dataclass
 from typing import Optional
 
 import torch
 import torch.nn as nn
-from torch.nn import functional as F
 from torch import Tensor
+from torch.nn import functional as F
+
 
 def find_multiple(n: int, k: int) -> int:
     if n % k == 0:
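Note: the hunk stops at the top of find_multiple, a round-up helper. A sketch of the likely remainder (not shown in this diff):

    def find_multiple(n: int, k: int) -> int:
        # Round n up to the nearest multiple of k, e.g. when padding sizes.
        if n % k == 0:
            return n
        return n + k - (n % k)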
File 5 of 8:

@@ -3,19 +3,12 @@
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-import importlib
 import time
-from math import ceil
 from pathlib import Path
 
 import torch
-import importlib
-import time
-
 import torch.nn as nn
 import torch.nn.functional as F
-
-from pathlib import Path
 from sentencepiece import SentencePieceProcessor
 
 try:
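Note: the `try:` this hunk ends on presumably guards an optional dependency, a common pattern around quantization kernels. The module name below is hypothetical and only illustrates the pattern:

    try:
        # Hypothetical optional backend; the real import is not visible in this diff.
        import fast_int4_kernels
        HAS_INT4_KERNELS = True
    except ImportError:
        HAS_INT4_KERNELS = False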
File 6 of 8:

@@ -4,19 +4,20 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import json
+import re
 import sys
 from pathlib import Path
 from typing import Optional
 
 import torch
-import re
 
 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))
 
 from model import ModelArgs
 
+
 @torch.inference_mode()
 def convert_hf_checkpoint(
     *,
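Note: the `@torch.inference_mode()` decorator seen in context here is PyTorch's stricter, faster variant of torch.no_grad(): tensors created under it carry no autograd metadata or version counters. A minimal self-contained illustration:

    import torch

    @torch.inference_mode()
    def double(x: torch.Tensor) -> torch.Tensor:
        # No autograd graph is recorded inside inference_mode.
        return x * 2

    out = double(torch.ones(3, requires_grad=True))
    print(out.requires_grad)  # False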
File 7 of 8:

@@ -4,11 +4,11 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import os
-from requests.exceptions import HTTPError
-import sys
-from pathlib import Path
 from typing import Optional
 
+from requests.exceptions import HTTPError
+
+
 def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None:
     from huggingface_hub import snapshot_download
     os.makedirs(f"checkpoints/{repo_id}", exist_ok=True)
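Note: hf_download's body continues past this hunk; given the imports, it presumably wraps huggingface_hub.snapshot_download and catches HTTPError for gated checkpoints. A sketch using only documented APIs (the real error handling is not part of this diff):

    import os
    from typing import Optional

    from requests.exceptions import HTTPError

    def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None:
        from huggingface_hub import snapshot_download
        os.makedirs(f"checkpoints/{repo_id}", exist_ok=True)
        try:
            snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", token=hf_token)
        except HTTPError as e:
            # A 401 here usually means the repo is gated and needs a valid token.
            raise RuntimeError(f"Download of {repo_id} failed; check hf_token") from e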
File 8 of 8:

@@ -4,14 +4,15 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import os
-from typing import Optional, List
+from typing import List, Optional
 
 import torch
-from torch import nn
 import torch.distributed as dist
+from torch import nn
 from torch.distributed import _functional_collectives as funcol
-from model import Transformer, Attention, FeedForward
-from quantize import WeightOnlyInt4Linear, WeightOnlyInt8Linear
+
+from model import Attention, FeedForward, Transformer
+from quantize import WeightOnlyInt4Linear
 
 
 def _get_rank() -> int:
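Note: _get_rank's body falls outside the hunk; in torchrun-launched tensor-parallel jobs the rank is conventionally read from the environment. A sketch under that assumption:

    import os

    def _get_rank() -> int:
        # torchrun exports LOCAL_RANK per process; default to 0 for standalone runs.
        return int(os.environ.get("LOCAL_RANK", "0"))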