|
3 | 3 |
|
4 | 4 | # This source code is licensed under the license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | | -import os |
7 | | -import sys |
8 | 6 |
|
9 | 7 | import torch |
10 | 8 |
|
11 | | -lm_evaluation_harness_path = "/".join( |
12 | | - os.getcwd().split("/")[:-1] + ["lm-evaluation-harness"] |
13 | | -) |
14 | | -sys.path.insert(0, lm_evaluation_harness_path) |
15 | | -import main as lm_evaluation_harness_main |
16 | 9 | import torch.fx as fx |
17 | 10 | import torch.nn as nn |
18 | 11 | import torch.nn.functional as F |
19 | 12 | from torch.utils._pytree import tree_flatten, tree_unflatten |
20 | 13 |
|
21 | | -from eval import setup_cache_padded_seq_input_pos_max_seq_length_for_prefill |
22 | | -from generate import encode_tokens |
23 | | - |
24 | 14 | aten = torch.ops.aten |
25 | 15 |
|
26 | | -try: |
27 | | - import lm_eval |
28 | | - class InputRecorder(lm_eval.base.BaseLM): |
29 | | - """ |
30 | | - This is a fake evaluation wrapper that just records the inputs |
31 | | - so that they can be used in calibration. |
32 | | -
|
33 | | - If pad_calibration_inputs is enabled, the input recorder will take |
34 | | - each input and pad/truncate it down to the calibration_seq_length. |
35 | | - It will also edit the model embeddings to be zero for the 0 token used |
36 | | - in padding and avoid any inputs with the 0 token. |
37 | | -
|
38 | | - If not, it will only truncate inputs to the desired length. |
39 | | - """ |
40 | | - |
41 | | - def __init__( |
42 | | - self, |
43 | | - model, |
44 | | - tokenizer, |
45 | | - calibration_seq_length, |
46 | | - pad_calibration_inputs=False, |
47 | | - ): |
48 | | - super().__init__() |
49 | | - self._model = model |
50 | | - self._tokenizer = tokenizer |
51 | | - self._device = torch.device("cpu") |
52 | | - self.vocab_size = model.config.vocab_size |
53 | | - self.calibration_seq_length = calibration_seq_length |
54 | | - self.pad_calibration_inputs = pad_calibration_inputs |
55 | | - self.inputs = None |
56 | | - |
57 | | - if self.pad_calibration_inputs: |
58 | | - # This is needed for the pad_calibration_inputs option |
59 | | - # to work properly, the 0 token's embeddings are set to 0 so that |
60 | | - # the padded inputs will not affect the model numerics. This token isn't used |
61 | | - # commonly in the eval tasks for the meta-llama tokenizer and we skip any inputs |
62 | | - # where it appears |
63 | | - try: |
64 | | - if isinstance(self._model.transformer.wte, nn.Embedding): |
65 | | - self.mod.transformer.wte.weight.data[0, :] *= 0 |
66 | | - except: |
67 | | - print( |
68 | | - "Did not find embeddings in model.transformer.wte, disabling padding" |
69 | | - ) |
70 | | - self.pad_calibration_inputs = False |
| 16 | +from eval import ( |
| 17 | + setup_cache_padded_seq_input_pos_max_seq_length_for_prefill, |
| 18 | + encode_tokens, |
| 19 | + eval_wrapper |
| 20 | +) |
71 | 21 |
|
72 | | - @property |
73 | | - def eot_token_id(self): |
74 | | - return self._tokenizer.eos_id() |
75 | 22 |
|
76 | | - @property |
77 | | - def max_length(self): |
78 | | - return self.calibration_seq_length |
| 23 | +class InputRecorder(eval_wrapper): |
| 24 | + """ |
| 25 | + This is a fake evaluation wrapper that just records the inputs |
| 26 | + so that they can be used in calibration. |
79 | 27 |
|
80 | | - @property |
81 | | - def max_gen_toks(self): |
82 | | - return 50 |
| 28 | + If pad_calibration_inputs is enabled, the input recorder will pad or |
| 29 | + truncate each input to calibration_seq_length. It will also zero the |
| 30 | + model's embedding for the 0 token used for padding and skip any |
| 31 | + inputs that contain the 0 token. |
83 | 32 |
|
84 | | - @property |
85 | | - def batch_size(self): |
86 | | - return 1 |
| 33 | + If not, it will only truncate inputs to the desired length. |
| 34 | + """ |
87 | 35 |
|
88 | | - @property |
89 | | - def device(self): |
90 | | - return self._device |
| 36 | + def __init__( |
| 37 | + self, |
| 38 | + model, |
| 39 | + tokenizer, |
| 40 | + calibration_seq_length, |
| 41 | + pad_calibration_inputs=False, |
| 42 | + ): |
| 43 | + super().__init__() |
| 44 | + self._model = model |
| 45 | + self._tokenizer = tokenizer |
| 46 | + self._device = torch.device("cpu") |
| 47 | + self.vocab_size = model.config.vocab_size |
| 48 | + self.calibration_seq_length = calibration_seq_length |
| 49 | + self.pad_calibration_inputs = pad_calibration_inputs |
| 50 | + self.inputs = None |
| 51 | + |
| 52 | + if self.pad_calibration_inputs: |
| 53 | + # For the pad_calibration_inputs option to work properly, the 0 |
| 54 | + # token's embedding is set to 0 so that padded inputs do not affect |
| 55 | + # the model numerics. This token is rarely used in the eval tasks |
| 56 | + # with the meta-llama tokenizer, and any inputs containing it are |
| 57 | + # skipped. |
| 58 | + try: |
| 59 | + if isinstance(self._model.transformer.wte, nn.Embedding): |
| 60 | + self._model.transformer.wte.weight.data[0, :] *= 0 |
| 61 | + except Exception: |
| 62 | + print( |
| 63 | + "Did not find embeddings in model.transformer.wte, disabling padding" |
| 64 | + ) |
| 65 | + self.pad_calibration_inputs = False |
91 | 66 |
|
92 | | - def tok_encode(self, string: str): |
93 | | - encoded = encode_tokens( |
94 | | - self._tokenizer, string, bos=True, device=self._device |
95 | | - ) |
96 | | - # encoded is a pytorch tensor, but some internal logic in the |
97 | | - # eval harness expects it to be a list instead |
98 | | - # TODO: verify this for multi-batch as well |
99 | | - encoded = encoded.tolist() |
100 | | - return encoded |
101 | | - |
102 | | - def tok_decode(self, tokens): |
103 | | - decoded = self._tokenizer.decode(tokens) |
104 | | - return decoded |
105 | | - |
106 | | - def add_input(self, args): |
107 | | - if self.inputs is None: |
108 | | - self.inputs = [MultiInput([arg]) for arg in args] |
109 | | - else: |
110 | | - self.inputs = [ |
111 | | - multi.add_input(arg) for (multi, arg) in zip(self.inputs, args) |
112 | | - ] |
| 67 | + @property |
| 68 | + def eot_token_id(self): |
| 69 | + return self._tokenizer.eos_id() |
113 | 70 |
|
114 | | - def get_recorded_inputs(self): |
115 | | - return self.inputs |
| 71 | + @property |
| 72 | + def max_length(self): |
| 73 | + return self.calibration_seq_length |
116 | 74 |
|
117 | | - def _model_call(self, inps): |
118 | | - inps = inps.squeeze(0) |
119 | | - T = len(inps) |
120 | | - if ( |
121 | | - # can't use inputs that are too short when padding disabled |
122 | | - (T < self.calibration_seq_length and not self.pad_calibration_inputs) |
123 | | - or |
124 | | - # can't use inputs that actually use token we use for padding |
125 | | - (self.pad_calibration_inputs and 0 in inps) |
126 | | - ): |
127 | | - # give random output |
128 | | - return torch.randn( |
129 | | - (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device |
130 | | - ) |
| 75 | + @property |
| 76 | + def max_gen_toks(self): |
| 77 | + return 50 |
131 | 78 |
|
132 | | - # pad or truncate to the right size |
133 | | - if T >= self.calibration_seq_length: |
134 | | - inps = inps[: self.calibration_seq_length] |
135 | | - else: |
136 | | - inps = F.pad(inps, (0, self.calibration_seq_length - T)) |
137 | | - |
138 | | - max_new_tokens = 1 |
139 | | - ( |
140 | | - seq, |
141 | | - input_pos, |
142 | | - max_seq_length, |
143 | | - ) = setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( |
144 | | - self._model, inps, max_new_tokens, self.max_length |
145 | | - ) |
146 | | - x = seq.index_select(0, input_pos).view(1, -1) |
147 | | - self.add_input((x, input_pos)) |
| 79 | + @property |
| 80 | + def batch_size(self): |
| 81 | + return 1 |
148 | 82 |
|
149 | | - # output `something` with correct shape to keep eval going |
| 83 | + @property |
| 84 | + def device(self): |
| 85 | + return self._device |
| 86 | + |
| 87 | + def tok_encode(self, string: str): |
| 88 | + encoded = encode_tokens( |
| 89 | + self._tokenizer, string, bos=True, device=self._device |
| 90 | + ) |
| 91 | + # encoded is a pytorch tensor, but some internal logic in the |
| 92 | + # eval harness expects it to be a list instead |
| 93 | + # TODO: verify this for multi-batch as well |
| 94 | + encoded = encoded.tolist() |
| 95 | + return encoded |
| 96 | + |
| 97 | + def tok_decode(self, tokens): |
| 98 | + decoded = self._tokenizer.decode(tokens) |
| 99 | + return decoded |
| 100 | + |
| 101 | + def add_input(self, args): |
| 102 | + if self.inputs is None: |
| 103 | + self.inputs = [MultiInput([arg]) for arg in args] |
| 104 | + else: |
| 105 | + self.inputs = [ |
| 106 | + multi.add_input(arg) for (multi, arg) in zip(self.inputs, args) |
| 107 | + ] |
| 108 | + |
| 109 | + def get_recorded_inputs(self): |
| 110 | + return self.inputs |
| 111 | + |
| 112 | + def _model_call(self, inps): |
| 113 | + inps = inps.squeeze(0) |
| 114 | + T = len(inps) |
| 115 | + if ( |
| 116 | + # can't use inputs that are too short when padding disabled |
| 117 | + (T < self.calibration_seq_length and not self.pad_calibration_inputs) |
| 118 | + or |
| 119 | + # can't use inputs that contain the token we use for padding |
| 120 | + (self.pad_calibration_inputs and 0 in inps) |
| 121 | + ): |
| 122 | + # give random output |
150 | 123 | return torch.randn( |
151 | 124 | (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device |
152 | 125 | ) |
153 | 126 |
|
154 | | - def _model_generate(self, context, max_length, eos_token_id): |
155 | | - raise Exception("unimplemented") |
156 | | -except ImportError: |
157 | | - pass |
| 127 | + # pad or truncate to the right size |
| 128 | + if T >= self.calibration_seq_length: |
| 129 | + inps = inps[: self.calibration_seq_length] |
| 130 | + else: |
| 131 | + inps = F.pad(inps, (0, self.calibration_seq_length - T)) |
| 132 | + |
| 133 | + max_new_tokens = 1 |
| 134 | + ( |
| 135 | + seq, |
| 136 | + input_pos, |
| 137 | + max_seq_length, |
| 138 | + ) = setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( |
| 139 | + self._model, inps, max_new_tokens, self.max_length |
| 140 | + ) |
| 141 | + x = seq.index_select(0, input_pos).view(1, -1) |
| 142 | + self.add_input((x, input_pos)) |
| 143 | + |
| 144 | + # output `something` with correct shape to keep eval going |
| 145 | + return torch.randn( |
| 146 | + (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device |
| 147 | + ) |
| 148 | + |
| 149 | + def _model_generate(self, context, max_length, eos_token_id): |
| 150 | + raise NotImplementedError("InputRecorder does not support generation") |
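# A minimal usage sketch (editorial, illustrative only, not part of this commit): the recorder
# is driven through the eval harness like any other eval_wrapper, but it stores each
# (tokens, input_pos) pair instead of producing real logits, so the recorded pairs can later
# be fed to GPTQ-style calibration. The `run_eval` driver named below is hypothetical; the
# actual entry point lives in eval.py / lm-evaluation-harness.
#
#   recorder = InputRecorder(model, tokenizer, calibration_seq_length=100)
#   run_eval(recorder, tasks=["wikitext"], limit=10)   # hypothetical driver
#   inputs = recorder.get_recorded_inputs()            # list of MultiInput, one per model arg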
158 | 151 |
|
159 | 152 |
|
160 | 153 | class MultiInput: |
|