Skip to content

Commit 5de9727

Browse files
committed
Filter outliers
Signed-off-by: Kira Selby <kaselby@uwaterloo.ca>
1 parent 37a238e commit 5de9727

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/cett.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,16 +154,16 @@ def calculate_threshold(neuron_outputs, cett_target, n_thresholds=10000):
154154
min_value = norms.min()
155155
max_value = norms.quantile(0.99)
156156
threshold_grid = torch.linspace(min_value, max_value, n_thresholds)
157-
#initial_cett = cett_from_threshold(neuron_outputs, max_value, norms=norms, tot_norm=tot_norm)
158-
#outlier_mask = initial_cett > cett_target
157+
max_cett = cett_from_threshold(neuron_outputs, max_value, norms=norms, tot_norm=tot_norm)
158+
outlier_mask = max_cett > cett_target
159159

160160
left = 0
161161
right = n_thresholds
162162
while left < right:
163163
print(left,right)
164164
mid = (left + right) // 2
165-
#cett = cett_from_threshold(neuron_outputs, threshold_grid[mid], norms=norms, tot_norm=tot_norm)[outlier_mask].mean()
166-
cett = cett_from_threshold(neuron_outputs, threshold_grid[mid], norms=norms, tot_norm=tot_norm).mean()
165+
cett = cett_from_threshold(neuron_outputs, threshold_grid[mid], norms=norms, tot_norm=tot_norm) # Compute CETT for each token
166+
cett = cett[outlier_mask].mean() # Remove outliers and take average
167167
if cett <= cett_target:
168168
left = mid + 1
169169
else:

0 commit comments

Comments
 (0)