Commit 0cdd268

noop guards
1 parent 78ba2f6 commit 0cdd268

File tree

3 files changed (+94, −9 lines)


pyvene/models/intervenable_base.py

Lines changed: 3 additions & 5 deletions
@@ -270,7 +270,6 @@ def _cleanup_states(self, skip_activation_gc=False):
         """
         Clean up all old in memo states of interventions
         """
-        self._skip_forward = False
         self._remove_forward_hooks()
         self._reset_hook_count()
         if not skip_activation_gc:
@@ -857,8 +856,7 @@ def _intervention_setter(
 
         def hook_callback(model, args, kwargs, output=None):
             if (
-                not self.is_model_stateless
-                and self._skip_forward
+                self._skip_forward
                 and state.setter_timestep <= 0
             ):
                 state.setter_timestep += 1
@@ -878,11 +876,11 @@ def hook_callback(model, args, kwargs, output=None):
             # in this code we assume that output is batched along its first axis.
             int_unit_loc = (
                 unit_locations_base[key_i]
-                if state.setter_timestep <= 0 or not timestep_selector
+                if state.setter_timestep <= 0
                 else [
                     (
                         [0]
-                        if timestep_selector[key_i](
+                        if timestep_selector != None and timestep_selector[key_i](
                             state.setter_timestep, output[i]
                         )
                         else None
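
Read as a per-position rule, the changed selector logic says: intervene at the configured locations on the prompt pass; during decoding, intervene on the fresh token only when a `timestep_selector` was passed and approves the step, and no-op otherwise. A minimal paraphrase of the diff follows; the standalone helper name `resolve_unit_location` is invented for illustration, since in pyvene this logic lives inline in `hook_callback`:

```python
# Paraphrase of the guarded selector logic above, not library code.
def resolve_unit_location(setter_timestep, timestep_selector, key_i,
                          unit_locations_base, output, i):
    if setter_timestep <= 0:
        # Prompt pass: use the statically configured unit locations.
        return unit_locations_base[key_i]
    if timestep_selector is not None and timestep_selector[key_i](
        setter_timestep, output[i]
    ):
        # A selector was provided and approved this decoding step:
        # intervene on the newly generated position.
        return [0]
    # No selector passed, or the selector declined: no-op at this position.
    return None
```

Net effect of the guard: without a `timestep_selector`, decoding steps are now left untouched rather than intervened unconditionally, which is exactly what the new `test_generation_noops` below asserts.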

tests/integration_tests/GenerationInterventionTestCase.py

Lines changed: 90 additions & 3 deletions
@@ -19,6 +19,7 @@ def setUpClass(cls):
         cls.device = DEVICE
 
         cls.config, cls.tokenizer, cls.tinystory = pv.create_gpt_neo()
+        cls.tinystory.to(cls.device)
 
     @classmethod
     def tearDownClass(cls):
@@ -65,7 +66,7 @@ def test_lm_generation(self):
 
         prompt = tokenizer("Once upon a time there was", return_tensors="pt")
         _, intervened_story = pv_tinystory.generate(
-            prompt, source_representations=emb_happy, max_length=32
+            prompt, source_representations=emb_happy, unit_locations={"sources->base": (0, [0, 1, 2])}, max_length=32
         )
         print(tokenizer.decode(intervened_story[0], skip_special_tokens=True))
 
@@ -81,7 +82,7 @@ def test_generation_with_source_intervened_prompt(self):
                 }
                 for l in range(self.config.num_layers)
             ],
-            model=self.tinystory.to(self.device),
+            model=self.tinystory,
         )
 
         prompt = self.tokenizer("Once upon a time there was", return_tensors="pt").to(
@@ -118,7 +119,7 @@ def test_dynamic_static_generation_intervention_parity(self):
                 }
                 for l in range(self.config.num_layers)
             ],
-            model=self.tinystory.to(self.device),
+            model=self.tinystory,
        )
 
         prompt = self.tokenizer("Once upon a time there was", return_tensors="pt").to(
@@ -144,6 +145,92 @@ def test_dynamic_static_generation_intervention_parity(self):
             orig_text != intervened_text
         ), "Aggressive intervention did not change the output. Probably something wrong."
 
+    def test_generation_noops(self):
+        torch.manual_seed(0)
+
+        # No-op intervention
+        pv_model = pv.IntervenableModel(
+            [
+                {
+                    "layer": l,
+                    "component": "mlp_output",
+                    "intervention": lambda b, s: b,
+                }
+                for l in range(self.config.num_layers)
+            ],
+            model=self.tinystory,
+        )
+
+        prompt = self.tokenizer("Once upon a time there was", return_tensors="pt").to(
+            self.device
+        )
+        sources = self.tokenizer(" love", return_tensors="pt").to(self.device)
+
+        orig, intervened = pv_model.generate(
+            prompt,
+            max_length=20,
+            sources=sources,
+            intervene_on_prompt=True,
+            unit_locations={"sources->base": (0, [0, 1, 2])},
+            output_original_output=True,
+        )
+        orig_text, intervened_text = (
+            self.tokenizer.decode(orig[0], skip_special_tokens=True),
+            self.tokenizer.decode(intervened[0], skip_special_tokens=True),
+        )
+
+        print(intervened_text)
+        assert (
+            orig_text == intervened_text
+        ), "No-op intervention changed the output. Probably something wrong."
+
+        # Aggressive intervention with intervene_on_prompt=False
+        aggressive_model = pv.IntervenableModel(
+            [
+                {
+                    "layer": l,
+                    "component": "mlp_output",
+                    "intervention": lambda b, s: s * 1000,
+                }
+                for l in range(self.config.num_layers)
+            ],
+            model=self.tinystory,
+        )
+
+        orig, intervened = aggressive_model.generate(
+            prompt,
+            max_length=20,
+            sources=sources,
+            intervene_on_prompt=False,
+            output_original_output=True,
+        )
+
+        orig_text, intervened_text = (
+            self.tokenizer.decode(orig[0], skip_special_tokens=True),
+            self.tokenizer.decode(intervened[0], skip_special_tokens=True),
+        )
+        print(orig_text)
+        print(intervened_text)
+        assert (
+            orig_text == intervened_text
+        ), "Aggressive intervention changed the output. Probably something wrong."
+
+        # Aggressive intervention with no prompt intervention, disabled selectors
+        orig, intervened = aggressive_model.generate(
+            prompt,
+            max_length=20,
+            sources=sources,
+            intervene_on_prompt=False,
+            output_original_output=True,
+            timestep_selector=[lambda idx, o: False] * self.config.num_layers,
+        )
+        orig_text, intervened_text = (
+            self.tokenizer.decode(orig[0], skip_special_tokens=True),
+            self.tokenizer.decode(intervened[0], skip_special_tokens=True),
+        )
+        assert (
+            orig_text == intervened_text
+        ), "Aggressive intervention changed the output. Probably something wrong."
 
 if __name__ == "__main__":
     unittest.main()
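
To run just this test file, something along these lines should work from the repository root (a sketch; the discovery path simply mirrors the file location above):

```python
# Sketch: discover and run the generation-intervention tests.
import unittest

suite = unittest.defaultTestLoader.discover(
    "tests/integration_tests", pattern="GenerationInterventionTestCase.py"
)
unittest.TextTestRunner(verbosity=2).run(suite)
```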

tests/unit_tests/InterventionUtilsTestCase.py

Lines changed: 1 addition & 1 deletion
@@ -468,7 +468,7 @@ def test_low_rank_gradient_positive(self):
             loss = F.mse_loss(output, golden)
             loss.backward()
             optimizer.step()
-            print(output)
+
             self.assertTrue(torch.allclose(golden, output, rtol=1e-02, atol=1e-02))
         except:
             pass # retry
