Skip to content

Commit 86edc10

Browse files
committed
fix negative_subspace, add intervention nulling with locations
1 parent 0cdd268 commit 86edc10

File tree

2 files changed

+78
-2
lines changed

2 files changed

+78
-2
lines changed

tests/integration_tests/IntervenableBasicTestCase.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,54 @@ def test_customized_intervention_function_zeroout(self):
646646
intervened_outputs_fn[1].last_hidden_state,
647647
)
648648

649+
def test_nulling_intervention(self):
    """An intervention can be skipped ("nulled") for individual batch
    entries by passing ``None`` as that entry's unit location.

    Three identical base prompts are run through the same intervention;
    the middle example's location is ``None``, so its logits must match
    the un-intervened baseline while examples 0 and 2 must differ.
    """
    _, tokenizer, gpt2 = pv.create_gpt2()
    gpt2.to("cuda")

    # Batch of three identical prompts so per-example effects are comparable.
    base = tokenizer(
        ["The capital of Spain is"] * 3, return_tensors="pt"
    ).to("cuda")

    base_output = gpt2(**base)
    base_logits = pv.embed_to_distrib(
        gpt2, base_output.last_hidden_state, logits=True
    )[0]

    # Simple interpolation intervention: average of base and source
    # mlp_output activations at layer 0.
    pv_gpt2 = pv.IntervenableModel(
        {
            "layer": 0,
            "component": "mlp_output",
            "intervention": lambda b, s: b * 0.5 + s * 0.5,
        },
        model=gpt2,
    )
    pv_gpt2.set_device("cuda")

    _, intervened_outputs = pv_gpt2(
        # the base input
        base=base,
        # the source input
        sources=tokenizer(["Egypt"] * 3, return_tensors="pt").to("cuda"),
        # intervene at the 3rd token for examples 0 and 2; the None entry
        # nulls the intervention for example 1
        unit_locations={"sources->base": (0, [[[3], None, [3]]])},
    )

    intervened_logits = pv.embed_to_distrib(
        gpt2, intervened_outputs.last_hidden_state, logits=True
    )
    assert not torch.allclose(
        base_logits, intervened_logits[0]
    ), "Intervention had no effect on example 0!"
    assert torch.allclose(
        base_logits, intervened_logits[1]
    ), "Intervention was not nulled on example 1!"
    assert not torch.allclose(
        base_logits, intervened_logits[2]
    ), "Intervention had no effect on example 2!"
696+
649697
@classmethod
650698
def tearDownClass(cls):
651699
print(f"Removing testing dir {cls._test_dir}")

tests/integration_tests/InterventionWithMLPTestCase.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,35 @@ def setUpClass(cls):
4242
intervention_types=VanillaIntervention,
4343
)
4444

45+
# Config used by the negative subspace test: two linked interventions on
# the same layer-0 mlp_activation, sharing one subspace partition.
cls.test_negative_subspace_config = IntervenableConfig(
    model_type=type(cls.mlp),
    representations=[
        RepresentationConfig(
            0,  # layer
            "mlp_activation",
            "pos",  # mlp layer creates a single token reprs
            1,  # max number of units
            # partition into two sets of subspaces
            subspace_partition=[[1, 4], [0, 1]],
            # linked ones target the same subspace
            intervention_link_key=0,
        )
        # two identical, linked representations
        for _ in range(2)
    ],
    intervention_types=VanillaIntervention,
)
73+
4574
cls.test_subspace_no_intervention_link_config = (
4675
IntervenableConfig(
4776
model_type=type(cls.mlp),
@@ -149,13 +178,12 @@ def test_with_subspace_negative(self):
149178
Negative test case to check input length.
150179
"""
151180
intervenable = IntervenableModel(
152-
self.test_subspace_intervention_link_config, self.mlp
181+
self.test_negative_subspace_config, self.mlp
153182
)
154183
# golden label
155184
b_s = 10
156185
base = {"inputs_embeds": torch.rand(b_s, 1, 3)}
157186
source_1 = {"inputs_embeds": torch.rand(b_s, 1, 3)}
158-
source_2 = {"inputs_embeds": torch.rand(b_s, 1, 3)}
159187

160188
try:
161189
intervenable(

0 commit comments

Comments
 (0)