
Commit 7f099fb

committed
intervention doesn't delete your passed reprs, also generation notebook
1 parent 0e3ecb2 commit 7f099fb

File tree: 4 files changed, +335 -37 lines

pyvene/models/gpt_neo/modelings_intervenable_gpt_neo.py

Lines changed: 8 additions & 8 deletions
@@ -19,16 +19,16 @@
     "mlp_activation": ("h[%s].mlp.act", CONST_OUTPUT_HOOK),
     "mlp_output": ("h[%s].mlp", CONST_OUTPUT_HOOK),
     "mlp_input": ("h[%s].mlp", CONST_INPUT_HOOK),
-    "attention_value_output": ("h[%s].attn.out_proj", CONST_INPUT_HOOK),
-    "head_attention_value_output": ("h[%s].attn.out_proj", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")),
+    "attention_value_output": ("h[%s].attn.attention.out_proj", CONST_INPUT_HOOK),
+    "head_attention_value_output": ("h[%s].attn.attention.out_proj", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")),
     "attention_output": ("h[%s].attn", CONST_OUTPUT_HOOK),
     "attention_input": ("h[%s].attn", CONST_INPUT_HOOK),
-    "query_output": ("h[%s].attn.q_proj", CONST_OUTPUT_HOOK),
-    "key_output": ("h[%s].attn.k_proj", CONST_OUTPUT_HOOK),
-    "value_output": ("h[%s].attn.v_proj", CONST_OUTPUT_HOOK),
-    "head_query_output": ("h[%s].attn.q_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
-    "head_key_output": ("h[%s].attn.k_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
-    "head_value_output": ("h[%s].attn.v_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
+    "query_output": ("h[%s].attn.attention.q_proj", CONST_OUTPUT_HOOK),
+    "key_output": ("h[%s].attn.attention.k_proj", CONST_OUTPUT_HOOK),
+    "value_output": ("h[%s].attn.attention.v_proj", CONST_OUTPUT_HOOK),
+    "head_query_output": ("h[%s].attn.attention.q_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
+    "head_key_output": ("h[%s].attn.attention.k_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
+    "head_value_output": ("h[%s].attn.attention.v_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
 }
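The path fix above tracks Hugging Face's GPT-Neo module hierarchy: `GPTNeoBlock.attn` is a `GPTNeoAttention` wrapper whose actual `GPTNeoSelfAttention` lives in an `.attention` attribute, so the q/k/v/out projections sit one module level deeper than in GPT-2. A minimal sanity-check sketch (the checkpoint name is illustrative, not part of this commit):

# Sketch: confirm the projections live under attn.attention in HF GPT-Neo,
# which is why every hook path above gains ".attention".
from transformers import GPTNeoForCausalLM

model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")
block = model.transformer.h[0]
print(type(block.attn).__name__)            # GPTNeoAttention (wrapper)
print(type(block.attn.attention).__name__)  # GPTNeoSelfAttention
print(block.attn.attention.q_proj)          # the Linear the hooks attach to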

pyvene/models/intervenable_base.py

Lines changed: 25 additions & 18 deletions
@@ -856,10 +856,7 @@ def _intervention_setter(
         ]  # batch_size

         def hook_callback(model, args, kwargs, output=None):
-            if (
-                self._skip_forward
-                and state.setter_timestep <= 0
-            ):
+            if self._skip_forward and state.setter_timestep <= 0:
                 state.setter_timestep += 1
                 return

@@ -881,7 +878,8 @@ def hook_callback(model, args, kwargs, output=None):
                 else [
                     (
                         [0]
-                        if timestep_selector != None and timestep_selector[key_i](
+                        if timestep_selector != None
+                        and timestep_selector[key_i](
                             state.setter_timestep, output[i]
                         )
                         else None
@@ -1054,11 +1052,11 @@ def _wait_for_forward_with_parallel_intervention(
                 group_get_handlers.remove()
             else:
                 # simply patch in the ones passed in
-                self.activations = activations_sources
-                for _, passed_in_key in enumerate(self.activations):
+                for passed_in_key, v in activations_sources.items():
                     assert (
                         passed_in_key in self.sorted_keys
                     ), f"{passed_in_key} not in {self.sorted_keys}, {unit_locations}"
+                    self.activations[passed_in_key] = torch.clone(v)

         # in parallel mode, we swap cached activations all into
         # base at once
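This hunk is the fix named in the commit message: `self.activations = activations_sources` made the internal cache an alias of the caller's dict, so teardown that clears `self.activations` after a run also wiped the representations the caller passed in. Cloning each value into the existing cache decouples the two. A standalone sketch of the pitfall (plain Python/PyTorch, not pyvene code):

import torch

passed_reprs = {"layer_0": torch.ones(2, 4)}

# Before the fix: the cache aliases the caller's dict, so clearing the cache
# also destroys the caller's copy.
cache = passed_reprs
cache.clear()
assert passed_reprs == {}  # caller's representations are gone

# After the fix: clone each value into a separate dict; the caller's dict
# and tensors survive whatever the intervention does to its cache.
passed_reprs = {"layer_0": torch.ones(2, 4)}
cache = {}
for key, v in passed_reprs.items():
    cache[key] = torch.clone(v)
cache.clear()
assert "layer_0" in passed_reprs  # untouched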
@@ -1094,17 +1092,25 @@ def _wait_for_forward_with_serial_intervention(
             if sources[group_id] is None:
                 continue  # smart jump for advance usage only

-            group_dest = "base" if group_id >= len(self._intervention_group) - 1 else f"source_{group_id+1}"
-            group_key = f'source_{group_id}->{group_dest}'
+            group_dest = (
+                "base"
+                if group_id >= len(self._intervention_group) - 1
+                else f"source_{group_id+1}"
+            )
+            group_key = f"source_{group_id}->{group_dest}"
             unit_locations_source = unit_locations[group_key][0]
             unit_locations_base = unit_locations[group_key][1]

             if activations_sources != None:
                 for key in keys:
                     self.activations[key] = activations_sources[key]
             else:
-                keys_with_source = [k for i, k in enumerate(keys) if unit_locations_source[i] != None]
-                get_handlers = self._intervention_getter(keys_with_source, unit_locations_source)
+                keys_with_source = [
+                    k for i, k in enumerate(keys) if unit_locations_source[i] != None
+                ]
+                get_handlers = self._intervention_getter(
+                    keys_with_source, unit_locations_source
+                )

             # call once per group. each intervention is by its own group by default
             if activations_sources is None:
@@ -1402,11 +1408,11 @@ def forward(

             self._output_validation()

-            collected_activations = []
+            collected_activations = {}
             if self.return_collect_activations:
                 for key in self.sorted_keys:
                     if isinstance(self.interventions[key][0], CollectIntervention):
-                        collected_activations += self.activations[key]
+                        collected_activations[key] = self.activations[key]

         except Exception as e:
             raise e
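With this change, callers that enable `return_collect_activations` get the activations back keyed by intervention key rather than concatenated into a flat list. A minimal sketch of the calling pattern, mirroring the `pyvene_101.ipynb` cell updated in this commit (the `pv_gpt2` and `tokenizer` setup is elided):

# Mirrors the updated notebook cell: pv_gpt2 wraps gpt2 with a
# CollectIntervention per layer; tokenizer is the matching GPT-2 tokenizer.
base = "When John and Mary went to the shops, Mary gave the bag to"
(_, collected_attn_w), _ = pv_gpt2(
    base=tokenizer(base, return_tensors="pt"),
    unit_locations={"base": [h for h in range(12)]},
)
print(collected_attn_w[0].shape)  # torch.Size([12, 14, 14]) in the notebook run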
@@ -1439,15 +1445,16 @@ def generate(
         self,
         base,
         sources: Optional[List] = None,
+        source_representations: Optional[Dict] = None,
+        intervene_on_prompt: bool = True,
         unit_locations: Optional[Dict] = None,
         timestep_selector: Optional[TIMESTEP_SELECTOR_TYPE] = None,
-        intervene_on_prompt: bool = True,
-        source_representations: Optional[Dict] = None,
         subspaces: Optional[List] = None,
         output_original_output: Optional[bool] = False,
         **kwargs,
     ) -> Tuple[
-        ModelOutput | Tuple[ModelOutput | None, List[torch.Tensor]] | None, ModelOutput
+        Optional[ModelOutput | Tuple[Optional[ModelOutput], Dict[str, torch.Tensor]]],
+        ModelOutput,
     ]:
         """
         Intervenable generation function that serves a
@@ -1532,11 +1539,11 @@ def generate(
             # run intervened generate
             counterfactual_outputs = self.model.generate(**base, **kwargs)

-            collected_activations = []
+            collected_activations = {}
             if self.return_collect_activations:
                 for key in self.sorted_keys:
                     if isinstance(self.interventions[key][0], CollectIntervention):
-                        collected_activations += self.activations[key]
+                        collected_activations[key] = self.activations[key]
         except Exception as e:
             raise e
         finally:
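The reorder groups `source_representations` and `intervene_on_prompt` next to `sources`, and the return annotation now reflects the dict of collected activations. A hedged usage sketch against the reordered signature (argument names come from the diff above; `pv_gpt2`, `tokenizer`, and `cached_reprs` are assumptions, not from this commit):

# Hypothetical call: patch in precomputed representations instead of
# running source examples, then generate from the intervened model.
original_and_collected, counterfactual = pv_gpt2.generate(
    base=tokenizer("When John and Mary went to the shops,", return_tensors="pt"),
    sources=None,
    source_representations=cached_reprs,   # dict of precomputed activations
    intervene_on_prompt=True,
    unit_locations={"sources->base": 3},   # illustrative location spec
    max_new_tokens=10,
)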

pyvene_101.ipynb

Lines changed: 61 additions & 11 deletions
@@ -160,15 +160,18 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 5,
    "id": "128be2dd-f089-4291-bfc5-7002d031b1e9",
-   "metadata": {},
+   "metadata": {
+    "metadata": {}
+   },
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "loaded model\n"
+      "loaded GPT2 model gpt2\n",
+      "torch.Size([12, 14, 14])\n"
      ]
     }
    ],
@@ -186,10 +189,11 @@
     " \"intervention\": pv.CollectIntervention()}, model=gpt2)\n",
     "\n",
     "base = \"When John and Mary went to the shops, Mary gave the bag to\"\n",
-    "collected_attn_w = pv_gpt2(\n",
+    "(_, collected_attn_w), _ = pv_gpt2(\n",
     "  base = tokenizer(base, return_tensors=\"pt\"\n",
     "  ), unit_locations={\"base\": [h for h in range(12)]}\n",
-    ")[0][-1][0]"
+    ")\n",
+    "print(collected_attn_w[0].shape)"
    ]
   },
   {
@@ -1753,15 +1757,31 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 28,
+   "execution_count": 1,
    "id": "61cd8fc9",
-   "metadata": {},
+   "metadata": {
+    "metadata": {}
+   },
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "loaded model\n",
+      "loaded GPT2 model gpt2\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/juice/scr/nathangk/text-intervention/pyvene/pyvene/models/intervenable_base.py:796: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+      "  cached_activations = torch.tensor(self.activations[key])\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
       "True True\n"
      ]
     }
@@ -2017,9 +2037,11 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 31,
+   "execution_count": 2,
    "id": "acce6e8f",
-   "metadata": {},
+   "metadata": {
+    "metadata": {}
+   },
    "outputs": [
     {
      "name": "stderr",
@@ -2028,6 +2050,34 @@
      "You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "517de63768da4f7f8f58e5018c6f75f6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "model.safetensors:   0%|          | 0.00/308M [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "378463b4bd174067b6269d64a1ddf1fe",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
     }
    ],
    "source": [
@@ -2676,7 +2726,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.13"
+   "version": "3.11.7"
   },
   "toc-autonumbering": true,
   "toc-showcode": false,
