Skip to content

Commit 0e3ecb2

Browse files
committed
fix negative_subspace, add intervention nulling with locations
1 parent 38b58d4 commit 0e3ecb2

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

tests/integration_tests/IntervenableBasicTestCase.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class IntervenableBasicTestCase(unittest.TestCase):
1616
def setUpClass(cls):
1717
_uuid = str(uuid.uuid4())[:6]
1818
cls._test_dir = os.path.join(f"./test_output_dir_prefix-{_uuid}")
19+
cls.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
1920

2021
def test_lazy_demo(self):
2122

@@ -649,10 +650,10 @@ def test_customized_intervention_function_zeroout(self):
649650
def test_nulling_intervention(self):
650651

651652
_, tokenizer, gpt2 = pv.create_gpt2()
652-
gpt2.to("cuda")
653+
gpt2.to(self.DEVICE)
653654
base = tokenizer(
654655
["The capital of Spain is" for i in range(3)], return_tensors="pt"
655-
).to("cuda")
656+
).to(self.DEVICE)
656657

657658
base_output = gpt2(**base)
658659
base_logits = pv.embed_to_distrib(
@@ -668,14 +669,14 @@ def test_nulling_intervention(self):
668669
},
669670
model=gpt2,
670671
)
671-
pv_gpt2.set_device("cuda")
672+
pv_gpt2.set_device(self.DEVICE)
672673

673674
_, intervened_outputs = pv_gpt2(
674675
# the base input
675676
base=base,
676677
# the source input
677678
sources=tokenizer(["Egypt" for i in range(3)], return_tensors="pt").to(
678-
"cuda"
679+
self.DEVICE
679680
),
680681
# the location to intervene at (3rd token)
681682
unit_locations={"sources->base": (0, [[[3], None, [3]]])},

0 commit comments

Comments
 (0)