@@ -19,6 +19,7 @@ def setUpClass(cls):
1919 cls .device = DEVICE
2020
2121 cls .config , cls .tokenizer , cls .tinystory = pv .create_gpt_neo ()
22+ cls .tinystory .to (cls .device )
2223
2324 @classmethod
2425 def tearDownClass (cls ):
@@ -65,7 +66,7 @@ def test_lm_generation(self):
6566
6667 prompt = tokenizer ("Once upon a time there was" , return_tensors = "pt" )
6768 _ , intervened_story = pv_tinystory .generate (
68- prompt , source_representations = emb_happy , max_length = 32
69+ prompt , source_representations = emb_happy , unit_locations = { "sources->base" : ( 0 , [ 0 , 1 , 2 ])}, max_length = 32
6970 )
7071 print (tokenizer .decode (intervened_story [0 ], skip_special_tokens = True ))
7172
@@ -81,7 +82,7 @@ def test_generation_with_source_intervened_prompt(self):
8182 }
8283 for l in range (self .config .num_layers )
8384 ],
84- model = self .tinystory . to ( self . device ) ,
85+ model = self .tinystory ,
8586 )
8687
8788 prompt = self .tokenizer ("Once upon a time there was" , return_tensors = "pt" ).to (
@@ -118,7 +119,7 @@ def test_dynamic_static_generation_intervention_parity(self):
118119 }
119120 for l in range (self .config .num_layers )
120121 ],
121- model = self .tinystory . to ( self . device ) ,
122+ model = self .tinystory ,
122123 )
123124
124125 prompt = self .tokenizer ("Once upon a time there was" , return_tensors = "pt" ).to (
@@ -144,6 +145,92 @@ def test_dynamic_static_generation_intervention_parity(self):
144145 orig_text != intervened_text
145146 ), "Aggressive intervention did not change the output. Probably something wrong."
146147
148+ def test_generation_noops (self ):
149+ torch .manual_seed (0 )
150+
151+ # No-op intervention
152+ pv_model = pv .IntervenableModel (
153+ [
154+ {
155+ "layer" : l ,
156+ "component" : "mlp_output" ,
157+ "intervention" : lambda b , s : b ,
158+ }
159+ for l in range (self .config .num_layers )
160+ ],
161+ model = self .tinystory ,
162+ )
163+
164+ prompt = self .tokenizer ("Once upon a time there was" , return_tensors = "pt" ).to (
165+ self .device
166+ )
167+ sources = self .tokenizer (" love" , return_tensors = "pt" ).to (self .device )
168+
169+ orig , intervened = pv_model .generate (
170+ prompt ,
171+ max_length = 20 ,
172+ sources = sources ,
173+ intervene_on_prompt = True ,
174+ unit_locations = {"sources->base" : (0 , [0 , 1 , 2 ])},
175+ output_original_output = True ,
176+ )
177+ orig_text , intervened_text = (
178+ self .tokenizer .decode (orig [0 ], skip_special_tokens = True ),
179+ self .tokenizer .decode (intervened [0 ], skip_special_tokens = True ),
180+ )
181+
182+ print (intervened_text )
183+ assert (
184+ orig_text == intervened_text
185+ ), "No-op intervention changed the output. Probably something wrong."
186+
187+ # Aggressive intervention with intervene_on_prompt=False
188+ aggressive_model = pv .IntervenableModel (
189+ [
190+ {
191+ "layer" : l ,
192+ "component" : "mlp_output" ,
193+ "intervention" : lambda b , s : s * 1000 ,
194+ }
195+ for l in range (self .config .num_layers )
196+ ],
197+ model = self .tinystory ,
198+ )
199+
200+ orig , intervened = aggressive_model .generate (
201+ prompt ,
202+ max_length = 20 ,
203+ sources = sources ,
204+ intervene_on_prompt = False ,
205+ output_original_output = True ,
206+ )
207+
208+ orig_text , intervened_text = (
209+ self .tokenizer .decode (orig [0 ], skip_special_tokens = True ),
210+ self .tokenizer .decode (intervened [0 ], skip_special_tokens = True ),
211+ )
212+ print (orig_text )
213+ print (intervened_text )
214+ assert (
215+ orig_text == intervened_text
216+ ), "Aggressive intervention changed the output. Probably something wrong."
217+
218+ # Aggressive intervention with no prompt intervention, disabled selectors
219+ orig , intervened = aggressive_model .generate (
220+ prompt ,
221+ max_length = 20 ,
222+ sources = sources ,
223+ intervene_on_prompt = False ,
224+ output_original_output = True ,
225+ timestep_selector = [lambda idx , o : False ] * self .config .num_layers ,
226+ )
227+ orig_text , intervened_text = (
228+ self .tokenizer .decode (orig [0 ], skip_special_tokens = True ),
229+ self .tokenizer .decode (intervened [0 ], skip_special_tokens = True ),
230+ )
231+ assert (
232+ orig_text == intervened_text
233+ ), "Aggressive intervention changed the output. Probably something wrong."
147234
if __name__ == "__main__":
    # Allow running this test module directly (python <file>.py) rather than
    # only through a test runner.
    unittest.main()
0 commit comments