Commit 9ee5d73

intervention doesnt delete your passed reprs, also generation notebook
1 parent cdd90de · commit 9ee5d73

File tree

4 files changed (+335, -37 lines)

pyvene/models/gpt_neo/modelings_intervenable_gpt_neo.py

Lines changed: 8 additions & 8 deletions
@@ -19,16 +19,16 @@
     "mlp_activation": ("h[%s].mlp.act", CONST_OUTPUT_HOOK),
     "mlp_output": ("h[%s].mlp", CONST_OUTPUT_HOOK),
     "mlp_input": ("h[%s].mlp", CONST_INPUT_HOOK),
-    "attention_value_output": ("h[%s].attn.out_proj", CONST_INPUT_HOOK),
-    "head_attention_value_output": ("h[%s].attn.out_proj", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")),
+    "attention_value_output": ("h[%s].attn.attention.out_proj", CONST_INPUT_HOOK),
+    "head_attention_value_output": ("h[%s].attn.attention.out_proj", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")),
     "attention_output": ("h[%s].attn", CONST_OUTPUT_HOOK),
     "attention_input": ("h[%s].attn", CONST_INPUT_HOOK),
-    "query_output": ("h[%s].attn.q_proj", CONST_OUTPUT_HOOK),
-    "key_output": ("h[%s].attn.k_proj", CONST_OUTPUT_HOOK),
-    "value_output": ("h[%s].attn.v_proj", CONST_OUTPUT_HOOK),
-    "head_query_output": ("h[%s].attn.q_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
-    "head_key_output": ("h[%s].attn.k_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
-    "head_value_output": ("h[%s].attn.v_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
+    "query_output": ("h[%s].attn.attention.q_proj", CONST_OUTPUT_HOOK),
+    "key_output": ("h[%s].attn.attention.k_proj", CONST_OUTPUT_HOOK),
+    "value_output": ("h[%s].attn.attention.v_proj", CONST_OUTPUT_HOOK),
+    "head_query_output": ("h[%s].attn.attention.q_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
+    "head_key_output": ("h[%s].attn.attention.k_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
+    "head_value_output": ("h[%s].attn.attention.v_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
 }
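For context (not part of the diff): in Hugging Face transformers, GPT-Neo nests its self-attention one module deeper than GPT-2, as h[i].attn.attention.q_proj rather than h[i].attn.q_proj, which is why every hook path above gains an extra ".attention" level. A minimal sketch to verify this locally; the tiny config values are arbitrary, chosen only to make instantiation cheap:

# Sketch (not from the commit): check where GPT-Neo's projection modules live.
from transformers import GPTNeoConfig, GPTNeoModel

config = GPTNeoConfig(
    hidden_size=64,
    num_layers=2,
    num_heads=4,
    attention_types=[[["global", "local"], 1]],  # expands to cover the 2 layers
)
model = GPTNeoModel(config)

# Prints names like "h.0.attn.attention.out_proj" -- the extra ".attention"
# level that the updated mapping above now includes.
for name, _ in model.named_modules():
    if name.endswith(("q_proj", "k_proj", "v_proj", "out_proj")):
        print(name)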

pyvene/models/intervenable_base.py

Lines changed: 25 additions & 18 deletions
@@ -855,10 +855,7 @@ def _intervention_setter(
             ]  # batch_size

             def hook_callback(model, args, kwargs, output=None):
-                if (
-                    self._skip_forward
-                    and state.setter_timestep <= 0
-                ):
+                if self._skip_forward and state.setter_timestep <= 0:
                     state.setter_timestep += 1
                     return

@@ -880,7 +877,8 @@ def hook_callback(model, args, kwargs, output=None):
                 else [
                     (
                         [0]
-                        if timestep_selector != None and timestep_selector[key_i](
+                        if timestep_selector != None
+                        and timestep_selector[key_i](
                             state.setter_timestep, output[i]
                         )
                         else None
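A hedged sketch of the selector contract this hunk implies: timestep_selector holds one callable per intervention key, invoked with the current decoding timestep and that batch element's output, and returning a truthy value to intervene (at position [0]). The names below are illustrative, not from the commit:

# Illustrative selector; signature inferred from the call site above:
#   timestep_selector[key_i](state.setter_timestep, output[i])
def first_three_steps(timestep, output):
    # Intervene only on the first three generated tokens.
    return timestep < 3

timestep_selector = [first_three_steps]  # one entry per intervention key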
@@ -1053,11 +1051,11 @@ def _wait_for_forward_with_parallel_intervention(
                 group_get_handlers.remove()
             else:
                 # simply patch in the ones passed in
-                self.activations = activations_sources
-                for _, passed_in_key in enumerate(self.activations):
+                for passed_in_key, v in activations_sources.items():
                     assert (
                         passed_in_key in self.sorted_keys
                     ), f"{passed_in_key} not in {self.sorted_keys}, {unit_locations}"
+                    self.activations[passed_in_key] = torch.clone(v)

         # in parallel mode, we swap cached activations all into
         # base at once
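This hunk is the headline fix of the commit: the old code made self.activations an alias of the caller's activations_sources dict, so internal bookkeeping later deleted the representations the user passed in. A standalone sketch of the aliasing hazard (the key name is invented for illustration):

import torch

passed_reprs = {"my_key": torch.ones(1, 4)}

# Old behaviour: the internal cache aliases the caller's dict,
# so clearing the cache also wipes the caller's representations.
activations = passed_reprs
activations.clear()
print(len(passed_reprs))  # 0 -- user's reprs deleted

# New behaviour: clone each tensor into the internal cache instead.
passed_reprs = {"my_key": torch.ones(1, 4)}
activations = {k: torch.clone(v) for k, v in passed_reprs.items()}
activations.clear()
print(len(passed_reprs))  # 1 -- user's dict untouched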
@@ -1093,17 +1091,25 @@ def _wait_for_forward_with_serial_intervention(
             if sources[group_id] is None:
                 continue  # smart jump for advance usage only

-            group_dest = "base" if group_id >= len(self._intervention_group) - 1 else f"source_{group_id+1}"
-            group_key = f'source_{group_id}->{group_dest}'
+            group_dest = (
+                "base"
+                if group_id >= len(self._intervention_group) - 1
+                else f"source_{group_id+1}"
+            )
+            group_key = f"source_{group_id}->{group_dest}"
             unit_locations_source = unit_locations[group_key][0]
             unit_locations_base = unit_locations[group_key][1]

             if activations_sources != None:
                 for key in keys:
                     self.activations[key] = activations_sources[key]
             else:
-                keys_with_source = [k for i, k in enumerate(keys) if unit_locations_source[i] != None]
-                get_handlers = self._intervention_getter(keys_with_source, unit_locations_source)
+                keys_with_source = [
+                    k for i, k in enumerate(keys) if unit_locations_source[i] != None
+                ]
+                get_handlers = self._intervention_getter(
+                    keys_with_source, unit_locations_source
+                )

             # call once per group. each intervention is by its own group by default
             if activations_sources is None:
@@ -1396,11 +1402,11 @@ def forward(

             self._output_validation()

-            collected_activations = []
+            collected_activations = {}
             if self.return_collect_activations:
                 for key in self.sorted_keys:
                     if isinstance(self.interventions[key][0], CollectIntervention):
-                        collected_activations += self.activations[key]
+                        collected_activations[key] = self.activations[key]

         except Exception as e:
             raise e
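With this hunk, collected_activations changes from a flat list (all CollectIntervention outputs concatenated) to a dict keyed by intervention key, so results can be attributed to a specific intervention. A hedged before/after sketch of consumer code; the key strings are invented for illustration:

import torch

# What the internal cache might hold for two CollectInterventions.
activations = {
    "collect_key_a": [torch.rand(12, 14, 14)],
    "collect_key_b": [torch.rand(12, 14, 14)],
}

# Old: one flat list; callers had to remember the ordering.
flat = []
for key in sorted(activations):
    flat += activations[key]

# New: direct lookup by intervention key.
collected = {key: activations[key] for key in sorted(activations)}
print(collected["collect_key_a"][0].shape)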
@@ -1433,15 +1439,16 @@ def generate(
         self,
         base,
         sources: Optional[List] = None,
+        source_representations: Optional[Dict] = None,
+        intervene_on_prompt: bool = True,
         unit_locations: Optional[Dict] = None,
         timestep_selector: Optional[TIMESTEP_SELECTOR_TYPE] = None,
-        intervene_on_prompt: bool = True,
-        source_representations: Optional[Dict] = None,
         subspaces: Optional[List] = None,
         output_original_output: Optional[bool] = False,
         **kwargs,
     ) -> Tuple[
-        ModelOutput | Tuple[ModelOutput | None, List[torch.Tensor]] | None, ModelOutput
+        Optional[ModelOutput | Tuple[Optional[ModelOutput], Dict[str, torch.Tensor]]],
+        ModelOutput,
     ]:
         """
         Intervenable generation function that serves a
@@ -1526,11 +1533,11 @@ def generate(
             # run intervened generate
             counterfactual_outputs = self.model.generate(**base, **kwargs)

-            collected_activations = []
+            collected_activations = {}
             if self.return_collect_activations:
                 for key in self.sorted_keys:
                     if isinstance(self.interventions[key][0], CollectIntervention):
-                        collected_activations += self.activations[key]
+                        collected_activations[key] = self.activations[key]
         except Exception as e:
             raise e
         finally:
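A hedged end-to-end sketch of calling the reordered generate(): the parameter names come from the signature above, the config style follows the pyvene_101 notebook cells diffed below, and the unpacking of collected activations mirrors the notebook's forward-call pattern. Treat the exact config and unit_locations values as assumptions, not as the library's documented behavior:

import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")

# Collect one component's activations during generation.
pv_gpt2 = pv.IntervenableModel(
    {"layer": 8, "component": "attention_output",
     "intervention": pv.CollectIntervention()},
    model=gpt2,
)

base = tokenizer("When John and Mary went to the shops,", return_tensors="pt")

# intervene_on_prompt and source_representations moved earlier in the
# signature, so positional callers must switch to keywords.
(_, collected), counterfactual = pv_gpt2.generate(
    base,
    intervene_on_prompt=True,
    unit_locations={"base": 3},
    max_new_tokens=5,
)
print(list(collected))  # after this commit: a dict keyed by intervention key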

pyvene_101.ipynb

Lines changed: 61 additions & 11 deletions
@@ -160,15 +160,18 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 5,
    "id": "128be2dd-f089-4291-bfc5-7002d031b1e9",
-   "metadata": {},
+   "metadata": {
+    "metadata": {}
+   },
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "loaded model\n"
+      "loaded GPT2 model gpt2\n",
+      "torch.Size([12, 14, 14])\n"
      ]
     }
    ],
@@ -186,10 +189,11 @@
     "    \"intervention\": pv.CollectIntervention()}, model=gpt2)\n",
     "\n",
     "base = \"When John and Mary went to the shops, Mary gave the bag to\"\n",
-    "collected_attn_w = pv_gpt2(\n",
+    "(_, collected_attn_w), _ = pv_gpt2(\n",
     "    base = tokenizer(base, return_tensors=\"pt\"\n",
     "    ), unit_locations={\"base\": [h for h in range(12)]}\n",
-    ")[0][-1][0]"
+    ")\n",
+    "print(collected_attn_w[0].shape)"
    ]
   },
   {
@@ -1753,15 +1757,31 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 28,
+   "execution_count": 1,
    "id": "61cd8fc9",
-   "metadata": {},
+   "metadata": {
+    "metadata": {}
+   },
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "loaded model\n",
+      "loaded GPT2 model gpt2\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/juice/scr/nathangk/text-intervention/pyvene/pyvene/models/intervenable_base.py:796: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+      "  cached_activations = torch.tensor(self.activations[key])\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
       "True True\n"
      ]
     }
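The new stderr cell above captures a PyTorch UserWarning triggered by cached_activations = torch.tensor(self.activations[key]) at intervenable_base.py:796. A minimal reproduction of the warning and the pattern PyTorch recommends instead:

import torch

cached = torch.rand(2, 3)
bad = torch.tensor(cached)        # copy-constructing a tensor emits the warning
good = cached.clone().detach()    # recommended equivalent, no warning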
@@ -2017,9 +2037,11 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 31,
+   "execution_count": 2,
    "id": "acce6e8f",
-   "metadata": {},
+   "metadata": {
+    "metadata": {}
+   },
    "outputs": [
     {
      "name": "stderr",
@@ -2028,6 +2050,34 @@
      "You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "517de63768da4f7f8f58e5018c6f75f6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "model.safetensors:   0%|          | 0.00/308M [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "378463b4bd174067b6269d64a1ddf1fe",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
    ],
    "source": [
@@ -2676,7 +2726,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.13"
+   "version": "3.11.7"
   },
   "toc-autonumbering": true,
   "toc-showcode": false,
